PyTorch深度学习实战(31)——生成对抗网络

0. 前言

生成对抗网络 (Generative Adversarial Networks, GAN) 是一种由两个相互竞争的神经网络组成的深度学习模型,它由一个生成网络和一个判别网络组成,通过彼此之间的博弈来提高生成网络的性能。生成对抗网络使用神经网络生成与原始图像集非常相似的新图像,它在图像生成中应用广泛,且 GAN 的相关研究正在迅速发展,以生成与真实图像难以区分的逼真图像。在本节中,我们将学习 GAN 网络的原理并使用 PyTorch 实现 GAN

1. GAN

生成对抗网络 (Generative Adversarial Networks, GAN) 包含两个网络:生成网络( Generator,也称生成器)和判别网络( discriminator,也称判别器)。在 GAN 网络训练过程中,需要有一个合理的图像样本数据集,生成网络从图像样本中学习图像表示,然后生成与图像样本相似的图像。判别网络接收(由生成网络)生成的图像和原始图像样本作为输入,并将图像分类为原始(真实)图像或生成(伪造)图像。
生成网络的目标是生成逼真的伪造图像骗过判别网络,判别网络的目标是将生成的图像分类为伪造图像,将原始图像样本分类为真实图像。本质上,GAN 中的对抗表示两个网络的相反性质,生成网络生成图像来欺骗判别网络,判别网络通过判别图像是生成图像还是原始图像来对输入图像进行分类:

PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)-LMLPHP

在上图中,生成网络根据输入随机噪声生成图像,判别网络接收生成网络生成的图像,并将它们与真实图像样本进行比较,以判断生成的图像是真实的还是伪造的。生成网络尝试生成尽可能逼真的图像,而判别网络尝试判定生成网络生成图像的真实性,从而学习生成尽可能逼真的图像。
GAN 的关键思想是生成网络和判别网络之间的竞争和动态平衡,通过不断的训练和迭代,生成网络和判别网络会逐渐提高性能,生成网络能够生成更加逼真的样本,而判别网络则能够更准确地区分真实和伪造的样本。
通常,生成网络和判别网络交替训练,将生成网络和判别网络视为博弈双方,并通过两者之间的对抗来推动模型性能的提升,直到生成网络生成的样本能够以假乱真,判别网络无法分辨真实样本和生成样本之间的差异:

  • 生成网络的训练过程:冻结判别网络权重,生成网络以噪声 z 作为输入,通过最小化生成网络与真实数据之间的差异来学习如何生成更好的样本,以便判别网络将图像分类为真实图像
  • 判别网络的训练过程:冻结生成网络权重,判别网络通过最小化真实样本和假样本之间的分类误差来更新判别网络,区分真实样本和生成样本,将生成网络生成的图像分类为伪造图像

重复训练生成网络与判别网络,直到达到平衡,当判别网络能够很好地检测到生成的图像时,生成网络对应的损失比判别网络对应的损失要高得多。通过不断训练生成网络和判别网络,直到生成网络可以生成逼真图像,而判别网络无法区分真实图像和生成图像。

2. GAN 模型分析

为了生成手写数字的图像,我们采取以下策略:

  • 导入 MNIST 数据
  • 初始化随机噪声
  • 定义生成网络模型
  • 定义判别网络模型
  • 使用生成网络生成伪造图像,生成网络在最初只能生成噪声图像,噪声图像是通过将一组噪声值通过权重随机的神经网络得到的图像
  • 交替训练两个模型
    • 将生成的图像与原始图像串联起来,判别网络预测每个图像是伪造图像还是真实图像,对判别网络进行训练,判别网络的损失是图像的预测值和实际值(标签)的二进制交叉熵,生成的伪造图像的实际值(标签)为 0,原始数据集中真实图像的实际值(标签)为 1
    • 训练生成网络利用输入噪声生成伪造图像,使其看起来更接近真实图像,从而使生成图像有可能欺骗判别网络
    • 输入噪声通过生成网络传递输出伪造图像,将生成网络生成的图像输入到判别网络中,此时,判别网络权重被冻结,因为生成网络的目标是欺骗判别网络,因此,假设生成的伪造图像实际值(标签)为 1,生成网络的损失是判别网络对输入图像的预测值和实际值 (1) 的二进制交叉熵

