PyTorch LLM训练过程中的精度调试实践
本文记录了在某加速卡上进行LLM训练时,精度问题的定位过程。
1.查看权值的最大值和最小值
tee dump_weight.py <<-'EOF'
import torch
import json
import numpy as np
# Collect the distinct checkpoint shard files named in the weight index.
with open('../llama-2-7b-hf/pytorch_model.bin.index.json', 'r') as f:
    index_data = json.load(f)["weight_map"]
weight_files = set(index_data.values())

gmax = []
gmin = []
# Sorted so the report comes out in the same order on every run; the
# iteration order of a set of strings is not stable between interpreter runs.
for i in sorted(weight_files):
    # NOTE(review): torch.load unpickles the file - only use on trusted checkpoints.
    for k, w in torch.load(f"../llama-2-7b-hf/{i}", map_location="cpu").items():
        # Reduce once per tensor and reuse the scalars for printing and for
        # the global statistics.
        w_max = w.max().item()
        w_min = w.min().item()
        print(f"{k:<64s},max:{w_max:8.2f},{w_min:8.2f}")
        gmax.append(w_max)
        gmin.append(w_min)
print(f"\n\nglobal max:{np.max(gmax)} min:{np.min(gmin)}")
EOF
python3 dump_weight.py
2.检测训练过程中的异常值
A.通过hook module,检测异常值
B.拦截算子,检测异常值,打印调用栈,保存输入参数,方便复现
C.拦截算子,同时执行cpu计算,对比误差,找到第一个精度异常的算子
D.以上三种方法的完整代码如下
import torch
from torch import nn
import math
import copy
import torch.nn.functional as F
# Use CUDA when torch reports it as available, otherwise the "xpu" backend.
device = "cuda" if torch.cuda.is_available() else "xpu"
def _current_rank():
    """Return the distributed rank, or 0 when no process group is initialised.

    Calling ``torch.distributed.get_rank()`` without an initialised default
    process group raises, which would turn a NaN report into a crash in
    single-process runs.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0


def check_hook_tensor(tensor, module_name, hook_type):
    """Print an error when a module-hook payload contains NaN or Inf.

    This helper has its own name on purpose: the file later defines
    ``check_tensor(name, stack, tensor)`` with a different parameter order,
    and that later definition replaces any earlier ``check_tensor`` by the
    time the hooks run.

    Args:
        tensor: A tensor, or an arbitrarily nested list/tuple of tensors.
            Anything else (for example ``None`` gradients) is ignored.
        module_name: Name of the module the hook is attached to.
        hook_type: Label for the pass being checked, used in the message.
    """
    if isinstance(tensor, torch.Tensor):
        if not torch.isfinite(tensor).all():
            # Report only; execution continues so later failures are seen too.
            print(f"[ERROR] Detected NaN or Inf in {hook_type} pass of {module_name} rank:{_current_rank()}")
    elif isinstance(tensor, (list, tuple)):
        for t in tensor:
            check_hook_tensor(t, module_name, hook_type)


# Hook callbacks that watch module inputs, outputs and gradients for NaN / Inf.
def forward_hook(module, inputs, output, module_name):
    """Forward hook body: check both the inputs and the output of a module."""
    check_hook_tensor(inputs, module_name, "forward-inputs")
    check_hook_tensor(output, module_name, "forward-output")


def backward_hook(module, grad_input, module_name):
    """Backward hook body: check the gradients handed to the hook."""
    check_hook_tensor(grad_input, module_name, "backward")
from torch.utils._python_dispatch import TorchDispatchMode
from dataclasses import dataclass
from typing import Any
import inspect
@dataclass
class _ProfilerState:
cls: Any
object: Any = None
def is_valid(val, name, stack):
    """Check a single value for NaN / Inf.

    Args:
        val: Value to inspect; only tensors and parameters are examined.
        name: Label printed with the error.
        stack: Call-stack text printed with the error.

    Returns:
        -1 after printing an error when ``val`` holds a non-finite element,
        0 otherwise (including when ``val`` is not a tensor).
    """
    if isinstance(val, (torch.Tensor, nn.Parameter)):
        # The finiteness test is evaluated on a CPU copy of the tensor.
        if not torch.isfinite(val.cpu()).all():
            print("[ERROR]:", name, stack)
            return -1
    return 0


def check_tensor(name, stack, tensor):
    """Check a tensor or a nested list/tuple of tensors for NaN / Inf.

    Operator arguments can contain lists of tensors inside the argument
    tuple, so sequences are walked recursively rather than one level deep.

    Args:
        name: Label printed with the error.
        stack: Call-stack text printed with the error.
        tensor: Tensor, list/tuple (possibly nested), or any other value.

    Returns:
        -1 as soon as one non-finite tensor is found, otherwise 0.
    """
    if isinstance(tensor, (torch.Tensor, nn.Parameter)):
        return is_valid(tensor, name, stack)
    if isinstance(tensor, (tuple, list)):
        for item in tensor:
            if check_tensor(name, stack, item) != 0:
                return -1
    return 0
def save_tensor_data(tensor, name):
    """Save tensors to disk with ``torch.save``.

    A tensor is written to ``<name>.pth``. A list or tuple is walked
    recursively, appending ``-<index>`` to the name for each element.
    Values of any other type are skipped.
    """
    if isinstance(tensor, (tuple, list)):
        for position, element in enumerate(tensor):
            save_tensor_data(element, f"{name}-{position}")
        return
    if isinstance(tensor, (torch.Tensor, nn.Parameter)):
        torch.save(tensor, f"{name}.pth")
def to_cpu_and_fp32(data):
    """Return an independent CPU copy of ``data`` with FP16 tensors as FP32.

    Args:
        data: A tensor, or a list, tuple or dict (possibly nested) whose
            leaves may be tensors.

    Returns:
        A structure of the same layout. Every tensor in it lives on the CPU,
        FP16 tensors are converted to FP32, and no tensor shares memory with
        the input. Non-tensor leaves are returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        tensor = data.cpu()
        if tensor.dtype == torch.float16:
            tensor = tensor.to(torch.float32)
        elif data.device.type == "cpu":
            # .cpu() returns the very same tensor when the input is already
            # on the CPU. Clone it so that an in-place operator replayed on
            # the converted arguments cannot modify the caller's tensor.
            tensor = tensor.clone()
        return tensor
    if isinstance(data, list):
        return [to_cpu_and_fp32(item) for item in data]
    if isinstance(data, tuple):
        return tuple(to_cpu_and_fp32(item) for item in data)
    if isinstance(data, dict):
        return {key: to_cpu_and_fp32(value) for key, value in data.items()}
    # Not a tensor and not a supported container: hand it back untouched.
    return data
