参考

如何计算神经网络模型的复杂度
深度学习卷积、全连接层、深度可分离层参数量和FLOPs计算公式

概念

  1. Params:模型的参数量。(空间复杂度
  2. FLOPs:FLoating point Operations,前向推理的计算量。(时间复杂度
  3. MAC:Memory Access Cost。基本上看每个计算输出结果 C o u t × H o u t × W o u t C_{out} \times H_{out} \times W_{out} Cout×Hout×Wout 相加的总和。
  4. MACC(MADD):multiply-accumulate operations:先乘起来再加起来的运算次数。 也就是 乘加 看做一次运算。
    所以 1个 MACC = 2个 FLOPs。
  5. 内存量

H i n H_{in} Hin: 输入的 height
W i n W_{in} Win: 输入的 width
H o u t H_{out} Hout: 输出的 height
W o u t W_{out} Wout: 输入的 width
K K K: 卷积核size
C i n C_{in} Cin: 输入的channel 数
C o u t C_{out} Cout: 输出的 channel数

参数量计算

全连接层

如何计算模型的复杂度(参数量,FLOPs)-LMLPHP

卷积层

普通卷积: 输入尺寸 C i ∗ H i ∗ W i C_i * H_i * W_i CiHiWi, 卷积核的大小为 K ∗ K K*K KK, 输出的尺寸大小为 C o ∗ H o ∗ W o C_o * H_o * W_o CoHoWo.

  • 不考虑 bias
    K 2 × C i × C o K^2 \times C_{i} \times C_{o} K2×Ci×Co
  • 考虑bias
    ( K 2 × C i + 1 ) × C o (K^2 \times C_{i} + 1) \times C_{o} (K2×Ci+1)×Co

池化层

对于池化层而言,常用的Max-pooling,Avg-pooling等是不存在参数量的。

batch norm

每个 batch 减均值,除方差。
再根据参数 α \alpha α, β \beta β 做缩放
在训练时计算的均值方差是直接计算,在预测时是用 running mean,running var.
如何计算模型的复杂度(参数量,FLOPs)-LMLPHP
所以参数量是?2HW*C, 错了, 是
2 × C i 2 \times C_{i} 2×Ci

激活函数

无参数

FLOPs

卷积层

  • 不考虑 bias
    ( 2 × ( K 2 × C i ) − 1 ) × ( C o × H o × W o ) (2\times (K^2 \times C_{i} ) -1 ) \times (C_{o} \times H_{o} \times W_{o}) (2×(K2×Ci)1)×(Co×Ho×Wo)

先计算输出的feature中一个元素需要的计算量。 ( K 2 × C i ) (K^2 \times C_{i} ) (K2×Ci) 表示乘法次数, ( K 2 × C i ) − 1 (K^2 \times C_{i} ) -1 (K2×Ci)1 表示加法次数。

  • 考虑bias
    带bias 的计算(一部分是乘法,一部分是加法)
    2 × ( K 2 × C i ) × ( C o × H o × W o ) 2\times (K^2 \times C_{i} ) \times (C_{o} \times H_{o} \times W_{o}) 2×(K2×Ci)×(Co×Ho×Wo)

全连接层

输入维度 C i C_i Ci, 输出 C o C_o Co. 全连接层就理解为一个矩阵,矩阵行数,矩阵列数,如考虑bias,则先计算输出向量中一个元素需要多少计算量,首先要做 C i C_i Ci 次乘法,然后做 C i − 1 C_i -1 Ci1 次加法。若考虑 bias,则做的加法会多一次。

  • 不考虑 bias : ( 2 N i n − 1 ) N o u t (2N_{in}-1)N_{out} (2Nin1)Nout
    N i n N o u t N_{in}N_{out} NinNout 为乘法的运算量,
    ( N i n − 1 ) N o u t (N_{in} - 1)N_{out} (Nin1)Nout为加法的运算量
  • 考虑 bias : ( 2 N i n ) N o u t (2N_{in})N_{out} (2Nin)Nout

工具

torchinfo
mmdetection 工具代码

如何计算模型的复杂度(参数量,FLOPs)-LMLPHP

More

https://github.com/sovrasov/flops-counter.pytorch
https://github.com/open-mmlab/mmcv/blob/2.x/mmcv/cnn/utils/flops_counter.py

02-03 21:30