学习开源深度学习框架PyTorch中的模块【Modules】-LMLPHP

简介

PyTorch 是一个著名的开源深度学习框架,由于其灵活性和易用性而广受欢迎。如果你正在使用 PyTorch框架来探索深度学习,那么你就会遇到核心之一+++PyTorch 模块。那么了解它们是什么、为什么它们很重要以及它们是如何工作的是非常必要的。

什么是Pytorch Modules

在 PyTorch 中,nn.Module 类是构建神经网络架构的基石。它可作为神经网络特定的组件。以下是它们的区别:

封装:
模块将模型层的可训练参数(权重和偏差)整齐地打包在一起。它们还可以包含其他模块,这样你就可以构建复杂的层次结构。

状态性:
模块会记住它们的参数和内部状态,本质上是存储它们在训练的时候所学习到的内容。

可组合性:
与构建块一样,模块可以连接起来形成复杂的神经网络。

为什么使用模块

组织:
模块提供了一种结构化的方式来组织神经网络的不同组件,从而增强了代码的可读性和可维护性。

可重用性:
自定义模块可以轻松地在不同的模型或项目中重复使用,从而促进代码模块化。

抽象:
模块抽象好了参数的处理和梯度计算的复杂性,这样你就可以专注于模型更高级别的架构。

预训练模型:
模块有助于加载和集成预训练模型组件,这是迁移学习中的常见做法。

模块的工作原理

模块最重要的方面之一是其管理参数的能力,其工作原理如下:

参数注册:
当使用 PyTorch 函数(如 nn.Linear、nn.Conv2d 等)在模块的构造函数(init 方法)中定义变量时,这些变量会自动作为参数注册到模块。

参数访问:
可以通过多种方式访问​​模块的参数:module.parameters():返回一个迭代器以访问所有参数。module.state_dict():返回一个包含所有参数及其值的 OrderedDict字典。

优化:
在训练期间,PyTorch 优化器(如 Adam、SGD)专门迭代模块中注册的参数,更新它们以最小化损失函数。

模块如何跟踪参数

nn.Parameter:
nn.Parameter 类是 PyTorch 的 Tensor 的一个特殊子类。当使用 nn.Parameter 在自定义模块中定义属性时,它会被标记为需要跟踪的参数。

setattr 重载:
nn.Module 类会重载 setattr 方法(​​负责设置属性)。当将 nn.Parameter 分配给模块的属性时(例如,self.linear_layer = nn.Linear(…)),它会在内部注册此参数。

注册:
nn.Module 维护一个内部列表或类似字典的结构来存储这些已注册的参数。

如下面的代码所示:

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(2,
08-20 14:10