@dataclass
class TensorDesc:
    """Plain four-field record.

    NOTE(review): presumably a tensor summary; nothing in the visible code
    constructs or reads it - confirm before relying on field meanings.
    """

    cat: Any
    shape: Any
    dtype: Any
    value: Any
@dataclass
class DataDescriptor:
    """Printable summary of one value (tensor, scalar, or list/tuple)."""

    # ``__class__.__name__`` of the described value.
    class_name: Any
    # Shape as a list, or None when the value has no shape.
    shape: Any
    # The value itself for scalars, a list of descriptors for sequences.
    value: Any
    dtype: Any
    # Largest / smallest element, or None when not recorded.
    max_v: Any
    min_v: Any

    def __repr__(self) -> str:
        """Render as ``<class_name>(<part>-<part>-...)``."""
        parts = []
        if self.shape:
            parts.append("shape:({})".format(",".join(str(x) for x in self.shape)))
        # Compare against None rather than truthiness: a maximum of exactly
        # 0.0 is a real statistic and must still be printed.
        if self.max_v is not None and self.min_v is not None:
            parts.append(f"max:{self.max_v:.6f} min:{self.min_v:.6f}")
        if self.value is not None:
            if self.class_name in ["list", "tuple"]:
                parts.extend(str(t) for t in self.value)
            else:
                parts.append(str(self.value))
        if self.dtype and self.class_name in ["Tensor", "ndarray", "Parameter"]:
            parts.append(str(self.dtype))
        return "{}({})".format(self.class_name, "-".join(parts))
