网络的教程来看,在半精度amp训练出现nan问题,无非就是这几种:
- 计算loss 时,出现了除以0的情况
- loss过大,被半精度判断为inf
- 网络参数中有nan,那么运算结果也会输出nan(这个更像是现象而不是原因,网络中出现nan肯定是之前出现了nan或inf)
但是总结起来就三种:
- 运算错误,比如计算Loss时出现x/0造成错误
- 数值溢出,运算结果超出了表示范围,比如权重和输入正常,但是运算结果Nan或Inf。比如loss过大其实就是超出表示范围变成inf
- 梯度问题,可能梯度回传出现问题(不了解)
0、结论
先说结论,我使用amp半精度训练,即中间会参杂float16数据类型,加快训练过程。
但是本文出现Nan就是因为float16,因为float16支持的最大值在65504,而我的模型中涉及一个矩阵乘法(其实就是transformer中的q@k运算)。其中,a∈[-38,40],b∈[-39,40],而矩阵乘法a@b=c,c∈[-61408,inf]。因为a和b的矩阵乘法运算后最大值超过了float16最大表示,造成出现inf,所以最终结果出现Nan。
1、粗定位
一个训练的过程可以表示为以下流程:
1.1 定位到epoch
首先看到,在epoch4的输出loss是正常的,意味着在epoch4中的0~498iter的训练过程中正常,那么问题就可能出现在epoch4第499iter和epoch0~499iter这501个iter之中。
1.2 定位到iter
现在我们需要定位到具体iter。
可以根据二分法进行判断,debug模型,在epoch=5轮次中的100iter、300iter、499iter分别查看loss是否正常,依此类推定位到具体的iter。
我的是在epoch=5的iter161~162之间,iter=161时loss正常,iter=162时loss为Nan。iter还是遵从上图的流程,可以看到问题无非出现在iter=161时的梯度计算和权重更新,以及iter=162的前向运算和损失计算,这4处。
1.3 定位到具体步骤
在debug时直接暂停到epoch=5和iter=162的前向运算之前。
首先来看权重是否正常:
# 在iter=162的模型推理之前,检查权重是否存在异常值,比如Nan或inf
if epoch == 5:
if i == 162:
print(epoch, i)
class bcolors:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
# print grad check
v_n = []
v_v = []
v_g = []
for name, parameter in model.named_parameters():
v_n.append(name)
v_v.append(parameter.detach().cpu().numpy() if parameter is not None else [0])
v_g.append(parameter.grad.detach().cpu().numpy() if parameter.grad is not None else [0])
for j in range(len(v_n)):
if np.isnan(np.max(v_v[j]).item() - np.min(v_v[j]).item()) or np.isnan(
np.max(v_g[j]).item() - np.min(v_g[j]).item()):
color = bcolors.FAIL + '*'
else:
color = bcolors.OKGREEN + ' '
print('%svalue %s: %.3e ~ %.3e' % (color, v_n[j], np.min(v_v[j]).item(), np.max(v_v[j]).item()))
print('%sgrad %s: %.3e ~ %.3e' % (color, v_n[j], np.min(v_g[j]).item(), np.max(v_g[j]).item()))
outputs = model(images)
通过检查,证明权重没有问题,所以问题被限定iter=162的前向推理和损失计算两处
检查输入和输出
通过代码:
print(images.mean()) # 检查输入,正常
outputs = model(images)
print(outputs .mean()) # 检查输出,Nan
由此我们知道了情况:模型权重正常,模型输入正常,但是模型的输出Nan
2、精确定位
到这里就好办了,借助pycharm,我们一步一步调试在模型中各个模型的输入输出,看看到底是在模型的哪一个部分出现了Nan或者Inf,最终定位到一行代码:
attn = (q @ k.transpose(-2, -1)) * self.scale
这句代码是想实现q和k的矩阵乘法, 他们的值域分别为:
从这里可以发现,就是单纯的计算问题,一种很常见的就是数值溢出,考虑到我使用半精度float16,通过查询其最大值是65504,所以很有可能是最大值溢出了。为了验证,我们可以在计算前,将q和k转为double(float64),可以发现其计算结果正常了,类型也是float64。这表明就是因为数值溢出造成的。
3、解决办法
现在已知我的原因是数值溢出,一种方法是截取:将inf或nan设置为一个常量,我则在运算前将q和k进行norm归一化到[-1,1],这样保证了运算结果不会太大(没有什么原因,就是无脑操作,不建议学习)。