深度学习--实战 LeNet5

数据集

本文选用 CIFAR-10 数据集。CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky、Ilya Sutskever 收集的一个用于普适物体识别的计算机视觉数据集,它包含 60000 张 32×32 的 RGB 彩色图片,共 10 个类别。其中 50000 张用于训练,10000 张用于测试。

模型实现

模型需要继承 nn.Module(注意大小写)。

import torch
from torch import  nn


class Lenet5(nn.Module):
    """LeNet-5 style convolutional classifier for the CIFAR-10 dataset.

    The network is two sub-modules applied one after the other:

    * ``conv_unit``: conv(3->6, 5x5) -> max-pool(2x2) -> conv(6->16, 5x5)
      -> avg-pool(2x2), mapping ``[b, 3, 32, 32]`` to ``[b, 16, 5, 5]``.
    * ``fc_unit``: linear(400->120) -> ReLU -> linear(120->84) -> ReLU
      -> linear(84->10), mapping ``[b, 400]`` to ``[b, 10]``.

    ``forward`` returns raw logits. No softmax is applied and no loss is
    stored on the module; per the original author's notes the model is meant
    to be trained with ``nn.CrossEntropyLoss``, which already contains the
    softmax step (``nn.MSELoss`` would be the choice for regression instead).
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor.
        # nn.Conv2d arguments: input channels, output channels (= number of
        # kernels), kernel size, stride, and border padding.
        feature_layers = [
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
        ]
        self.conv_unit = nn.Sequential(*feature_layers)

        # Fully connected classifier operating on the flattened feature maps.
        # To double-check the 16*5*5 input size, run a random
        # [2, 3, 32, 32] tensor through ``conv_unit`` and inspect its shape.
        classifier_layers = [
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        ]
        self.fc_unit = nn.Sequential(*classifier_layers)

    def forward(self, x):
        """Compute class logits for a batch of images.

        :param x: input batch of shape [b, 3, 32, 32]
        :return: logits of shape [b, 10]
        """
        n_samples = x.size(0)

        # [b, 3, 32, 32] => [b, 16, 5, 5]
        feature_maps = self.conv_unit(x)

        # Flatten: [b, 16, 5, 5] => [b, 16*5*5]
        flattened = feature_maps.view(n_samples, 16 * 5 * 5)

        # [b, 16*5*5] => [b, 10]
        return self.fc_unit(flattened)




def main():
    """Smoke-test Lenet5: feed it a random batch and print the output shape."""
    model = Lenet5()
    # Random stand-in for a batch of two 3-channel 32x32 images.
    dummy_batch = torch.rand(2, 3, 32, 32)
    print("lenet_out:", model(dummy_batch).shape)

# Run the shape smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

训练与测试

import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from lenet5 import Lenet5
import torch.nn.functional as F
from torch import  nn,optim

def main(batch_size=32, epochs=1000, learn_rate=1e-3):
    """Train Lenet5 on CIFAR-10 and print the test accuracy after every epoch.

    The dataset is downloaded into the local ``cifar`` directory on first use.
    Training runs on the GPU when CUDA is available and falls back to the CPU
    otherwise.

    Args:
        batch_size: Number of images per mini-batch, used for both loaders.
        epochs: Number of full passes over the training set.
        learn_rate: Learning rate handed to the Adam optimizer.
    """
    # One preprocessing pipeline shared by both splits: resize to 32x32 and
    # convert each image to a tensor.
    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])

    # The Dataset yields one (image, label) pair at a time; the DataLoader
    # groups those pairs into mini-batches.
    train_set = datasets.CIFAR10('cifar', train=True, transform=preprocess, download=True)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    test_set = datasets.CIFAR10('cifar', train=False, transform=preprocess, download=True)
    # Evaluation order does not influence accuracy, so the test split is not shuffled.
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    # Peek at a single batch to sanity-check the tensor shapes.
    x, label = next(iter(train_loader))
    print("x:", x.shape, "label:", label.shape)

    # A hard-coded 'cuda' device raises on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Lenet5().to(device)
    print(model)
    # Classification loss; it takes the raw logits and integer class labels.
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=learn_rate)

    for epoch in range(epochs):
        model.train()
        for x, label in train_loader:
            x, label = x.to(device), label.to(device)

            # logits: [b, 10]
            logits = model(x)
            loss = criteon(logits, label)

            # Backpropagation: clear stale gradients, compute new ones, update weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # NOTE: this is the loss of the last mini-batch only, not an epoch average.
        print(epoch, loss.item())

        model.eval()
        total_correct = 0
        total_num = 0
        with torch.no_grad():
            for x, label in test_loader:
                x, label = x.to(device), label.to(device)

                # [b, 10] => [b]: index of the highest-scoring class per sample.
                pred = model(x).argmax(dim=1)

                # [b] vs [b] => number of matching predictions in this batch.
                total_correct += (pred == label).sum().item()
                total_num += x.size(0)

        acc = total_correct / total_num
        print("epoch:", epoch, "acc:", acc)


# Start training only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
05-04 05:22