class InputDescriptor:
    """Collects DataDescriptor summaries for the arguments of one call."""

    def __init__(self) -> None:
        self.input_vars = []
        self.input_kwargs = {}

    def _save_var(self, v):
        """Summarise one argument.

        Returns:
            A DataDescriptor for tensors, parameters, ndarrays, the listed
            scalar-like types, and lists/tuples (summarised recursively).
            None for UntypedStorage and any other type.
        """
        class_name = v.__class__.__name__
        if class_name in ["Tensor", "Parameter"]:
            if v.numel() == 0 or v.is_complex():
                # max()/min() raise for tensors without elements and for
                # complex dtypes; record the shape and dtype only so the
                # traced program is not interrupted.
                return DataDescriptor(class_name, list(v.shape), None, v.dtype, None, None)
            return DataDescriptor(class_name, list(v.shape), None, v.dtype, v.max().item(), v.min().item())
        if class_name in ["int", "float", "str", "dtype", "layout", "device", "NoneType", "bool", "memory_format"]:
            return DataDescriptor(class_name, None, v, type(v), None, None)
        if class_name in ["ndarray"]:
            return DataDescriptor(class_name, list(v.shape), None, v.dtype, None, None)
        if class_name in ["list", "tuple"]:
            return DataDescriptor(class_name, None, [self._save_var(t) for t in v], None, None, None)
        # UntypedStorage and unrecognised types are not summarised.
        return None

    def save_vars(self, *args, **kwargs):
        """Record summaries of every positional and keyword argument."""
        for arg in args:
            self.input_vars.append(self._save_var(arg))
        for k, v in kwargs.items():
            self.input_kwargs[k] = self._save_var(v)

    def __repr__(self) -> str:
        return str(self.input_vars) + "#" + str(self.input_kwargs)
def compare_tensor(tensorA, tensorB, name, params):
    """Compare two results by mean squared error.

    Args:
        tensorA: Reference result: a tensor, a list/tuple of results, or any
            other value (which is treated as matching).
        tensorB: Result under test, with the same structure as ``tensorA``.
        name: Label printed in front of each comparison line.
        params: Text appended to each comparison line.

    Returns:
        True when every compared tensor pair has an MSE below 1e-2, False at
        the first pair that does not (or whose shapes differ).
    """
    if isinstance(tensorA, (torch.Tensor, nn.Parameter)):
        expected = tensorA.cpu().float()
        actual = tensorB.cpu().float()
        if expected.shape != actual.shape:
            # Report a shape mismatch as a difference instead of letting the
            # loss broadcast or raise.
            print(f"{name:<64s} shape mismatch {tuple(expected.shape)} vs {tuple(actual.shape)} {params}")
            return False
        if expected.numel() == 0:
            # The mean over zero elements is NaN, which would otherwise be
            # reported as a mismatch although there is nothing to compare.
            return True
        loss = torch.nn.MSELoss()(expected, actual).item()
        print(f"{name:<64s} {loss:f} {params}")
        return loss < 1e-2
    if isinstance(tensorA, (tuple, list)):
        for idx, item in enumerate(tensorA):
            if not compare_tensor(item, tensorB[idx], f"{name}-{idx}", params):
                return False
    return True
