Alexnet-2012
研究背景
ILSVRC-2012 ImageNet Large Scale Visual Recognition Challenge.
ImageNet 数据集包含21,841个 类别, 14,197,122张图片。
top 5 error
AlexNet 在ILSVRC-2012以超出第二名10.9个百分点夺冠。
研究意义
- 里程碑式的论文
- 加速计算机视觉应用落地。 端到端式的,不需要再加特征工程。
论文精读
Abstruct
- ILSVRC-2010的120万张图片上训练AlexNet,最有结果: top1 error: 37.5, top-5 error 17%.
- 该网络由5个卷积层和3个全联接层组成,共计6000万个参数, 65万个神经元。
- 为加快训练,采用ReLU + GPU进行训练。
- 为减轻过拟合,采用Dropout.
- 基于上面的技巧,在ILSVRC-2012以超过第二名10.9个百分点的成绩夺冠。
1. Introduction
3. Architecture
3.1 ReLU Nonlinearity
ReLU Nonlinearity
f ( x ) = m a x ( 0 , x ) f(x) = max(0,x) f(x)=max(0,x)
Tanh 激活函数
f ( x ) = 1 1 + e − x f(x) = \frac{1}{1+ e^{-x}} f(x)=1+e−x1
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-u5AuKAgR-1669454493669)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125142750067.png)]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jedPVQAr-1669454493670)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125143351852.png)]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-xsITmv18-1669454493670)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125143323365.png)]
3.2 Training on Multiple GPUs
3.3 LRN local response normalization
局部响应标准化
局部响应标准化:有助于AlexNet泛化能力的提升,受到真实神经元侧抑制启发。
侧抑制:细胞分化变为不同时,它会对周围细胞产生抑制信号,阻止它们向相同的方向分化,最终表现为细胞命运的不同。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PYPDrwYz-1669454493671)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125143654127.png)]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-B1P6Olno-1669454493671)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125144420240.png)]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bl8cgN2z-1669454493671)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125144449149.png)]
3.4 overlapping pooling
带重叠的池化层
3.5 Overall Architecture
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PDYK01zq-1669454493672)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125114705661.png)]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RM3vdF0d-1669454493673)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125140113824.png)]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4zYD3PC0-1669454493673)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125140142396.png)]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-W7ktWEks-1669454493674)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125141951508.png)]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rsxS6DuG-1669454493674)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125142006998.png)]
4. Reducing Overfitting
4.1 Data Augmentation 图像增强
方法1. 针对位置
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Gg67stBy-1669454493674)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125145455866.png)]
方法2 针对色彩
通过PCA方法修改RGB通道的像素值,实现颜色扰动,效果有限。
4.2 DropOut
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YcagXjJc-1669454493675)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125150515033.png)]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yu9zZ1aw-1669454493675)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125150551510.png)]
实验结果和分析
卷积核的可视化
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EwOq5u9N-1669454493675)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125151740366.png)]
特征的相似性
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-QtFT1dfh-1669454493676)(https://gitee.com/ml_666/markdown_pictures/raw/master/image-20221125152844529.png)]
相似图片的第二个全联接层输出特征的欧式距离相近。
关键代码
torch.topk(input, k, dim = None, largest=True, sorted = True, out = None)
"""
功能: 找出前k大的数据,及其索引序列号
1. input : 张量
2. k 决定选取k个值
3. dim: 索引维度
返回
1. Tensor: 前k个最大的值
2. LongTensor: 前k大的值所在的位置
"""
transforms.FiveCrop(size)
transforms.TenCrop(size, vertical_flip = False)
"""
功能:在图片的上下左右及其中心裁出尺寸为size的五张图片, TenCrop 对这五张图片进行水平或者垂直镜像获得10张图片。
1. size: 所需要裁剪的尺寸
2. vertical_flip 是否要垂直翻转
"""
torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range = None, scale_each = False, pad_value = 0)
"""
功能:制作网格图像
1. tensor: 图像数据, B*C*H*W 形式
2. nrow : 行数(列数自动计算)
3. padding : 图像间距(像素单位)
4. normalize: 是否将像素值标准化
5. range 标准化范围
6. scale_each: 是否单张图片维度标准化
7. pad_value: padding 的像素值
"""
torchvision介绍
torchvision.datasets: 一些加载数据的函数及常用的数据接口
torchvision.models: 包含常用的深度学习模型(含预训练模型)
torchvision.transforms: 常用的图像变化,例如裁剪,旋转等
torchvision.utils: 其他的一些有用的方法
class torchvision.transforms.Compose(transforms):
# Composes several transforms together
# parameters: transforms (list of transform objects) -list of transforms to compose
transforms.Compose([
transforms.CenterCrop(10),
transforms.ToTensor(),
])
model.eval()
# 模型中有BatchNormalization和Dropout,在预测时使用model.eval()后会将其关闭以免影响预测结果。
# https://blog.csdn.net/qq_38410428/article/details/101102075
# -*- coding: utf-8 -*-
"""
# @file name : train_alexnet.py
# @author : TingsongYu https://github.com/TingsongYu
# @date : 2020-02-14
# @brief : alexnet traning
"""
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from ToolsUtils.my_dataset import CatDogDataset
from torch.utils.data import DataLoader
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_model(path_state_dict, vis_model=False):
"""
创建模型,加载参数
:param path_state_dict:
:return:
"""
model = models.alexnet()
pretrained_state_dict = torch.load(path_state_dict)
model.load_state_dict(pretrained_state_dict)
if vis_model:
from torchsummary import summary
summary(model, input_size=(3, 224, 224), device="cpu")
model.to(device)
return model
if __name__ == "__main__":
# config
data_dir = os.path.join(BASE_DIR, "..", "data", "train")
path_state_dict = os.path.join(BASE_DIR, "..", "data", "alexnet-owt-4df8aa71.pth")
num_classes = 2
MAX_EPOCH = 3 # 可自行修改
BATCH_SIZE = 128 # 可自行修改
LR = 0.001 # 可自行修改
log_interval = 1 # 可自行修改
val_interval = 1 # 可自行修改
classes = 2
start_epoch = -1
lr_decay_step = 1 # 可自行修改
# ============================ step 1/5 数据 ============================
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
train_transform = transforms.Compose([
transforms.Resize((256)), # (256, 256) 区别
transforms.CenterCrop(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
normalizes = transforms.Normalize(norm_mean, norm_std)
valid_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.TenCrop(224, vertical_flip=False),
transforms.Lambda(lambda crops: torch.stack([normalizes(transforms.ToTensor()(crop)) for crop in crops])),
])
# 构建MyDataset实例
train_data = CatDogDataset(data_dir=data_dir, mode="train", transform=train_transform)
valid_data = CatDogDataset(data_dir=data_dir, mode="valid", transform=valid_transform)
# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=4)
# ============================ step 2/5 模型 ============================
alexnet_model = get_model(path_state_dict, False)
num_ftrs = alexnet_model.classifier._modules["6"].in_features
alexnet_model.classifier._modules["6"] = nn.Linear(num_ftrs, num_classes)
alexnet_model.to(device)
# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()
# ============================ step 4/5 优化器 ============================
# 冻结卷积层
flag = 0
# flag = 1
if flag:
fc_params_id = list(map(id, alexnet_model.classifier.parameters())) # 返回的是parameters的 内存地址
base_params = filter(lambda p: id(p) not in fc_params_id, alexnet_model.parameters())
optimizer = optim.SGD([
{'params': base_params, 'lr': LR * 0.1}, # 0
{'params': alexnet_model.classifier.parameters(), 'lr': LR}], momentum=0.9)
else:
optimizer = optim.SGD(alexnet_model.parameters(), lr=LR, momentum=0.9) # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1) # 设置学习率下降策略
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(patience=5)
# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()
for epoch in range(start_epoch + 1, MAX_EPOCH):
loss_mean = 0.
correct = 0.
total = 0.
alexnet_model.train()
for i, data in enumerate(train_loader):
# if i > 1:
# break
# forward
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
outputs = alexnet_model(inputs)
# backward
optimizer.zero_grad()
loss = criterion(outputs, labels)
loss.backward()
# update weights
optimizer.step()
# 统计分类情况
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).squeeze().cpu().sum().numpy()
# 打印训练信息
loss_mean += loss.item()
train_curve.append(loss.item())
if (i+1) % log_interval == 0:
loss_mean = loss_mean / log_interval
print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
loss_mean = 0.
scheduler.step() # 更新学习率
# validate the model
if (epoch+1) % val_interval == 0:
correct_val = 0.
total_val = 0.
loss_val = 0.
alexnet_model.eval()
with torch.no_grad():
for j, data in enumerate(valid_loader):
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
bs, ncrops, c, h, w = inputs.size() # [4, 10, 3, 224, 224
outputs = alexnet_model(inputs.view(-1, c, h, w))
outputs_avg = outputs.view(bs, ncrops, -1).mean(1)
loss = criterion(outputs_avg, labels)
_, predicted = torch.max(outputs_avg.data, 1)
total_val += labels.size(0)
correct_val += (predicted == labels).squeeze().cpu().sum().numpy()
loss_val += loss.item()
loss_val_mean = loss_val/len(valid_loader)
valid_curve.append(loss_val_mean)
print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_mean, correct_val / total_val))
alexnet_model.train()
train_x = range(len(train_curve))
train_y = train_curve
train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curve
plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')
plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.savefig()
plt.show()