Copyright notice: this is an original post by the author. You are welcome to repost it, but please credit the source. Contact: [email protected]
The previous posts covered MNIST; on those simple images, LeNet-5 reaches about 99% accuracy.
CIFAR10 is another well-known deep-learning image-classification dataset. It is more complex than MNIST, and its images are RGB color images rather than grayscale.
Let's see what accuracy the relatively simple LeNet-5 can reach on it. The network structure is essentially the same as in the earlier MNIST code; the main difference is the number of input channels. The code is as follows:
# -*- coding:utf-8 -*-
u"""Train a LeNet convolutional neural network on CIFAR10"""

__author__ = 'zhengbiqing [email protected]'


import torch as t
import torchvision as tv
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
import torch.backends.cudnn as cudnn
import datetime
import argparse


# Number of worker processes for loading samples
WORKERS = 4

# File name for saving the network parameters
PARAS_FN = 'cifar_lenet_params.pkl'

# Location of the CIFAR10 data
ROOT = '/home/zbq/PycharmProjects/cifar'

# Loss function
loss_func = nn.CrossEntropyLoss()

# Best result so far
best_acc = 0


# Network model definition
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()

        # Convolutional layers
        self.cnn = nn.Sequential(
            # Conv layer 1: 3 input channels, 6 kernels of size 5*5
            # Output size: 32 - 5 + 1 = 28, i.e. 28*28
            # After 2*2 max pooling: 14*14
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Conv layer 2: 6 input channels, 16 kernels of size 5*5
            # Output size: 14 - 5 + 1 = 10, i.e. 10*10
            # After 2*2 max pooling: 5*5
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # Fully connected layers
        self.fc = nn.Sequential(
            # 16 feature maps, each 5*5
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        x = self.cnn(x)

        # x.size()[0] is the batch size; flatten the feature maps
        x = x.view(x.size()[0], -1)

        x = self.fc(x)
        return x


'''
Train the network for one epoch
net: network model
train_data_load: training data loader
optimizer: optimizer
epoch: current epoch number
log_interval: how often (in batches) to print the loss and accuracy during training
'''
def net_train(net, train_data_load, optimizer, epoch, log_interval):
    net.train()

    begin = datetime.datetime.now()

    # Total number of training samples
    total = len(train_data_load.dataset)

    # Accumulated loss over the batches trained so far
    train_loss = 0

    # Number of correctly classified samples
    ok = 0

    for i, data in enumerate(train_data_load, 0):
        img, label = data
        img, label = img.cuda(), label.cuda()

        optimizer.zero_grad()

        outs = net(img)
        loss = loss_func(outs, label)
        loss.backward()
        optimizer.step()

        # Accumulate the loss
        train_loss += loss.item()

        _, predicted = t.max(outs.data, 1)

        # Accumulate the number of correct predictions
        ok += (predicted == label).sum()

        if (i + 1) % log_interval == 0:
            # Mean loss so far
            loss_mean = train_loss / (i + 1)

            # Number of samples trained so far
            trained_total = (i + 1) * len(label)

            # Training accuracy so far
            acc = 100. * ok / trained_total

            # Progress through this epoch as a percentage
            progress = 100. * trained_total / total

            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} Acc: {:.6f}'.format(
                epoch, trained_total, total, progress, loss_mean, acc))

    end = datetime.datetime.now()
    print('one epoch spend: ', end - begin)


'''
Check accuracy on the test set
'''
def net_test(net, test_data_load, epoch):
    net.eval()

    ok = 0

    for i, data in enumerate(test_data_load):
        img, label = data
        img, label = img.cuda(), label.cuda()

        outs = net(img)
        _, pre = t.max(outs.data, 1)
        ok += (pre == label).sum()

    acc = ok.item() * 100. / (len(test_data_load.dataset))
    print('EPOCH:{}, ACC:{}\n'.format(epoch, acc))

    global best_acc
    if acc > best_acc:
        best_acc = acc


'''
Show one image from the dataset
'''
def img_show(dataset, index):
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    show = ToPILImage()

    data, label = dataset[index]
    print('img is a ', classes[label])

    # Undo the [-1, 1] normalization before converting back to a PIL image
    show((data + 1) / 2).resize((100, 100)).show()


def main():
    # Training hyperparameters, settable from the command line
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 LeNet Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=20, metavar='N',
                        help='number of epochs to train (default: 20)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                        help='how many batches to wait before logging training status (default: 100)')
    parser.add_argument('--no-train', action='store_true', default=False,
                        help='skip training and only evaluate saved parameters')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='for saving the current model')
    args = parser.parse_args()

    # Image-to-tensor conversion; from the ToTensor source docstring:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """
    # Normalize then maps [0.0, 1.0] to [-1, 1]: ([0, 1] - 0.5) / 0.5 = [-1, 1]
    transform = tv.transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    # Datasets
    train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
    test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)

    train_load = t.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=WORKERS)
    test_load = t.utils.data.DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False, num_workers=WORKERS)

    net = LeNet().cuda()
    print(net)

    # If not training, load the saved network parameters and evaluate on the test set
    if args.no_train:
        net.load_state_dict(t.load(PARAS_FN))
        net_test(net, test_load, 0)
        return

    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)

    start_time = datetime.datetime.now()

    for epoch in range(1, args.epochs + 1):
        net_train(net, train_load, optimizer, epoch, args.log_interval)

        # After each epoch, check accuracy on the test set
        net_test(net, test_load, epoch)

    end_time = datetime.datetime.now()

    global best_acc
    print('CIFAR10 pytorch LeNet Train: EPOCH:{}, BATCH_SZ:{}, LR:{}, ACC:{}'.format(
        args.epochs, args.batch_size, args.lr, best_acc))
    print('train spend time: ', end_time - start_time)

    if args.save_model:
        t.save(net.state_dict(), PARAS_FN)


if __name__ == '__main__':
    main()
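As a quick sanity check of the size arithmetic in the comments (32 → 28 → 14 → 10 → 5, so the flattened feature vector is 16 * 5 * 5 = 400), here is a minimal sketch that is not part of the script above; it pushes a random batch through the model on the CPU:

import torch as t

net = LeNet()                # CPU is fine for a shape check
x = t.randn(4, 3, 32, 32)    # dummy batch of 4 CIFAR10-sized images
feats = net.cnn(x)
print(feats.size())          # torch.Size([4, 16, 5, 5]) -> flattens to 400
print(net(x).size())         # torch.Size([4, 10]), one score per class

Assuming the script is saved as, say, cifar_lenet.py (the original post does not give a file name), a typical run would be:

python cifar_lenet.py --epochs 20 --save-model     # train, then save the parameters
python cifar_lenet.py --no-train                   # evaluate the saved parameters only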
The program output is as follows:
Files already downloaded and verified
LeNet(
(cnn): Sequential(
(0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(4): ReLU()
(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(fc): Sequential(
(0): Linear(in_features=400, out_features=120, bias=True)
(1): ReLU()
(2): Linear(in_features=120, out_features=84, bias=True)
(3): ReLU()
(4): Linear(in_features=84, out_features=10, bias=True)
)
)
Train Epoch: 1 [6400/50000 (13%)] Loss: 2.297558 Acc: 10.000000
Train Epoch: 1 [12800/50000 (26%)] Loss: 2.219855 Acc: 16.000000
Train Epoch: 1 [19200/50000 (38%)] Loss: 2.117518 Acc: 20.000000
Train Epoch: 1 [25600/50000 (51%)] Loss: 2.030452 Acc: 23.000000
Train Epoch: 1 [32000/50000 (64%)] Loss: 1.956154 Acc: 26.000000
Train Epoch: 1 [38400/50000 (77%)] Loss: 1.894052 Acc: 29.000000
Train Epoch: 1 [44800/50000 (90%)] Loss: 1.845520 Acc: 31.000000
one epoch spend: 0:00:02.007186
EPOCH:1, ACC:43.86
Train Epoch: 2 [6400/50000 (13%)] Loss: 1.497962 Acc: 44.000000
Train Epoch: 2 [12800/50000 (26%)] Loss: 1.471271 Acc: 45.000000
Train Epoch: 2 [19200/50000 (38%)] Loss: 1.458853 Acc: 46.000000
Train Epoch: 2 [25600/50000 (51%)] Loss: 1.445787 Acc: 47.000000
Train Epoch: 2 [32000/50000 (64%)] Loss: 1.436431 Acc: 47.000000
Train Epoch: 2 [38400/50000 (77%)] Loss: 1.425798 Acc: 47.000000
Train Epoch: 2 [44800/50000 (90%)] Loss: 1.415501 Acc: 48.000000
one epoch spend: 0:00:01.879316
EPOCH:2, ACC:53.16
Train Epoch: 3 [6400/50000 (13%)] Loss: 1.288907 Acc: 52.000000
Train Epoch: 3 [12800/50000 (26%)] Loss: 1.293646 Acc: 53.000000
Train Epoch: 3 [19200/50000 (38%)] Loss: 1.284784 Acc: 53.000000
Train Epoch: 3 [25600/50000 (51%)] Loss: 1.281050 Acc: 53.000000
Train Epoch: 3 [32000/50000 (64%)] Loss: 1.281222 Acc: 53.000000
Train Epoch: 3 [38400/50000 (77%)] Loss: 1.269620 Acc: 54.000000
Train Epoch: 3 [44800/50000 (90%)] Loss: 1.262982 Acc: 54.000000
one epoch spend: 0:00:01.928787
EPOCH:3, ACC:54.31
Train Epoch: 4 [6400/50000 (13%)] Loss: 1.157912 Acc: 58.000000
Train Epoch: 4 [12800/50000 (26%)] Loss: 1.157038 Acc: 58.000000
Train Epoch: 4 [19200/50000 (38%)] Loss: 1.164880 Acc: 58.000000
Train Epoch: 4 [25600/50000 (51%)] Loss: 1.169460 Acc: 58.000000
Train Epoch: 4 [32000/50000 (64%)] Loss: 1.169655 Acc: 58.000000
Train Epoch: 4 [38400/50000 (77%)] Loss: 1.169239 Acc: 58.000000
Train Epoch: 4 [44800/50000 (90%)] Loss: 1.159252 Acc: 58.000000
one epoch spend: 0:00:01.928551
EPOCH:4, ACC:60.15
Train Epoch: 5 [6400/50000 (13%)] Loss: 1.081296 Acc: 61.000000
Train Epoch: 5 [12800/50000 (26%)] Loss: 1.073868 Acc: 61.000000
Train Epoch: 5 [19200/50000 (38%)] Loss: 1.086076 Acc: 61.000000
Train Epoch: 5 [25600/50000 (51%)] Loss: 1.088019 Acc: 61.000000
Train Epoch: 5 [32000/50000 (64%)] Loss: 1.083983 Acc: 61.000000
Train Epoch: 5 [38400/50000 (77%)] Loss: 1.088050 Acc: 61.000000
Train Epoch: 5 [44800/50000 (90%)] Loss: 1.087298 Acc: 61.000000
one epoch spend: 0:00:01.898825
EPOCH:5, ACC:59.84
Train Epoch: 6 [6400/50000 (13%)] Loss: 0.979352 Acc: 65.000000
Train Epoch: 6 [12800/50000 (26%)] Loss: 1.005338 Acc: 64.000000
Train Epoch: 6 [19200/50000 (38%)] Loss: 1.019300 Acc: 63.000000
Train Epoch: 6 [25600/50000 (51%)] Loss: 1.022704 Acc: 63.000000
Train Epoch: 6 [32000/50000 (64%)] Loss: 1.021217 Acc: 63.000000
Train Epoch: 6 [38400/50000 (77%)] Loss: 1.022035 Acc: 63.000000
Train Epoch: 6 [44800/50000 (90%)] Loss: 1.024987 Acc: 63.000000
one epoch spend: 0:00:01.926922
EPOCH:6, ACC:60.04
Train Epoch: 7 [6400/50000 (13%)] Loss: 0.952975 Acc: 66.000000
Train Epoch: 7 [12800/50000 (26%)] Loss: 0.965437 Acc: 65.000000
Train Epoch: 7 [19200/50000 (38%)] Loss: 0.964711 Acc: 65.000000
Train Epoch: 7 [25600/50000 (51%)] Loss: 0.962520 Acc: 65.000000
Train Epoch: 7 [32000/50000 (64%)] Loss: 0.964768 Acc: 65.000000
Train Epoch: 7 [38400/50000 (77%)] Loss: 0.966530 Acc: 65.000000
Train Epoch: 7 [44800/50000 (90%)] Loss: 0.971995 Acc: 65.000000
one epoch spend: 0:00:01.858537
EPOCH:7, ACC:62.63
Train Epoch: 8 [6400/50000 (13%)] Loss: 0.901441 Acc: 67.000000
Train Epoch: 8 [12800/50000 (26%)] Loss: 0.896776 Acc: 68.000000
Train Epoch: 8 [19200/50000 (38%)] Loss: 0.898365 Acc: 68.000000
Train Epoch: 8 [25600/50000 (51%)] Loss: 0.898383 Acc: 68.000000
Train Epoch: 8 [32000/50000 (64%)] Loss: 0.909455 Acc: 67.000000
Train Epoch: 8 [38400/50000 (77%)] Loss: 0.910068 Acc: 67.000000
Train Epoch: 8 [44800/50000 (90%)] Loss: 0.914733 Acc: 67.000000
one epoch spend: 0:00:01.849259
EPOCH:8, ACC:62.99
Train Epoch: 9 [6400/50000 (13%)] Loss: 0.842184 Acc: 69.000000
Train Epoch: 9 [12800/50000 (26%)] Loss: 0.853178 Acc: 69.000000
Train Epoch: 9 [19200/50000 (38%)] Loss: 0.863828 Acc: 69.000000
Train Epoch: 9 [25600/50000 (51%)] Loss: 0.868452 Acc: 69.000000
Train Epoch: 9 [32000/50000 (64%)] Loss: 0.870991 Acc: 69.000000
Train Epoch: 9 [38400/50000 (77%)] Loss: 0.874963 Acc: 69.000000
Train Epoch: 9 [44800/50000 (90%)] Loss: 0.878533 Acc: 68.000000
one epoch spend: 0:00:01.954615
EPOCH:9, ACC:62.5
Train Epoch: 10 [6400/50000 (13%)] Loss: 0.837819 Acc: 70.000000
Train Epoch: 10 [12800/50000 (26%)] Loss: 0.823905 Acc: 70.000000
Train Epoch: 10 [19200/50000 (38%)] Loss: 0.833733 Acc: 70.000000
Train Epoch: 10 [25600/50000 (51%)] Loss: 0.838861 Acc: 70.000000
Train Epoch: 10 [32000/50000 (64%)] Loss: 0.841117 Acc: 70.000000
Train Epoch: 10 [38400/50000 (77%)] Loss: 0.849762 Acc: 69.000000
Train Epoch: 10 [44800/50000 (90%)] Loss: 0.850071 Acc: 69.000000
one epoch spend: 0:00:01.812348
EPOCH:10, ACC:63.48
Train Epoch: 11 [6400/50000 (13%)] Loss: 0.781857 Acc: 72.000000
Train Epoch: 11 [12800/50000 (26%)] Loss: 0.773329 Acc: 72.000000
Train Epoch: 11 [19200/50000 (38%)] Loss: 0.785191 Acc: 72.000000
Train Epoch: 11 [25600/50000 (51%)] Loss: 0.797921 Acc: 71.000000
Train Epoch: 11 [32000/50000 (64%)] Loss: 0.802146 Acc: 71.000000
Train Epoch: 11 [38400/50000 (77%)] Loss: 0.804404 Acc: 71.000000
Train Epoch: 11 [44800/50000 (90%)] Loss: 0.805919 Acc: 71.000000
one epoch spend: 0:00:01.881838
EPOCH:11, ACC:63.72
Train Epoch: 12 [6400/50000 (13%)] Loss: 0.734165 Acc: 74.000000
Train Epoch: 12 [12800/50000 (26%)] Loss: 0.739923 Acc: 74.000000
Train Epoch: 12 [19200/50000 (38%)] Loss: 0.753080 Acc: 73.000000
Train Epoch: 12 [25600/50000 (51%)] Loss: 0.755026 Acc: 73.000000
Train Epoch: 12 [32000/50000 (64%)] Loss: 0.758760 Acc: 73.000000
Train Epoch: 12 [38400/50000 (77%)] Loss: 0.765208 Acc: 72.000000
Train Epoch: 12 [44800/50000 (90%)] Loss: 0.774539 Acc: 72.000000
one epoch spend: 0:00:01.856290
EPOCH:12, ACC:63.71
Train Epoch: 13 [6400/50000 (13%)] Loss: 0.709528 Acc: 75.000000
Train Epoch: 13 [12800/50000 (26%)] Loss: 0.713831 Acc: 74.000000
Train Epoch: 13 [19200/50000 (38%)] Loss: 0.720146 Acc: 74.000000
Train Epoch: 13 [25600/50000 (51%)] Loss: 0.723680 Acc: 74.000000
Train Epoch: 13 [32000/50000 (64%)] Loss: 0.730473 Acc: 73.000000
Train Epoch: 13 [38400/50000 (77%)] Loss: 0.742575 Acc: 73.000000
Train Epoch: 13 [44800/50000 (90%)] Loss: 0.744857 Acc: 73.000000
one epoch spend: 0:00:01.808256
EPOCH:13, ACC:61.71
Train Epoch: 14 [6400/50000 (13%)] Loss: 0.700821 Acc: 74.000000
Train Epoch: 14 [12800/50000 (26%)] Loss: 0.691082 Acc: 75.000000
Train Epoch: 14 [19200/50000 (38%)] Loss: 0.693119 Acc: 75.000000
Train Epoch: 14 [25600/50000 (51%)] Loss: 0.706147 Acc: 74.000000
Train Epoch: 14 [32000/50000 (64%)] Loss: 0.710033 Acc: 74.000000
Train Epoch: 14 [38400/50000 (77%)] Loss: 0.717097 Acc: 74.000000
Train Epoch: 14 [44800/50000 (90%)] Loss: 0.724987 Acc: 74.000000
one epoch spend: 0:00:01.797417
EPOCH:14, ACC:63.15
Train Epoch: 15 [6400/50000 (13%)] Loss: 0.624073 Acc: 77.000000
Train Epoch: 15 [12800/50000 (26%)] Loss: 0.637354 Acc: 77.000000
Train Epoch: 15 [19200/50000 (38%)] Loss: 0.646385 Acc: 76.000000
Train Epoch: 15 [25600/50000 (51%)] Loss: 0.662080 Acc: 76.000000
Train Epoch: 15 [32000/50000 (64%)] Loss: 0.668658 Acc: 76.000000
Train Epoch: 15 [38400/50000 (77%)] Loss: 0.679682 Acc: 75.000000
Train Epoch: 15 [44800/50000 (90%)] Loss: 0.688876 Acc: 75.000000
one epoch spend: 0:00:01.916400
EPOCH:15, ACC:62.81
Train Epoch: 16 [6400/50000 (13%)] Loss: 0.611007 Acc: 78.000000
Train Epoch: 16 [12800/50000 (26%)] Loss: 0.612629 Acc: 78.000000
Train Epoch: 16 [19200/50000 (38%)] Loss: 0.622980 Acc: 77.000000
Train Epoch: 16 [25600/50000 (51%)] Loss: 0.638267 Acc: 77.000000
Train Epoch: 16 [32000/50000 (64%)] Loss: 0.650756 Acc: 76.000000
Train Epoch: 16 [38400/50000 (77%)] Loss: 0.656675 Acc: 76.000000
Train Epoch: 16 [44800/50000 (90%)] Loss: 0.665181 Acc: 75.000000
one epoch spend: 0:00:01.878367
EPOCH:16, ACC:61.64
Train Epoch: 17 [6400/50000 (13%)] Loss: 0.591583 Acc: 78.000000
Train Epoch: 17 [12800/50000 (26%)] Loss: 0.601943 Acc: 78.000000
Train Epoch: 17 [19200/50000 (38%)] Loss: 0.612084 Acc: 78.000000
Train Epoch: 17 [25600/50000 (51%)] Loss: 0.619225 Acc: 77.000000
Train Epoch: 17 [32000/50000 (64%)] Loss: 0.633562 Acc: 77.000000
Train Epoch: 17 [38400/50000 (77%)] Loss: 0.641217 Acc: 77.000000
Train Epoch: 17 [44800/50000 (90%)] Loss: 0.648393 Acc: 76.000000
one epoch spend: 0:00:01.894760
EPOCH:17, ACC:61.44
Train Epoch: 18 [6400/50000 (13%)] Loss: 0.553651 Acc: 80.000000
Train Epoch: 18 [12800/50000 (26%)] Loss: 0.569668 Acc: 79.000000
Train Epoch: 18 [19200/50000 (38%)] Loss: 0.584057 Acc: 78.000000
Train Epoch: 18 [25600/50000 (51%)] Loss: 0.598776 Acc: 78.000000
Train Epoch: 18 [32000/50000 (64%)] Loss: 0.610767 Acc: 78.000000
Train Epoch: 18 [38400/50000 (77%)] Loss: 0.617563 Acc: 77.000000
Train Epoch: 18 [44800/50000 (90%)] Loss: 0.628669 Acc: 77.000000
one epoch spend: 0:00:01.925175
EPOCH:18, ACC:62.46
Train Epoch: 19 [6400/50000 (13%)] Loss: 0.554530 Acc: 79.000000
Train Epoch: 19 [12800/50000 (26%)] Loss: 0.574952 Acc: 78.000000
Train Epoch: 19 [19200/50000 (38%)] Loss: 0.576819 Acc: 79.000000
Train Epoch: 19 [25600/50000 (51%)] Loss: 0.584052 Acc: 78.000000
Train Epoch: 19 [32000/50000 (64%)] Loss: 0.590673 Acc: 78.000000
Train Epoch: 19 [38400/50000 (77%)] Loss: 0.599807 Acc: 78.000000
Train Epoch: 19 [44800/50000 (90%)] Loss: 0.607849 Acc: 78.000000
one epoch spend: 0:00:01.827582
EPOCH:19, ACC:62.16
Train Epoch: 20 [6400/50000 (13%)] Loss: 0.534505 Acc: 80.000000
Train Epoch: 20 [12800/50000 (26%)] Loss: 0.547133 Acc: 80.000000
Train Epoch: 20 [19200/50000 (38%)] Loss: 0.557482 Acc: 79.000000
Train Epoch: 20 [25600/50000 (51%)] Loss: 0.567949 Acc: 79.000000
Train Epoch: 20 [32000/50000 (64%)] Loss: 0.579047 Acc: 79.000000
Train Epoch: 20 [38400/50000 (77%)] Loss: 0.591825 Acc: 78.000000
Train Epoch: 20 [44800/50000 (90%)] Loss: 0.598099 Acc: 78.000000
one epoch spend: 0:00:01.846124
EPOCH:20, ACC:62.47
CIFAR10 pytorch LeNet Train: EPOCH:20, BATCH_SZ:64, LR:0.01, ACC:63.72
train spend time: 0:00:46.669295
The trained LeNet tops out at around 63% test accuracy, far below the ~99% it reaches on MNIST. The log above also shows a widening gap: by epoch 20 the training accuracy is still climbing (around 78%) while the test accuracy has plateaued near 63% since roughly epoch 10, so the network is starting to overfit. A simple LeNet is just not accurate enough for these more complex, full-color images.
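A common, inexpensive next step for a small network like this is data augmentation of the training set. As a hedged sketch (not something measured in this post), the training transform could be replaced with random crops and horizontal flips, leaving the test transform and everything else unchanged:

import torchvision.transforms as transforms

# Augmented training transform: pad-and-crop plus random horizontal flip.
# The test set should keep the plain ToTensor + Normalize transform.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

This usually buys a few percentage points, but it will not close the gap to MNIST-level accuracy; the limited capacity of LeNet itself is the main bottleneck on CIFAR10.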