def is_in_blacklist(name):
    """Return True when ``name`` contains none of the skip-list fragments.

    The result is inverted relative to what the function name suggests:
    False means the name matched the skip list, True means it did not.
    """
    skip_fragments = (
        "empty", "like", "zero", "detach", "has", "view",
        "copy", "arange", "fill", "ones", "lift_fresh", "alias",
        "scalar_tensor", "clone", "stack", "slice", "source",
        "select", "random", "unsqueeze", "expand", "normal", "bernoulli",
    )
    return not any(fragment in name for fragment in skip_fragments)
class TorchOpDiffDispatchMode(TorchDispatchMode):
    """Dispatch mode that replays operators on the CPU and compares results.

    For every operator that passes the ``is_in_blacklist`` filter, the call
    is first executed on CPU copies of the arguments, then executed for
    real, and the two results are compared. On a mismatch the arguments and
    the real result are saved to ``.pth`` files.
    """

    def __init__(self, parent):
        super().__init__()
        self.parent = parent
        # Counts every intercepted operator; used in the dump file names.
        self.global_index = 0

    @staticmethod
    def _rank():
        """Return the distributed rank, or 0 without a process group."""
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        op_name = f"{func}"
        # is_in_blacklist returns True for operators that are NOT skipped.
        enable_dump = is_in_blacklist(op_name)
        self.global_index += 1
        if kwargs is None:
            kwargs = {}
        if enable_dump:
            # The reference is computed before the real call is made.
            args_cpu = to_cpu_and_fp32(args)
            kwargs_cpu = to_cpu_and_fp32(kwargs)
            cpu_out = func(*args_cpu, **kwargs_cpu)
        ret = func(*args, **kwargs)
        if enable_dump:
            desc = InputDescriptor()
            desc.save_vars(*args, **kwargs)
            if not compare_tensor(cpu_out, ret, op_name, str(desc)):
                # get_rank() raises in single-process runs without a process
                # group, so the rank is looked up defensively.
                prefix = f"{self.global_index}_{self._rank()}{op_name}"
                save_tensor_data(args, f"{prefix}-input")
                save_tensor_data(ret, f"{prefix}-output")
        return ret
class TorchNanDetDispatchMode(TorchDispatchMode):
    """Dispatch mode that checks operator inputs and outputs for NaN / Inf.

    For every operator that passes the ``is_in_blacklist`` filter, the
    arguments and the result are checked. When a non-finite value is found,
    the Python call stack is printed and the arguments and result are saved
    to ``.pth`` files.
    """

    def __init__(self, parent):
        super().__init__()
        self.parent = parent
        # Counts every intercepted operator; used in the dump file names.
        self.global_index = 0

    @staticmethod
    def _rank():
        """Return the distributed rank, or 0 without a process group."""
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        op_name = f"{func}"
        self.global_index += 1
        # is_in_blacklist returns True for operators that are NOT skipped.
        enable_dump = is_in_blacklist(op_name)
        if kwargs is None:
            kwargs = {}
        if not enable_dump:
            # Skipped operators need neither a stack trace nor a check.
            return func(*args, **kwargs)

        # context=0: only file and function names are used, so there is no
        # need to read source lines for every frame on every operator.
        stacks = inspect.stack(context=0)
        msg = []
        # Outermost frame first; frames 0 and 1 (this method and its direct
        # caller) are left out.
        for idx in range(len(stacks) - 1, 1, -1):
            frame_info = stacks[idx]
            frame_locals = frame_info.frame.f_locals
            if "self" in frame_locals:
                class_name = frame_locals["self"].__class__.__name__
            else:
                class_name = ""
            msg.append(f"{frame_info.filename}:[{class_name}]:{frame_info.function}")
        # Drop the frame references promptly to avoid reference cycles.
        del stacks
        stack_text = "\n".join(msg)

        valid = check_tensor(f"aten-{op_name}-input", stack_text, args)
        ret = func(*args, **kwargs)
        valid += check_tensor(f"{op_name}-output", stack_text, ret)
        if valid != 0:
            # get_rank() raises in single-process runs without a process
            # group, so the rank is looked up defensively.
            prefix = f"{self.global_index}_{self._rank()}{op_name}"
            save_tensor_data(args, f"{prefix}-input")
            save_tensor_data(ret, f"{prefix}-output")
        return ret
