0. 前言

在 PyTorch 中,nn.ModuleListnn.ParameterList 是两种非常有用的工具,可以让你以更加灵活的方式构建和管理动态网络结构。这两种列表允许你在构建模型时轻松地添加或删除层,这对于构建自适应模型、循环网络或其他需要动态调整结构的场景非常有用。

本文将详细介绍这两个类的使用方法及其应用场景,帮助你更好地理解和运用它们来构建复杂和灵活的神经网络模型。

1. 为什么需要 nn.ModuleListnn.ParameterList

在构建深度学习模型时,我们经常需要创建包含多个层的网络。传统的做法是显式地定义每一层,但这在某些情况下可能不够灵活。例如,当你需要根据输入数据动态决定网络结构时,就需要一种更加灵活的方式来组织和管理这些层。

nn.ModuleListnn.ParameterList 提供了这样的灵活性。它们允许你将多个层或参数集合组织在一起,并且可以方便地在运行时增加、删除或修改这些层或参数。

2. nn.ModuleList:管理模块的列表

2.1 什么是 nn.ModuleList

nn.ModuleList 是一个包含 nn.Module 子类实例的有序列表。它可以用于管理一个模型中的多个层,而且这些层可以是任意类型的 nn.Module 对象。

2.2 创建 nn.ModuleList

要创建一个 nn.ModuleList,你可以简单地将 nn.Module 的实例作为一个列表传递给构造函数。例如:

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(5)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    
M = MyModel()
print(M)

输出为:

MyModel(
  (layers): ModuleList(
    (0-4): 5 x Linear(in_features=10, out_features=10, bias=True)
  )
)

在这个例子中,MyModel 包含了一个由五个 nn.Linear 层组成的 ModuleList。每个层都将输入的维度从 10 映射到 10。

2.3 动态添加或删除层

nn.ModuleList 支持像 Python 列表那样的索引操作,因此可以轻松地添加、删除或替换其中的层:

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(5)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

M = MyModel()

M.layers.append(nn.Conv2d(in_channels=1,out_channels=3,kernel_size=3))
print(M.layers[5])

输出为:

Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))

3. nn.ParameterList:管理参数列表

3.1 什么是 nn.ParameterList

nn.ParameterList 类似于 nn.ModuleList,但它用于管理一组 nn.Parameter 对象。这些参数可以是权重矩阵、偏置向量等。

3.2 创建 nn.ParameterList

要创建一个 nn.ParameterList,你可以将 nn.Parameter 对象作为一个列表传递给构造函数:

import torch.nn as nn
import torch

torch.manual_seed(666)

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.weights = nn.ParameterList([nn.Parameter(torch.randn(3, 3)) for _ in range(5)])

    def forward(self, x):
        for weight in self.weights:
            x = torch.mm(x, weight)
        return x

M = MyModel()
print(M.weights)
print(M.weights[-1])

输出为:

ParameterList(
    (0): Parameter containing: [torch.float32 of size 3x3]
    (1): Parameter containing: [torch.float32 of size 3x3]
    (2): Parameter containing: [torch.float32 of size 3x3]
    (3): Parameter containing: [torch.float32 of size 3x3]
    (4): Parameter containing: [torch.float32 of size 3x3]
)
Parameter containing:
tensor([[ 2.1743, -0.9672, -0.7672],
        [-0.5229, -2.2826,  0.1051],
        [-0.2497, -1.5241,  1.5813]], requires_grad=True)

在这个例子中,MyModel 包含了一个由五个随机权重矩阵组成的 ParameterList

3.3 动态添加或删除参数

nn.ParameterList 同样支持像 Python 列表那样的索引操作,因此你可以轻松地添加、删除或替换其中的参数:

import torch.nn as nn
import torch

torch.manual_seed(666)

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.weights = nn.ParameterList([nn.Parameter(torch.randn(3, 3)) for _ in range(5)])

    def forward(self, x):
        for weight in self.weights:
            x = torch.mm(x, weight)
        return x

M = MyModel()

M.weights.append(torch.zeros(3,3))
print(M.weights)
print(M.weights[-1])

输出为:

ParameterList(
    (0): Parameter containing: [torch.float32 of size 3x3]
    (1): Parameter containing: [torch.float32 of size 3x3]
    (2): Parameter containing: [torch.float32 of size 3x3]
    (3): Parameter containing: [torch.float32 of size 3x3]
    (4): Parameter containing: [torch.float32 of size 3x3]
    (5): Parameter containing: [torch.float32 of size 3x3]
)
Parameter containing:
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], requires_grad=True)

4. 自适应模型

比如构建一个模型,其中的某些层可以根据输入数据动态决定是否使用。你可以使用 nn.ModuleList 来存储这些层,并在 .forward() 方法中根据条件决定是否使用它们:

import torch.nn as nn
class AdaptiveModel(nn.Module):
    def __init__(self, num_layers):
        super(AdaptiveModel, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(5, 5) for _ in range(num_layers)])

    def forward(self, x):
        use_layers = [True, False, True, True, False]  # 示例:使用第0、2、3层
        for i, layer in enumerate(self.layers):
            if use_layers[i]:
                x = layer(x)
        return x

5. 总结

nn.ModuleListnn.ParameterList 提供了一种灵活的方式来构建和管理动态网络结构。通过这些工具,可以轻松地构建自适应模型、循环网络或其他需要动态调整结构的场景。

09-17 00:05