问题描述
我非常清楚加载字典然后有一个实例加载旧的参数字典(例如这个很好的问题和答案).不幸的是,当我有一个 torch.nn.Sequential
时,我当然没有它的类定义.
I am very well aware of loading the dictionary and then having a instance of be loaded with the old dictionary of parameters (e.g. this great question & answer). Unfortunately, when I have a torch.nn.Sequential
I of course do not have a class definition for it.
所以我想仔细检查一下,正确的做法是什么.我相信 torch.save
就足够了(到目前为止我的代码还没有崩溃),尽管这些事情可能比人们想象的更微妙(例如,我在使用 pickle 时收到警告,但 torch.save
在内部使用它,所以很混乱).此外,numpy 有它自己的保存功能(例如,参见 这个答案),这往往更有效,所以可能会有我可能会忽略一些微妙的权衡.
So I wanted to double check, what is the proper way to do it. I believe torch.save
is sufficient (so far my code has not collapsed), though these things can be more subtle than one might expect (e.g. I get a warning when I use pickle but torch.save
uses it internally so it's confusing). Also, numpy has it's own save functions (e.g. see this answer) which tend to be more efficient, so there might be a subtle trade off I might be overlooking.
我的测试代码:
# creating data and running through a nn and saving it
import torch
import torch.nn as nn
from pathlib import Path
from collections import OrderedDict
import numpy as np
import pickle
path = Path('~/data/tmp/').expanduser()
path.mkdir(parents=True, exist_ok=True)
num_samples = 3
Din, Dout = 1, 1
lb, ub = -1, 1
x = torch.torch.distributions.Uniform(low=lb, high=ub).sample((num_samples, Din))
f = nn.Sequential(OrderedDict([
('f1', nn.Linear(Din,Dout)),
('out', nn.SELU())
]))
y = f(x)
# save data torch to numpy
x_np, y_np = x.detach().cpu().numpy(), y.detach().cpu().numpy()
np.savez(path / 'db', x=x_np, y=y_np)
print(x_np)
# save model
with open('db_saving_seq', 'wb') as file:
pickle.dump({'f': f}, file)
# load model
with open('db_saving_seq', 'rb') as file:
db = pickle.load(file)
f2 = db['f']
# test that it outputs the right thing
y2 = f2(x)
y_eq_y2 = y == y2
print(y_eq_y2)
db2 = {'f': f, 'x': x, 'y': y}
torch.save(db2, path / 'db_f_x_y')
print('Done')
db3 = torch.load(path / 'db_f_x_y')
f3 = db3['f']
x3 = db3['x']
y3 = db3['y']
yy3 = f3(x3)
y_eq_y3 = y == y3
print(y_eq_y3)
y_eq_yy3 = y == yy3
print(y_eq_yy3)
相关:
Related:
推荐答案
从代码中可以看出torch.nn.Sequential
是基于torch.nn.Module
:https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#Sequential
As can be seen in the code torch.nn.Sequential
is based on torch.nn.Module
:https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#Sequential
所以你可以使用
f = torch.nn.Sequential(...)
torch.save(f.state_dict(), path)
就像任何其他torch.nn.Module
一样.
这篇关于如何正确保存 pytorch 中的 torch.nn.Sequential 模型?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!