Problem Description
I am trying to write a PyTorch module with multiple layers. Since I need the intermediate outputs, I cannot simply put all the layers in a Sequential as usual. On the other hand, since there are many layers, my idea is to put the layers in a list and access them by index in a loop. The code below describes what I am trying to achieve:
import torch
import torch.nn as nn
import torch.optim as optim
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer_list = []  # a plain Python list
        self.layer_list.append(nn.Linear(2,3))
        self.layer_list.append(nn.Linear(3,4))
        self.layer_list.append(nn.Linear(4,5))

    def forward(self, x):
        res_list = [x]
        for i in range(len(self.layer_list)):
            res_list.append(self.layer_list[i](res_list[-1]))
        return res_list
model = MyModel()
x = torch.randn(4,2)
y = model(x)
print(y)
optimizer = optim.Adam(model.parameters())
The forward method works fine, but when I try to set up an optimizer, the program raises:
ValueError: optimizer got an empty parameter list
It appears that the layers in the list are not registered. What can I do?
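The failure can be reproduced in isolation: PyTorch only registers a submodule when it is assigned as an attribute of the module (or stored in one of PyTorch's own containers), so layers hidden in a plain list contribute nothing to parameters(). A minimal sketch (class names are just for illustration):

```python
import torch.nn as nn

class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # plain Python list: PyTorch never sees these layers
        self.layers = [nn.Linear(2, 3), nn.Linear(3, 4)]

class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        # direct attribute assignment registers the submodule
        self.fc = nn.Linear(2, 3)

print(len(list(PlainList().parameters())))   # 0 -> optimizer raises ValueError
print(len(list(Registered().parameters())))  # 2 (weight and bias)
```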
Answer
If you put your layers in a plain Python list, PyTorch does not register them correctly. You have to use a ModuleList instead
(https://pytorch.org/docs/master/generated/torch.nn.ModuleList.html).
ModuleList can be indexed like a regular Python list, but the modules it contains are properly registered and will be visible to all Module methods.
Your code should look like this:
import torch
import torch.nn as nn
import torch.optim as optim
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer_list = nn.ModuleList()  # << the only changed line! <<
        self.layer_list.append(nn.Linear(2,3))
        self.layer_list.append(nn.Linear(3,4))
        self.layer_list.append(nn.Linear(4,5))

    def forward(self, x):
        res_list = [x]
        for i in range(len(self.layer_list)):
            res_list.append(self.layer_list[i](res_list[-1]))
        return res_list
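As a quick sanity check, here is a self-contained sketch of the fixed model (with the ModuleList built directly from a list, which is equivalent to appending) plus a toy training step; the optimizer now sees all parameters:

```python
import torch
import torch.nn as nn
import torch.optim as optim

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers every contained layer as a submodule
        self.layer_list = nn.ModuleList(
            [nn.Linear(2, 3), nn.Linear(3, 4), nn.Linear(4, 5)]
        )

    def forward(self, x):
        res_list = [x]
        for layer in self.layer_list:
            res_list.append(layer(res_list[-1]))
        return res_list

model = MyModel()
optimizer = optim.Adam(model.parameters())  # no ValueError now

x = torch.randn(4, 2)
outputs = model(x)            # [input, out1, out2, out3]
loss = outputs[-1].sum()      # toy loss on the final output

optimizer.zero_grad()
loss.backward()
optimizer.step()
```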
By using a ModuleList, you make sure all layers are registered as submodules, so their parameters are visible to model.parameters() and therefore to the optimizer.
There is also a ModuleDict that you can use if you want to index your layers by name. You can check PyTorch's containers here: https://pytorch.org/docs/master/nn.html#containers
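A short sketch of the ModuleDict variant, accessing layers by string keys instead of positions (the key names and class name here are just illustrative):

```python
import torch
import torch.nn as nn

class NamedModel(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleDict registers each layer under a string key
        self.layers = nn.ModuleDict({
            "input": nn.Linear(2, 3),
            "hidden": nn.Linear(3, 4),
            "output": nn.Linear(4, 5),
        })

    def forward(self, x):
        for name in ["input", "hidden", "output"]:
            x = self.layers[name](x)
        return x

model = NamedModel()
y = model(torch.randn(4, 2))
print(y.shape)  # torch.Size([4, 5])
```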