有人可以告诉我forward()
方法中多个参数背后的概念吗?
通常,forward()
方法的实现有两个参数
自
输入
如果正向方法的参数多于这些参数,PyTorch将如何使用正向方法。
让我们考虑以下代码库:
https://github.com/bamps53/kaggle-autonomous-driving2019/blob/master/models/centernet.py
在这里在线236作者使用了带有两个其他参数的正向方法:
中心
return_embeddings
我找不到任何一篇文章可以回答我关于第254(return_embeddings:
)行和第257(if centers is not None:
)行将要执行的条件的查询。据我所知,该方法由nn模块内部调用。有人可以为此点灯吗?
最佳答案
您设置的转发功能。这意味着您可以根据需要添加更多参数。例如,您可以添加输入,如下所示
def forward(self, input1, input2,input3):
x = self.layer1(input1)
y = self.layer2(input2)
z = self.layer3(input3)
net = torch.cat((x,y,z),1)
return net
您必须在馈入网络时控制参数。不能使用超过一个参数的方式来馈送图层。因此,您需要一个一个地从输入中提取特征,并用
torch.cat((x,y),1)
(维数为1)将它们连接起来。