1. 准备数据集

1.1 MNIST数据集获取:

  • torchvision.datasets接口直接下载,该接口可以直接构建数据集,推荐

  • 其他途径下载后,编写程序进行读取,然后由Datasets构建自己的数据集

​​本文使用第一种方法获取数据集,并使用Dataloader进行按批装载。如果使用程序下载失败,请将其他途径下载的MNIST数据集 [文件][解压文件] 放置在 <data/MNIST/raw/> 位置下,本文的程序及文件结构图如下:

Pytorch CNN网络MNIST数字识别 [超详细记录] 学习笔记(三)-LMLPHP

​​其中,model文件夹用来存储每个epoch训练的模型参数,根文件夹下包含model.py用于训练模型,test.py为测试集测试,show.py为展示部分

1.2 程序部分

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time

# 1. 准备数据集
## 1.1 使用torchvision自动下载MNIST数据集
train_data = datasets.MNIST(root='data\\',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)

## 1.2 构建数据集装载器
train_loader = DataLoader(dataset=train_data,
                          batch_size=100,
                          shuffle=True,
                          drop_last=False,
                          num_workers=4)

if __name__ == "__main__":
    print("===============数据统计===============")
    print("训练集样本:",train_data.__len__(), train_data.data.shape)

Pytorch CNN网络MNIST数字识别 [超详细记录] 学习笔记(三)-LMLPHP

​​【代码解析】

  • root为存放MNIST的路径,trian=True代表下载的为训练集和训练集标签,False则代表测试集和标签

  • transforms.ToTensor()表示将shape为(H, W, C)的 numpy 数组或 img 转为shape为(C, H, W)的tensor,并将数值归一化为[0,1]

  • download为True则代表自动下载,若该文件夹下已经下载,则直接跳过下载步骤

  • shuffle=True,表示对分好的batch进行洗牌操作,drop_last=True表示对最后不足batch大小的剩余样本舍去,False表示保留

  • num_works表示每次读取的进程数,和核心数有关

​​Dataset和Dataloader详细说明,请移步:[Pytorch Dataset和Dataloader 学习笔记(二)]

2. 设计网络结构

2.1 网络设计

Pytorch CNN网络MNIST数字识别 [超详细记录] 学习笔记(三)-LMLPHP

​​网络结构如上图所示,输入图像—>卷积1—>池化1—>卷积2—>池化2—>全连接1—>全连接2—>softmax,每次卷积通道数都增加一倍,最后送入全连接层实现分类

2.2 程序部分

# 2. Design model using class
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv_layer1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.max_pooling1 = nn.MaxPool2d(2)
        self.conv_layer2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.max_pooling2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(1568, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.max_pooling1(F.relu(self.conv_layer1(x)))
        x = self.max_pooling2(F.relu(self.conv_layer2(x)))
        x = x.view(-1, 32*7*7)
        x = F.relu(self.fc1(x))
        y_hat = self.fc2(x)     # CrossEntropyLoss会自动激活最后一层的输出以及softmax处理
        return y_hat

net = Net()

# 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)

Pytorch CNN网络MNIST数字识别 [超详细记录] 学习笔记(三)-LMLPHP

​​【代码解析】

  • fc1的1568维度是因为最后一次池化后的shape为32*7*7=1568

  • 在最后一层,并没有进行relu激活以及接入softmax,是因为,在CrossEntropyLoss中会自动激活最后一层的输出以及softmax处理

Pytorch CNN网络MNIST数字识别 [超详细记录] 学习笔记(三)-LMLPHP

​​CrossEntropyLoss图参考:《PyTorch深度学习实践》完结合集
​​详细网络结构搭建说明,请移步:Pytorch线性规划模型 学习笔记(一)

3. 迭代训练

# 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)

# 4. Training
if __name__ == "__main__":
    print("Training...")
    for epoch in range(20):
        strat = time.time()
        total_correct = 0
        for x, y in train_loader:
            y_hat = net(x)
            y_pre = torch.argmax(y_hat, dim=1)
            total_correct += sum(torch.eq(y_pre, y))    # 统计当前epoch下的正确个数

            loss = criterion(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        acc = (float(total_correct) / train_data.__len__())*100
        save_path = "model/net" + str(epoch+1) + ".pth"
        torch.save(obj=net.state_dict(), f=save_path)
        print("epoch:", str(epoch + 1) + "/20",
              " \n time:", "%.1f" % (time.time() - strat) + "s"
              " train_loss:", loss.item(),
              " acc:%.3f%%" % acc,)

    print("we are done!")

Pytorch CNN网络MNIST数字识别 [超详细记录] 学习笔记(三)-LMLPHP

​​【代码解析】

  • total_correct变量用于统计每个epoch下正确预测值的个数,每进行epoch进行一次清零
  • torch.argmax(y_hat, dim=1)用于选取y_hat下每一行的最大值(每个样本的最高得分),并返回与y相同维度的tensor
  • torch.eq(y_pre, y)用于比较两个矩阵元素是否相同,相同则返回True,不同则返回False,用于判断预测值与真实值是否相同
  • torch.save保存了每个epoch的网络权重参数

4. 测试集预测部分

# 测试模型,测试集为test_data

import torch
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from model import Net

test_data = datasets.MNIST(root='data\\',
                           train=False,
                           transform=transforms.ToTensor(),
                           download=True)
test_loader = DataLoader(dataset=test_data,
                          batch_size=100,
                          shuffle=True,
                          drop_last=False,
                          num_workers=4)

if __name__ == "__main__":
    print("---------------预测分析---------------")
    print("测试集样本:", test_data.__len__(), test_data.data.shape)
    model = Net()
    model.load_state_dict(torch.load("model/net20.pth"))
    model.eval()

    total_correct = 0
    for x, y in test_loader:
        y_hat = model(x)
        y_pre = torch.argmax(y_hat, dim=1)
        total_correct += sum(torch.eq(y_pre, y))

    acc = (float(total_correct) / test_data.__len__())*100
    print("total_test_samples:", test_data.__len__(),
          " test_acc:", "%.3f%%" % acc)

Pytorch CNN网络MNIST数字识别 [超详细记录] 学习笔记(三)-LMLPHP

​​经过20个epoch的训练,在测试集上达到了98.590%的准确率,部分batch真实值与预测值展示如下:

5. 全部代码

链接:链接:https://pan.baidu.com/s/1GGhG1Slw2Tlsgl13yzHUIw
提取码:82l4

转载请说明出处

06-21 22:00