Pytorch用ConvTranspose2d替代Upsample
本文介绍了Pytorch如何用ConvTranspose2d算子等价替代Upsample算子。
背景介绍:
- 某些AI加速卡上Upsample算子的性能不够高,能否用别的算子临时替代呢?
- 可以手动推导出ConvTranspose2d的权值,使其与Upsample算子等价(对于scale_factor=2的最近邻上采样,取kernel_size=stride=2,权值W[i][o]在i=o时为全1的2×2矩阵、其余为0即可)
- 也可以搭建一个模型,把同一输入分别送入ConvTranspose2d和Upsample算子,通过训练使两者输出之间的L1Loss最小
- 当网络收敛后,对ConvTranspose2d的权值做舍入处理
- 最后用上面的权值初始化ConvTranspose2d
网络结构
import onnx
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import numpy as np
class UpsampleModel(torch.nn.Module):
    """Run nearest-neighbour ``nn.Upsample`` and a ``ConvTranspose2d`` side by side.

    Both branches receive the same input so their outputs can be compared,
    e.g. with an L1 loss while training ``deconv1``, or for exact equality
    once its weights have been fixed.

    Args:
        channels: Number of input and output channels of the deconvolution
            (default 3).
        scale: Integer upsampling factor (default 2). It is used both as the
            Upsample ``scale_factor`` and as the deconvolution kernel size and
            stride, so the two branches always produce the same output size.
    """

    def __init__(self, channels=3, scale=2):
        super().__init__()
        self.up = nn.Upsample(scale_factor=scale, mode='nearest')
        # kernel_size == stride == scale: every input pixel is written into its
        # own non-overlapping scale x scale output patch, matching nearest
        # upsampling's pixel replication.
        self.deconv1 = nn.ConvTranspose2d(
            channels, channels, kernel_size=scale, stride=scale,
            groups=1, bias=False,
        )

    def forward(self, x):
        """Return ``(self.up(x), self.deconv1(x))`` for the same input ``x``."""
        out0 = self.up(x)
        out1 = self.deconv1(x)
        return out0, out1
训练ConvTranspose2d的权值
def train(input_shape=(1, 3, 224, 224), epochs=2100, steps_per_epoch=100,
          tol=1e-4, lr=0.001):
    """Fit ``deconv1`` so that it reproduces ``nn.Upsample``; print and return its weights.

    Each step feeds a fresh random batch to both branches of ``UpsampleModel``
    and minimises the L1 distance between their outputs with Adam. The mean
    loss is printed after every epoch; training stops early once it falls
    below ``tol``.

    Args:
        input_shape: Shape of the random input batch. Its channel dimension
            must match the model (3 for ``UpsampleModel()``).
        epochs: Maximum number of epochs.
        steps_per_epoch: Optimisation steps per epoch (must be >= 1).
        tol: Mean-loss threshold that ends training early.
        lr: Adam learning rate.

    Returns:
        numpy.ndarray: ``deconv1`` weights rounded to the nearest integer
        (shape ``(in_channels, out_channels, kH, kW)``). They are also printed.
        If ``tol`` is never reached a warning is printed and the rounded
        weights from the last step are returned anyway.

    Raises:
        ValueError: If ``steps_per_epoch`` is less than 1.
    """
    if steps_per_epoch < 1:
        raise ValueError('steps_per_epoch must be >= 1')
    model = UpsampleModel()
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        running_loss = 0.0
        for _ in range(steps_per_epoch):
            input_data = torch.randn(input_shape)
            optimizer.zero_grad()
            out0, out1 = model(input_data)
            # out0 (Upsample) has no parameters; only deconv1 is updated.
            loss = criterion(out0, out1)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_loss = running_loss / steps_per_epoch
        print('[%d] loss: %f' % (epoch + 1, avg_loss))
        if avg_loss < tol:
            break
    else:
        # Loop exhausted without reaching tol: report it instead of staying silent.
        print('warning: loss did not fall below %g in %d epochs' % (tol, epochs))
    # Rounding snaps the learned near-0/near-1 values to the exact weights.
    w = np.round(model.deconv1.weight.detach().numpy())
    print(w)
    return w
# Train deconv1 to mimic nn.Upsample (runs at import time; up to 2100 epochs by default).
train()
结果(训练收敛后打印的舍入权值,形状为[in_channels, out_channels, kH, kW];其中-0.与0.相同):
[[[[ 1. 1.]
[ 1. 1.]]
[[-0. -0.]
[-0. -0.]]
[[ 0. -0.]
[-0. -0.]]]
[[[ 0. 0.]
[ 0. -0.]]
[[ 1. 1.]
[ 1. 1.]]
[[ 0. 0.]
[ 0. -0.]]]
[[[ 0. -0.]
[-0. -0.]]
[[ 0. -0.]
[ 0. 0.]]
[[ 1. 1.]
[ 1. 1.]]]]
用上面生成的权值验证
def val(input_shape=(1, 3, 224, 224)):
    """Check that the rounded trained weights make ``deconv1`` match ``nn.Upsample``.

    The weight is the matrix obtained from training: ``w[i, o, :, :]`` is all
    ones when ``i == o`` and all zeros otherwise. This copies each input
    channel unchanged into the matching 2x2 output patch.

    Args:
        input_shape: Shape of the random test input. Its channel dimension
            must be 3 to match ``UpsampleModel()``.

    Returns:
        bool: True if both branches give bit-identical outputs (same shape and
        values) for a random input. The result is also printed.
    """
    # Channel identity broadcast over the 2x2 kernel -> shape (3, 3, 2, 2);
    # identical to the rounded weights printed by training.
    w = np.tile(np.eye(3, dtype=np.float32).reshape(3, 3, 1, 1), (1, 1, 2, 2))
    model = UpsampleModel().eval()
    model.deconv1.weight = torch.nn.Parameter(torch.from_numpy(w))  # set the weights
    input_data = torch.randn(input_shape)
    with torch.no_grad():
        out0, out1 = model(input_data)
    # torch.equal checks shape and exact element equality.
    ret = bool(torch.equal(out0, out1))
    print(ret)
    return ret
# Compare the hand-set deconv1 weights against nn.Upsample on a random input.
val()
输出(ret的值,即两路输出是否完全相同):
True