问题描述
给定一个 3d Tenzor,说:batch x 句子长度 x embedding dim
Given a 3d tenzor, say:batch x sentence length x embedding dim
a = torch.rand((10, 1000, 96))
以及每个句子的实际长度数组(或张量)
and an array(or tensor) of actual lengths for each sentence
lengths = torch .randint(1000,(10,))
输出张量([ 370., 502., 652., 859., 545., 964., 566., 576.,1000., 803.])
如何根据张量长度"在维度 1(句子长度)的某个索引后用零填充张量a"?
How to fill tensor ‘a’ with zeros after certain index along dimension 1 (sentence length) according to tensor ‘lengths’ ?
我想要这样的:
a[ : , lengths : , : ] = 0
一种方法(如果批量足够大,速度会很慢):
One way of doing it (slow if batch size is big enough):
for i_batch in range(10):
a[ i_batch , lengths[i_batch ] : , : ] = 0
推荐答案
您可以使用二进制掩码来完成.
使用 lengths
作为 mask
的列索引,我们指示每个序列的结束位置(注意我们使 mask
长于 a.size(1)
允许全长序列).
使用 cumsum()
我们将 seq len 之后的 mask
中的所有条目设置为 1.
You can do it using a binary mask.
Using lengths
as column-indices to mask
we indicate where each sequence ends (note that we make mask
longer than a.size(1)
to allow for sequences with full length).
Using cumsum()
we set all entries in mask
after the seq len to 1.
mask = torch.zeros(a.shape[0], a.shape[1] + 1, dtype=a.dtype, device=a.device)
mask[(torch.arange(a.shape[0]), lengths)] = 1
mask = mask.cumsum(dim=1)[:, :-1] # remove the superfluous column
a = a * (1. - mask[..., None]) # use mask to zero after each column
对于 a.shape = (10, 5, 96)
和 lengths = [1, 2, 1, 1, 3, 0, 4, 4, 1, 3]
.
在每一行为各自的 lengths
分配 1,mask
看起来像:
For a.shape = (10, 5, 96)
, and lengths = [1, 2, 1, 1, 3, 0, 4, 4, 1, 3]
.
Assigning 1 to respective lengths
at each row, mask
looks like:
mask =
tensor([[0., 1., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0.],
[0., 1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0.],
[1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1., 0.],
[0., 1., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0.]])
在cumsum
之后你得到
mask =
tensor([[0., 1., 1., 1., 1.],
[0., 0., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 0., 0., 1., 1.],
[1., 1., 1., 1., 1.],
[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[0., 1., 1., 1., 1.],
[0., 0., 0., 1., 1.]])
请注意,它在有效序列条目所在的位置正好有零,而在序列长度之外的地方有一个.取 1 - mask
给你你想要的.
Note that it exactly has zeros where the valid sequence entries are and ones beyond the lengths of the sequences. Taking 1 - mask
gives you exactly what you want.
享受;)
这篇关于在特定索引后用零填充火炬张量的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!