1. 简介

化学提及到蒸馏:加热液体汽化,再使蒸气液化,从而除去其中的杂质,获得所需要的产品

深度学习笔记(52) 知识蒸馏-LMLPHP
知识蒸馏也比较相似
利用一个大模型(教师模型)萃取知识,将其提取(迁移)到一个小模型(学生模型)上

深度学习笔记(52) 知识蒸馏-LMLPHP

通过上述的压缩已训练好的大模型方式,知识蒸馏就可以轻量化神经网络,得到小模型
然后就可以部署在边缘计算设备,实现算法应用落地


2. 知识的表示与迁移

深度学习笔记(52) 知识蒸馏-LMLPHP
在训练一个虎的识别时,通过hard targets的标签进行训练,之后将图片出入模型进行识别后得到一个soft targets
从soft targets中可以看出虎的概率是比较大的,识别为猫和车的概率都是比较小的
同样可以看出不同类别的相关性,如虎和猫存在一定相似性,而和车关联就比较少了
因此soft targets包含了更多的信息,如非正确类别概率的相对大小

那么可以用hard targets的标签训练教师模型输出soft targets,再将soft targets作为标签训练学生模型


3. 蒸馏温度T

如果对soft Target的输出信息还不满意,可以新增一个 蒸馏温度T
蒸馏温度T使用在softmax函数中,修正输出标签

s o f t m a x ( Z i ) = e Z i ∑ 1 C e Z c softmax(Z_{i}) = \frac{e^{Z_{i}}}{\sum_{1}^{C}e^{Z_{c}}} softmax(Zi)=1CeZceZi > > > q = e Z i / T ∑ 1 C e Z c / T q = \frac{e^{Z_{i}/T}}{\sum_{1}^{C}e^{Z_{c}/T}} q=1CeZc/TeZi/T

当T=1时,还是原始的softmax函数
当T=3时,可以看相关分类的相似度降低了,其他不相关分类的相似度有所增加
深度学习笔记(52) 知识蒸馏-LMLPHP
当T变大,每个分类所获得的相似度就越平均,越小会发现类别的相似度会很大


4. 知识蒸馏过程

深度学习笔记(52) 知识蒸馏-LMLPHP1. 选用一个已经训练完成的教师模型,然后输入训练集数据,进行数据推算且调整蒸馏温度T=t 的softmax,得到 soft labels
2. 再把训练集数据输入训练学生模型,进行数据推算,进行数据推算且调整蒸馏温度T=t 的softmax,得到 soft predictions,然后和教师模型的 soft labels 进行相似度比较求 蒸馏损失 distillation loss
3. 学生模型进行数据推算时还输出蒸馏温度T=1 的原softmax,得到 hard predictions,与训练集数据标签 hard labels 进行相似度比较求 学生损失 student loss
4. 按系数 α α α β β β 对 学生损失 student loss 和 蒸馏损失 distillation loss 进行求和得到 总损失 total loss

这样学生模型既考虑了标准标签,也考虑了教师模型的结果


4.1. student loss

学生损失 student loss 比较简单
上述提到,就是学生模型输出 hard predictions 和 数据标签 hard labels 进行使用 交叉熵 相似度损失
其他类别标签均为0,目标类别为1,则有 s t u d e n t   l o s s = − l o g ( x i ) = − l o g ( s o f t m a x ( Z i ) ) = − l o g ( e Z i ∑ 1 C e Z c ) student \ loss = -log(x_i)= -log(softmax(Z_{i})) = -log(\frac{e^{Z_{i}}}{\sum_{1}^{C}e^{Z_{c}}}) student loss=logxi=log(softmax(Zi))=log(1CeZceZi)


4.2. distillation loss

与学生损失 student loss 的区别就是其他类型的标签概率不再为0,且蒸馏温度T存在变化
需要每个类别一对一的求损失,再求和

