目录
前言
Exchanging Dual-Encoder–Decoder: A New Strategy for Change Detection With Semantic Guidance and Spatial Localization
来源:TGRS2023
官方代码:https://github.com/NJU-LHRS/official-SGSLN
Temporal Fusion Attention Module(TFAM)由作者为孪生网络的变化检测所提出。作者认为卷积增强方法通过应用各种卷积操作增强多尺度和多语义级别的双时相特征,这减少了双时相特征中的噪声干扰。然后,它使用加法、减法或连接来融合双时相特征。注意力增强方法通常在通道维度上连接双时相特征,然后使用注意力机制实现有效融合。然而,卷积增强方法专注于融合前的双时相特征的增强,而注意力增强方法专注于简单融合后的双时相特征的增强。它们都忽略了双时相特征之间的时间信息。为了解决上述问题,作者提出了一种时间特征融合模块(TFAM),利用时间信息进行有效的特征融合,使用通道和空间注意力来确定特征的重要部分,并使用时间信息来确定双时相特征之间的重要部分。笔者认为该模块不仅可用于变化检测,可也用于目标检测、分割等领域中图像的特征融合。
一、TFAM网络讲解
如图1TFAM结构所示,TFAM包括两个分支,通道分支(channel branch)和空间分支(spatial branch)。通道分支用于增强通道信息关注,空间分支用于增强空间信息关注。对两个输入特征同时进行空间维度池化和通道维度池化,接着分别将空间维度池化和通道维度池化进行Concat,随后进行权重确定,保留有用信息,最后进行分离。通过权重调整,保留了双时相特征中更有价值的部分,同时抑制了不重要或误导性的信息,从而提高了变化检测的准确性和鲁棒性。
图1 TFAM结构
二、TFAM计算
TFAM使用通道和空间注意力来确定要素的重要部分,并使用时间信息来确定要素之间的重要部分。在通道分支中,输入的双时态要素通过跨空间维度的全局池化进行传递,以聚合空间信息。聚合过程可以表述为:
其中Sc 表示聚合的空间特征,T1和T2表示双时相特征,Avg(·) 和Max(·) 分别表示跨空间维度的全局平均池化和全局最大池化。聚合的空间特征被传递到两个一维卷积,它们与ECA模块相同,以确定输入双时态特征的双时相通道权重。两个通道权重可以表述为:
其中 Wc1,Wc2 表示双时相通道权重,Conv1 (·) 和 Conv2(·) 表示一维卷积。然后在双时相通道权重中使用 Softmax 使它们的总和等于 1,这意味着比较双时相权重以确定它们之间的较高值,从而确定通道维度中双时相特征之间的重要部分。Softmax 方法可以被公式化为:
其中 Wc1',Wc2' 表示输出双时相通道权重。在空间分支中,双时空间权重 Ws1' 和 Ws2'以相同的方式确定,从而确定空间维度中双时相特征之间的重要部分。对双时相通道权重和双时相空间权重进行汇总,得到双时相权重,确定双时相特征之间的重要部分。
最后,将双时态权重与双时态特征相乘并进行汇总,以有效地融合双时态特征。输出可以表述为:
其中 Output 表示融合特征。由于双时态权重之和等于1,因此保留了双时态特征之间的有用部分,而丢弃了无用部分,从而实现了有效的特征融合。
三、TFAM参数量
利用thop库的profile函数计算FLOPs和Param。Input:(64,32,32)(64,32,32)。
四、代码详解
import torch
import torch.nn as nn
import math
def kernel_size(in_channel):
"""Compute kernel size for one dimension convolution in eca-net"""
k = int((math.log2(in_channel) + 1) // 2) # parameters from ECA-net
if k % 2 == 0:
return k + 1
else:
return k
class TFAM(nn.Module):
"""Fuse two feature into one feature."""
def __init__(self, in_channel):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.k = kernel_size(in_channel)
self.channel_conv1 = nn.Conv1d(4, 1, kernel_size=self.k, padding=self.k // 2)
self.channel_conv2 = nn.Conv1d(4, 1, kernel_size=self.k, padding=self.k // 2)
self.spatial_conv1 = nn.Conv2d(4, 1, kernel_size=7, padding=3)
self.spatial_conv2 = nn.Conv2d(4, 1, kernel_size=7, padding=3)
self.softmax = nn.Softmax(0)
def forward(self, t1, t2, log=None, module_name=None,
img_name=None):
# channel part
t1_channel_avg_pool = self.avg_pool(t1) # b,c,1,1
t1_channel_max_pool = self.max_pool(t1) # b,c,1,1
t2_channel_avg_pool = self.avg_pool(t2) # b,c,1,1
t2_channel_max_pool = self.max_pool(t2) # b,c,1,1
channel_pool = torch.cat([t1_channel_avg_pool, t1_channel_max_pool,
t2_channel_avg_pool, t2_channel_max_pool],
dim=2).squeeze(-1).transpose(1, 2) # b,4,c
t1_channel_attention = self.channel_conv1(channel_pool) # b,1,c
t2_channel_attention = self.channel_conv2(channel_pool) # b,1,c
channel_stack = torch.stack([t1_channel_attention, t2_channel_attention],
dim=0) # 2,b,1,c
channel_stack = self.softmax(channel_stack).transpose(-1, -2).unsqueeze(-1) # 2,b,c,1,1
# spatial part
t1_spatial_avg_pool = torch.mean(t1, dim=1, keepdim=True) # b,1,h,w
t1_spatial_max_pool = torch.max(t1, dim=1, keepdim=True)[0] # b,1,h,w
t2_spatial_avg_pool = torch.mean(t2, dim=1, keepdim=True) # b,1,h,w
t2_spatial_max_pool = torch.max(t2, dim=1, keepdim=True)[0] # b,1,h,w
spatial_pool = torch.cat([t1_spatial_avg_pool, t1_spatial_max_pool,
t2_spatial_avg_pool, t2_spatial_max_pool], dim=1) # b,4,h,w
t1_spatial_attention = self.spatial_conv1(spatial_pool) # b,1,h,w
t2_spatial_attention = self.spatial_conv2(spatial_pool) # b,1,h,w
spatial_stack = torch.stack([t1_spatial_attention, t2_spatial_attention], dim=0) # 2,b,1,h,w
spatial_stack = self.softmax(spatial_stack) # 2,b,1,h,w
# fusion part, add 1 means residual add
stack_attention = channel_stack + spatial_stack + 1 # 2,b,c,h,w
fuse = stack_attention[0] * t1 + stack_attention[1] * t2 # b,c,h,w
return fuse
if __name__ == '__main__':
from thop import profile
model = TFAM(in_channel=32)
flops, params = profile(model, inputs=(torch.randn(1, 64, 32, 32), torch.randn(1, 64, 32, 32)))
print(f"FLOPs: {flops}, Params: {params}")