I created the following deep network with dropout layers, as shown below:
class QNet_dropout(nn.Module):
    """
    An MLP with 2 hidden layers and dropout.

    observation_dim (int): number of observation features
    action_dim (int): dimension of each action
    seed (int): random seed
    """
    def __init__(self, observation_dim, action_dim, seed):
        super(QNet_dropout, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(observation_dim, 128)
        self.fc2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Dropout(0.5)
        self.fc5 = nn.Linear(64, action_dim)

    def forward(self, observations):
        """Forward propagation of the neural network."""
        x = F.relu(self.fc1(observations))
        x = F.linear(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.linear(self.fc4(x))
        x = self.fc5(x)
        return x
However, when I try to run the code, I get the following error:
/home/workspace/QNetworks.py in forward(self, observations)
90
91 x = F.relu(self.fc1(observations))
---> 92 x = F.linear(self.fc2(x))
93 x = F.relu(self.fc3(x))
94 x = F.linear(self.fc4(x))
TypeError: linear() missing 1 required positional argument: 'weight'
It seems I am not using/forwarding the dropout layers correctly. What is the correct way to forward through the dropout layers? Thanks!
Best Answer
The F.linear() function is being used incorrectly. You should call the linear layers you declared in __init__ directly instead of going through torch.nn.functional. Also, the dropout layer should be applied after the ReLU. The ReLU itself can still be called from torch.nn.functional.
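For context, here is a minimal sketch (the shapes below are illustrative assumptions, not taken from the question) of why the original call fails: torch.nn.functional.linear expects an explicit weight tensor (and optional bias), while the nn.Linear module stores its own parameters internally.

    import torch
    import torch.nn.functional as F

    x = torch.rand(8, 128)           # illustrative batch of activations
    weight = torch.rand(64, 128)     # F.linear needs the weight passed explicitly
    bias = torch.rand(64)

    out = F.linear(x, weight, bias)  # OK: functional form is linear(input, weight, bias)
    # F.linear(x)                    # TypeError: missing required argument 'weight'
    print(out.shape)                 # torch.Size([8, 64])

With the corrected class below, the declared nn.Linear and nn.Dropout modules are simply called on the activations: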
import torch
import torch.nn as nn
import torch.nn.functional as F

class QNet_dropout(nn.Module):
    """
    An MLP with 2 hidden layers and dropout.

    observation_dim (int): number of observation features
    action_dim (int): dimension of each action
    seed (int): random seed
    """
    def __init__(self, observation_dim, action_dim, seed):
        super(QNet_dropout, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(observation_dim, 128)
        self.fc2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Dropout(0.5)
        self.fc5 = nn.Linear(64, action_dim)

    def forward(self, observations):
        """Forward propagation of the neural network."""
        x = self.fc2(F.relu(self.fc1(observations)))  # linear -> ReLU -> dropout
        x = self.fc4(F.relu(self.fc3(x)))             # linear -> ReLU -> dropout
        x = self.fc5(x)                               # output layer, no activation
        return x
observation_dim = 512
model = QNet_dropout(observation_dim, 10, 512)
batch_size = 8
inpt = torch.rand(batch_size, observation_dim)
output = model(inpt)
print("output shape:", output.shape)
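One follow-up worth noting (my addition, not part of the original answer): nn.Dropout only randomizes activations while the module is in training mode, so for a Q-network you will usually want model.eval() before acting greedily to get deterministic outputs. A minimal sketch, reusing the model and inpt defined above:

    model.train()                     # dropout active: repeated passes usually differ
    out_a = model(inpt)
    out_b = model(inpt)
    print("train mode, outputs equal:", torch.allclose(out_a, out_b))  # usually False

    model.eval()                      # dropout disabled: outputs are deterministic
    with torch.no_grad():
        out_c = model(inpt)
        out_d = model(inpt)
    print("eval mode, outputs equal:", torch.allclose(out_c, out_d))   # True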
Regarding "python - How to correctly forward through dropout layers", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/56401266/