d i s t i l l a t i o n   l o s s = − 1 N ∑ j = 1 N ∑ i = 1 C y i j ∗ l o g ( x i j ) distillation \ loss = - \frac{1}{N}\sum_{j=1}^{N}\sum_{i=1}^{C}y_{ij}*log(x_{ij}) distillation loss=N1j=1Ni=1Cyijlog(xij)

以上面提及到的 虎/猫/车 分类为例,
假设 教师模型 蒸馏温度T=t 的softmax 结果为:0.86 / 0.12 / 0.02
假设 学生模型 蒸馏温度T=t 的softmax 结果为:0.66 / 0.22 / 0.12

那么 蒸馏损失 = − [ 0.86 ∗ l o g ( 0.66 ) + 0.12 ∗ l o g ( 0.22 ) + 0.02 ∗ l o g ( 0.12 ) ] = -[0.86*log(0.66)+0.12*log(0.22)+0.02*log(0.12)] =[0.86log(0.66)+0.12log(0.22)+0.02log(0.12)]


也可以参考下图, i i i 代表当前样本编号
深度学习笔记(52) 知识蒸馏-LMLPHP


5. 背后的机理

读万卷书不如行万里路,行万里路不如阅人无数,阅人无数不如 名师指路

深度学习笔记(52) 知识蒸馏-LMLPHP

绿色是教师模型求解空间(比较大),蓝色是学生模型求解空间(比较小)
红色为教师模型的答案空间,浅绿色为学生模型的答案空间
橙色是在知识蒸馏的情况下得到的答案空间也是最优解

如果不加引导学生模型会在自己的求解空间中试探着寻找,最后找到浅绿色的答案
在增加了教师模型之后,学生模型查找求解空间时,教师模型会给予指导
让学生模型得到的答案更准确,或者让其往教师模型的答案空间靠

所以知识蒸馏会得到更轻便且效果好的模型


6. 应用场景

  • 模型压缩
  • 优化训练、防止过拟合(潜在的正则化)
  • 无限大、无监督数据集的数据挖掘
  • 少样本、零样本学习

知识蒸馏可以看成是迁移学习的一个特例
二者的相同点都是想从大数据、大模型学习知识到目标数据上,以提高模型在目标数据上的表现

不同的是,迁移学习是一个宏大的概念
而知识蒸馏就单纯指的是通过最小化教师模型与学生模型的不同
以达到较小的学生模型可以模拟逼近教师模型的作用
因此,知识蒸馏是实现迁移学习的一种有效形式


7. 代码实现

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader


