如何修改 PyTorch scripted_model 模型
本文演示如何修改 PyTorch scripted_model 的图结构。需求背景如下:
- 某些AI加速卡的推理软件栈会对模型做图优化,一些模型的图匹配策略不完善,导致编译失败
- 方案一是等待厂家解决,方案二是自己修改图结构,向厂家支持的结构靠拢
源码
import sys
import os
import torch
# Sequence length of the dummy input used for the inference test at the end
# of the script.
max_seq_length = 384

# Dummy integer tensor of shape (1, max_seq_length).
# NOTE: the name `input` shadows the builtin; it is kept because the final
# inference test in this script refers to it by that name.
input = torch.randint(0, 2, (1, max_seq_length), dtype=torch.long)

# NOTE(review): `model_path` and `prefix` are not defined anywhere in this
# script. They must be assigned before this point (the TorchScript file to
# load and the directory that receives the graph dumps).
scripted_model = torch.jit.load(model_path).eval()

# Graph clean-up passes applied in place on the forward graph.
# NOTE(review): inlining runs last here; confirm that the two earlier passes
# are not expected to see the inlined submodule code.
torch._C._jit_pass_constant_propagation(scripted_model.graph)
torch._C._jit_pass_dce(scripted_model.graph)
torch._C._jit_pass_inline(scripted_model.graph)

# Dump the graph as it looks before any pattern rewrite. The encoding is
# explicit so the dump does not depend on the platform default.
with open("{}/graph.txt".format(prefix), "w", encoding="utf-8") as f:
    f.write(str(scripted_model.graph))
# Rewrite the position_ids slicing of the embedding so that the end index of
# the second slice is a constant instead of a value computed from the size of
# input_ids.
#
# The pattern matches the chain size -> NumToTensor -> add -> Int whose
# result is used as the end of the second aten::slice over position_ids.
pattern = """
graph(%input_ids.1,%44,%position_ids,%37,%45,%28):
%63 = aten::size(%input_ids.1, %44)
%seq_length.1 = prim::NumToTensor(%63)
%65 = aten::add(%seq_length.1, %37, %44)
%66 = aten::Int(%65)
%67 = aten::slice(%position_ids, %45, %45, %28, %44)
%input.11 = aten::slice(%67, %44, %45, %66, %44)
return (%input.11)
"""
# The constant end index is taken from max_seq_length rather than repeated as
# a literal, so it cannot drift from the length of the dummy input.
# NOTE(review): the replacement drops the addend %37 entirely; this assumes
# size + %37 always equals max_seq_length at inference time - TODO confirm.
replacement = """
graph(%input_ids.1,%44,%position_ids,%37,%45,%28):
%35 : int = prim::Constant[value={seq_len}]()
%67 = aten::slice(%position_ids, %45, %45, %28, %44)
%input.11 = aten::slice(%67, %44, %45, %35, %44)
return (%input.11)
""".format(seq_len=max_seq_length)
torch._C._jit_pass_custom_pattern_based_rewrite_graph(
    pattern, replacement, scripted_model.graph
)
# Dead-code elimination after the rewrite.
torch._C._jit_pass_dce(scripted_model.graph)
# Replace aten::linear with an explicit transpose + matmul + in-place add.
linear_pattern = """
graph(%input.9, %weight.6, %bias.6):
%x.5 = aten::linear(%input.9, %weight.6, %bias.6)
return (%x.5)
"""
# The first graph input is named differently from the pattern (%input.7 vs
# %input.9).
# NOTE(review): presumably graph inputs are paired by position, not by
# name - confirm.
# NOTE(review): the bias is added unconditionally; a linear whose bias is
# None would presumably not survive this rewrite - verify against the model.
matmul_replacement = """
graph(%input.7, %weight.6, %bias.6):
%120 = aten::t(%weight.6)
%45 : int = prim::Constant[value=1]()
%output.10 = aten::matmul(%input.7, %120)
%122 = aten::add_(%output.10, %bias.6, %45)
return (%122)
"""
torch._C._jit_pass_custom_pattern_based_rewrite_graph(
    linear_pattern,
    matmul_replacement,
    scripted_model.graph,
)
# Dead-code elimination after the rewrite.
torch._C._jit_pass_dce(scripted_model.graph)
# Remove the trailing split: the pattern covers split -> ListUnpack ->
# squeeze/contiguous on both halves -> tuple construct/unpack/construct.
split_pattern = """
graph(%1056,%45,%44,%43):
%1057 = aten::split(%1056, %44, %43)
%start_logits.1 , %end_logits.1 = prim::ListUnpack(%1057)
%1060 = aten::squeeze(%start_logits.1, %43)
%1061 = aten::contiguous(%1060, %45)
%1062 = aten::squeeze(%end_logits.1, %43)
%1063 = aten::contiguous(%1062, %45)
%1064 = prim::TupleConstruct(%1061, %1063)
%11,%12 = prim::TupleUnpack(%1064)
%15 =prim::TupleConstruct(%11, %12)
return (%15)
"""
# The replacement forwards the tensor that fed the split, so the matched
# subgraph now yields that single value instead of the two-element tuple.
passthrough_replacement = """
graph(%1056,%45,%44,%43):
return (%1056)
"""
torch._C._jit_pass_custom_pattern_based_rewrite_graph(
    split_pattern,
    passthrough_replacement,
    scripted_model.graph,
)
# Dead-code elimination after the rewrite.
torch._C._jit_pass_dce(scripted_model.graph)
# Dump the graph again at this point of the script. The encoding is explicit
# so the dump does not depend on the platform default.
with open("{}/graph_opt.txt".format(prefix), "w", encoding="utf-8") as f:
    f.write(str(scripted_model.graph))

# Inference test to confirm the model still runs: the same dummy tensor is
# passed for all three model inputs and the shape of each item obtained by
# iterating over the result is printed.
out = scripted_model(input, input, input)
for i in out:
    print(i.shape)