了解了 GAN 的基本原理后,在下一小节,我们实现 GAN 生成 MNIST 手写数字图像。

3. 利用 GAN 模型生成手写数字

(1) 导入相关库并定义设备:

import torch
from torch import nn
from torch import optim
from matplotlib import pyplot as plt
import numpy as np
from torchvision.utils import make_grid
device = "cuda" if torch.cuda.is_available() else "cpu"

from torchvision.datasets import MNIST
from torchvision import transforms

(2) 导入 MNIST 数据,定义具有内置数据转换功能的数据加载器,以便缩放输入数据:

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,))
])

data_loader = torch.utils.data.DataLoader(MNIST('MNIST/', train=True, download=True, transform=transform),batch_size=128, shuffle=True, drop_last=True)

(3) 定义判别网络模型类:

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential( 
            nn.Linear(784, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.model(x)

在以上代码中,使用 LeakyReLU 激活函数替换 ReLU。打印判别网络的简要信息:

from torchsummary import summary
discriminator = Discriminator().to(device)
print(summary(discriminator, (1,784)))

模型简要信息输出结果如下所示:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1              [-1, 1, 1024]         803,840
         LeakyReLU-2              [-1, 1, 1024]               0
           Dropout-3              [-1, 1, 1024]               0
            Linear-4               [-1, 1, 512]         524,800
         LeakyReLU-5               [-1, 1, 512]               0
           Dropout-6               [-1, 1, 512]               0
            Linear-7               [-1, 1, 256]         131,328
         LeakyReLU-8               [-1, 1, 256]               0
           Dropout-9               [-1, 1, 256]               0
           Linear-10                 [-1, 1, 1]             257
          Sigmoid-11                 [-1, 1, 1]               0
================================================================
Total params: 1,460,225
Trainable params: 1,460,225
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.04
Params size (MB): 5.57
Estimated Total Size (MB): 5.61
----------------------------------------------------------------

(4) 定义生成网络模型类 Generator

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

生成网络根据 100 维随机噪声输入生成图像。打印生成网络模型的简要信息:

generator = Generator().to(device)
print(summary(generator, (1,100)))

模型简要信息输出结果如下所示:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1               [-1, 1, 256]          25,856
         LeakyReLU-2               [-1, 1, 256]               0
            Linear-3               [-1, 1, 512]         131,584
         LeakyReLU-4               [-1, 1, 512]               0
            Linear-5              [-1, 1, 1024]         525,312
         LeakyReLU-6              [-1, 1, 1024]               0
            Linear-7               [-1, 1, 784]         803,600
              Tanh-8               [-1, 1, 784]               0
================================================================
Total params: 1,486,352
Trainable params: 1,486,352
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.04
Params size (MB): 5.67
Estimated Total Size (MB): 5.71
----------------------------------------------------------------

(5) 定义函数生成随机噪声并将其注册到设备中:

def noise(size):
    n = torch.randn(size, 100)
    return n.to(device)

(6) 定义函数来训练判别网络。

判别网络训练函数 (discriminator_train_step) 将真实数据 (real_data) 和伪造数据 (fake_data) 作为输入:

def discriminator_train_step(real_data, fake_data, loss, d_optimizer):

重置优化器梯度:

    d_optimizer.zero_grad()

在对损失值执行反向传播之前,预测真实数据 (real_data) 并计算损失 (error_real):

    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, torch.ones(len(real_data), 1).to(device))
    error_real.backward()

在真实数据上计算判别网络损失时,我们期望判别网络预测输出为 1。因此,在判别网络的训练过程中,使用 torch.ones 作为标签,期望判别网络在真实数据上的输出为 1,从而计算判别网络在真实数据上的损失。

在对损失值执行反向传播之前,预测伪造数据 (fake_data) 并计算损失 (error_fake):

    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, torch.zeros(len(fake_data), 1).to(device))
    error_fake.backward()

在伪造数据上计算判别网络损失时,我们期望判别网络预测输出为 0。因此,在判别网络的训练过程中,使用 torch.zeros 作为标签,期望判别网络在伪造数据上的输出为 0,从而计算判别网络在伪造数据上的损失。

更新权重并返回整体损失(将模型在 real_dataerror_realfake_dataerror_fake 的损失值相加):

    d_optimizer.step()
    return error_real + error_fake

(7) 训练生成网络模型。

