一、什么是ModelEMA:

在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。
指数移动平均(Exponential Moving Average)是一种加权移动平均(Weighted Moving Average),其特点是给予近期数据更高的权重。

二、如何实现ModelEMA

创建EMA eval mode,去并行化

self.ema = deepcopy(de_parallel(model)).eval() 

EMA更新次数

self.updates = updates

根据更新次数,获取衰减系数

self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)

去掉梯度,ema不需要梯度

for p in self.ema.parameters():
    p.requires_grad_(False)

EMA更新次数+1

self.updates += 1

根据更新次数,获取衰减系数

d = self.decay(self.updates)

根据衰减系数,当前模型(去并行化)来修改当前ema模型

msd = de_parallel(model).state_dict()  # model state_dict
for k, v in self.ema.state_dict().items():
    if v.dtype.is_floating_point:
        v *= d
        v += (1 - d) * msd[k].detach()

三、ModelEMA完整实现

#----------------------#
#   判断是否并行训练模式
#----------------------#
def is_parallel(model):
    """Return True when `model` is wrapped in DataParallel or DistributedDataParallel.

    Note: this is an exact type check, not isinstance, matching the original.
    """
    parallel_wrappers = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    return type(model) in parallel_wrappers

#----------------------#
#   去并行训练模式
#----------------------#
def de_parallel(model):
    """Unwrap a DP/DDP-wrapped model; return `model` unchanged otherwise."""
    if is_parallel(model):
        return model.module
    return model

#----------------------#
#   模型拷贝
#----------------------#
def copy_attr(a, b, include=(), exclude=()):
    """Copy public instance attributes from `b` onto `a`.

    If `include` is non-empty, only those names are considered; names in
    `exclude` or starting with an underscore are always skipped.
    """
    for name, value in b.__dict__.items():
        wanted = not include or name in include
        if wanted and not name.startswith('_') and name not in exclude:
            setattr(a, name, value)


class ModelEMA:
    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Create the EMA copy in eval mode, de-parallelized (DP/DDP unwrapped).
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        # Number of EMA updates performed so far.
        self.updates = updates
        # Decay ramps from 0 towards `decay` as updates grow — helps early
        # epochs, when the raw model weights still change quickly.
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        # The EMA copy is never optimized directly, so it needs no gradients.
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        """Blend the (de-parallelized) model's floating-point state into the EMA copy."""
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)
            # BUGFIX: removed the debug `print` and the append to the global
            # `dict_decay`, which is undefined in this standalone implementation
            # and raised NameError on every update.
            msd = de_parallel(model).state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                # Only float tensors are averaged; integer buffers (e.g.
                # BatchNorm's num_batches_tracked) are left untouched.
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Mirror plain (non-tensor) attributes of the live model onto the EMA copy.
        copy_attr(self.ema, model, include, exclude)

四、ModelEMA在训练框架中的使用

#----------------------#
#   Build the training scaffold
#----------------------#
model = models.AlexNet()
model = model.train()
#----------------------#
#   Create the EMA model
#----------------------#
ema = ModelEMA(model)

# Toy schedule: 50 samples, batches of 10 -> 5 iterations per epoch.
num_train_data = 50
batch_size = 10
epoch_step = num_train_data // batch_size

Init_epoch = 50
Total_epoch = 60

#----------------------#
#   Seed the EMA update counter as if training resumed at Init_epoch
#----------------------#
ema.updates = Init_epoch * epoch_step

#----------------------#
#   Training loop
#----------------------#
dict_epoch = []  # BUGFIX: was appended to below without ever being defined (NameError)
for epoch in range(Init_epoch, Total_epoch):
    dict_epoch.append(epoch)
    for iter in range(epoch_step):
        # Each step: compute the ramped decay from the update count and fold
        # the live (de-parallelized) model weights into the EMA model.
        ema.update(model)

#----------------------#
#   Validation: evaluate with the EMA weights (eval mode, de-parallelized)
#----------------------#
model = ema.ema
for epoch in range(Init_epoch, Total_epoch):
    for iter in range(epoch_step):
        pass

#----------------------#
#   Save the EMA model's weights
#----------------------#
save_state_dict = ema.ema.state_dict()
path = "yourpath"
torch.save(save_state_dict, path)
print('done')  # BUGFIX: typo 'dene'
print('dene')

五、完整代码

import torch
import math
import torch.nn as nn
from copy import deepcopy
from torchvision import models
import matplotlib.pyplot as plt

