1. 介绍
本文介绍如何在pytorch中载入模型的部分权重
, 总结了2个比较常见的问题:
- 第1个常见的问题: 在分类网络中,当载入的预训练权重的全连接层与我们自己实例化模型的节点个数不一样时,该如何载入?
- 第2个常见的问题: 如果对网络的结构进行了一定的修改,修改之后很明显是不能直接载入预训练权重了。
2. 代码实现说明
以分类网络ResNet为例说明,对应项目中的load_weights.py
来介绍对部分权重进行载入。
import os
import torch
import torch.nn as nn
from model import resnet34
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# load pretrain weights
# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
model_weight_path = "./resnet34-pre.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
# option1
net = resnet34()
net.load_state_dict(torch.load(model_weight_path, map_location=device))
# change fc layer structure
in_channel = net.fc.in_features
net.fc = nn.Linear(in_channel, 5)
# option2
# net = resnet34(num_classes=5)
# pre_weights = torch.load(model_weight_path, map_location=device)
# del_key = []
# for key, _ in pre_weights.items():
# if "fc" in key:
# del_key.append(key)
#
# for key in del_key:
# del pre_weights[key]
#
# missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
# print("[missing_keys]:", *missing_keys, sep="\n")
# print("[unexpected_keys]:", *unexpected_keys, sep="\n")
if __name__ == '__main__':
main()
下载官方提供的ResNet34预训练模型, 并将它命名为resnet34-pre.pth
,接下来介绍官方提供的载入部分权重的方法。
2. 1 pytorch 官方提供方法
- 首先实例化resnet34模型,注意并没有传入
num_classes
参数,此时默认的num_classes=1000,此时就可以直接载入官方的预训练权重。因为我们使用的是默认的全连接层个数1000,与预训练权重是一致的。
# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
model_weight_path = "./resnet34-pre.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
# option1
net = resnet34()
net.load_state_dict(torch.load(model_weight_path, map_location=device))
- 由于我们自己的分类个数是不等于1000的,比如我们这里的分类个数为
5
,接下来该怎么办呢?首先查看resnet34模型搭建的源码。可以看到全连接层是通过sef.fc=nn.Linear(512*block.expansion,num_class)
这条语句实现的。
点开nn.Linear
类,可以看到它有这么几个参数self.in_features
和self.out_features
,分别表示全连接层的输入和输出
的节点个数。对于imagenet-1k,输出节点个数self.out_features对应的就是1000. 因此我们可以通过fc.in_features
获得网络的输入节点个数,然后输出节点个数定义为我们自己的分类个数5
。
net.fc=nn.Linear(in_channel,5)
通过创建新的全连接层来替换原来的全连接层。这样我们就变相的载入了Conv1
到layer4_x
的层结构,替换掉全连接层相当于没有载入全连接层权重,刚好符合我们的要求
2. 2 另外一种实现方式
net = resnet34(num_classes=5)
pre_weights = torch.load(model_weight_path, map_location=device)
del_key = []
for key, _ in pre_weights.items():
if "fc" in key:
del_key.append(key)
for key in del_key:
del pre_weights[key]
missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
print("[missing_keys]:", *missing_keys, sep="\n")
print("[unexpected_keys]:", *unexpected_keys, sep="\n")
- 首先实例化resnet34,这里需要注意的是我们传入了
num_classes
参数,也就是最后一个全连接层节点个数一开始就设置为5了。此时就不能像前一种方法一样直接通过net.load_state_dict(torch.load(model_weight_path, map_location=device))
载入预训练权重了。因为网络的全连接层节点个数和预训练模型是不一样的,直接载入就会报错
。我们应该怎么办呢? - 通过
(torch.load(model_weight_path, map_location=device)
,先读取预训练权重保存为一个有序字典Orderedict
的形式。每个键值对对应一组参数和权重。
- 由于我们只想保留除全连接层fc之外的预训练权重,我们可以通过遍历pre_weights字典,去删减掉不需要的键值对。通过点击
resnet34
查看构建的代码,可以看到,其全连接层为self.fc
包含了fc
字段。除此之外,也可以通过实例化后的模型,调用state_dict()
函数,查看模型的所有模型权重的key和value值:
net = resnet34(num_classes=5)
net_weights = net.state_dict()
- 可以看到全连接层包含两个权重,分别是
fc.weight
和fc.bias
,此时我们可以遍历pre_weights的每个key值,如果key中包含有fc
这个字段我们就可以知道它是属于全连接层的权重,后续把包含fc的权重删除掉,然后我们再去载入剩下的权重。 - 我们
实例化的模型和载入的模型
,他们权重的名称(key值)要是一样的才可以载入和方便删减。还有一种情况可能载入模型的key与实例化的模型中的key
值不一样。那么这种情况的话就会比较麻烦点。那么就需要将载入模型的key值跟实例化一一对应,将载入模型的key改为实例化模型的key值
。这就需要你对网络搭建过程非常清楚,你要知道每个层它所对应的权重是什么,这样的话就可以编辑有序字典中的key来载入你想载入的权重。这个例子我们载入的权重和我们创建的模型它的key值都是一样的,因此相对于刚才说的这种情况,载入会比较简单些。 - 上面的例子,只要包含了
fc
字段,我们就将这个key值先存到del_key
列表中。通过调试可以发现del_key
存的就是fc_weights
和fc_bias
。紧接着我们再遍历del_key依次将这些key从pre_weights
字典中删除。
pre_weights = torch.load(model_weight_path, map_location=device)
del_key = []
for key, _ in pre_weights.items():
if "fc" in key:
del_key.append(key)
for key in del_key:
del pre_weights[key]
- 这里需要注意,在载入预训练权重的时候,我们多传入了一个参数
strict=False
, 如果你不传的话,它默认是为True的。如果strict=True
它会严格的载入每个key值,因为我们删减掉全连接中的权重,因此就不能将strict设置为True。net.load_state_dict(pre_weights, strict=False)
会返回两个 变量,分别是missing_keys
和unexpected_keys
。missing_key
:表示在我们实例化的模型net中有部分权重并没有在pre_weights
预训练权重中出现,就相当于与pre_weights
中漏掉了这些权重。unexpected_key
:就是说在我们载入的pre_weights
中有一部分权重它不在我们的net中,此时就会存在unexpected_keys
中。针对我们刚才讲的情况,应该会出现两个missing_key
:fc.weights和fc.bias:
执行以后打印的信息:missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False) print("[missing_keys]:", *missing_keys, sep="\n") print("[unexpected_keys]:", *unexpected_keys, sep="\n")
可以看到>> [missing_keys]: >> fc.weight >> fc.bias >> [unexpected_keys]:
missing_key
中有fc.weights和fc.bias,在unexpected_keys
中是没有任何参数的。也就时除了fc.weights和fc.bias两个全连接参数外,其他参数都载入进来了。
如果有些人,除了
fc层外还改动了某些高层的结构如resnet中Conv5_x,我们如何去载入低层没有改动的权重呢?
: 此时对于resnet模型就需要载入除了Conv5_x
和fc
层之外的所有权重
此时我们可以在条件中,判断key是否包含layer4
,如果有的话也将它删掉。
pre_weights = torch.load(model_weight_path, map_location=device)
del_key = []
for key, _ in pre_weights.items():
if "fc" in key or "layer4" in key:
del_key.append(key)
for key in del_key:
del pre_weights[key]
执行之后,我们发现在missing_key
列表中除了我们之前两个全连接层权重之外,剩下,剩下的都是layer4所对应的权重,也就是说我们也没有将layer4所对应的权重载入进去。
总结
以上介绍的是2种比较常见的载入部分权重的方法,除了我们讲到的在载入的权重的有序字典筛选之外,我们可以自己新创建一个字典,新创建一个字典之后,可以自己组建key,value然后用上文介绍的方法进行载入就可以了,这样的话会更加的灵活.
- 在这里感谢B站霹雳吧啦Up主
代码链接:https://pan.baidu.com/s/1j34QBVb9ZKxWX7d1Vm9QrQ?pwd=stxx
提取码:stxx