以BERT为例，了解LoRA是如何添加到模型中的
本文以BERT为例，对比了添加LoRA模块前后的网络结构图。
说明：
- 1. 为了加快速度，将BERT修改为只有一层（num_hidden_layers=1）；
- 2. LoRA只添加到intermediate.dense层，方便对比；
- 3. 使用了几种不同的可视化方式（ONNX可视化、torchviz图、torch.fx可视化、TensorBoard可视化），其中torchviz图只针对原始模型生成。
可参考的点：
- 1. PEFT库的使用
- 2. 几种不同的PyTorch模型可视化方法
一.效果图
1.torch.fx可视化
A.添加前
B.添加后
2.ONNX可视化
A.添加前
B.添加后
3.TensorBoard可视化
A.添加前
B.添加后
二.复现步骤
1.生成配置文件(num_hidden_layers=1)
# Write a BERT config with a single encoder layer (num_hidden_layers=1) to
# ./config.json; bert_lora.py loads it via BertConfig.from_pretrained("./config.json").
# The quoted 'EOF' delimiter disables shell expansion inside the heredoc.
tee ./config.json <<-'EOF'
{
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"directionality": "bidi",
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 1,
"pad_token_id": 0,
"pooler_fc_size": 768,
"pooler_num_attention_heads": 12,
"pooler_num_fc_layers": 3,
"pooler_size_per_head": 128,
"pooler_type": "first_token_transform",
"type_vocab_size": 2,
"vocab_size": 21128
}
EOF
2.运行测试脚本
# Write the visualisation script to bert_lora.py. The quoted 'EOF' delimiter
# disables shell expansion; note that <<- strips leading TABs (not spaces) from
# every heredoc line, so the Python code must be indented with spaces.
tee bert_lora.py <<-'EOF'
# NOTE(review): os, torchvision.models, nn, init, np, get_peft_config and
# get_peft_model_state_dict are not referenced anywhere in this script.
import time
import os
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.init as init
# NOTE(review): duplicate of the `import time` above.
import time
import numpy as np
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType
from torchviz import make_dot
from torch.utils.tensorboard import SummaryWriter
from torch._functorch.partitioners import draw_graph
def onnx_infer_shape(onnx_path):
    """Run ONNX shape inference on the file at *onnx_path* and overwrite it.

    After this call the saved graph carries the tensor shapes that ONNX could
    infer, which graph viewers can then display alongside each edge.

    Args:
        onnx_path: path of an existing ``.onnx`` file; it is rewritten in place.
    """
    import onnx

    inferred = onnx.shape_inference.infer_shapes(onnx.load_model(onnx_path))
    onnx.save_model(inferred, onnx_path)
def get_model():
    """Build a freshly initialised BERT masked-LM from ``./config.json``.

    The global torch seed is fixed first so that the random weights are
    the same on every run.

    Returns:
        tuple: ``(model, config)`` — the model created by
        ``AutoModelForMaskedLM.from_config`` and the ``BertConfig`` it was
        built from.
    """
    torch.manual_seed(1)
    from transformers import AutoModelForMaskedLM, BertConfig

    bert_config = BertConfig.from_pretrained("./config.json")
    masked_lm = AutoModelForMaskedLM.from_config(bert_config)
    return masked_lm, bert_config
def my_compiler(fx_module: torch.fx.GraphModule, _):
    """``torch.compile`` backend that only renders the captured FX graph.

    Every graph handed to the backend is drawn to an SVG named after the
    current timestamp. The module's own ``forward`` is returned unchanged,
    so execution stays eager.

    Args:
        fx_module: graph module captured by ``torch.compile``.
        _: example inputs supplied by ``torch.compile`` (unused).
    """
    svg_path = f"bert.{time.time()}.svg"
    draw_graph(fx_module, svg_path)
    return fx_module.forward
def _export_onnx(net, example_input, onnx_path):
    """Export *net* to ONNX and annotate the saved graph with inferred shapes.

    Args:
        net: module to export.
        example_input: tensor passed positionally to ``net`` while tracing.
        onnx_path: destination file; ``onnx_infer_shape`` rewrites it in place.
    """
    # export_params=False leaves the weights out of the file (they become
    # graph inputs instead); only the structure is needed for visualisation.
    torch.onnx.export(net, example_input,
                      onnx_path,
                      export_params=False,
                      opset_version=11,
                      do_constant_folding=True)
    onnx_infer_shape(onnx_path)


def _draw_fx_graphs(net, example_input):
    """Compile *net* with ``my_compiler`` and run it once; return its output.

    Compilation happens on the first call, and ``my_compiler`` saves each
    captured FX graph as an SVG file.
    """
    compiled = torch.compile(net, backend=my_compiler)
    return compiled(example_input)


def _log_tensorboard_graph(net, example_input, log_dir):
    """Trace *net* on *example_input* and write its graph to *log_dir* for TensorBoard."""
    # `with` closes the writer (flushing the event file) even if tracing raises.
    with SummaryWriter(log_dir) as writer:
        # Non-strict tracing; presumably required because the Hugging Face
        # model returns a dict-like output object — TODO confirm.
        writer.add_graph(net, input_to_model=example_input, use_strict_trace=False)


if __name__ == "__main__":
    model, config = get_model()
    model.eval()
    # One random sequence of 128 token ids, shape (1, 128).
    input_tokens = torch.randint(0, config.vocab_size, (1, 128))

    # 1. Original model
    # 1.1 ONNX visualisation
    _export_onnx(model, input_tokens, "bert_base.onnx")
    # 1.2 torchviz: autograd graph rooted at the logits
    output = model(input_tokens)
    logits = output.logits
    viz = make_dot(logits, params=dict(model.named_parameters()))
    viz.render("bert_base", view=False)
    # 1.3 torch.fx visualisation (SVGs written by my_compiler)
    output = _draw_fx_graphs(model, input_tokens)
    # 1.4 TensorBoard visualisation
    _log_tensorboard_graph(model, input_tokens, './runs')

    # 2. LoRA model
    # NOTE(review): the base model is a masked LM, but PEFT's TaskType has no
    # masked-LM entry; CAUSAL_LM presumably just selects a wrapper whose
    # forward accepts input_ids — confirm this is the intended choice.
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=True,
        r=8,
        lora_alpha=32,  # the LoRA update is scaled by lora_alpha / r
        target_modules=['intermediate.dense'],
        lora_dropout=0.1,
    )
    # get_peft_model injects the adapters into `model` itself, so every
    # base-model visualisation above has to run before this line.
    lora_model = get_peft_model(model, peft_config)
    lora_model.eval()
    _export_onnx(lora_model, input_tokens, "bert_base_lora_inference_mode.onnx")
    output = _draw_fx_graphs(lora_model, input_tokens)
    _log_tensorboard_graph(lora_model, input_tokens, './runs_lora')
EOF
# Install dependencies.
# graphviz provides the `dot` binary used by torchviz's render() and, through
# pydot, by torch's draw_graph. apt-get (not apt) has a stable CLI for scripts.
apt-get install -y graphviz
# bert_lora.py also imports onnx, torch.utils.tensorboard (needs the
# tensorboard package), transformers and peft, so install them as well.
# It further imports torch and torchvision: install builds matching your
# platform/CUDA setup if they are not already present.
pip install torchviz pydot onnx tensorboard transformers peft
# Run the test script.
python bert_lora.py