pytorch实现 --- 手写数字识别

目录

1.项目介绍

2.实现方法

3.程序代码

4.运行结果


1.项目介绍

        使用pytorch实现手写数字识别,十分简单的小项目,环境搭建好,一跑就通。


2.实现方法

2.1方式1        

 安装库:

pip install numpy torch torchvision matplotlib

 运行:

python test.py

首次运行会下载MNIST数据集,请保持网络畅通

2.2方式2


3.程序代码

"""手写数字识别项目
    时间:2023.11.6
    环境:pytorch
    作者:Rainbook
"""

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt

class Net(torch.nn.Module):  # 定义一个Net类,神经网络的主体
    def __init__(self):  # 全连接层,四个
        super().__init__()
        self.fc1 = torch.nn.Linear(28*28, 64)  # 输入层输入28*28,输出64
        self.fc2 = torch.nn.Linear(64, 64)  # 中间层,输入64,输出64
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 10)  # 中间层(隐藏层)的最后一层,输出10个特征值
    
    def forward(self, x):  # 前向传播过程
        # self.fc1(x)全连接线性计算,再套上一个激活函数torch.nn.functional.relu()
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        # 最后一层进行softmax归一化,log_softmax是为了提高计算稳定性,在softmax后面套上了一个对数运算
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
        return x


def get_data_loader(is_train):
    to_tensor = transforms.Compose([transforms.ToTensor()])  # 定义数据转换类型tensor,多维数组(张量)
    """下载MNIST数据集,
        "":当前位置
        is_train:判断是训练集还是测试集;
        batch_size:一个批次包含15张图片;
        shuffle:数据随机打乱的
    """
    data_set = MNIST("", is_train, transform=to_tensor, download=True)
    return DataLoader(data_set, batch_size=15, shuffle=True)  # 数据加载器


def evaluate(test_data, net):  # 用来评估神经网络
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for (x, y) in test_data:
            outputs = net.forward(x.view(-1, 28*28))  # 计算神经网络的预测值
            for i, output in enumerate(outputs):  # 对每个批次的预测值进行比较,累加正确预测的数量
                if torch.argmax(output) == y[i]:
                    n_correct += 1
                n_total += 1
    return n_correct / n_total  # 返回正确率


def main():
    # 导入训练集和测试集
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    net = Net()  # 初始化神经网络

    # 打印初始网络的正确率,应当是10%附近。手写数字有十种结果,随机猜的正确率就是1/10
    print("initial accuracy:", evaluate(test_data, net))
    """训练神经网络
    pytorch的固定写法
    """
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    for epoch in range(5):  # 需要在一个数据集上反复训练神经网络,epoch网络轮次,提高数据集的利用率
        for (x, y) in train_data:
            net.zero_grad()  # 初始化
            output = net.forward(x.view(-1, 28*28))  # 正向传播
            # 计算差值,nll_loss对数损失函数,为了匹配log_softmax的log运算
            loss = torch.nn.functional.nll_loss(output, y)
            loss.backward()  # 反向误差传播
            optimizer.step()  # 优化网络参数
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))  # 打印当前网络的正确率

    """测试神经网络
        训练完成后,随机抽取3张图片进行测试
    """
    for (n, (x, _)) in enumerate(test_data):
        if n > 3:
            break
        predict = torch.argmax(net.forward(x[0].view(-1, 28*28)))  # 测试结果
        plt.figure(n)  # 画出图像
        plt.imshow(x[0].view(28, 28))  # 像素大小28*28
        plt.title("prediction: " + str(int(predict)))  # figure的标题
    plt.show()


if __name__ == "__main__":
    main()

4.运行结果

4.1正确率

pytorch实现 --- 手写数字识别-LMLPHP

4.2测试结果

pytorch实现 --- 手写数字识别-LMLPHPpytorch实现 --- 手写数字识别-LMLPHPpytorch实现 --- 手写数字识别-LMLPHP

11-06 16:28