class TorchHook:
    """Context manager that drives one instance of the given class.

    Only one TorchHook may be active at a time; the active one is tracked in
    the class attribute ``_CURRENT_Dumper``.
    """

    _CURRENT_Dumper = None

    def __init__(self, state):
        # ``state`` is the class to instantiate when the hook is entered.
        self.p = _ProfilerState(state)

    def __enter__(self):
        assert TorchHook._CURRENT_Dumper is None
        TorchHook._CURRENT_Dumper = self
        if self.p.object is not None:
            # An instance already exists: advance it instead of recreating.
            self.p.object.step()
            return self
        instance = self.p.cls(self)
        instance.__enter__()
        self.p.object = instance
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        TorchHook._CURRENT_Dumper = None
        if self.p.object is None:
            return
        self.p.object.__exit__(exc_type, exc_val, exc_tb)
        # Removing the instance attribute exposes the dataclass default
        # (None) again.
        del self.p.object
def clones(module, N):
    """Return an ``nn.ModuleList`` holding N deep copies of ``module``."""
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas
class ScaledDotProductAttention(nn.Module):
    """Computes softmax(Q K^T / sqrt(d_k)) V."""

    def __init__(self):
        super().__init__()

    def forward(self, query, key, value, mask=None, dropout=None):
        """Apply attention.

        Args:
            query, key, value: Tensors whose last dimension of ``query`` is
                used as d_k for scaling.
            mask: Optional tensor; positions where it equals 0 are excluded.
            dropout: Optional callable applied to the attention weights.

        Returns:
            Tuple ``(weighted values, attention weights)``.
        """
        d_k = query.size(-1)
        scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)
        if mask is not None:
            # Fill with the most negative value the score dtype can hold.
            # A fixed -1e20 is not representable in float16, the dtype this
            # file runs the model in.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return p_attn @ value, p_attn
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from four linear layers.

    ``d_model`` must be divisible by the head count ``h``; each head works on
    ``d_model // h`` features.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        # Three input projections (query, key, value) plus the output one.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        # Attention weights of the most recent forward call.
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        self.attention = ScaledDotProductAttention()

    def forward(self, query, key, value, mask=None):
        """Project the inputs, attend per head, then merge and project."""
        if mask is not None:
            # Extra dimension so the same mask applies to every head.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        head_shape = (nbatches, -1, self.h, self.d_k)

        # (batch, seq, d_model) -> (batch, heads, seq, d_k) for each input.
        query = self.linears[0](query).view(*head_shape)
        query = query.transpose(1, 2)
        key = self.linears[1](key).view(*head_shape)
        key = key.transpose(1, 2)
        value = self.linears[2](value).view(*head_shape).transpose(1, 2)

        x, self.attn = self.attention(query, key, value, mask=mask,
                                      dropout=self.dropout)

        # Back to (batch, seq, heads * d_k) before the output projection.
        merged = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](merged)
torch.random.manual_seed(1)
model = MultiHeadAttention(h=8, d_model=64).half().to(device)
model.eval()

# Approach A: attach hooks to every module to detect NaN / Inf values.
# ``name`` is bound as a lambda default so each hook keeps its own module name.
for name, module in model.named_modules():
    module.register_forward_hook(
        lambda module, inputs, output, name=name: forward_hook(module, inputs, output, name))
    module.register_full_backward_pre_hook(
        lambda module, grad_output, name=name: backward_hook(module, grad_output, f"{name}-in"))
    # A full backward hook is called with (module, grad_input, grad_output).
    # All three must be accepted, otherwise grad_output is bound to ``name``.
    module.register_full_backward_hook(
        lambda module, grad_input, grad_output, name=name: backward_hook(module, grad_input, f"{name}-out"))

