将维度embSize平均分成n_head份,构成multi-head的self-attention。平分后的qkv经过Scaled Dot-Product Attention之后在进行concat最终得到维度不变的输入数据。
以下代码实现下图红色框中的功能,即multi-head + residual + normalization.
n_head: head的个数
d_model为词嵌入维度embedSize
注意
- 多头输入的Q, K, V对于self-attention而言是相同的数据,因此维度d_model是相同的,而对应soft-attention而言Q与K,V的维度未必相同;
- 输入的Q,K,V经过linear之后通过view变成n_head个q, k, v(此处用小写),此时的q,k,v的维度dim*n_head未必等于输入的维度d_model;
- k与q的维度dim是相同的,但seq未必相同, 而k与v的seq是相同的,但dim未必相同。
对应self-attention而言,输入的数据q, k, v为同一个数据,维度为(batchSize, seqLen, d_model)
最终输出数据q的维度和输入维度相同,输出的attn为Q和K的score,维度为(batchSize, n_head , seqLen, seqLen)
ScaledDotProductAttention 见文章: https://my.oschina.net/u/4228078/blog/4497939
class MultiHeadAttention(nn.Module):
''' Multi-Head Attention module '''
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
super().__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
def forward(self, q, k, v, mask=None):
# q/k/v.shape:(batch, seqlen, d_model)
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
residual = q # q.shape: (batch, seqLen, d_model)
# Pass through the pre-attention projection: b x lq x (n*dv)
# Separate different heads: (batchSize, seqLen, n_head, dim)
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
# Transpose for attention dot product: (batchSize, n_head, seqLen, dim)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
if mask is not None: # mask.shape: (batchSize, 1, seqLen)
mask = mask.unsqueeze(1) # For head axis broadcasting. mask.shape: (batchSize, 1, 1, seqLen)
# q.shape: (batch, n_head, seqLen, d_k) attn.shape: (batch, n_head, seqLen, seqLen)
q, attn = self.attention(q, k, v, mask=mask)
# Transpose to move the head dimension back: b x lq x n x dv
# Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
q = self.dropout(self.fc(q))
q += residual
q = self.layer_norm(q)
return q, attn