Problem description
When I try to look into a batch by printing the next iteration of the BucketIterator object, an AttributeError is thrown: 'Field' object has no attribute 'vocab'.
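# legacy torchtext (<= 0.8); in 0.9-0.11 these live under torchtext.legacy.data
from torchtext import data
from torchtext.data import BucketIterator

# TEXT and LABEL are torchtext Field objects defined earlier (not shown)
tv_datafields = [("Tweet", TEXT), ("Anger", LABEL), ("Fear", LABEL), ("Joy", LABEL), ("Sadness", LABEL)]
train, vld = data.TabularDataset.splits(
    path="./data/", train="train.csv", validation="test.csv",
    format="csv", fields=tv_datafields)
train_iter, val_iter = BucketIterator.splits(
    (train, vld),
    batch_sizes=(64, 64),
    device=-1,  # -1 selects the CPU in older torchtext releases
    sort_key=lambda x: len(x.Tweet),
    sort_within_batch=False,
    repeat=False
)
print(next(iter(train_iter)))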
Recommended answer
I am not sure about the specific error you are getting, but in this case you can iterate over the batches with the following code:
for i in train_iter:
    print(i.Tweet)
    print(i.Anger)
    print(i.Fear)
    print(i.Joy)
    print(i.Sadness)
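Each i in the loop above is one batch produced by BucketIterator, which uses the sort_key from the question (len(x.Tweet)) to put tweets of similar length into the same batch and so reduce padding. The following is a simplified plain-Python sketch of that bucketing idea, not torchtext's actual implementation: bucket_batches and pool_factor are hypothetical names, and the 100x pool size is an assumption modeled loosely on torchtext (the real iterator differs in details such as how it shuffles).

```python
import random


def bucket_batches(examples, batch_size, sort_key, pool_factor=100, seed=0):
    """Simplified bucketing sketch (hypothetical helper, not torchtext code).

    Examples are taken in large pools, each pool is sorted by sort_key so
    that examples of similar length end up next to each other, the sorted
    pool is cut into batches, and finally the batch order is shuffled.
    """
    rng = random.Random(seed)
    batches = []
    pool_size = batch_size * pool_factor
    for start in range(0, len(examples), pool_size):
        pool = sorted(examples[start:start + pool_size], key=sort_key)
        for b in range(0, len(pool), batch_size):
            batches.append(pool[b:b + batch_size])
    rng.shuffle(batches)  # batch order is random, contents stay bucketed
    return batches


# Toy "tweets" of varying token length
tweets = [["tok"] * n for n in (5, 1, 4, 2, 6, 3, 1, 5)]
for batch in bucket_batches(tweets, batch_size=2, sort_key=len):
    print([len(t) for t in batch])
```

Because each batch is cut from a sorted pool, the lengths inside a batch are close together, so less padding is needed when the batch is turned into a tensor.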
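i.Tweet (like any sequential field with the default batch_first=False) is a tensor of shape (input_data_length, batch_size). If LABEL was defined with sequential=False, the label fields are 1-D tensors of shape (batch_size,) instead.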
So, to view a single example within the batch (say example 0), you can do print(i.Tweet[:, 0]).
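To make the (input_data_length, batch_size) layout concrete, here is a small sketch that uses nested lists instead of torch tensors: rows are time steps, columns are examples, and shorter sequences are padded. The helpers pad_batch and column are hypothetical, and pad_index=1 is an assumption (it is the usual <pad> index in a legacy torchtext vocabulary, but check your own vocab). If the Field was created with batch_first=True, the two dimensions are swapped and example 0 would be i.Tweet[0] instead.

```python
def pad_batch(token_ids, pad_index=1):
    """Lay out a batch as rows = time steps, columns = examples.

    Mimics the (input_data_length, batch_size) layout; pad_index=1 is an
    assumption about the padding index, not something read from a vocab.
    """
    max_len = max(len(seq) for seq in token_ids)
    padded = [seq + [pad_index] * (max_len - len(seq)) for seq in token_ids]
    # Transpose: padded[example][step] -> batch[step][example]
    return [list(step) for step in zip(*padded)]


def column(batch, j):
    """Plain-Python equivalent of tensor[:, j]: example j across all steps."""
    return [row[j] for row in batch]


batch = pad_batch([[4, 9, 7], [5, 2]])
print(len(batch), len(batch[0]))  # input_data_length, batch_size
print(column(batch, 0))           # token ids of the first example
print(column(batch, 1))           # second example, padded at the end
```

Selecting a column rather than a row is what distinguishes "one example" from "one time step across the whole batch".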
The same goes for val_iter (and test_iter, if needed).
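As for the error itself, 'Field' object has no attribute 'vocab' usually means the vocabulary was never built: in legacy torchtext a Field only gets its vocab attribute after build_vocab is called, so the usual fix is to call TEXT.build_vocab(train) (and, unless LABEL was created with use_vocab=False, LABEL.build_vocab(train)) before creating or iterating the BucketIterator. The sketch below reproduces that failure mode with a hypothetical MiniField class; it is a stand-in for illustration only, not torchtext's Field.

```python
from collections import Counter


class MiniField:
    """Hypothetical stand-in for a torchtext Field (not the real class).

    Like a legacy torchtext Field, it only gains a `vocab` attribute once
    build_vocab() has been called on some data.
    """

    def build_vocab(self, examples):
        counts = Counter(tok for ex in examples for tok in ex)
        # Index 0 is reserved for unknown tokens, 1 for padding
        self.vocab = {"<unk>": 0, "<pad>": 1}
        for tok, _ in counts.most_common():
            self.vocab[tok] = len(self.vocab)

    def numericalize(self, tokens):
        # Accessing self.vocab before build_vocab() raises AttributeError,
        # the same kind of failure as in the question.
        return [self.vocab.get(tok, 0) for tok in tokens]


examples = [["so", "angry"], ["so", "happy", "today"]]
field = MiniField()

try:
    field.numericalize(examples[0])
except AttributeError as err:
    print("before build_vocab:", err)

field.build_vocab(examples)
print("after build_vocab:", field.numericalize(examples[0]))
```

Turning a batch into tensors requires mapping tokens to indices, which is why the error only shows up once you try to pull a batch out of the iterator.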