q1 = torch.ones((100, 50, 64), dtype=torch.float32).half().to(device)
k1 = q1.clone()
v1 = q1.clone()

# Approach B: intercept operators to detect NaN values, print the call stack
# and save the arguments so the failure can be reproduced later.
with TorchHook(TorchNanDetDispatchMode):
    out = model(q1, k1, v1).sum()
print("out:", out.item())

# Approach C: compare each operator against a CPU run and save the arguments
# of operators whose results differ.
with TorchHook(TorchOpDiffDispatchMode):
    out = model(q1, k1, v1).sum()
print("out:", out.item())
3.根据上面dump的数据,准备最小复现环境
tee demo.py <<-'EOF'
import torch
import numpy as np
# Device string the reproduction is run on.
device = "xpu"
def largest_k_errors(tensor1, tensor2, k, eps=1e-12):
    """Find the k elements with the largest relative error between two tensors.

    The relative error of a pair (a, b) is ``|a - b| / (|a| + |b| + eps)``.
    Non-finite error values are treated as zero.

    Args:
        tensor1: First input tensor.
        tensor2: Second input tensor; broadcast against ``tensor1``.
        k (int): Number of elements to return; capped at the element count.
        eps (float): Small constant that avoids division by zero.

    Returns:
        Tuple ``(errors, indices, orig_values1, orig_values2)``: the k
        largest errors in descending order, their flat indices into the
        broadcast shape, and the corresponding values of each input.
    """
    # Bring both inputs to a common shape.
    tensor1, tensor2 = torch.broadcast_tensors(tensor1, tensor2)
    absolute_diff = torch.abs(tensor1 - tensor2)
    denom = torch.abs(tensor1) + torch.abs(tensor2) + eps
    relative_error = absolute_diff / denom
    # Replace inf and nan with zero so they do not dominate the ranking.
    relative_error = torch.where(
        torch.isfinite(relative_error),
        relative_error,
        torch.zeros_like(relative_error),
    )
    # Never ask for more elements than exist.
    k = min(k, tensor1.numel())
    errors, indices = torch.topk(relative_error.reshape(-1), k)
    # reshape, not view: a broadcast-expanded tensor is not contiguous and
    # cannot be flattened with view.
    orig_values1 = tensor1.reshape(-1)[indices]
    orig_values2 = tensor2.reshape(-1)[indices]
    return errors, indices, orig_values1, orig_values2
# Load the two arguments that were dumped for the failing operator.
left = torch.load("8611_0aten.silu_backward.default-input-0.pth", map_location="cpu")
right = torch.load("8611_0aten.silu_backward.default-input-1.pth", map_location="cpu")

# Run the same operator on the CPU and on the accelerator, then flatten both
# results to float32 for comparison.
tensor1 = torch.ops.aten.silu_backward.default(left, right).cpu().float().reshape(-1)
out_xpu = torch.ops.aten.silu_backward.default(left.to(device), right.to(device))
tensor2 = out_xpu.cpu().float().reshape(-1)

k = 10  # number of largest-error elements to report
errors, indices, orig_values1, orig_values2 = largest_k_errors(tensor1, tensor2, k)
print("----------------- torch.ops.aten.silu_backward.default --------------------------")
# stride must be called; the bare attribute prints the bound-method repr.
print(f"left dtype:{left.dtype} stride:{left.stride()} shape:{left.shape} max:{left.max().item()} min:{left.min().item()}")
print(f"right dtype:{right.dtype} shape:{right.shape} max:{right.max().item()} min:{right.min().item()}")
print("Top", k, "errors:")
# Iterate over what was actually returned: fewer than k entries come back
# when the tensors hold fewer than k elements.
for index, error, cpu_value, xpu_value in zip(indices, errors, orig_values1, orig_values2):
    print(f"Index {index.item()}: Error {error.item()}, "
          f"Original Values: CPU_FP16 = {cpu_value.item()}, "
          f"XPU_FP16 = {xpu_value.item()}")
EOF
python3 demo.py