• 前文 PyTorch入门(2)—— 自动求梯度 介绍过 Pytorch 中的自动微分机制,这是实现神经网络反向传播的基础,也是所有深度学习框架最重要的基础设施之一
  • 梯度计算是需要占用计算资源的,而我们并不总是需要计算梯度(比如做评估时),Pytorch 提供了几种方式来控制梯度计算,本文对这些方法进行梳理
  • 参考自 pytorch 文档:Locally disabling gradient computation

1. 回顾 pytorch 的自动微分机制

  • PyTorch 提供的 autograd 是一个反向自动微分系统,它能根据对 tensor 的操作过程自动构建计算图。具体而言,计算图是一个有向无环图,记录了前向传播过程的全部数据操作,图中的根节点是输出张量(Output Tensor),叶节点是输入张量(Input Tensor)。沿着这个图即可使用链式法则计算得到中间的梯度。下面给出一个例子
  • pytorch 中计算的对象都是 tensor,所以计算图中每个节点也都是一个 tensor 对象。Pytorch 使用了动态计算图机制,每次前向传播过程中都会从头构造一次计算图,我们可以在每一次迭代中改变计算过程或禁用部分梯度计算,从而改变计算图的形状和大小
  • 关于梯度计算的更多细节,请参考前文 PyTorch入门(2)—— 自动求梯度

2. 局部梯度控制

  • 有几种机制可以在Python中临时禁用梯度计算:
    1. 要在一段代码块中禁用梯度计算,可以使用 no-grad 模式和 inference模式等上下文管理器
    2. 要更精确地控制梯度,比如从计算图中剔除部分子图,可以通过设置计算图中节点 tensor 的 requires_grad 字段来实现。这样可以打断有向无环图中的一些有向边,从而选择性地排除某些子图不参与梯度计算
  • PyTorch 中还有一个针对 nn.Module 的评估模式方法 nn.Module.eval(),它实际上并不用于禁用梯度计算。但是由于其名称的误导性,经常与上述机制混淆使用

2.1 通过设置 requires_grad 实现精确梯度控制

  • Tensor.requires_grad 是 tensor 对象的一个标志变量,默认为 False,它在前向传播和反向传播中都起作用,允许对梯度计算中的子图进行精细排除。

    1. 正向传递过程中,一个操作只有在其输入张量中至少有一个 requires_grad=True 时才会记录在计算图中
    2. 在向后传递期间,只有 requires_grad=True叶张量才会将梯度累积到它们的 .grad 字段中

    有几种方式可以将 Tensor.requires_grad 设置为 True

    1. 定义 tensor 时设置参数,如 torch.ones(2,2,requires_grad=True)
    2. 使用 requires_grad_() 进行 in-place 设置,如 a.requires_grad_(True)
    3. 使用 nn.Parameter 对 tensor 进行包装,如 nn.Parameter(torch.zeros(2,2))。定义神经网络时如果需要优化一个张量(比如定义Transformer的可训练位置编码),通常使用这种方法
  • 值得注意的是,尽管每个张量都有这个标志,但设置它只对 leaf tensor 有意义;non-leaf tensor 是有可以记录其计算过程的 .grad_fn 方法,有一部分反向图与之相关的 tensor,它们自动具有 require_grad=True,因为要计算 leaf tensor 的梯度时必须借助相关的 non-leaf tensor 的梯度作为中间结果

  • 设置 requires_grad 是控制模型进行部分梯度计算的主要方法。例如考虑函数: y 3 = y 1 + y 2 = x 2 + x 3 y_3 = y_1+y_2=x^2+x^3 y3=y1+y2=x2+x3,有
    ∂ y 3 ∂ x = ∂ y 1 ∂ x + ∂ y 2 ∂ x = 2 x + 3 x = 5 x \frac{\partial y_3}{\partial x} = \frac{\partial y_1}{\partial x} + \frac{\partial y_2}{\partial x} = 2x+3x = 5x xy3=xy1+xy2=2x+3x=5x 如果将其中的 y 2 y_2 y2 设置为 requires_grad=False,梯度就无法从 y 2 y_2 y2 往回传播,这时有
    ∂ y 3 ∂ x = ∂ y 1 ∂ x = 2 x \frac{\partial y_3}{\partial x} = \frac{\partial y_1}{\partial x}= 2x xy3=xy1=2x 计算图如下
    Pytorch入门(6)—— 梯度计算控制-LMLPHP

    x = torch.tensor(1.0, requires_grad=True)
    y1 = x ** 2 
    with torch.no_grad():
        y2 = x ** 3
    y3 = y1 + y2
    
    print(x.requires_grad)		# True
    print(y1, y1.requires_grad) # tensor(1., grad_fn=<PowBackward0>) True
    print(y2, y2.requires_grad) # tensor(1.) False
    print(y3, y3.requires_grad) # tensor(2., grad_fn=<AddBackward0>) True
    
    y3.backward()
    print(x.grad)               # tensor(2.)
    
    #y2.backward() # 报错: element 0 of tensors does not require grad and does not have a grad_fn
    
  • 另外,requires_grad 也可以在模块级别通过 nn.Module.requires_grad_() 进行设置,这对模块内的所有参数生效 (默认情况下requires_grad=True)

2.2 三种梯度计算模式

  • 除了设置 requires_grad 之外,Pytorch 还提供了三种可以影响 autograd 内部梯度计算的模式:默认模式/梯度模式(Grad Mode)无梯度模式(No-grad Mode)推理模式(Inference Mode),所有这些模式都可以通过python语法中的上下文管理器和装饰器进行切换

