What I want to do is use DataParallel with my custom RNN class.

It seems I'm initializing hidden_0 the wrong way.



import torch as T
import torch.nn as nn
from torch.autograd import Variable

# batch_size, num_steps and N_CHARACTERS are presumably globals defined elsewhere in the script

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers

        self.encoder = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_size, output_size)
        self.init_hidden(batch_size)

    def forward(self, input):
        input = self.encoder(input)
        output, self.hidden = self.gru(input, self.hidden)
        output = self.decoder(output.contiguous().view(-1, self.hidden_size))
        output = output.contiguous().view(batch_size, num_steps, N_CHARACTERS)
        # print(output.size())  # 10, 50, 67
        return output

    def init_hidden(self, batch_size):
        # Stores h_0 on the module itself; nothing is returned
        self.hidden = Variable(T.zeros(self.n_layers, batch_size, self.hidden_size).cuda())
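On the hidden_0 part of the question: one common pattern (a sketch, not the asker's code and not the only option) is to build h_0 inside forward from the current input instead of storing it on self. Under DataParallel every GPU runs a fresh replica on its own slice of the batch, so a self.hidden created once in __init__ has the full batch size, lives on a single device, and anything a replica assigns to self.hidden is discarded after the call. The sketch also uses DataParallel's default dim=0, because with batch_first=True the batch is dimension 0 of the input; dim=1 would instead split each sequence along time. The sizes 67 and 32 are illustrative, and this version does not carry the hidden state across calls.

```python
import torch
import torch.nn as nn


class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.encoder = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        # input: (batch, seq_len) token indices; batch is dim 0 because batch_first=True
        # h_0 is built per call from this replica's slice, on the input's device
        h0 = torch.zeros(self.n_layers, input.size(0), self.hidden_size,
                         device=input.device)
        emb = self.encoder(input)          # (batch, seq_len, hidden_size)
        output, _ = self.gru(emb, h0)      # (batch, seq_len, hidden_size)
        return self.decoder(output)        # Linear maps the last dim to output_size


# Default dim=0 matches batch_first=True; without GPUs DataParallel just calls the module
model = nn.DataParallel(RNN(67, 32, 67))
x = torch.randint(0, 67, (10, 50))
if torch.cuda.is_available():
    model, x = model.cuda(), x.cuda()

out = model(x)
print(out.shape)  # torch.Size([10, 50, 67])
```

Because nn.Linear operates on the last dimension, the view/contiguous reshaping from the original forward is not needed here.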


I create the network like this:

decoder = T.nn.DataParallel(RNN(N_CHARACTERS, HIDDEN_SIZE, N_CHARACTERS), dim=1).cuda()


Then I start training:

for epoch in range(EPOCH_):
    hidden = decoder.init_hidden()


But I get the following error, and I don't know how to fix it:


'DataParallel' object has no attribute 'init_hidden'


Thanks for your help!

Best answer

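When you use DataParallel, the original module is available as the module attribute of the parallel wrapper, so custom methods such as init_hidden must be called through it: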

for epoch in range(EPOCH_):
    decoder.module.init_hidden(batch_size)  # sets decoder.module.hidden; returns None
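A minimal sketch of the answer's point, using a hypothetical TinyRNN stand-in whose init_hidden returns the tensor (unlike the question's version) so it is easy to inspect. Attribute lookup on the wrapper fails for the custom method, while wrapped.module is the very object that was wrapped. It runs on CPU too, since DataParallel then simply holds the module.

```python
import torch
import torch.nn as nn


class TinyRNN(nn.Module):
    """Hypothetical stand-in for the RNN class in the question."""

    def __init__(self, hidden_size=8, n_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, batch_first=True)

    def init_hidden(self, batch_size):
        # Custom method: the DataParallel wrapper does not expose it itself
        return torch.zeros(self.n_layers, batch_size, self.hidden_size)


net = TinyRNN()
wrapped = nn.DataParallel(net)

try:
    wrapped.init_hidden(4)
    error = None
except AttributeError as e:
    error = str(e)  # the same kind of error as in the question
print(error)

print(wrapped.module is net)                # True
print(wrapped.module.init_hidden(4).shape)  # torch.Size([1, 4, 8])
```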

09-16 02:53