1.模型保存与加载
1.1
#a、保存 推荐仅仅保存模型的state_dict
torch.save(model.state_dict(), MODELPATH) # .pt .pth
#b、加载
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
#Pytorch保存的模型后缀一般是.pt或者.pth
#必须在加载模型后调用model.eval函数来将dropout及批归一化层设置为预测模式。如果不这么做结果出错。
1.2 a、保存临时模型用于预测或再训练
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss, ... },
PATH)
当保存一个临时模型用于预测或再训练时,需要保存比state_dict更多的参数。包括优化器的state_dict,迭代次数epoch,最后一层迭代的loss及其他任何需要的参数。
当保存多个组件时,将多个组件以字典的形式组织,然后用torch.savee()来序列化该字典。在Pytorch中常用.tar文件后缀表示这种模型。
b、加载
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval() #预测 # - or - model.train() #再训练
e.g.
save_checkpoint({
'epoch': epoch + 1,
'state_dict': model.state_dict(),
'lr': args.lr,
'optimizer' : optimizer.state_dict(),
}, checkpoint=args.checkpoint)