2.2.1 梯度模式 (Grad Mode)

  • 这是 Pytorch 工作的默认模式,是我们在没有启用其他模式时隐含的模式。为了与 “无梯度模式” 形成对比,有时也被称为 “梯度模式”。梯度模式是 requires_grad=True 生效的唯一模式,requires_grad 在其他两种模式中总是被设置为 False

2.2.2 无梯度模式 (No-grad Mode)

  • 在无梯度模式下即使有 require_grad=True 的输入,也不会在反向图中记录。有两种常用的进入无梯度模式的方法:使用上下文管理器(with语法)和函数装饰器
    with torch.no_grad():
        do_something()
    
    @torch.no_grad()
    def do_something_func():
        do_something()
    
    这两种方法可以方便地禁用代码块或函数的梯度。另外还有一种手动设置的方法
    torch.set_grad_enabled(False)
    do_something()
    torch.set_grad_enabled(True)
    
  • 无梯度模式适用于有一些操作无需记录梯度,但需要中间计算结果用于后续(梯度模式下)梯度计算的情况(可以理解成将计算图的一部分变成一个常数)
    1. 编写优化器时可能很适合使用无梯度模式:每轮迭代中,优化器要就地更新模型参数,这些更新操作不应被记录梯度,之后在下一轮的前向传递中要使用更新后的参数进行梯度模式的计算,例如
      def sgd(params, lr, batch_size): 
          """小批量随机梯度下降"""
          with torch.no_grad():
              for param in params:    # param 是一个list,如果模型是y=Xw+b,则param=[w,b]
                  param -= lr * param.grad / batch_size
                  param.grad.zero_()
      
    2. torch.nn.init 方法的实现也依赖于无梯度模式,以避免在就地初始化参数时就自动跟踪梯度了
    3. 做模型验证或评估时也常常使用无梯度模式
  • 由于无梯度模式下不会生成反向计算图,显存占用和计算资源的消耗都大大减少了,体现在代码上就是计算验证损失时 batch size 可以远大于计算训练损失时的 batch size,而且计算更快

2.2.3 推断模式 (Inference Mode)

  • 推断模式是无梯度模式的极端版本,这种模式下也不会记录反向图,它的执行速度更快,但缺点在于推断模式下创建的 tensor 将无法用于后续(在梯度模式下)梯度计算。进入推断模式也有上下文管理器(with语法)和函数装饰器两种方法
    with torch.inference_mode():
    	do_something()
    
    @torch.inference_mode()
    def func(x):
      	do_something()
    
  • 建议在代码中不需要自动梯度跟踪的部分 (如数据处理和模型评估阶段) 尝试推理模式,相比过去使用无梯度模式这可以无成本地提升性能。如果在启用推理模式后遇到错误,请检查是否在退出推理模式后由 Autograd 记录的计算中使用了在推理模式下创建的 tensor,如果无法避免这种情况,你可以随时切换回无梯度模式
  • 需要注意的是,推断模式是从 pytorch 1.10 开始引入的新特性,使用前需要确保 Pytorch 版本支持

2.3 容易混淆的模型评估模式(Evaluation Mode)

  • (模型)评估模式 nn.Module.eval() 实际上不是一种影响 autograd 内部梯度计算的模式,但它有时会被混淆成这样一种机制。 在功能上,module.eval()/ module.train() 与 2.2 节介绍的三种模式完全正交。model.eval() 如何影响模型完全取决于模型中使用的特定模块,以及它们是否定义了任何特定于训练模式的行为。具体而言:model.eval() 的作用是不启用 Batch Normalization 和 Dropout,即
    1. Dropout 层会让所有的激活单元都通过,不会随机失能
    2. Batch Normalization 层会停止计算和更新 mean 和 var,直接使用在训练阶段已经学出的 mean 和 var 值
  • 建议无论模型定义中是否涉及上述操作,都在训练时始终使用 model.train(),在评估模型(验证/测试)时始终使用 model.eval(),以免受到任何潜在导致这些操作的模型更新的影响
  • 注意在梯度模式下,即使调用了 module.eval(),所有梯度还是会被计算

3. 总结

  • Pytorch 使用 Autograd 机制自动追踪对 tensor 的各种操作,并实时生成可以用于计算梯度的反向计算图。可以通过设置 Tensor.requires_grad 参数来打断计算图中的某些边,以实现对梯度计算的精确控制。在微调模型时可以用这种方式冻结某些参数
  • Pytorch 还提供了三种可以影响梯度计算的模式
    1. 梯度模式:仅在这种模式下 requires_grad=True 生效,会进行计算图构建,这也是默认模式
    2. 无梯度模式:这时即使有 require_grad=True 的输入也不会在反向图中记录,适用于有一些操作无需记录梯度,但需要中间计算结果用于后续(梯度模式下)梯度计算的情况。这种模式下显存占用和计算资源的消耗都大大减少了
    3. 推断模式:无梯度模式的极端版本,也不会记录反向计算图,执行速度更快,但推断模式下创建的 tensor 将无法用于后续(在梯度模式下)梯度计算
  • nn.Module 单独有一种评估模式,它的功能和以上三种梯度计算模式是正交的,它仅影响 Dropout 和 Batch Normalization 的行为模式而和梯度计算无关。在训练应时始终使用 model.train(),在评估模型 (验证/测试) 时应始终使用 model.eval()
09-12 06:14