定义生成网络训练函数 generator_train_step 并传入伪造数据 fake_data 作为参数:

def generator_train_step(real_data, fake_data, loss, g_optimizer):

重置优化器梯度:

    g_optimizer.zero_grad()

预测判别网络对伪造数据 (fake_data) 的输出:

    prediction = discriminator(fake_data)

在计算生成网络的损失时,使用 torch.ones 作为标签,期望判别网络在伪造数据上的输出为 1,以在训练生成网络时欺骗判别网络输出值 1,以此来鼓励生成网络生成更加逼真的数据,并让判别网络无法区分其真伪:

    error = loss(prediction, torch.ones(len(real_data), 1).to(device))

执行反向传播,更新权重,并返回损失:

    error.backward()
    g_optimizer.step()
    return error

(8) 定义模型对象、生成网络和判别网络的优化器,以及损失函数:

discriminator = Discriminator().to(device)
generator = Generator().to(device)
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
loss = nn.BCELoss()

(9) 训练模型。

循环训练模型 200epochs (num_epochs):

num_epochs = 200

d_loss_epoch = []
g_loss_epoch = []
for epoch in range(num_epochs):
    N = len(data_loader)
    d_loss_items = []
    g_loss_items = []
    for i, (images, _) in enumerate(data_loader):

加载真实数据 (real_data) 和伪造数据,其中,伪造数据是通过将大小与真实数据样本数相同的噪声数据 (batch_size = len(real_data)) 传入生成网络网络获得的。需要注意的是,必须调用 fake_data.detach(),否则训练无法正常进行。通过 detach() 函数分离出来一个新的张量,这样在 discriminator_train_step() 中调用 error.backward() 时,与生成网络相关的张量(生成 fake_data )不会受到影响。使用 discriminator_train_step 函数训练判别网络:

        real_data = images.view(len(images), -1).to(device)
        fake_data = generator(noise(len(real_data))).to(device)
        fake_data = fake_data.detach()

训练判别网络后,继续训练生成网络。从噪声数据生成一组新的伪造图像 (fake_data) 并使用 generator_train_step 函数训练生成网络:

        fake_data = generator(noise(len(real_data))).to(device)
        g_loss = generator_train_step(real_data, fake_data, loss, g_optimizer)

记录损失变化:

        d_loss_items.append(d_loss.item())
        g_loss_items.append(g_loss.item())
    d_loss_epoch.append(np.average(d_loss_items))
    g_loss_epoch.append(np.average(g_loss_items))

绘制判别网络和生成网络的损失随训练的变化情况:

epochs = np.arange(num_epochs)+1
plt.plot(epochs, d_loss_epoch, 'bo', label='Discriminator Training loss')
plt.plot(epochs, g_loss_epoch, 'r-', label='Generator Training loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()

PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)-LMLPHP

(10) 可视化模型训练后生成的伪造数据:

z = torch.randn(64, 100).to(device)
sample_images = generator(z).data.cpu().view(64, 1, 28, 28)
grid = make_grid(sample_images, nrow=8, normalize=True)
plt.imshow(grid.cpu().detach().permute(1,2,0), cmap='gray')
plt.show()

PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)-LMLPHP

在上图中,可以看到利用 GAN 生成逼真的图像,但仍有一定的改进空间,在之后的学习中,我们将介绍更多 GAN 的改进模型生成更逼真的图像。

小结

生成对抗网络是一种强大的深度学习模型,由生成器网络和判别器网络组成,通过彼此之间的竞争来提高性能,已经在图像生成、图像修复、图像转换和自然语言处理等领域取得了巨大的成功。其核心思想是通过生成器和判别器之间的博弈过程来实现真实样本的生成。生成器负责生成逼真的样本,而判别器则负责判断样本是真实还是伪造。通过不断的训练和迭代,生成器和判别器会相互竞争并逐渐提高性能。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割
PyTorch深度学习实战(25)——自编码器(Autoencoder)
PyTorch深度学习实战(26)——卷积自编码器(Convolutional Autoencoder)
PyTorch深度学习实战(27)——变分自编码器(Variational Autoencoder, VAE)
PyTorch深度学习实战(28)——对抗攻击(Adversarial Attack)
PyTorch深度学习实战(29)——神经风格迁移
PyTorch深度学习实战(30)——Deepfakes

01-23 07:11