class StudentModel(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(StudentModel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=7, stride=7)
        self.fc1 = nn.Linear(1 * 4 * 4, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        return x


class TeacherModel(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(TeacherModel, self).__init__()
        self.out_channel_layer1 = 64
        self.out_channel_layer2 = 128
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=self.out_channel_layer1, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=self.out_channel_layer1, out_channels=self.out_channel_layer2, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(self.out_channel_layer2 * 7 * 7, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = F.dropout(x, p=0.5)
        x = self.fc2(x)
        return x


def print_train(step_now, train_loader_len, epoch_now, epochs, lose_item):
    step_schedule_num = int(40 * step_now / train_loader_len)
    print("\r", end="")
    print("Train epoch: {}/{}\t step: {}/{} [{}{}] - loss: {:.5f}".format(epoch_now, epochs,
                                                                          step_now, train_loader_len,
                                                                          ">" * step_schedule_num,
                                                                          "-" * (40 - step_schedule_num),
                                                                          lose_item), end="")


def print_test(epoch_now, epochs, acc):
    print(("Test  epoch: {}/{}\t Accuracy:{:.4f}").format(epoch_now, epochs, acc))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1.载入训练集和测试集
train_dataset = torchvision.datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root="dataset/", train=False, transform=transforms.ToTensor(), download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

train_loader_len = len(train_loader)
train_loader_dataset_len = len(train_loader.dataset)

# 2.设置教师模型训练
print("Teacher model train.")
model = TeacherModel().to(device)
loss_function = nn.CrossEntropyLoss()
Learning_Rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=Learning_Rate)

epochs = 5  # 训练5轮
for epoch in range(epochs):
    model.train()
    step_now, losses = 0, []
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        # 优化器梯度初始化为零
        optimizer.zero_grad()
        # 前向预测
        preds = model(data)
        # 计算损失函数
        loss = loss_function(preds, targets)
        # 反向传播,优化权重
        loss.backward()
        # 结束一次前传+反传之后,更新优化器参数
        optimizer.step()
        # 显示进度
        step_now += 1
        losses.append(loss.item())
        print_train(step_now, train_loader_len, epoch+1, epochs, sum(losses)/len(losses))
    print()

    # 测试集上评估性能
    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct / num_samples).item()
    print_test(epoch+1, epochs, acc)

# 训练完成保存教师模型
teacher_model = model

# 3.设置普通小模型训练
print("Mini model train.")
model = StudentModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=Learning_Rate)

epochs = 5
for epoch in range(epochs):
    model.train()
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        preds = model(data)
        loss = criterion(preds, targets)
        loss.backward()
        optimizer.step()

    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
            acc = (num_correct / num_samples).item()
    print_test(epoch+1, epochs, acc)


# 4.设置学生模型训练
print("Student model train.")
model = StudentModel().to(device)
temp = 5  # 蒸馏温度
hard_loss_alpha = 0.3  # hard_loss权重
hard_loss = nn.CrossEntropyLoss()
soft_loss = nn.KLDivLoss(reduction="batchmean")
optimizer = torch.optim.Adam(model.parameters(), lr=Learning_Rate)

epochs = 5
for epoch in range(epochs):
    model.train()
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()

        # 学生模型预测
        student_preds = model(data)
        student_loss = hard_loss(student_preds, targets)

        # 教师模型预测
        teacher_model.eval()
        with torch.no_grad():
            teacher_preds = teacher_model(data)

        # 计算蒸馏后的预测结果
        distillation_loss = soft_loss(
            F.softmax(student_preds/temp, dim=1),
            F.softmax(teacher_preds/temp, dim=1)
        )

        # 将 hard_loss 和 soft_loss 加权求和
        loss = hard_loss_alpha * student_loss + (1-hard_loss_alpha) * distillation_loss

        loss.backward()
        optimizer.step()

    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()
    print_test(epoch+1, epochs, acc)


# Teacher model train.
# Train epoch: 1/1         step: 1875/1875 [>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] - loss: 0.21462
# Test  epoch: 1/1         Accuracy:0.9788
# Train epoch: 2/5         step: 1875/1875 [>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] - loss: 0.06607
# Test  epoch: 2/5         Accuracy:0.9860
# Train epoch: 3/5         step: 1875/1875 [>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] - loss: 0.04896
# Test  epoch: 3/5         Accuracy:0.9863
# Train epoch: 4/5         step: 1875/1875 [>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] - loss: 0.04104
# Test  epoch: 4/5         Accuracy:0.9883
# Train epoch: 5/5         step: 1875/1875 [>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] - loss: 0.03268
# Test  epoch: 5/5         Accuracy:0.9886
# Mini model train.
# Test  epoch: 1/5         Accuracy:0.3423
# Test  epoch: 2/5         Accuracy:0.5190
# Test  epoch: 3/5         Accuracy:0.6088
# Test  epoch: 4/5         Accuracy:0.6365
# Test  epoch: 5/5         Accuracy:0.6584
# Student model train.
# Test  epoch: 1/5         Accuracy:0.3597
# Test  epoch: 2/5         Accuracy:0.5896
# Test  epoch: 4/5         Accuracy:0.6690
# Test  epoch: 4/5         Accuracy:0.7096
# Test  epoch: 5/5         Accuracy:0.7286

数字分类是比较简单的分类,学生模型需要比较弱或差时对比才明显


谢谢

11-07 06:40