1. 引言

本文重点介绍BatchNorm的定义和相关特性,并介绍了其详细实现和具体应用。希望可以帮助大家加深对其理解。

嗯嗯,闲话少说,我们直接开始吧!

2. 什么是BatchNorm?

BatchNorm是2015年提出的网络层,这个层具有以下特性:

  • 易于训练:由于网络权重的分布随这一层的变化小得多,因此我们可以使用更高的学习率。我们在训练中收敛的方向没有那么不稳定,这样我们就可以更快地朝着loss收敛的方向前进。

  • 提升正则化:尽管网络在每个epoch都会遇到相同的训练样本,但每个小批量的归一化是不同的,因此每次都会稍微改变其值。

  • 提升精度:可能是由于前面两点的结合,论文提到他们获得了比当时最先进的结果更好的准确性。

3. BatchNorm是如何工作的?

BatchNorm所做的是确保接收到的输入具有平均值0和标准偏差1。
本文中介绍的算法如下:
一文弄懂CNN中的BatchNorm-LMLPHP
下面是我自己用pytorch进行的实现:

import numpy as np
import torch
from torch import nn
from torch.nn import Parameter

class BatchNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.gamma = Parameter(torch.Tensor(num_features))
        self.beta = Parameter(torch.Tensor(num_features))
        self.register_buffer("moving_avg", torch.zeros(num_features))
        self.register_buffer("moving_var", torch.ones(num_features))
        self.register_buffer("eps", torch.tensor(eps))
        self.register_buffer("momentum", torch.tensor(momentum))
        self._reset()
    
    def _reset(self):
        self.gamma.data.fill_(1)
        self.beta.data.fill_(0)
    
    def forward(self, x):
        if self.training:
            mean = x.mean(dim=0)
            var = x.var(dim=0)
            self.moving_avg = self.moving_avg * momentum + mean * (1 - momentum)
            self.moving_var = self.moving_var * momentum + var * (1 - momentum)
        else:
            mean = self.moving_avg
            var = self.moving_var
            
        x_norm = (x - mean) / (torch.sqrt(var + self.eps))
        return x_norm * self.gamma + self.beta

这里对其进行补充说明如下:

  • 我们在训练和推理过程中BatchNorm有不同的行为。在训练中,我们记录均值和方差的指数移动平均值,以供以后在推理时使用。其原因是,在训练期间处理批次时,我们可以获得输入随时间变化的均值和方差的更好估计,然后将其用于推理。在推理过程中使用输入批次的平均值和方差将不太准确,因为其大小可能比训练中使用的小得多,大数定律在这里发挥了作用。

4. 什么时候使用Batchnorm ?

这似乎总是有帮助的,所以没有理由不使用它。通常它出现在全连接层/卷积层和激活函数之间。但也有人认为,最好把它放在激活层之后。我找不到任何关于激活函数之后使用它的论文,所以最安全的选择是按照每个人的做法,在激活函数前使用它。

5. 一些技巧总结

列举下关于实际应用中BatchNorm的技巧总结如下:

  • 我们知道,一个已经训练的网络包含用于训练它的数据集的移动平均值和方差,这可能是一个问题。在迁移学习期间,我们通常会冻结大部分层,如果不小心,BatchNorm层也会冻结,这意味着应用的移动平均值属于原始数据集,而不是新数据集。解冻BatchNorm层是一个好主意,将允许网络重新计算自己数据集上的移动平均值和方差。
11-21 06:50