Mask2Former的论文中有这样一张图,表示masked attenion比cross attention效果要好
那么这个attention map是怎么画出来的?
在mask2attention的源代码中 CrossAttentionLayer这个类中,在forward_post函数中做如下修改:
def forward_post(self, tgt, memory,
memory_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt2, atten_weight = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask, average_attn_weights=False)
atten_weight = atten_weight.squeeze().detach().cpu().numpy()
head_num = 0
selected_query_num = 0
if atten_weight.shape[-1] == 21888:
import matplotlib.pyplot as plt
# 创建2行4列的图形
fig, axs = plt.subplots(2, 4, figsize=(12, 6))
# 使用8次for循环在每个子图中进行绘制
for i in range(2):
for j in range(4):
atten_map = atten_weight[head_num, selected_query_num, :]
atten_map = atten_map.reshape((128, 171))
head_num += 0
axs[i, j].imshow(atten_map)
plt.show()
tgt = tgt + self.dropout(tgt2)
tgt = self.norm(tgt)
return tgt
在 nn.MultiheadAttention 类实例的forward方法中,加入
average_attn_weights=False
得到每个注意力头的attention map,将attention_weight可视化,就得到了论文中的图片。