目录
一.引言
SFT workflow 微调工作流程 一文中我们介绍了模型微调从数据到最终应用的流程
FastAPI 实现 get、post 请求 一文中我们介绍了如何使用 FastAPI 搭建简易接口
结合以上两者,我们使用 FastAPI 搭建一个简易的问答 Server。
二.辅助函数
1.黑名单
def check_sentence(_sentence, _blacklist=[]):
"""
检查句子中是否包含黑名单中的单词
参数:
sentence (str): 待检查的句子
blacklist (list): 黑名单单词列表
返回:
bool: 如果句子中不包含黑名单中的单词,则返回 True,否则返回 False
"""
for word in _blacklist:
if word in _sentence:
return False
return True
黑名单的逻辑很简单,遍历 sentence 中的 word 是否在自定义提供的 blacklist 中即可,这里主要是保证服务生成的句子不包含敏感和违法词汇,确保服务的安全性。
2.清除函数
def clean_sentence(_sentence):
# 删除网页链接
text = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '', _sentence)
# 删除@提及和#话题标签
text = re.sub(r'\@\w+|\#', '', text)
# 删除标点符号和特殊字符
text = re.sub(r'[%s]' % re.escape(punctuation), '', text)
# 去除 \r\s\n\t
text = re.sub(r'\\r|\\s|\\n|\\t|\r|\s|\n|\t', '', text)
# 合并正文中过多的空格
text = re.sub(r'\s+', ' ', text)
# 去除\u200b字符
text = text.replace('\u200b', '')
return text
黑名单逻辑保证生成句子的安全性,清除函数保证生成句子的合理性,这里是几个常用逻辑,大家有可以根据自己场景的需求和模型生成句子的特点进行修改:
◆ 删除网页链接
◆ 删除@与话题词
◆ 删除标点符号与特殊字符
◆ 删除 \t \n 等转移符号
◆ 删除过多空格
三.模型函数
1.加载模型
def load_lora_model(model_path, ckpt_path, compute_type=torch.bfloat16):
st = time.time()
# 载入预训练模型与 Tokenizer
config_kwargs = {
"trust_remote_code": True,
"cache_dir": None,
"revision": 'main',
"use_auth_token": None,
}
# 载入预训练模型
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side="left", **config_kwargs)
config = AutoConfig.from_pretrained(model_path, **config_kwargs)
model = AutoModelForCausalLM.from_pretrained(
model_path,
config=config,
torch_dtype=compute_type,
low_cpu_mem_usage=True,
trust_remote_code=True,
revision='main'
)
model = PeftModel.from_pretrained(model, ckpt_path)
model = model.merge_and_unload()
# 修正模型参数
model.requires_grad_(False)
# 精度减半[cast from fp32 to fp16] export 模型预测时不 half
model = model.half() if model.config.torch_dtype == torch.float32 else model
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
device_map = infer_auto_device_map(model, max_memory=get_balanced_memory(model))
model = dispatch_model(model, device_map)
print('multi GPU predict => {}'.format(device_map))
else:
model = model.cuda()
print("single GPU predict")
print('config = ', model.config)
end = time.time()
print('time cost: {}'.format(end - st))
return model, tokenizer
◆ 加载 LoRA 模型: LoRA 模型合并与保存
由于博主使用 LoRA 微调后的模型,所以涉及到加载 LoRA 参数,如果有完整的模型文件,忽略最后两行 PeftModel 和 merge_and_unload 方法即可。
◆ 多卡加载: 多卡加载与推理测试
多卡加载适合单卡内存不足以支持服务部署,或者希望多卡可以加速推理的情况。以 P40 为例,13B 模型使用 2 张 P40 部署服务,而 LLaMA-33B 则需要 4 张 P40。以 A800 为例,可以部署 2 x 13B 模型 + 1 x 7B 模型,或者单独部署一个 33B 模型。这里如果卡资源比较富裕,忽略即可。
◆ 量化加载: Model Load_in_8bit
量化加载对应 QLoRA,博主给出的示例并未使用 QLoRA,所以没有相关量化的步骤,有需要的同学可以参考上文,也可以到对应模型的 Git 界面,一些新模型内置简单易用的量化 API,可以更便捷的实现 LoRA。
2.生成配置
# 获取生成配置
def init_generation_args():
gen_conf = {
'do_sample': True,
'temperature': 0.95,
'top_p': 0.7,
'top_k': 50,
'num_beams': 1,
'max_new_tokens': 512,
'repetition_penalty': 1.0,
'length_penalty': 1.0
}
return gen_conf
◆ 批量推理: model batch generate 生成文本
服务生成的配置,为了统一写到一个函数里,如果想要动态控制,也可以通过 post 方法传参。更多参数含义与批量生成的方法可以参考上面的链接。
四.服务部署
1.post - predict
@app.post("/predict")
async def predict(request: Request):
now = datetime.datetime.now()
print('TIME: {} 开始预测 ...'.format(now.strftime("%Y-%m-%d %H:%M:%S")))
start = time.time()
js = await request.json()
req = js['question']
template = (
"{}"
)
query = template.format(req)
response = ''
i = 0
while len(re.findall(u"([\u4e00-\u9fa5])", response)) < 6:
input_ids = tokenizer(query, return_tensors="pt")['input_ids'].to(model.device)
output_ids = model.generate(input_ids, gen_config)
input_id_token_num = input_ids[0].shape[0]
response = tokenizer.decode(output_ids[0][input_id_token_num:], skip_special_tokens=True)
i += 1
if i > 5:
break
end = time.time()
cost = end - start
print('问:{}=>答:{}'.format(req, response))
if not check_sentence(response):
response = ''
print('time cost: {}'.format(cost))
return {'question': req, 'result': clean_sentence(response)}
文章顶部的链接介绍了如何实现简单的 get 和 post 请求,由于 LLM 模型语言生成时需要传入对应的 query,所以我们的推理方法需要使用 post 请求。
◆ Template
req = js['question']
template = (
"{}"
)
query = template.format(req)
上面给了默认的 Template 即模板,这里模板最好和训练时候对应的模板相对应,例如 Baichuan、LLaMA 等模型,官方都应用了不同的 Template,所以如果存在多个模型,需要注意修改正确的 Template。
◆ Sentence Length
while len(re.findall(u"([\u4e00-\u9fa5])", response)) < 6:
input_ids = tokenizer(query, return_tensors="pt")['input_ids'].to(model.device)
output_ids = model.generate(input_ids, gen_config)
input_id_token_num = input_ids[0].shape[0]
response = tokenizer.decode(output_ids[0][input_id_token_num:], skip_special_tokens=True)
i += 1
if i > 5:
break
第一个 While 循环条件 '[\u4e00-\u9fa5]' 这个正则表达式是用来匹配所有的中文字符。[\u4e00-\u9fa5] 是一个 Unicode 范围,代表了所有的中文字符。所以 re.findall 这行代码的含义是在 response 字符串中查找并返回所有的中文字符,而循环的要求需要生成的 response 中至少包含 6 个字符,否则持续生成,这里生成的配置根据上面的 init_generation_args 函数。
◆ CheckAndClean
print('问:{}=>答:{}'.format(req, response))
if not check_sentence(response):
response = ''
print('time cost: {}'.format(cost))
return {'question': req, 'result': clean_sentence(response)}
如果回答命中 black_list 则此次 response = '',除此之外还需要对 response 执行 clean 操作,去除无关的符号与字符,最终以 json 的形式返回。
2.get - clean_cache
@app.get("/clean_cache")
async def clean_cache():
import torch
torch.cuda.empty_cache()
print('receive clean_cache instruction...')
return {'flag': 'success'}
empty_cache 是 PyTorch 中的一个函数,它的作用是清理当前 CUDA 设备上未使用的缓存,以释放一些 GPU 内存。在你的程序长时间运行且需要频繁地进行张量创建和移动操作时我们可以调用该方法。然而要注意的是,频繁地调用此函数可能会导致效率下降,因为清理和重新填充缓存的操作本身也需要时间。由于该方法无需传递参数,所以使用 get 请求即可。
3.main - run_app
# -*- coding: utf-8 -*-
from fastapi import FastAPI, Request
import torch
import time
import datetime
import re
from string import punctuation
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from peft import PeftModel
app = FastAPI()
... ...
if __name__ == '__main__':
# 加载模型与生成参数
model_path = ""
ckpt_path = ""
model, tokenizer = load_lora_model(model_path, ckpt_path)
print("Finish Load Model...")
gen_kwargs = init_generation_args()
gen_config = GenerationConfig(**gen_kwargs)
print('generating_args = {}'.format(gen_kwargs))
print('gen_config = {}'.format(gen_config.to_dict()))
import uvicorn
uvicorn.run(app, host='0.0.0.0', port=8098)
把上面的函数添加到 ... 处,传入自己对应的模型地址和 LoRA 参数地址,uvicorn.run 运行服务即可。可以使用 post 方法调用获取模型文本生成的回答,也可以 get 方法清除内存。
五.总结
结合前面的文章,我们实现了 LLM 的训练与实践,通过服务的部署,我们可以将自己垂直领域微调得到的模型应用到自己的业务场景中,使用 LLM 的力量助力业务的扩展。