# Module-level logs filled by ModelEMA.update and consumed by the plot below.
dict_decay = []  # decay coefficient recorded at every EMA update (y-axis)
dict_update_num = []  # running update counter (x-axis)

#----------------------#
#   判断是否并行训练模式
#----------------------#
def is_parallel(model):
    """True iff `model` is a DP or DDP wrapper (exact type match, as in the original)."""
    return type(model) in {nn.parallel.DataParallel, nn.parallel.DistributedDataParallel}

#----------------------#
#   去并行训练模式
#----------------------#
def de_parallel(model):
    """Return the underlying single-GPU module of a DP/DDP wrapper, or `model` itself."""
    unwrapped = model.module if is_parallel(model) else model
    return unwrapped

#----------------------#
#   模型拷贝
#----------------------#
def copy_attr(a, b, include=(), exclude=()):
    """Transfer attributes of `b` to `a`, honoring include/exclude filters.

    Private names (leading underscore) and names in `exclude` are never
    copied; a non-empty `include` restricts copying to exactly those names.
    """
    for key in b.__dict__:
        if key.startswith('_') or key in exclude:
            continue
        if include and key not in include:
            continue
        setattr(a, key, b.__dict__[key])


class ModelEMA:
    """Exponential Moving Average of model weights (after
    https://github.com/rwightman/pytorch-image-models).

    Maintains a moving average of everything in the model state_dict
    (parameters and buffers). For EMA background see
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Shadow copy: de-parallelized, eval mode, FP32.
        self.ema = deepcopy(de_parallel(model)).eval()
        # How many EMA updates have been applied so far.
        self.updates = updates
        # Ramped decay: near zero at first, approaching `decay` as x grows
        # (helps early epochs, when the raw model still changes quickly).
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        # The shadow model is never optimized, so freeze its parameters.
        for param in self.ema.parameters():
            param.requires_grad_(False)

    def update(self, model):
        """Fold the current model weights into the EMA shadow copy."""
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)
            print('decay:',d)
            dict_decay.append(d)  # logged for the decay plot below

            source_state = de_parallel(model).state_dict()
            for name, ema_tensor in self.ema.state_dict().items():
                if not ema_tensor.dtype.is_floating_point:
                    continue  # skip integer buffers, e.g. num_batches_tracked
                ema_tensor *= d
                ema_tensor += (1 - d) * source_state[name].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Mirror plain (non-tensor) attributes of `model` onto the EMA copy."""
        copy_attr(self.ema, model, include, exclude)


#----------------------#
#   Build the training scaffold
#----------------------#
model = models.AlexNet()
model = model.train()
#----------------------#
#   Create the EMA model
#----------------------#
ema = ModelEMA(model)

# Toy schedule: 100 samples, batches of 10 -> 10 iterations per epoch.
num_train_data = 100
batch_size = 10
epoch_step = num_train_data // batch_size

Init_epoch = 50
Total_epoch = 300

#----------------------#
#   Seed the EMA update counter as if training resumed at Init_epoch
#----------------------#
ema.updates = Init_epoch * epoch_step

#----------------------#
#   Training loop: update the EMA and record the update count for plotting
#----------------------#
num_update = 0
for epoch in range(Init_epoch, Total_epoch):
    for iter in range(epoch_step):
        # Each call appends the current decay value to dict_decay.
        ema.update(model)
        num_update += 1
        dict_update_num.append(num_update)


#----------------------#
#   Validation: evaluate with the EMA weights (eval mode, de-parallelized)
#----------------------#
model = ema.ema
for epoch in range(Init_epoch, Total_epoch):
    for iter in range(epoch_step):
        pass

#----------------------#
#   Save the EMA model's weights (disabled in this demo)
#----------------------#
save_state_dict = ema.ema.state_dict()
path = "yourpath"
#torch.save(save_state_dict,path)
print('done')  # BUGFIX: typo 'dene'



# -----------------------------------------------#
#   save EMA decay figure
# -----------------------------------------------#
plt.figure()
plt.title('EMA decay during training')
plt.plot(dict_update_num, dict_decay, label="EMA decay")
plt.legend()
plt.grid()
plt.draw()
plt.savefig('EMA decay')
plt.show()

EMA decay 曲线变化图
(图:训练过程中 EMA decay 随更新次数的变化曲线,由上面的 plt.savefig('EMA decay') 生成)

04-27 14:48