一、什么是ModelEMA:
在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以求提高测试指标并增加模型的鲁棒性。
指数移动平均(Exponential Moving Average),也称指数加权移动平均(Exponentially Weighted Moving Average,EWMA),是一种给予近期数据更高权重的加权平均方法。
二、如何实现ModelEMA
创建EMA eval mode,去并行化
self.ema = deepcopy(de_parallel(model)).eval()
EMA更新次数
self.updates = updates
根据更新次数,获取衰减系数
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
去掉梯度,ema不需要梯度
for p in self.ema.parameters():
p.requires_grad_(False)
EMA更新次数+1
self.updates += 1
根据更新次数,获取衰减系数
d = self.decay(self.updates)
根据衰减系数,当前模型(去并行化)来修改当前ema模型
msd = de_parallel(model).state_dict() # model state_dict
for k, v in self.ema.state_dict().items():
if v.dtype.is_floating_point:
v *= d
v += (1 - d) * msd[k].detach()
三、ModelEMA完整实现
# ----------------------#
# Check whether the model is in a parallel-training wrapper
# ----------------------#
def is_parallel(model):
    """Return True when *model* is a DataParallel or DistributedDataParallel wrapper."""
    parallel_types = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    return type(model) in parallel_types
# ----------------------#
# Unwrap a parallel-trained model
# ----------------------#
def de_parallel(model):
    """Return the underlying single-GPU module if *model* is DP/DDP wrapped, else *model* itself."""
    # The DP/DDP check is inlined here; `type(...) in` (not isinstance) keeps the
    # original exact-type semantics.
    if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
        return model.module
    return model
# ----------------------#
# Copy attributes between models
# ----------------------#
def copy_attr(a, b, include=(), exclude=()):
    """Copy b's public instance attributes onto a.

    If *include* is non-empty, only those names are copied; names starting
    with '_' or listed in *exclude* are always skipped.
    """
    for name, value in b.__dict__.items():
        wanted = (not include) or (name in include)
        if wanted and not name.startswith('_') and name not in exclude:
            setattr(a, name, value)
