简介
PyTorch 是一个著名的开源深度学习框架,由于其灵活性和易用性而广受欢迎。如果你正在使用 PyTorch框架来探索深度学习,那么你就会遇到核心之一+++PyTorch 模块。那么了解它们是什么、为什么它们很重要以及它们是如何工作的是非常必要的。
什么是Pytorch Modules
在 PyTorch 中,nn.Module 类是构建神经网络架构的基石。它可作为神经网络特定的组件。以下是它们的区别:
封装:
模块将模型层的可训练参数(权重和偏差)整齐地打包在一起。它们还可以包含其他模块,这样你就可以构建复杂的层次结构。
状态性:
模块会记住它们的参数和内部状态,本质上是存储它们在训练的时候所学习到的内容。
可组合性:
与构建块一样,模块可以连接起来形成复杂的神经网络。
为什么使用模块
组织:
模块提供了一种结构化的方式来组织神经网络的不同组件,从而增强了代码的可读性和可维护性。
可重用性:
自定义模块可以轻松地在不同的模型或项目中重复使用,从而促进代码模块化。
抽象:
模块抽象好了参数的处理和梯度计算的复杂性,这样你就可以专注于模型更高级别的架构。
预训练模型:
模块有助于加载和集成预训练模型组件,这是迁移学习中的常见做法。
模块的工作原理
模块最重要的方面之一是其管理参数的能力,其工作原理如下:
参数注册:
当使用 PyTorch 函数(如 nn.Linear、nn.Conv2d 等)在模块的构造函数(init 方法)中定义变量时,这些变量会自动作为参数注册到模块。
参数访问:
可以通过多种方式访问模块的参数:module.parameters():返回一个迭代器以访问所有参数。module.state_dict():返回一个包含所有参数及其值的 OrderedDict字典。
优化:
在训练期间,PyTorch 优化器(如 Adam、SGD)专门迭代模块中注册的参数,更新它们以最小化损失函数。
模块如何跟踪参数
nn.Parameter:
nn.Parameter 类是 PyTorch 的 Tensor 的一个特殊子类。当使用 nn.Parameter 在自定义模块中定义属性时,它会被标记为需要跟踪的参数。
setattr 重载:
nn.Module 类会重载 setattr 方法(负责设置属性)。当将 nn.Parameter 分配给模块的属性时(例如,self.linear_layer = nn.Linear(…)),它会在内部注册此参数。
注册:
nn.Module 维护一个内部列表或类似字典的结构来存储这些已注册的参数。
如下面的代码所示:
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(2,