Transformer Explained (3): Multi-Head Self-Attention

attention

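The attention used throughout is the scaled dot-product attention from "Attention Is All You Need": given queries Q, keys K and values V,

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

The scores QK^T are divided by sqrt(d_k) to keep their variance roughly constant, so the softmax does not saturate when d_k is large.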

multi-head attention

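Multi-head attention projects Q, K and V h times with different learned linear maps, runs attention on each projection in parallel, then concatenates the results and applies one more output projection:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)W^O, \qquad \mathrm{head}_i = \mathrm{Attention}(QW_i^Q,\, KW_i^K,\, VW_i^V)$$

With d_model = 128 and h = 8 heads, each head works in d_k = d_model / h = 16 dimensions, which matches the shape comments in the code below.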

PyTorch implementation

import math
import torch
from torch import nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, heads=8, d_model=128, dropout=0.1):
        super().__init__()

        self.d_model = d_model  # 128
        self.d_k = d_model // heads  # 128//8=16
        self.h = heads  # 8

        self.q_linear = nn.Linear(d_model, d_model)  # (50,128)@(128,128)->(50,128); the (128,128) weight matrix is learned during training
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def attention(self, q, k, v, d_k, mask=None, dropout=None):
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # matrix multiply: (32,8,50,16)@(32,8,16,50)->(32,8,50,50)

        if mask is not None:
            mask = mask.unsqueeze(1)  # add a head dimension so the mask broadcasts over all heads
            scores = scores.masked_fill(mask == 0, -1e9)  # masked positions get a large negative score, ~0 after softmax

        scores = F.softmax(scores, dim=-1)

        if dropout is not None:
            scores = dropout(scores)

        output = torch.matmul(scores, v)  # (32,8,50,50)*(32,8,50,16)->(32,8,50,16)
        return output

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)  # batch size; 32 in this example
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        # (32,50,128)->(32,50,128)->(32,50,8,16)  8*16=128: each embedding is split into 8 heads of 16 dims

        k = k.transpose(1, 2)  # (32,50,8,16)->(32,8,50,16)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)

        scores = self.attention(q, k, v, self.d_k, mask, self.dropout)  # (32,8,50,16)
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)  # (32,50,128)
        output = self.out(concat)  # (32,50,128)

        return output


if __name__ == '__main__':
    multi_head_attention = MultiHeadAttention(8, 128)
    normal_tensor = torch.randn(32, 50, 128)  # samples from a standard normal distribution; batch_size=32, seq_len=50, embedding dim=128
    x = torch.sigmoid(normal_tensor)  # squash every value into (0, 1)
    output = multi_head_attention(x, x, x)
    print(output.shape)  # torch.Size([32, 50, 128])
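
The forward pass also accepts an optional mask, which the example above does not exercise. Below is a minimal sketch of building and passing a padding mask, reusing the MultiHeadAttention class defined above; the lengths tensor is a made-up example of per-sequence valid lengths, so adapt it to your own data.

import torch

mha = MultiHeadAttention(heads=8, d_model=128)
x = torch.randn(32, 50, 128)  # (batch, seq_len, d_model)
lengths = torch.randint(10, 51, (32,))  # hypothetical number of real (non-padded) tokens per sequence
mask = torch.arange(50).unsqueeze(0) < lengths.unsqueeze(1)  # (32, 50), True for real tokens, False for padding
mask = mask.unsqueeze(1)  # (32, 1, 50); attention() adds the head dimension itself
out = mha(x, x, x, mask=mask)  # padded key positions are filled with -1e9 before the softmax
print(out.shape)  # torch.Size([32, 50, 128])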