torch.save()和torch.load()是PyTorch中用于模型保存和加载的函数。它们提供了一种方便的方式来保存和恢复模型的状态、结构和参数。可以使用它们来保存和加载整个模型或其他任意的Python对象,并且可以在加载模型时指定目标设备。
1.语法介绍
1.1 torch.save()语法
torch.save()函数用于将PyTorch模型保存到磁盘上的文件中,以便以后可以重新加载和使用。它的基本语法如下:
torch.save(obj, f, pickle_module=<module 'pickle' from '...'>, pickle_protocol=2)
参数说明:
obj是要保存的对象,通常是一个模型的状态字典(state_dict())。
f是文件的路径或文件对象,用于存储模型。
pickle_module是用于序列化的Python模块,默认为pickle。
pickle_protocol是序列化时使用的协议版本,默认为2。
1.2 torch.load()语法
torch.load()函数用于从磁盘上的文件加载保存的模型。它的基本语法如下:
torch.load(f, map_location=None, pickle_module=<module 'pickle' from '...'>)
参数说明:
f是要加载的文件的路径或文件对象。
map_location用于指定加载模型的设备(CPU或特定的GPU设备)。默认情况下,加载的模型将被存储在与保存模型时相同的设备上。
pickle_module是用于反序列化的Python模块,默认为pickle。
2. 基本使用示例介绍
2.1 保存和加载整个模型
除了保存和加载模型的状态字典外,torch.save()和torch.load()还可以用于保存和加载整个模型,包括模型的结构、参数和其他相关信息。
要保存整个模型,使用以下代码:
torch.save(model, 'model.pth')
要加载整个模型,使用以下代码:
model = torch.load('model.pth')
注意,加载整个模型时,需要确保模型的定义代码可用,因为它将用于重新创建模型的结构。
2.2 保存和加载其他对象
torch.save()和torch.load()不仅限于保存和加载模型,还可以用于保存和加载其他任意的Python对象。只需将要保存的对象传递给torch.save(),然后使用torch.load()来加载该对象。
例如:
data = [1, 2, 3, 4, 5]
torch.save(data, 'data.pth')
loaded_data = torch.load('data.pth')
这样可以方便地保存和加载各种数据,如训练集、测试集、预处理数据等。
2.3 跨设备加载模型
torch.load()函数允许在加载模型时指定目标设备。通过使用map_location参数,可以将模型加载到不同的设备上,例如从GPU加载到CPU或从一种GPU加载到另一种GPU。
以下是一个示例:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 从GPU加载到CPU
model = torch.load('model.pth', map_location='cpu')
# 从一种GPU加载到另一种GPU
model = torch.load('model.pth', map_location='cuda:1')
这对于在不同设备上运行模型或在没有GPU的机器上加载训练好的GPU模型非常有用。
2.4 序列化兼容性
torch.save()使用Python的pickle模块进行序列化,默认使用协议版本2。这个默认版本在PyTorch 1.6及更高版本中是兼容的。如果您需要与旧版本的PyTorch或其他Python库进行兼容,您可以通过设置pickle_protocol参数来选择不同的协议版本。
torch.save(model.state_dict(), 'model.pth', pickle_protocol=4)
在选择协议版本时,需要权衡序列化的性能和兼容性。
3. 模型保存和加载
当涉及到模型保存和加载时,还有一些其他的注意事项和用法:
3.1 保存和加载模型的状态字典
通常情况下,我们只保存和加载模型的状态字典(state_dict()),而不是整个模型。状态字典包含了模型的参数和缓冲区(如权重和偏置),但不包括模型的结构。这种做法更加灵活,因为它允许在加载模型时自由选择模型的结构,并且可以与不同的模型架构进行兼容。
#保存模型的状态字典:
torch.save(model.state_dict(), 'model.pth')
#加载模型的状态字典:
model = MyModel()
model.load_state_dict(torch.load('model.pth'))
请确保在加载模型之前,模型的定义与保存时的模型结构相匹配。
3.2 冻结某些层或参数
在某些情况下,可能希望冻结模型的某些层或参数,即在加载模型后不更新它们的参数。可以通过设置参数的requires_grad属性来实现这一点。
例如,假设模型有一个名为fc的全连接层,您可以冻结该层的参数:
model = MyModel()
model.load_state_dict(torch.load('model.pth'))
# 冻结全连接层的参数
for param in model.fc.parameters():
param.requires_grad = False
3.3 多个模型的保存和加载
如果您需要保存和加载多个模型,您可以将它们保存为一个字典,并使用一个文件来存储整个字典。
保存多个模型:
state = {
'model1': model1.state_dict(),
'model2': model2.state_dict()
}
torch.save(state, 'models.pth')
加载多个模型:
state = torch.load('models.pth')
model1.load_state_dict(state['model1'])
model2.load_state_dict(state['model2'])
这种方法可以方便地保存和加载多个相关模型。
3.3 保存和加载检查点
在训练过程中,可以定期保存模型的检查点,以便在训练过程中发生意外情况时能够恢复模型。通过定期保存检查点,可以避免从头开始训练,并从最新的检查点继续训练。
# 训练循环中的保存检查点
if epoch % checkpoint_interval == 0:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, f'checkpoint_{epoch}.pth')
在发生中断或需要恢复训练时,可以加载最新的检查点:
# 加载最新的检查点
latest_checkpoint = get_latest_checkpoint(checkpoint_dir)
checkpoint = torch.load(latest_checkpoint)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
这样,可以从最新的检查点恢复训练。