Multi-Head Attention(Transformer)-LMLPHP

将维度embSize平均分成n_head份,构成multi-head的self-attention。平分后的qkv经过Scaled Dot-Product Attention之后在进行concat最终得到维度不变的输入数据。

以下代码实现下图红色框中的功能,即multi-head + residual + normalization.

Multi-Head Attention(Transformer)-LMLPHP

n_head: head的个数

d_model为词嵌入维度embedSize

注意

  • 多头输入的Q, K, V对于self-attention而言是相同的数据,因此维度d_model是相同的,而对应soft-attention而言Q与K,V的维度未必相同;
  • 输入的Q,K,V经过linear之后通过view变成n_head个q, k, v(此处用小写),此时的q,k,v的维度dim*n_head未必等于输入的维度d_model;
  • k与q的维度dim是相同的,但seq未必相同, 而k与v的seq是相同的,但dim未必相同。

对应self-attention而言,输入的数据q, k, v为同一个数据,维度为(batchSize, seqLen, d_model)

最终输出数据q的维度和输入维度相同,输出的attn为Q和K的score,维度为(batchSize, n_head , seqLen, seqLen)

ScaledDotProductAttention 见文章: https://my.oschina.net/u/4228078/blog/4497939

class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)


    def forward(self, q, k, v, mask=None):
        # q/k/v.shape:(batch, seqlen, d_model)
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        residual = q  # q.shape: (batch, seqLen, d_model)

        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate different heads: (batchSize, seqLen, n_head, dim)
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for attention dot product: (batchSize, n_head, seqLen, dim)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None: # mask.shape: (batchSize, 1, seqLen)
            mask = mask.unsqueeze(1)   # For head axis broadcasting.  mask.shape: (batchSize, 1, 1, seqLen)
        # q.shape: (batch, n_head, seqLen, d_k) attn.shape: (batch, n_head, seqLen, seqLen)
        q, attn = self.attention(q, k, v, mask=mask)

        # Transpose to move the head dimension back: b x lq x n x dv
        # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        q += residual

        q = self.layer_norm(q)

        return q, attn
09-09 06:55