How do I change the activation layers of a pretrained PyTorch network?
Here is my code:
print("All modules")
for child in net.children():
if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
print(child)
print('Before changing activation')
for child in net.children():
if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
print(child)
child=nn.SELU()
print(child)
print('after changing activation')
for child in net.children():
if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
print(child)
Here is my output:
All modules
ReLU(inplace=True)
Before changing activation
ReLU(inplace=True)
SELU()
after changing activation
ReLU(inplace=True)
Best answer
I assume you are using the module interface nn.ReLU to create the activation layers, rather than the functional interface F.relu. (Note that rebinding the loop variable with child = nn.SELU() only changes a local Python reference; it never touches the module stored inside the network, which is why your output still shows ReLU.) If so, setattr works for me.
import torch
import torch.nn as nn

# Recursively replace every ReLU module with a SELU module.
def replace_relu_to_selu(model):
    for child_name, child in model.named_children():
        if isinstance(child, nn.ReLU):
            setattr(model, child_name, nn.SELU())
        else:
            replace_relu_to_selu(child)
########## A toy example ##########
net = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, stride=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(32, 32, kernel_size=3, stride=1),  # in_channels must match the previous layer's output
    nn.ReLU(inplace=True)
)
########## Test ##########
print('Before changing activation')
for child in net.children():
    if isinstance(child, nn.ReLU) or isinstance(child, nn.SELU):
        print(child)
# Before changing activation
# ReLU(inplace=True)
# ReLU(inplace=True)

replace_relu_to_selu(net)  # perform the replacement; without this call the check below would still print ReLU

print('after changing activation')
for child in net.children():
    if isinstance(child, nn.ReLU) or isinstance(child, nn.SELU):
        print(child)
# after changing activation
# SELU()
# SELU()
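Since the question asks about pretrained networks rather than a hand-built Sequential, here is a minimal sketch of applying the same function to a torchvision model. It assumes torchvision is installed; resnet18 and its relu attribute come from torchvision, not from the question.

import torchvision.models as models

# Load a pretrained ResNet-18 (assumes torchvision is available).
resnet = models.resnet18(pretrained=True)

# ResNet stores ReLU modules both at the top level and inside nested
# BasicBlock modules; the recursion in replace_relu_to_selu reaches all of them.
replace_relu_to_selu(resnet)

# Verify: no ReLU modules should remain anywhere in the network.
assert not any(isinstance(m, nn.ReLU) for m in resnet.modules())
print(resnet.relu)  # SELU()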
Regarding "python - How to change activation layer in Pytorch pretrained module?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/58297197/