class ModelEMA:
    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # ----------------------#
        # Create the EMA copy in eval mode, de-parallelized (FP32 EMA)
        # ----------------------#
        self.ema = deepcopy(de_parallel(model)).eval()
        # ----------------------#
        # Number of EMA updates performed so far
        # ----------------------#
        self.updates = updates
        # ----------------------#
        # Decay ramps up exponentially with the update count (helps early epochs)
        # ----------------------#
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        # ----------------------#
        # The EMA weights are never trained directly, so disable gradients
        # ----------------------#
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        """Blend the current (de-parallelized) model weights into the EMA weights."""
        with torch.no_grad():
            # ----------------------#
            # One more EMA update; recompute the decay coefficient from the count
            # ----------------------#
            self.updates += 1
            d = self.decay(self.updates)
            # FIX: removed debug `print(...)` and `dict_decay.append(d)` — the
            # global `dict_decay` is not defined in this standalone implementation
            # and would raise NameError on the first update.
            msd = de_parallel(model).state_dict()  # model state_dict
            # state_dict() returns live tensor references, so the in-place ops
            # below mutate the EMA weights directly.
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy selected (non-private) attributes from *model* onto the EMA model."""
        copy_attr(self.ema, model, include, exclude)
四、ModelEMA在训练框架中的使用
# ----------------------#
# Build the training skeleton
# ----------------------#
model = models.AlexNet()
model = model.train()
# ----------------------#
# Create the EMA model
# ----------------------#
ema = ModelEMA(model)
num_train_data = 50
batch_size = 10
epoch_step = num_train_data // batch_size
Init_epoch = 50
Total_epoch = 60
# ----------------------#
# Seed the EMA update counter with the steps already taken before resuming
# ----------------------#
ema.updates = Init_epoch * epoch_step
# ----------------------#
# Training loop
# ----------------------#
dict_epoch = []  # FIX: was appended to below without ever being defined (NameError)
for epoch in range(Init_epoch, Total_epoch):
    dict_epoch.append(epoch)
    for iter in range(epoch_step):
        # Blend the current (de-parallelized) model into the EMA model.
        ema.update(model)
# ----------------------#
# Validation: evaluate with the EMA model (eval mode, de-parallelized)
# ----------------------#
model = ema.ema
for epoch in range(Init_epoch, Total_epoch):
    for iter in range(epoch_step):
        pass  # placeholder for a real validation step
# ----------------------#
# Save the EMA model's weights
# ----------------------#
save_state_dict = ema.ema.state_dict()
path = "yourpath"
# torch.save(save_state_dict, path)  # FIX: commented out — "yourpath" is a placeholder path
print('done')  # FIX: was the typo 'dene'
五、完整代码
import torch
import math
import torch.nn as nn
from copy import deepcopy
from torchvision import models
import matplotlib.pyplot as plt
# Module-level recorders feeding the EMA-decay plot at the bottom of the file.
dict_decay = []       # decay coefficient observed after each EMA update
dict_update_num = []  # running count of EMA updates (x-axis of the plot)
# ----------------------#
# Parallel-training-mode check
# ----------------------#
def is_parallel(model):
    """Report whether *model* is a DataParallel / DistributedDataParallel wrapper."""
    wrappers = {nn.parallel.DataParallel, nn.parallel.DistributedDataParallel}
    return type(model) in wrappers
# ----------------------#
# Strip any parallel-training wrapper
# ----------------------#
def de_parallel(model):
    """Return the raw single-GPU module, unwrapping a DP/DDP container if present."""
    # Inlined DP/DDP exact-type check (same semantics as the helper above).
    wrapped = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    return model.module if wrapped else model
# ----------------------#
# Copy model attributes
# ----------------------#
def copy_attr(a, b, include=(), exclude=()):
    """Copy b's public attributes onto a, with an optional include whitelist
    and an exclude blacklist; private ('_'-prefixed) names are never copied."""
    for key, val in b.__dict__.items():
        if key.startswith('_') or key in exclude:
            continue
        if include and key not in include:
            continue
        setattr(a, key, val)
class ModelEMA:
    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """
    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # ----------------------#
        # Create the EMA copy in eval mode, de-parallelized
        # ----------------------#
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        # ----------------------#
        # Number of EMA updates performed so far
        # ----------------------#
        self.updates = updates
        # ----------------------#
        # Decay coefficient as a function of the update count
        # ----------------------#
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        # ----------------------#
        # The EMA weights need no gradients
        # ----------------------#
        for p in self.ema.parameters():
            p.requires_grad_(False)
    # ----------------------#
    # Get the decay from the update count, then blend the current
    # (de-parallelized) model into the EMA model with that decay
    # ----------------------#
    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            # ----------------------#
            # Increment the EMA update count
            # ----------------------#
            self.updates += 1
            # ----------------------#
            # Decay coefficient for this update
            # ----------------------#
            d = self.decay(self.updates)
            print('decay:',d)  # debug trace of the decay coefficient
            dict_decay.append(d)  # records d for the plot; relies on the module-level dict_decay list
            # ----------------------#
            # Blend the current (de-parallelized) model into the EMA model
            # ----------------------#
            msd = de_parallel(model).state_dict()  # model state_dict
            # state_dict() returns live tensor references, so the in-place
            # ops below mutate the EMA weights directly.
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()
    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes: copy selected public attributes from model onto the EMA model
        copy_attr(self.ema, model, include, exclude)
# ----------------------#
# Build the training skeleton
# ----------------------#
model = models.AlexNet()
model = model.train()
# ----------------------#
# Create the EMA model
# ----------------------#
ema = ModelEMA(model)
num_train_data = 100
batch_size = 10
epoch_step = num_train_data // batch_size
Init_epoch = 50
Total_epoch = 300
# ----------------------#
# Seed the EMA update counter with the steps already taken before resuming
# ----------------------#
ema.updates = Init_epoch * epoch_step
# ----------------------#
# Training loop
# ----------------------#
num_update = 0
for epoch in range(Init_epoch, Total_epoch):
    for iter in range(epoch_step):
        # Blend the current (de-parallelized) model into the EMA model.
        ema.update(model)
        num_update += 1
        dict_update_num.append(num_update)
# ----------------------#
# Validation: evaluate with the EMA model (eval mode, de-parallelized)
# ----------------------#
model = ema.ema
for epoch in range(Init_epoch, Total_epoch):
    for iter in range(epoch_step):
        pass  # placeholder for a real validation step
# ----------------------#
# Save the EMA model's weights
# ----------------------#
save_state_dict = ema.ema.state_dict()
path = "yourpath"
# torch.save(save_state_dict, path)  # left commented out: "yourpath" is a placeholder path
print('done')  # FIX: was the typo 'dene'
# -----------------------------------------------#
# Save the EMA-decay figure
# -----------------------------------------------#
plt.figure()
plt.title('EMA decay during training')
plt.plot(dict_update_num, dict_decay, label="EMA decay")
plt.legend()
plt.grid()
plt.draw()
plt.savefig('EMA decay')
plt.show()
EMA decay 曲线变化图