前言
深度学习框架提供了内置函数来保存和加载整个网络。需要注意的是,这将保存模型的参数而不是整个模型。
加载和保存
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.hidden=nn.Linear(20,256)
self.output=nn.Linear(256,10)
def forward(self,X):
return self.output(F.relu(self.hidden(X)))
net=MLP()
X=torch.randn(size=(2,20))
Y=net(X)
torch.save(net.state_dict(),'mlp.param')
clone=MLP()
clone.load_state_dict(torch.load('mlp.param'))