有人可以告诉我forward()方法中多个参数背后的概念吗?
通常,forward()方法的实现有两个参数



输入


如果正向方法的参数多于这些参数,PyTorch将如何使用正向方法。

让我们考虑以下代码库:
https://github.com/bamps53/kaggle-autonomous-driving2019/blob/master/models/centernet.py
在这里在线236作者使用了带有两个其他参数的正向方法:


中心
return_embeddings


我找不到任何一篇文章可以回答我关于第254(return_embeddings:)行和第257(if centers is not None:)行将要执行的条件的查询。据我所知,该方法由nn模块内部调用。有人可以为此点灯吗?

最佳答案

您设置的转发功能。这意味着您可以根据需要添加更多参数。例如,您可以添加输入,如下所示

def forward(self, input1, input2,input3):
    x = self.layer1(input1)
    y = self.layer2(input2)
    z = self.layer3(input3)

    net = torch.cat((x,y,z),1)

    return net


您必须在馈入网络时控制参数。不能使用超过一个参数的方式来馈送图层。因此,您需要一个一个地从输入中提取特征,并用torch.cat((x,y),1)(维数为1)将它们连接起来。

10-07 13:17
查看更多