我是一个对称矩阵

我是一个对称矩阵

网络的教程来看,在半精度amp训练出现nan问题,无非就是这几种:

  • 计算loss 时,出现了除以0的情况
  • loss过大,被半精度判断为inf
  • 网络参数中有nan,那么运算结果也会输出nan(这个更像是现象而不是原因,网络中出现nan肯定是之前出现了nan或inf)

但是总结起来就三种:

  • 运算错误,比如计算Loss时出现x/0造成错误
  • 数值溢出,运算结果超出了表示范围,比如权重和输入正常,但是运算结果Nan或Inf。比如loss过大其实就是超出表示范围变成inf
  • 梯度问题,可能梯度回传出现问题(不了解)

0、结论

先说结论,我使用amp半精度训练,即中间会参杂float16数据类型,加快训练过程。

但是本文出现Nan就是因为float16,因为float16支持的最大值在65504,而我的模型中涉及一个矩阵乘法(其实就是transformer中的q@k运算)。其中,a∈[-38,40],b∈[-39,40],而矩阵乘法a@b=c,c∈[-61408,inf]。因为a和b的矩阵乘法运算后最大值超过了float16最大表示,造成出现inf,所以最终结果出现Nan。

1、粗定位

一个训练的过程可以表示为以下流程:
记录PyTorch中半精度amp训练出现Nan的排查过程-LMLPHP

1.1 定位到epoch

首先看到,在epoch4的输出loss是正常的,意味着在epoch4中的0~498iter的训练过程中正常,那么问题就可能出现在epoch4第499iter和epoch0~499iter这501个iter之中。

1.2 定位到iter

现在我们需要定位到具体iter。

可以根据二分法进行判断,debug模型,在epoch=5轮次中的100iter、300iter、499iter分别查看loss是否正常,依此类推定位到具体的iter。

我的是在epoch=5的iter161~162之间,iter=161时loss正常,iter=162时loss为Nan。iter还是遵从上图的流程,可以看到问题无非出现在iter=161时的梯度计算和权重更新,以及iter=162的前向运算和损失计算,这4处。

1.3 定位到具体步骤

在debug时直接暂停到epoch=5和iter=162的前向运算之前。

首先来看权重是否正常:

# 在iter=162的模型推理之前,检查权重是否存在异常值,比如Nan或inf
if epoch == 5:
    if i == 162:
        print(epoch, i)

        class bcolors:
            HEADER = '\033[95m'
            OKBLUE = '\033[94m'
            OKGREEN = '\033[92m'
            WARNING = '\033[93m'
            FAIL = '\033[91m'
            ENDC = '\033[0m'
            BOLD = '\033[1m'
            UNDERLINE = '\033[4m'

        # print grad check
        v_n = []
        v_v = []
        v_g = []
        for name, parameter in model.named_parameters():
            v_n.append(name)
            v_v.append(parameter.detach().cpu().numpy() if parameter is not None else [0])
            v_g.append(parameter.grad.detach().cpu().numpy() if parameter.grad is not None else [0])
        for j in range(len(v_n)):
            if np.isnan(np.max(v_v[j]).item() - np.min(v_v[j]).item()) or np.isnan(
                    np.max(v_g[j]).item() - np.min(v_g[j]).item()):
                color = bcolors.FAIL + '*'
            else:
                color = bcolors.OKGREEN + ' '
            print('%svalue %s: %.3e ~ %.3e' % (color, v_n[j], np.min(v_v[j]).item(), np.max(v_v[j]).item()))
            print('%sgrad  %s: %.3e ~ %.3e' % (color, v_n[j], np.min(v_g[j]).item(), np.max(v_g[j]).item()))

outputs = model(images)

通过检查,证明权重没有问题,所以问题被限定iter=162的前向推理和损失计算两处

检查输入和输出
通过代码:

print(images.mean())	# 检查输入,正常
outputs = model(images)
print(outputs .mean())	# 检查输出,Nan

由此我们知道了情况:模型权重正常,模型输入正常,但是模型的输出Nan

2、精确定位

到这里就好办了,借助pycharm,我们一步一步调试在模型中各个模型的输入输出,看看到底是在模型的哪一个部分出现了Nan或者Inf,最终定位到一行代码:

attn = (q @ k.transpose(-2, -1)) * self.scale

这句代码是想实现q和k的矩阵乘法, 他们的值域分别为:

从这里可以发现,就是单纯的计算问题,一种很常见的就是数值溢出,考虑到我使用半精度float16,通过查询其最大值是65504,所以很有可能是最大值溢出了。为了验证,我们可以在计算前,将q和k转为double(float64),可以发现其计算结果正常了,类型也是float64。这表明就是因为数值溢出造成的。
记录PyTorch中半精度amp训练出现Nan的排查过程-LMLPHP

3、解决办法

现在已知我的原因是数值溢出,一种方法是截取:将inf或nan设置为一个常量,我则在运算前将q和k进行norm归一化到[-1,1],这样保证了运算结果不会太大(没有什么原因,就是无脑操作,不建议学习)。

12-07 07:13