在本文中,将介绍一些关于小样本学习的相关知识,以及介绍如何使用pytorch构建一个原型网络(Prototypical Networks),并应用于miniImageNet 数据集。
实验环境:
pytorch:1.11.0
代码地址:https://github.com/xiaohuiduan/deeplearning-study/tree/main/小样本学习
小样本学习引入
在这一节将简要的对小样本学习(FSL)相关的知识进行介绍。由于我并不是专门研究小样本的(我学习FSL也只是为了完成我的课程作业),因此,如果本文存在任何问题,欢迎进行批评指正😎。
首先将模型看成一个黑盒子,不去关注它的内部结构,而是关注其input和output。
在分类模型中,input是一张猫or狗的图片,output则为0/1(代表其为猫或者狗;实际上输出的是两者的预测概率)。
但是关于上面的模型,存在一个问题,训练这样的模型需要大量的数据,根据@petewarden的说法,训练一个分类图片的网络,每个类别需要大约1000张。但是,很多场景我们并没有足够的数据进行训练,也就是说数据集的样本比较少,这时候针对小样本数据集可以有两种处理方式:
数据增强:比如说对图像进行旋转,裁剪等等。
数据建模:如使用小样本学习的方法进行对数据进行建模。
小样本学习是什么
以VGG模型预测猫狗分类举例,模型对于一张新的图片的预测可以形象的解释为:
但是在小样本学习中,不是这样的,小样本学习对于一张新的图片的预测可以形象的解释为:
小样本的分类模型与传统的深度学习分类模型(如VGG)有着不同,这里引用论文的一句话:
也就是说,小样本分类并不是像VGG一样,test的数据是以前训练过的类别,对于小样本分类来说,进行test的数据是一些新的类,并且这些类别的样本很少,因此,没法对其进行re-training,否则会造成过拟合。
小样本学习方法
小样本学习可以认为是一个N-way K-shot的分类问题(不确定是不是所有的小样本分类任务都被认为是N-way K-shot分类问题)。
无论对于测试集还是训练集,都需要进行如下的划分,将数据集划分为两个部分:左边DataSet代表数据集,右边分别代表Support set,右边代表Query Set。在train或者test数据集中:
- 首先对于所有类别,随机选择其中N(图中N=3)个类别(图中,选择了类别2,类别3和类别5)。
- 在step 1中选择的类别样本中,随机选择K(图中K=3)个样本(绿色的部分),构成Support Set。也就是说,Support Set中拥有K*N个样本。
- 然后在所选择类别的剩余样本中,选择X(这里X=1)个样本(红色的部分),构成Query Set。也就是说,Query Set中拥有X*N个样本。
在训练集中,以上步骤构成的Support Set和Query Set会被input到Model中进行训练,称之为一个eposides(相当于mini-batch)。
对于模型来说,其目的则为判断Query Set中的样本与哪一个支持集最相似。
原型网络(Prototype Network)
原理简述
Prototype Network的原理很简单,可以简单的概括为:将support set中的图片\(data_1,data_2,\cdots,data_n\)映射到某一个向量空间\(c_1,c_2,\cdots,c_n\);对于Query set中的某一张图片\(query_i\)使用同一个映射函数,也映射到到向量空间\(x_i\),然后判断\(x_i\)与\(c_1,c_2,\cdots,c_n\)的距离(余弦距离or欧氏距离),选择距离最近向量所对应的类别作为\(query_i\)所属的类别。
示意图如下所示:
如果了解NLP中word2vec的话,会发现,其与word2vec的Embedding思想是很相似的。
算法流程
算法流程图如下所示,红色框和绿色框中的过程已经在前文进行介绍,这里主要是来介绍一下loss的计算方式。
实际上,loss的计算方式就是一个交叉熵损失函数,pytorch中CrossEntropyLoss的计算方法如下所示,class代表\(x\)实际所属类别\(x[j]\)代表模型对于\(x\)所属类别\(j\)的概率预测。
但是,在算法流程图中,大家会发现,其loss计算的正负号刚好与上面公式中的相反,解释如下:
以上,便是原型网络的算法流程。
算法实现
数据集处理
mini-Imagenet是一个专门用于训练小样本学习的训练集,数据集中一共有100个类别,每个类别600张图片,一共有60000张图片。数据集可以从mini-ImageNet | Kaggles上面下载。在下载文件中,一共有4个文件,一个是数据集图片的压缩包,另外3个csv文件分别代表了训练、验证和测试集相关的信息。其中训练集有64个类,验证集16个类,测试集20个类。
csv的部分数据如下所示,filename代表了图片的名字,label代表了图片对应的标签。
因此,可以构建一个label所对应filename的字典
def read_csv(csv_path):
dict = collections.defaultdict(list)
df = pd.read_csv(csv_path)
for index,row in df.iterrows():
dict[row["label"]].append(row["filename"])
return dict
train_dict = read_csv(train_csv_path)
val_dict = read_csv(val_csv_path)
test_dict = read_csv(test_csv_path)
同时,构建data与labels的对应关系:
from PIL import Image
import numpy as np
from torchvision import transforms
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
resize_transform = transforms.Resize(84) # 提前对图片进行缩放,以节省内存空间,将最短的边变成84
def build_data(data_dict):
datas = []
labels = []
label_index = 0
for label in data_dict.keys(): # 对图片的标签进行迭代
for path in data_dict[label]: # 对标签对应的文件名进行迭代
img_path = os.path.join(img_root_dir,path)
img = Image.open(img_path) # 读取文件
img = resize_transform(img) # 进行缩放
datas.append(img)
labels.append(label_index)
label_index += 1
return {"datas":datas,"labels":labels}
随机产生Support set 和 Query set
在下面代码中,CategoriesSampler
的作用是为了产生index,然后供给dataloader使用。
class CategoriesSampler():
"""
目的是为了随机产生K_way*(N_support+N_query)个图片对应的index
"""
def __init__(self, data, n_batch, K_way, N_per):
self.n_batch = n_batch
self.K_way = K_way
self.N_per = N_per
labels = np.array(data["labels"]) # [0,0,0,0,1,1,1,1,2,2,2,2……]
self.index = [] # 记录label对应的索引位置
for i in range(max(labels)+1):
ind = np.argwhere(labels == i).reshape(-1)
self.index.append(torch.from_numpy(ind))
def __len__(self):
return self.n_batch
def __iter__(self):
for i_batch in range(self.n_batch):
batch = []
classes = torch.randperm(len(self.index))[:self.K_way] # 随机选择K个类别构成support set和query set
for c in classes:
l = self.index[c] # 类别c对应的图片数组的索引,如 l = [5,6,7,8,9]
pos = torch.randperm(len(l))[:self.N_per] # 如 pos = [4,1,0]
batch.append(l[pos]) # 如 l[pos] = [9,6,5]
batch = torch.stack(batch).reshape(-1)
yield batch
同时,定义Dataset类,如下:
class MiniImageNet(Dataset):
def __init__(self, data):
self.datas = data["datas"]
self.labels = data["labels"]
self.transform = transforms.Compose([
transforms.CenterCrop(84),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def __len__(self):
return len(self.datas)
def __getitem__(self, i):
img, label = self.datas[i], self.labels[i]
return self.transform(img), label
使用示例如下,在MiniImageNet的__getitem__
函数中,其参数i
由CategoriesSampler
的__iter__
函数所产生。
batch_sampler = CategoriesSampler(datas,eval_step,K_way,N_shot+N_query)
data_loader = DataLoader(dataset=data_set, batch_sampler=batch_sampler,
num_workers=16, pin_memory=True)
实际上,上面的代码就是为了实现如下图所示的功能:
Embedding网络构建
前文说到,需要将图片进行向量化表示,在下面的代码中,便可以将一张图片(shape=\(3\times84\times84\))变成一个1600维的向量。(网络结构来自于论文)
class CNN_Net(nn.Module):
"""
用于特征提取
"""
def __init__(self, input_dim):
super(CNN_Net, self).__init__()
self.input_dim = input_dim
def conv_block(in_channel,out_channel):
return nn.Sequential(
nn.Conv2d(in_channel, out_channel, 3,padding=1),
nn.BatchNorm2d(out_channel),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.encoder = nn.Sequential(
conv_block(input_dim,64),
conv_block(64,64),
conv_block(64,64),
conv_block(64,64),
)
def forward(self, x):
x = self.encoder(x)
x = x.view(x.size(0), -1)
return x
损失函数
关于计算损失的关键函数如下所示(参考了论文作者的源代码):
def cal_euc_distance(self, query_z, center,K_way, N_query):
"""
计算query_z与center的距离
query_z : (K_way*N_query,z_dim)
center : (K_way,z_dim)
"""
center = center.unsqueeze(0).expand(
K_way*N_query, K_way, self.z_dim) # (K_way*N_query,K_way,z_dim)
query_z = query_z.unsqueeze(1).expand(
K_way*N_query, K_way, self.z_dim) # (K_way*N_query,K_way,z_dim)
return torch.pow(query_z-center, 2).sum(2) # (K_way*N_query,K_way)
def loss_acc(self, query_z, center, K_way, N_query):
"""
计算loss和acc
query_z : (K_way*N_query,z_dim)
center : (K_way,z_dim)
"""
target_inds = torch.arange(0, K_way).view(K_way, 1).expand(
K_way, N_query).long().to(self.device) # shape=(K_way, N_query)
distance = self.cal_euc_distance(query_z, center,K_way, N_query) # (K_way*N_query,K_way)
predict_label = torch.argmin(distance, dim=1) # (K_way*N_query) 预测出来的label
acc = torch.eq(target_inds.contiguous().view(-1),
predict_label).float().mean() # 准确率
loss = F.log_softmax(-distance, dim=1).view(K_way,
N_query, K_way) # (K_way,N_query,K_way)
loss = - \
loss.gather(dim=2, index=target_inds.unsqueeze(2)).view(-1).mean()
return loss, acc
def set_forward_loss(self, K_way, N_shot, N_query,sample_datas):
"""
sample_datas: shape(K_way*(N_shot+N_query),3,84,84)
"""
z = self.cnn_net(sample_datas) # shape=(K_way*(N_shot+N_query),z_dim) ,将support set和query set都进行向量化表示
z = z.view(K_way,N_shot+N_query,-1) # shape = (K_way,N_shot+N_query,1600)
support_z = z[:,:N_shot] # support set的向量化表示 shape=(K_way,N_shot,1600)
query_z = z[:,N_shot:].contiguous().view(K_way*N_query,-1) # Query set的向量化表示 shape=(K_way*N_query,1600)
center = torch.mean(support_z, dim=1) # 计算support set的向量均值,shape=(K_way,1600)
return self.loss_acc(query_z, center,K_way,N_query)
关于实验中具体的参数设计,可以参考论文或者Github上面的源代码。在原论文中,对于实验的设计讲得非常清楚。
实验结果
下面的表格为测试集的acc(当验证集acc为最大值时测试集所对应的acc):
总结
总的来说,原型网络是一个容易理解的网络模型,思想简单,易于实现。