如何修改 PyTorch scripted_model 模型

本文演示了如何修改 PyTorch scripted_model 的图结构。需求背景如下:

  • 某些AI加速卡的推理软件栈会对模型做图优化,一些模型的图匹配策略不完善,导致编译失败
  • 方案一是等待厂家解决,方案二是自己修改图结构,向厂家支持的结构靠拢

源码

import sys
import os
import torch

# Usage: python <this_script> <scripted_model_path> [output_dir]
#
# `model_path` and `prefix` are read below but were never assigned, so the
# script raised NameError as written. They are taken from the command line
# here; the output directory defaults to the directory holding the model.
if len(sys.argv) < 2:
    sys.exit("usage: {} <scripted_model_path> [output_dir]".format(sys.argv[0]))
model_path = sys.argv[1]
prefix = sys.argv[2] if len(sys.argv) > 2 else (os.path.dirname(model_path) or ".")
os.makedirs(prefix, exist_ok=True)

# Sequence length used for the dummy input below.
max_seq_length = 384

# Dummy integer input with values in {0, 1} and shape (1, max_seq_length).
# NOTE: the name shadows the `input` builtin; it is kept because later
# statements in this script refer to it by this name.
input = torch.randint(0, 2, (1, max_seq_length), dtype=torch.long)

# Load the serialized TorchScript module and switch it to eval mode.
scripted_model = torch.jit.load(model_path).eval()

# Clean up the graph in place before pattern matching: constant propagation,
# dead-code elimination, then inlining of calls.
torch._C._jit_pass_constant_propagation(scripted_model.graph)
torch._C._jit_pass_dce(scripted_model.graph)
torch._C._jit_pass_inline(scripted_model.graph)

# Dump the pre-rewrite graph so it can be compared with the rewritten one.
with open("{}/graph.txt".format(prefix), "w") as f:
    f.write(str(scripted_model.graph))

# Rewrite the position-id slicing so that its end index is a constant.
#
# The matched subgraph derives the end index of the second aten::slice at run
# time (aten::size -> prim::NumToTensor -> aten::add -> aten::Int). The
# replacement keeps both slices but feeds the second one a prim::Constant.
pattern = """
    graph(%input_ids.1,%44,%position_ids,%37,%45,%28):
        %63  = aten::size(%input_ids.1, %44)
        %seq_length.1  = prim::NumToTensor(%63)
        %65  = aten::add(%seq_length.1, %37, %44)
        %66  = aten::Int(%65)
        %67  = aten::slice(%position_ids, %45, %45, %28, %44)
        %input.11  = aten::slice(%67, %44, %45, %66, %44)
        return (%input.11)
"""
# The constant is interpolated from max_seq_length instead of being written
# as a literal 384, so the rewritten graph stays in sync with the script's
# sequence-length setting.
replacement = """
    graph(%input_ids.1,%44,%position_ids,%37,%45,%28):
        %35 : int = prim::Constant[value={seq_len}]()
        %67  = aten::slice(%position_ids, %45, %45, %28, %44)
        %input.11  = aten::slice(%67, %44, %45, %35, %44)
        return (%input.11)
""".format(seq_len=max_seq_length)

torch._C._jit_pass_custom_pattern_based_rewrite_graph(
    pattern, replacement, scripted_model.graph
)
# Drop the nodes left without users after the rewrite.
torch._C._jit_pass_dce(scripted_model.graph)

# Replace every aten::linear with an explicit transpose + matmul + in-place add.
# NOTE(review): the replacement always adds %bias.6; presumably every matched
# linear carries a bias -- confirm for models that have bias-free linears.
linear_pattern = """
    graph(%input.9, %weight.6, %bias.6):
        %x.5 = aten::linear(%input.9, %weight.6, %bias.6)
        return (%x.5)
"""
matmul_replacement = """
    graph(%input.7, %weight.6, %bias.6):
        %120  = aten::t(%weight.6)
        %45 : int = prim::Constant[value=1]()
        %output.10  = aten::matmul(%input.7, %120) 
        %122  = aten::add_(%output.10, %bias.6, %45)
        return (%122)
"""

model_graph = scripted_model.graph
torch._C._jit_pass_custom_pattern_based_rewrite_graph(
    linear_pattern,
    matmul_replacement,
    model_graph,
)
# Remove whatever the rewrite left unused.
torch._C._jit_pass_dce(model_graph)

# Remove the trailing split: the matched subgraph splits %1056, squeezes each
# half, makes it contiguous and packs the pair into a tuple. The replacement
# returns %1056 itself, so the split/squeeze/tuple chain is dropped.
split_pattern = """
    graph(%1056,%45,%44,%43):
        %1057 = aten::split(%1056, %44, %43)
        %start_logits.1 , %end_logits.1  = prim::ListUnpack(%1057)
        %1060 = aten::squeeze(%start_logits.1, %43)
        %1061 = aten::contiguous(%1060, %45)
        %1062 = aten::squeeze(%end_logits.1, %43)
        %1063 = aten::contiguous(%1062, %45)
        %1064 = prim::TupleConstruct(%1061, %1063)
        %11,%12  = prim::TupleUnpack(%1064)
        %15 =prim::TupleConstruct(%11, %12)
        return (%15)
"""
passthrough_replacement = """
    graph(%1056,%45,%44,%43):
        return (%1056)
"""

model_graph = scripted_model.graph
torch._C._jit_pass_custom_pattern_based_rewrite_graph(
    split_pattern,
    passthrough_replacement,
    model_graph,
)
# Clean up nodes that no longer have users.
torch._C._jit_pass_dce(model_graph)

# Write the rewritten graph next to the original dump for comparison.
opt_graph_path = "{}/graph_opt.txt".format(prefix)
with open(opt_graph_path, "w") as graph_file:
    graph_file.write(str(scripted_model.graph))

# Smoke test: run the modified model once, passing the same dummy tensor for
# all three inputs, and print the shape of each element of the result.
out = scripted_model(input, input, input)
for item in out:
    print(item.shape)
02-26 23:45