我们说的大模型参数量都是6B,130B等等, 一般模型运行需要多少显存呢.
我其实在想,为啥要发明一个B出来, 我们用G来描述不香吗.
推理过程:
计算的公式是 G = B ( 1.024 ) 3 × s i z e o f ( 参数类型 ) G = \frac{B}{(1.024)^3}\times sizeof(参数类型) G=(1.024)3B×sizeof(参数类型)
比如6B全精度, 那就是 G = 6 ( 1.024 ) 3 × s i z e o f ( f l o a t ) ≈ 22.35 G B G= \frac{6}{(1.024)^3}\times sizeof(float)\approx22.35GB G=(1.024)36×sizeof(float)≈22.35GB, 当然,一个模型计算,除了模型参数外还有一些计算参数, 但是大概显存在22.35G左右.
训练过程:
训练过程要保存梯度等计算过程, 所以需要 上面的 G ∗ 4 G*4 G∗4 ,也就是说, 上面全精度6B需要 G × 4 ≈ 22.35 × 4 = 89.40 G B G\times4\approx22.35\times4=89.40GB G×4≈22.35×4=89.40GB显存. 如果是24G显存的显卡, 需要4张才能够展开训练.
当然训练速度能耗也是大家需要考虑的, 所以可以考虑RoLA 等低参数训练方式.
有什么不对的地方, 欢迎指出, 一起学习!