我是Deepeeping和Pytorch的初学者。
我不明白在使用SWA时如何使用BatchNormalization。
pytorch.org在https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/中说:
请注意,权重的SWA平均值从未用于
训练期间的预测,因此批次归一化层可以
重置重置后没有计算激活统计信息
带有opt.swap_swa_sgd()
的模型的权重
这意味着它适合使用SWA后添加BatchNormalization层吗?
# it means, in my idea
#for example
opt = torchcontrib.optim.SWA(base_opt)
for i in range(100):
opt.zero_grad()
loss_fn(model(input), target).backward()
opt.step()
if i > 10 and i % 5 == 0:
opt.update_swa()
opt.swap_swa_sgd()
#save model once
torch.save(model,"swa_model.pt")
#model_load
saved_model=torch.load("swa_model.pt")
#it means adding BatchNormalization layer??
model2=saved_model
model2.add_module("Batch1",nn.BatchNorm1d(10))
# decay learning_rate more
learning_rate=0.005
optimizer = torch.optim.SGD(model2.parameters(), lr=learning_rate)
# train model again
for epoch in range(num_epochs):
loss = train(train_loader)
val_loss, val_acc = valid(test_loader)
感谢您的答复。
按照您的建议,
我尝试使示例模型添加optimizer.bn_update()
# add optimizer.bn_update() to model
criterion = nn.CrossEntropyLoss()
learning_rate=0.01
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer = SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)
def train(train_loader):
#mode:train
model.train()
running_loss = 0
for batch_idx, (images, labels) in enumerate(train_loader):
optimizer.zero_grad()
outputs = model(images)
#loss
loss = criterion(outputs, labels)
running_loss += loss.item()
loss.backward()
optimizer.step()
optimizer.swap_swa_sgd()
train_loss = running_loss / len(train_loader)
return train_loss
def valid(test_loader):
model.eval()
running_loss = 0
correct = 0
total = 0
#torch.no_grad
with torch.no_grad():
for batch_idx, (images, labels) in enumerate(test_loader):
outputs = model(images)
loss = criterion(outputs, labels)
running_loss += loss.item()
_, predicted = torch.max(outputs, 1)
correct += (predicted == labels).sum().item()
total += labels.size(0)
val_loss = running_loss / len(test_loader)
val_acc = float(correct) / total
return val_loss, val_acc
num_epochs=30
loss_list = []
val_loss_list = []
val_acc_list = []
for epoch in range(num_epochs):
loss = train(train_loader)
val_loss, val_acc = valid(test_loader)
optimizer.bn_update(train_loader, model)
print('epoch %d, loss: %.4f val_loss: %.4f val_acc: %.4f'
% (epoch, loss, val_loss, val_acc))
# logging
loss_list.append(loss)
val_loss_list.append(val_loss)
val_acc_list.append(val_acc)
# optimizer.bn_updata()
optimizer.bn_update(train_loader, model)
# go on evaluating model,,,
最佳答案
该文档告诉您的是,由于SWA计算权重的平均值,但是在训练期间这些权重并未用于预测,因此批归一化层将看不到那些权重。这意味着他们尚未为他们计算出各自的统计信息(因为他们从未能够),这很重要,因为权重是在实际预测期间(即不在训练期间)使用的。
这意味着,他们假设您的模型中有批量归一化层,并希望使用SWA对其进行训练。 由于上述原因,这(或多或少)不简单。
一种方法如下:
要计算激活统计信息,您只需在培训结束后使用SWA模型就可以对培训数据进行前向传递。
或者,您可以使用其帮助程序类:
在SWA
类中,我们提供了一个辅助函数opt.bn_update(train_loader, model)
。通过向前传递train_loader
数据加载器,它更新了模型中每个批次归一化层的激活统计信息。您只需在训练结束时调用一次此功能。
如果您使用的是Pytorch的 DataLoader
class,则只需将模型(训练后)和训练加载器提供给bn_update
函数,即可为您更新所有批次归一化统计信息。您只需在训练结束时调用一次此功能。
进行步骤:
opt.bn_update(train_loader, model)
并提供训练有素的模型