学习目标:
Transformer
Vision Transformer
学习内容:
transoformer和vision transoformer的code
学习时间:
11.5-11.11
学习产出:
一、Transformer
1.1 Self-Attention
def attention(query, key, value, mask=None, dropout=None):
"Compute 'Scaled Dot Product Attention'"
# 在函数中, 首先取query的最后一维的大小, 一般情况下就等同于我们的词嵌入维度, 命名为d_k
d_k = query.size(-1)
# 自注意力公式实现,自注意力中Q、K、V矩阵都是相同的,普通注意力计算时Q不同而K、V相同
# 按照注意力公式, 将query与key的转置相乘, 这里面key是将最后两个维度进行转置, 再除以缩放系数根号下d_k, 这种计算方法也称为缩放点积注意力计算.
# 得到注意力得分张量scores
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# 判断是否使用掩码张量
if mask is not None:
# 使用tensor的masked_fill方法, 将掩码张量和scores张量每个位置一一比较, 如果掩码张量处为0,则对应的scores张量用-1e9这个值来替换, 如下演示
scores = scores.masked_fill(mask == 0, -1e9)
# 对scores的最后一维进行softmax操作, 使用F.softmax方法获得最终的注意力张量
p_attn = scores.softmax(dim=-1)
# 判断是否使用dropout
if dropout is not None:
p_attn = dropout(p_attn)
# 根据公式将p_attn与value张量相乘获得最终的query注意力表示, 同时返回注意力张量
return torch.matmul(p_attn, value), p_attn
1.2 Multi-Head Attention
# 实现多头注意力机制
class MultiHeadedAttention(nn.Module):
# 在类的初始化时, 会传入三个参数,head代表头数,embedding_dim代表词嵌入的维度,dropout代表进行dropout操作时置0比率,默认是0.1.
def __init__(self, h, d_model, dropout=0.1):
"Take in model size and number of heads."
super(MultiHeadedAttention, self).__init__()
# 断言语句,判断h是否能被d_model整除,以便后续计算embedding_dim/head的数量,因为需要计算给每个头分配等量的词特征.
assert d_model % h == 0
# We assume d_v always equals d_k
# 获得每个头的分割词向量维度
self.d_k = d_model // h
self.h = h
# 通过nn的Linear实例化,它的内部变换矩阵是embedding_dim->embedding_dim,然后使用clones函数克隆四个,
# 生成四个Linear:Q、K、V以及最后拼接矩阵各一个
self.linears = clones(nn.Linear(d_model, d_model), 4)
# 最后得到的注意力张量
self.attn = None
self.dropout = nn.Dropout(p=dropout)
# 前向传播,mask是注意力机制中需要的mask掩码张量
def forward(self, query, key, value, mask=None):
"Implements Figure 2"
if mask is not None:
# Same mask applied to all h heads.
# 使用unsqueeze拓展维度
mask = mask.unsqueeze(1)
# 得到一个batch_size的变量,他是query尺寸的第1个数字,代表有多少条样本.
nbatches = query.size(0)
# 1) Do all the linear projections in batch from d_model => h x d_k
# 使用zip将QKV和线性层组合到一起,然后使用for循环将输入QKV分别传入线性层
# 线性变换之后开始为每个头分割输入,使用view方法对线性变换的结果进行维度重塑,将QKV的d_model维向量分解为h * d_k
# -1代表自适应维度,机器会根据这种变换自动计算这里的值,然后对第二维和第三维进行转置
# 这里使用了init定义中四个Linear层的三个
query, key, value = [
lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
for lin, x in zip(self.linears, (query, key, value))
]
# 2) Apply attention on all the projected vectors in batch.
x, self.attn = attention(
query, key, value, mask=mask, dropout=self.dropout
)
# 3) "Concat" using a view and apply a final linear.
# 得到每个头计算结果组成的四维张量后需要转换为输入的形状,即先对第二维和第三维进行转置,然后使用contiguous()方法使转换后的形状能够使用view()方法
# 然后使用view重塑形状,变成和输入形状相同的三维
x = (
x.transpose(1, 2)
.contiguous()
.view(nbatches, -1, self.h * self.d_k)
)
del query
del key
del value
# 最后使用线性层列表中的最后一个线性层对输入进行线性变换得到最终的多头注意力结构的输出.
return self.linears[-1](x)
2 .1 Add&Norm
# 通过LayerNorm实现规范化层的类,随着网络层数的增加,通过多层的计算后参数可能开始出现过大或过小的情况,这样可能会导致学习过程出现异常,模型可能收敛非常的慢,因此都会在一定层数后接规范化层进行数值的规范化,使其特征数值在合理范围内
class LayerNorm(nn.Module):
"Construct a layernorm module (See citation for details)."
# features表示词嵌入的维度,eps是一个足够小的数,在规范化公式的分母中出现,防止分母为0
def __init__(self, features, eps=1e-6):
super(LayerNorm, self).__init__()
# 根据features的形状初始化两个张良a2、b2,a2初始化为全为1的张量,b2初始化为全为0的张量,这两个张量就是规范化层的参数
# 因为直接对上一层得到的结果做规范化公式计算,将改变结果的正常表征,因此就需要有参数作为调节因子,
# 使其即能满足规范化要求,又能不改变针对目标的表征.最后使用nn.parameter封装,代表他们是模型的参数
self.a_2 = nn.Parameter(torch.ones(features))
self.b_2 = nn.Parameter(torch.zeros(features))
self.eps = eps
# 先对输入变量x的最后一个维度求均值,保持其输出维度和输入维度一致
# 接着求最后一个维度的标准差,然后使用标准正态分布公式计算规范化,即X减去均值除以标准差(X减去均值除以方差)
# 最后对结果诚意缩放参数,即a2,再加上位移参数b2
def forward(self, x):
mean = x.mean(-1, keepdim=True) # 均值
std = x.std(-1, keepdim=True) # 标准差
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
2.2 Feed Forward
class PositionwiseFeedForward(nn.Module):
"Implements FFN equation."
# 因为我们希望输入通过前馈全连接层后输入和输出的维度不变,第一个参数d_model是第一个线性层的输入维度也是第二个线性层的输出维度,第二个参数d_ff是第一个线性层的输出维度和第二个线性层的输入维度.
def __init__(self, d_model, d_ff, dropout=0.1):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# 首先经过第一个线性层,然后使用Funtional中relu函数进行激活,
# 之后再使用dropout进行随机置0,最后通过第二个线性层w2,返回最终结果.
return self.w_2(self.dropout(self.w_1(x).relu()))
2.3 Encoder
# 使用SublayerConnection实现子层连接结构的类
class SublayerConnection(nn.Module):
"""
A residual connection followed by a layer norm.
Note for code simplicity the norm is first as opposed to last.
"""
# size是词嵌入维度的大小
def __init__(self, size, dropout):
super(SublayerConnection, self).__init__()
# 实例化规范化对象self.norm
self.norm = LayerNorm(size)
self.dropout = nn.Dropout(dropout)
# 前向传播中接收上一层或者子层的输入作为第一个参数,将该子层连接中的子层函数作为第二个参数
def forward(self, x, sublayer):
"Apply residual connection to any sublayer with the same size."
# LayerNorm+sublayer(Self-Attention/FeedForward)+dropout+残差连接
# 先对输入进行规范化然后将结果传给子层处理,子层做dropout操作后进行add操作,将输入x与dropout后的子层输出结果相加作为最终的子层连接输出
return x + self.dropout(sublayer(self.norm(x)))
# EncoderLayer由self-attention和feed forward组成
class EncoderLayer(nn.Module):
"Encoder is made up of self-attn and feed forward (defined below)"
def __init__(self, size, self_attn, feed_forward, dropout):
super(EncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
# 构造两个SublayerConnection
self.sublayer = clones(SublayerConnection(size, dropout), 2)
self.size = size
def forward(self, x, mask):
"Follow Figure 1 (left) for connections."
# z = lambda x: self.self_attn(x, x, x, mask)
# sublayer[0]是SublayerConnection对象)的__call__方法,最终会调到它的forward方法,需要输入Tensor以及callable这两个参数forward调用sublayer[0],
# sublayer[0]是个callable,self.sublayer[0](x,z)会调用self.sublayer[0].call(x,z),然后调用SublayerConnection.forward(x, z)
# 然后会调用sublayer(self.norm(x)),sublayer就是传入的参数z,因此就是z(self.norm(x)
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
return self.sublayer[1](x, self.feed_forward)
# 将多层EncoderLayer连接,使用clones函数复制多个layer返回ModuleList
# 多个EncoderLayer的Stack
class Encoder(nn.Module):
"Core encoder is a stack of N layers"
# layer表示编码器层,N表示编码器层的个数
def __init__(self, layer, N):
super(Encoder, self).__init__()
# layer是SubLayer,使用clone复制N个编码器层放在self.layers中
self.layers = clones(layer, N)
# 初始化一个规范化层,它将用在编码器的最后面
self.norm = LayerNorm(layer.size)
def forward(self, x, mask):
"Pass the input (and mask) through each layer in turn."
# 逐层处理,对克隆的编码器层进行循环,每次都会得到一个新的x,循环过程相当于输出的x经过N个编码器层的处理
for layer in self.layers:
x = layer(x, mask)
return self.norm(x)
3. Decoder
class DecoderLayer(nn.Module):
"Decoder is made of self-attn, src-attn, and feed forward (defined below)"
# Decoder包括self-attn, src-attn, 和feed forward,self_attn和scr_attn只有输入的Q、K、V矩阵不同
# 分别是来自上一层的输入x,来自编码器层的语义存储变量memory
def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
super(DecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.sublayer = clones(SublayerConnection(size, dropout), 3)
# 普通Attention(src_attn)的Query是下层输入进来的(来自self_attn的输出),Key和Value是Encoder最后一层的输出memory;Self-Attention的Query、Key和Value都是来自下层输入进来的
def forward(self, x, memory, src_mask, tgt_mask):
"Follow Figure 1 (right) for connections."
m = memory
# 将x传入第一个子层结构,第一个子层结构的输入分别是x和self-attn函数,因为是自注意力机制,所以Q,K,V都是x,
# 最后一个参数是目标数据掩码张量,这时要对目标数据进行遮掩,因为此时模型可能还没有生成任何目标数据(模型生成第一个字符时会遮掩,生成第二个字符只能利用第一个字符的信息)
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
# 进入第二个子层,这个子层中计算常规的注意力,q是输入x; k,v是编码层输出的memory矩阵,
# 同样也传入source_mask,但是进行源数据遮掩的原因并非是抑制信息泄漏,而是遮蔽掉对结果没有意义的字符而产生的注意力值,
x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
# 最后经过一个全连接子层后返回结果
return self.sublayer[2](x, self.feed_forward)
"""
由于Encoder是一次性计算出所有词,而Decoder在解码时必须根据encoder在不同时刻输出的内容进行解码,不能一次性进行解码,是根据已经预测出来的词预测下一个词,而计算是并行的,因此需要mask
subsequent_mask:
[[[ True False False False False]
[ True True False False False]
[ True True True False False]
[ True True True True False]
[ True True True True True]]]
可以将上面的输出表示为已经预测的单词个数为1、2、3、4、5
"""
# N个DecoderLayer组成的解码器
class Decoder(nn.Module):
"Generic N layer decoder with masking."
# layer是DecoderLayer,layer会使用回调函数callable调用DecoderLayer的forward方法
def __init__(self, layer, N):
super(Decoder, self).__init__()
self.layers = clones(layer, N)
self.norm = LayerNorm(layer.size)
# 输入x,输出memory,输入mask,输出mask
def forward(self, x, memory, src_mask, tgt_mask):
# x经过N个解码器层解码
for layer in self.layers:
x = layer(x, memory, src_mask, tgt_mask)
return self.norm(x)
二、Vision Transformer
1 Patch embed
# 将2D图片转换为196x768的Patch Embedding(token,词向量)
# 将图片分成16x16的patch,每个patch可以看做是一个token(词向量),共有224x224/16/16=196个token,每个token的维度为768,然后再经过layernorm映射到768维
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
image_size(224,224)
patch_size(16,16)
grid_size(224/16,224/16)=(14,14)
num_patches=14 * 14 = 196
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) # 窗格数量
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
# 使用16x16的卷积核,输入通道数3,输出通道数768,stride为16,即224x224x3->14x14x768
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
# 如果norm_layer = true 则使用layernorm,否则直接输出
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
'''
self.proj(x):[B,3,224,224]->[B,768,14,14]
flatten(2),降维:[B,768,14,14]->[B,768,14*14]=[B,768,196]
transpose(1, 2):[B,768,196]->[B,196,768]
'''
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x
2 Multi-Head Attention
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
'''
num_heads = 12
head_dim = 768 / 12 = 64
scale = head_dim ** -0.5 = 8
'''
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5 # 维度缩放
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
# B = batch_size,N = 196 + 1 = 197,C = 768
B, N, C = x.shape
'''
qkv(x):[B,197,768]->[B,197,768*3]
reshape:[B,197,768*3]->[B,197,3,12,64]
permute(2,0,3,1,4):[B,197,3,12,64]->[3,B,12,197,64]
'''
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
# qkv.unbind(0):qkv[0], qkv[1], qkv[2],给q、k、v分别赋值,即每个为[B,12,197,64]
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
'''
k.transpose(-2, -1):[B,12,197,64]->[B,12,64,197]
q @ k.transpose(-2, -1):[B,12,197,64] @ [B,12,64,197] = [B,12,197,197],@是矩阵乘法
attn:[B,12,197,197]
'''
attn = (q @ k.transpose(-2, -1)) * self.scale # q和k计算
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
'''
(attn @ v):[B,12,197,197] @ [B,12,197,64] = [B,12,197,64]
transpose(1, 2):[B,197,12,64]
reshape(B, N, C):[B,197,12,64]->[B,197,12*64] = [B,197,768]
'''
x = (attn @ v).transpose(1, 2).reshape(B, N, C) # q、k计算完后和v计算
x = self.proj(x)
x = self.proj_drop(x)
return x
3 Block
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop=0.,
attn_drop=0.,
init_values=None,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
4 VisionTransformer
class VisionTransformer(nn.Module):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool='token',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
init_values=None,
class_token=True,
no_embed_class=False,
pre_norm=False,
fc_norm=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
weight_init='',
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
block_fn=Block,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
global_pool (str): type of global pooling for final sequence (default: 'token')
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True,是否使用 qkv 偏置(即使用 Linear 将输入映射到 qkv 时,Linear是否使用 bias )
init_values: (float): layer-scale init values
class_token (bool): use class token
fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
weight_init (str): weight init scheme
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
act_layer: (nn.Module): MLP activation layer
"""
super().__init__()
assert global_pool in ('', 'avg', 'token')
assert class_token or global_pool != 'token'
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
'''
self.num_features = self.embed_dim = embed_dim = 768
self.num_prefix_tokens = 1
'''
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1 if class_token else 0
self.no_embed_class = no_embed_class
self.grad_checkpointing = False
# 构建patch embedding layer
self.patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
)
num_patches = self.patch_embed.num_patches # num_patches=14 * 14 = 196
'''
self.cls_token:[1,1,768]
self.pos_embed:[1,197,768],位置编码
cls token的作用是为了同NLP领域的Transformer保持一致,最后直接提取cls token作为网络提取得到的特征,作为类别预测的输出,放入MLP进行分类
不使用cls token也是可以的,对196个维度为768的patch进行全局均值池化,得到结果维度为[196,1]后作为类别预测的输出,放入MPL进行分类
'''
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None # 生成一个(1,1,768),全为0的tensor作为cls_token
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens # 196
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) # (1,196,768)
self.pos_drop = nn.Dropout(p=drop_rate)
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
'''
构建首项为0,长度为depth的等差数列,且每一项小于drop_path_rate(默认为0),即传入block的drop_path概率是递增的
构建12层block后进行LayerNorm
'''
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule,linspace()函数用于在线性空间中以均匀步长生成数字序列。
self.blocks = nn.Sequential(*[
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
init_values=init_values,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer
)
for i in range(depth)])
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
# Classifier Head
# 构建分类器:
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if weight_init != 'skip':
self.init_weights(weight_init)
# 初始化权重
def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'moco', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.pos_embed, std=.02)
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(get_init_weights_vit(mode, head_bias), self)
def _init_weights(self, m):
# this fn left here for compat with downstream users
init_weights_vit_timm(m)
# 加载预训练模型
@torch.jit.ignore()
def load_pretrained(self, checkpoint_path, prefix=''):
_load_weights(self, checkpoint_path, prefix)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token', 'dist_token'}
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes: int, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg', 'token')
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
# 初始化pos_embed、cls_token
def _pos_embed(self, x):
'''
合并 cls_token:
x:[B,196,768]
如果没有embed_class,那么x先与位置编码相加然后再加上cls,否则x先与加上cls然后再加上位置编码
'''
if self.no_embed_class:
# deit-3, updated JAX (big vision)
# position embedding does not overlap with class token, add then concat
'''
patch_embed加上pos_embed构成encoder输入
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1):加上cls,[B,197,768]
'''
x = x + self.pos_embed
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) # cls和xconcat
else:
# original timm, JAX, and deit vit impl
# pos_embed has entry for class token, concat then add
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.pos_embed
return self.pos_drop(x)
def forward_features(self, x):
# self.patch_embed(x):[B,3,224,224]->[B,196,768]
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.norm_pre(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
x = self.norm(x)
return x
#
def forward_head(self, x, pre_logits: bool = False):
'''
如果为全局平均池化则会对全部token进行平均池化然后送入layernorm,否则不进行池化只使用cls_toke,即x[:,0]只取第一个token(即只有cls输出结果)
'''
if self.global_pool:
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x) # 分类
return x if pre_logits else self.head(x)
'''
forward_features(x):x先进行PatchEmbed,然后进行PosEmbed(PatchEmbed加上PosEmbed后再加上cls),再进入block
block:layernorm->attention->(layerscale+dropout)->layernorm->mlp->(layerscale+dropout) ,最后再进行一个layernorm
'''
def forward(self, x):
x = self.forward_features(x) # [B,768]
x = self.forward_head(x) # [B,1000]
return x