padding mask
一个batch中不同长度的句子需要添加padding变成统一的长度,因此需要使用padding mask功能对padding进行清除操作。
seq.shape: (batch, seqlen)
pad_idx : padding值
return.shape: (batch, 1, seqlen)
def get_pad_mask(seq, pad_idx):
return (seq != pad_idx).unsqueeze(-2)
sequence mask
在decoder时,为了防止当前时刻看到未来时刻的信息,需要将未来时刻的信息进行掩码操作。
以下函数会返回一个对角线以下为True的矩阵。
# seq.shape: (batch, seqlen)
# return.shape: (1, seqlen, seqlen)
def get_subsequent_mask(seq):
''' For masking out the subsequent info. '''
sz_b, len_s = seq.size()
subsequent_mask = (1 - torch.triu(
torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)).bool()
return subsequent_mask
inputs:
data = torch.randn(3, 5)
return :
tensor([[[ True, False, False, False, False],
[ True, True, False, False, False],
[ True, True, True, False, False],
[ True, True, True, True, False],
[ True, True, True, True, True]]])