ResNet is a very effective model for image classification and recognition. For background, see the following link:
https://blog.csdn.net/qq_45649076/article/details/120494328
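A ResNet is built from basic blocks with a residual structure, and each basic block contains a few convolutional layers. Besides the output computed by these layers, the block's input is also added directly to the block's output (the shortcut connection). This design helps prevent the degradation problem, in which accuracy gets worse as the network is made deeper.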
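To make the shortcut concrete, here is a minimal NumPy sketch (separate from the TensorFlow code used later) of a residual unit y = ReLU(F(x) + x). For illustration only, the residual branch F is assumed to be two small dense transforms with hypothetical weights `w1`, `w2` instead of convolutions. If F's weights are zero, the unit simply passes non-negative inputs through unchanged, which is one intuition for why adding residual blocks need not make a network worse.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def residual_unit(x, w1, w2):
    """Toy residual unit: y = relu(F(x) + x), with F(x) = relu(x @ w1) @ w2.

    w1 and w2 are hypothetical square weight matrices standing in for the
    two convolutions of a real ResNet block.
    """
    f = relu(x @ w1) @ w2   # residual branch F(x)
    return relu(f + x)      # shortcut: add the input, then apply ReLU

rng = np.random.default_rng(0)
x = relu(rng.standard_normal((4, 8)))   # a batch of 4 non-negative feature vectors
zeros = np.zeros((8, 8))

# With a zero residual branch the unit is an identity mapping on non-negative inputs
y = residual_unit(x, zeros, zeros)
print(np.allclose(y, x))  # True
```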
Common variants include ResNet-18, ResNet-34, ResNet-50 and ResNet-152. As an example, the following builds a ResNet-18 model with TensorFlow.
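The number in each name counts the weighted layers: the first (stem) convolution, the convolutions inside all residual blocks, and the final fully connected layer. The sketch below checks this arithmetic using the standard stage configurations from the ResNet paper, which are an addition here rather than something stated in this post: ResNet-18/34 use "basic" blocks with 2 convolutions, ResNet-50/152 use "bottleneck" blocks with 3 convolutions.

```python
# Standard stage configurations: (convolutions per block, blocks per stage)
CONFIGS = {
    "ResNet-18": (2, [2, 2, 2, 2]),
    "ResNet-34": (2, [3, 4, 6, 3]),
    "ResNet-50": (3, [3, 4, 6, 3]),
    "ResNet-152": (3, [3, 8, 36, 3]),
}

def depth(convs_per_block, blocks):
    # 1 stem convolution + convolutions in all residual blocks + 1 fully connected layer
    # (1x1 projection shortcuts are conventionally not counted)
    return 1 + convs_per_block * sum(blocks) + 1

for name, (convs, blocks) in CONFIGS.items():
    print(name, depth(convs, blocks))
# ResNet-18 18
# ResNet-34 34
# ResNet-50 50
# ResNet-152 152
```

The model built below follows the ResNet-18 row: four stages with two basic blocks each.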
First, import the required modules:
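```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
# import layers/models from the same Keras that tensorflow.keras refers to
from tensorflow.keras import models, layers
import matplotlib.pyplot as plt
```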
Next, define the basic ResNet block:
```python
# define class of basic block
class ResBlock(keras.Model):
    def __init__(self, filters, strides=1, down_sample=False):
        super().__init__()
        self.down_sample = down_sample
        self.conv1 = layers.Conv2D(filters, (3,3), strides=strides, padding='same', use_bias=False)
        self.bn1 = layers.BatchNormalization()
        self.relu1 = layers.Activation('relu')
        self.conv2 = layers.Conv2D(filters, (3,3), strides=1, padding='same', use_bias=False)
        self.bn2 = layers.BatchNormalization()
        if self.down_sample:
            # 1x1 projection so the shortcut matches the main branch's shape
            self.down_conv = layers.Conv2D(filters, (1,1), strides=strides, padding='same', use_bias=False)
            self.down_bn = layers.BatchNormalization()
        self.relu2 = layers.Activation('relu')

    def call(self, inputs):
        net = self.conv1(inputs)
        net = self.bn1(net)
        net = self.relu1(net)
        net = self.conv2(net)
        net = self.bn2(net)
        # down sample inputs if dimension changes
        if self.down_sample:
            identity_out = self.down_conv(inputs)
            identity_out = self.down_bn(identity_out)
        else:
            identity_out = inputs
        net = self.relu2(net + identity_out)
        return net
```
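ResBlock consists of two convolutional layers, each followed by a BatchNormalization layer; the first convolution is also followed by a ReLU layer. The block input is added to the output of the second convolution (after its BN layer), and the sum passes through a second ReLU to produce the block output. There are two cases. If the first convolution in the block downsamples with stride > 1 (usually 2), the input must be downsampled as well, using a convolution with kernel_size 1 and the same stride, which also brings it to the new number of channels. Otherwise the input is added unchanged.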
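Why the shortcut needs its own 1x1 convolution can be checked with a little shape arithmetic. With padding='same', a convolution with stride s maps a spatial size n to ceil(n / s) regardless of kernel size, and the residual addition only works if both branches end with identical shapes. The helpers `conv_same` and `block_shapes` below are hypothetical names used for illustration only.

```python
import math

def conv_same(shape, filters, stride):
    """Output (height, width, channels) of a Conv2D with padding='same'."""
    h, w, _ = shape
    return (math.ceil(h / stride), math.ceil(w / stride), filters)

def block_shapes(in_shape, filters, stride, project):
    # Main branch: conv1 (may use stride > 1), then conv2 with stride 1
    main = conv_same(conv_same(in_shape, filters, stride), filters, 1)
    # Shortcut: either the raw input or a 1x1 convolution with the same stride
    shortcut = conv_same(in_shape, filters, stride) if project else in_shape
    return main, shortcut

# First block of the second stage: 32x32x64 input, 128 filters, stride 2
print(block_shapes((32, 32, 64), 128, 2, project=False))  # ((16, 16, 128), (32, 32, 64)): cannot be added
print(block_shapes((32, 32, 64), 128, 2, project=True))   # ((16, 16, 128), (16, 16, 128)): shapes match
```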
Then build the ResNet network from ResBlock:
```python
# define class of Resnet-18 model
class Resnet18(keras.Model):
    def __init__(self, initial_filters=64):
        super().__init__()
        filters = initial_filters
        # input layer (stem): 3x3 convolution, no downsampling
        self.input_layer = models.Sequential([
            layers.Conv2D(filters, (3,3), strides=1, padding='same', use_bias=False),
            layers.BatchNormalization(),
            layers.Activation('relu')
        ])
        # first stage, no downsampling
        self.layer1 = models.Sequential([
            ResBlock(filters),
            ResBlock(filters)
        ])
        # second stage: number of filters doubles, spatial size halves
        filters *= 2
        self.layer2 = models.Sequential([
            ResBlock(filters, strides=2, down_sample=True),
            ResBlock(filters)
        ])
        # third stage
        filters *= 2
        self.layer3 = models.Sequential([
            ResBlock(filters, strides=2, down_sample=True),
            ResBlock(filters)
        ])
        # fourth stage
        filters *= 2
        self.layer4 = models.Sequential([
            ResBlock(filters, strides=2, down_sample=True),
            ResBlock(filters)
        ])
        # output layer: global average pooling + 10-way softmax
        self.output_layer = models.Sequential([
            layers.GlobalAveragePooling2D(),
            layers.Dense(10, activation='softmax')
        ])

    def call(self, inputs):
        # input layer
        net = self.input_layer(inputs)
        # Resnet stages
        net = self.layer1(net)
        net = self.layer2(net)
        net = self.layer3(net)
        net = self.layer4(net)
        # output layer
        net = self.output_layer(net)
        return net
```
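Resnet18 consists of an input layer, four intermediate stages and an output layer. Each stage contains two basic blocks. In every stage except the first, the first ResBlock downsamples and the second ResBlock keeps the dimensions unchanged. The number of convolution filters doubles from one stage to the next (64, 128, 256, 512). The input layer is a convolutional layer followed by a BN layer and a ReLU layer. Some implementations use a downsampling 7x7 convolution followed by a pooling layer here; since this model is meant for the small images of CIFAR-10, no downsampling is done in the input layer. The output layer is a global average pooling layer followed by a fully connected layer.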
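The reason for dropping the 7x7 stem on CIFAR-10 can be seen by tracing the spatial size through the network. The sketch below assumes the ImageNet-style stem is a stride-2 convolution followed by a stride-2 max pooling, both with 'same'-style padding (this matches the padding usually used in that stem). The function `trace` is a hypothetical helper for illustration.

```python
import math

def trace(input_size=32, imagenet_stem=False):
    """Spatial size after the stem and after each of the 4 stages (padding='same')."""
    size = input_size
    if imagenet_stem:
        size = math.ceil(size / 2)  # 7x7 convolution, stride 2
        size = math.ceil(size / 2)  # 3x3 max pooling, stride 2
    sizes = [size]                  # after the stem
    for stage_stride in (1, 2, 2, 2):  # layer1 keeps the size, layer2-4 halve it
        size = math.ceil(size / stage_stride)
        sizes.append(size)
    return sizes

print(trace(32))                       # [32, 32, 16, 8, 4]
print(trace(32, imagenet_stem=True))   # [8, 8, 4, 2, 1]
print(trace(224, imagenet_stem=True))  # [56, 56, 28, 14, 7]
```

With the 3x3 stem used here, a 32x32 image still has a 4x4 feature map (512 channels) before global average pooling; with the ImageNet-style stem it would already be reduced to 1x1.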
Next, build the network and test it on the CIFAR-10 dataset.
Build the Resnet18 network and choose the optimizer:
```python
# build model, Resnet18
model = Resnet18()
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
```
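The "sparse" loss and metric take integer class labels (CIFAR-10 labels are integers 0 to 9) instead of one-hot vectors. For each sample the loss is −log of the probability predicted for the true class, averaged over the batch. A minimal NumPy sketch of both computations, for illustration only and not the Keras implementation (which also clips probabilities for numerical stability):

```python
import numpy as np

def sparse_categorical_crossentropy(labels, probs):
    """labels: (N,) or (N, 1) integer class ids; probs: (N, num_classes) softmax outputs."""
    labels = np.asarray(labels).reshape(-1)         # accept CIFAR-style (N, 1) labels
    picked = probs[np.arange(len(labels)), labels]  # probability assigned to the true class
    return float(np.mean(-np.log(picked)))

def sparse_categorical_accuracy(labels, probs):
    labels = np.asarray(labels).reshape(-1)
    return float(np.mean(np.argmax(probs, axis=1) == labels))

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])
labels = np.array([[0], [1]])
print(sparse_categorical_crossentropy(labels, probs))  # mean of -log(0.7) and -log(0.1), about 1.3296
print(sparse_categorical_accuracy(labels, probs))      # 0.5
```

For reference, a model that guesses uniformly over 10 classes has a loss of −log(0.1) ≈ 2.30.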
Load the CIFAR-10 data:
```python
# test on CIFAR-10 data, load CIFAR-10 dataset
(train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()
# train_images: (50000, 32, 32, 3) uint8, train_labels: (50000, 1)
# test_images: (10000, 32, 32, 3) uint8, test_labels: (10000, 1)
# pre-process data: scale pixel values to [0, 1]
train_input = train_images/255.0
test_input = test_images/255.0
train_output = train_labels
test_output = test_labels
```
Define the data augmentation generator:
```python
# define data generator (data augmentation)
# note: ImageDataGenerator is marked as deprecated in recent TensorFlow releases
data_generator = keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=False,             # set the dataset mean to 0, per feature
    samplewise_center=False,              # set each sample's mean to 0
    featurewise_std_normalization=False,  # divide by the dataset std, per feature
    samplewise_std_normalization=False,   # divide each sample by its own std
    zca_whitening=False,                  # ZCA whitening
    # zca_epsilon=1e-6,                   # epsilon for ZCA whitening (default 1e-6)
    rotation_range=15,                    # range in degrees for random rotations
    width_shift_range=0.1,                # random horizontal shift, as a fraction of image width
    height_shift_range=0.1,               # random vertical shift, as a fraction of image height
    horizontal_flip=True,                 # randomly flip images horizontally
    vertical_flip=False                   # no random vertical flip
)
# fit() computes dataset statistics; it only matters for the featurewise/ZCA options
data_generator.fit(train_input)
```
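For intuition about what the generator does to each training image, here is a simplified NumPy sketch of a random horizontal flip followed by an integer pixel shift, filling uncovered pixels with the nearest edge value (the generator's default fill_mode is 'nearest'). It is an illustration under these assumptions, not ImageDataGenerator's actual code, which also applies the random rotation and interpolates. `flip_and_shift` and `random_augment` are hypothetical names.

```python
import numpy as np

def flip_and_shift(img, dx, dy, flip):
    """img: (H, W, C) array; dx/dy: integer shifts in pixels (positive = right/down)."""
    h, w = img.shape[:2]
    if flip:
        img = img[:, ::-1]                       # mirror left-right
    rows = np.clip(np.arange(h) - dy, 0, h - 1)  # source row for each output row
    cols = np.clip(np.arange(w) - dx, 0, w - 1)  # source column for each output column
    return img[rows[:, None], cols[None, :]]     # out-of-range pixels repeat the edge

def random_augment(img, rng, shift_range=0.1):
    """Random flip plus a shift of up to shift_range * size, mirroring the settings above."""
    h, w = img.shape[:2]
    dx = int(round(rng.uniform(-shift_range, shift_range) * w))
    dy = int(round(rng.uniform(-shift_range, shift_range) * h))
    return flip_and_shift(img, dx, dy, flip=bool(rng.random() < 0.5))

img = np.arange(16, dtype=float).reshape(4, 4, 1)
print(flip_and_shift(img, 1, 0, False)[0, :, 0])  # [0. 0. 1. 2.]
print(flip_and_shift(img, 0, 0, True)[0, :, 0])   # [3. 2. 1. 0.]
```

For a 32x32 CIFAR-10 image, a shift range of 0.1 means shifts of at most about 3 pixels in each direction.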
Train the model:
```python
# train, with batch size and epochs
epochs = 60
batch_size = 128
history = model.fit(data_generator.flow(train_input, train_output, batch_size=batch_size),
                    epochs=epochs,
                    steps_per_epoch=len(train_input)//batch_size,
                    validation_data=(test_input, test_output))
```
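The value of steps_per_epoch explains the "390/390" counter in the log below. A quick check of the arithmetic, using the sizes from the code above:

```python
num_train = 50000   # CIFAR-10 training images
batch_size = 128
epochs = 60

steps_per_epoch = num_train // batch_size           # 390, matching "390/390" in the log
leftover = num_train - steps_per_epoch * batch_size # 80 images do not fill a complete batch
total_updates = steps_per_epoch * epochs            # 23400 optimizer steps in total
print(steps_per_epoch, leftover, total_updates)     # 390 80 23400
```

Integer division means each epoch consists of complete batches only; using math.ceil instead would add one smaller batch per epoch.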
The output is as follows:
```
Epoch 1/60
390/390 [==============================] - 47s 113ms/step - loss: 1.4509 - sparse_categorical_accuracy: 0.4840 - val_loss: 2.3173 - val_sparse_categorical_accuracy: 0.3550
Epoch 2/60
390/390 [==============================] - 43s 110ms/step - loss: 0.9640 - sparse_categorical_accuracy: 0.6601 - val_loss: 1.2524 - val_sparse_categorical_accuracy: 0.5868
Epoch 3/60
390/390 [==============================] - 43s 111ms/step - loss: 0.7742 - sparse_categorical_accuracy: 0.7273 - val_loss: 1.4201 - val_sparse_categorical_accuracy: 0.5968
...
Epoch 59/60
390/390 [==============================] - 43s 110ms/step - loss: 0.0483 - sparse_categorical_accuracy: 0.9824 - val_loss: 0.3880 - val_sparse_categorical_accuracy: 0.9106
Epoch 60/60
390/390 [==============================] - 43s 110ms/step - loss: 0.0501 - sparse_categorical_accuracy: 0.9830 - val_loss: 0.4432 - val_sparse_categorical_accuracy: 0.9006
```
After 60 epochs, the training accuracy reaches 98.3% and the test accuracy 90.06%.
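The last epoch is not necessarily the best one: among the epochs shown in the log above, test accuracy was 91.06% at epoch 59 but 90.06% at epoch 60. The sketch below picks the best epoch from a history-like dict; it only uses the values printed above (epochs 4 to 58 are elided in the log), so it is illustrative only.

```python
# Validation (test) accuracy of the epochs printed in the log above; epochs 4-58 omitted
val_acc_by_epoch = {1: 0.3550, 2: 0.5868, 3: 0.5968, 59: 0.9106, 60: 0.9006}

best_epoch = max(val_acc_by_epoch, key=val_acc_by_epoch.get)
print(best_epoch, val_acc_by_epoch[best_epoch])  # 59 0.9106
```

With the full history, Keras offers the same idea through the keras.callbacks.ModelCheckpoint callback with save_best_only=True, monitoring 'val_sparse_categorical_accuracy'. Because the test set is used as the validation data here, an accuracy selected this way is somewhat optimistic.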
Plot the training curves:
```python
# plot train history
loss = history.history['loss']
val_loss = history.history['val_loss']
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
plt.figure(figsize=(11,3.5))
plt.subplot(1,2,1)
plt.plot(loss, color='blue', label='train')
plt.plot(val_loss, color='red', label='test')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.subplot(1,2,2)
plt.plot(acc, color='blue', label='train')
plt.plot(val_acc, color='red', label='test')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()
```
The resulting plot is shown below: