train_dada
首先初始化权重
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
他的训练数据是imagenet的rgb,然后利用Perlin 噪声来模拟深度图像
angles
生成随机角度,用于确定每个网格点的梯度向量。gradients
根据angles
生成的二维单位向量表示梯度。tt
是在整个噪声图像区域重复gradients
,以确保每个网格细胞内部有相同的梯度向量。
计算噪声值:
dot
函数计算网格点与梯度向量的点积。n00
,n10
,n01
,n11
计算四个角的点积值。t
使用预定义的渐变函数fade
对网格坐标进行调整,以实现平滑过渡。- 最终的噪声值通过插值函数
lerp_np
(一个预定义的线性函数)
和上述点积值结合,生成整个噪声图。
应用旋转变换:
perlin_noise = rot(image=perlin_noise)
对生成的 Perlin 噪声图应用旋转变换。
设置阈值并应用:
threshold = torch.rand(1).numpy()[0] * beta + beta
计算阈值,用于后续的二值化处理。perlin_thr = np.where(np.abs(perlin_noise) > threshold, np.ones_like(perlin_noise), np.zeros_like(perlin_noise))
根据阈值二值化处理 Perlin 噪声,生成阈值化后的噪声图perlin_thr
。
生成归一化 Perlin 噪声:
norm_perlin = np.where(np.abs(perlin_noise) > threshold, perlin_noise, np.zeros_like(perlin_noise))
生成归一化的 Perlin 噪声图norm_perlin
,在噪声值低于阈值的地方设为0。
函数最终返回 norm_perlin
(归一化 Perlin 噪声)->perlin_norm、perlin_thr
(阈值化 Perlin 噪声)、原始的 perlin_noise
和使用的阈值 threshold
->p_thr。
随机缩放噪声:
生成一个 [0, 1] 范围内的随机数 beta
image = beta * perlin_noise:使用这个随机数 beta 对 Perlin 噪声进行缩放,模拟不同深度的变化。
随机平移噪声:
生成另一个 [0, 1] 范围内的随机数 beta2。
image = image + (beta2 * (1 - beta)):将缩放后的噪声图进一步平移,以增加深度图的变化性。
裁剪和调整深度图:
image = np.clip(image, 0.0, 1.0) 确保深度值在 [0, 1] 范围内。
image = np.expand_dims(image, 2) 增加一个维度,使图像从二维变为三维。
image = np.transpose(image, (2, 0, 1)) 调整深度图的维度顺序,适配 PyTorch 的要求。
所以这个深度图是通过模拟而非直接从现实世界的深度传感器获取,和rgb也没有半点关系,完全随机出来的
训练过程中,VectorQuantizerEMA的两个实例里多了几步
首先补充说明一下几个类内初始化变量
self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
这种初始化定义方式是为了准备在后续过程中使用指数移动平均(EMA)来更新聚类大小
register_buffer是nn.Module的一个方法
用于注册一个不需要梯度的缓冲区
这是因为_ema_cluster_size不是模型的参数(不需要学习)
但它是模型的一部分
并且在模型的训练过程中会更新
通过注册为缓冲区
确保在模型保存和加载时_ema_cluster_size也会被保存和加载。
和requires_grad=False的区别可能在于register_buffer的变量一定可以被加载或保存
self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
self._ema_w.data.normal_()
self._decay = decay
self._epsilon = epsilon
quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
从获得quantized 1,48,48,256开始
# Use EMA to update the embedding vectors
if self.training:
self._ema_cluster_size = self._ema_cluster_size * self._decay + \
(1 - self._decay) * torch.sum(encodings, 0)
这步更新是反映每个嵌入向量当前被选中的频率
这是通过将当前的_ema_cluster_size乘以衰减因子_decay
然后加上新的观察值(由encodings的求和得到)来实现的。
# Laplace smoothing of the cluster size
n = torch.sum(self._ema_cluster_size.data)
self._ema_cluster_size = (
(self._ema_cluster_size + self._epsilon)
/ (n + self._num_embeddings * self._epsilon) * n)
对_ema_cluster_size应用拉普拉斯平滑(加一个小的常数_epsilon)
以避免任何聚类大小变为零
这个操作确保了即使某些聚类在当前批次中未被观察到(即其聚类大小为零)
它们的大小也会被设置为一个小的非零值
这有助于增加数值稳定性。
dw = torch.matmul(encodings.t(), flat_input)
self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
更新嵌入向量,确保了嵌入向量的更新考虑了不同嵌入被选择的频率
# Loss
e_latent_loss = F.mse_loss(quantized.detach(), inputs)
最后输出就要前仨
输出loss_b, loss_t, recon_out
loss_vq = loss_b + loss_t
recon_depth = recon_out[:,:1,:,:]
recon_rgb = recon_out[:,1:,:,:]
# Using L1 loss may work better and lead to improved reconstructions
#l2_recon_d_loss = torch.mean((depth_image - recon_depth)**2)
l2_recon_d_loss = torch.mean(torch.abs(depth_image - recon_depth))
#l2_recon_rgb_loss = torch.mean((rgb_image - recon_rgb)**2)
l2_recon_rgb_loss = torch.mean(torch.abs(rgb_image - recon_rgb))
#l2_recon_loss = torch.mean((model_in - recon_out)**2)
l2_recon_loss = torch.mean(torch.abs(model_in - recon_out))
recon_loss = l2_recon_loss + loss_vq
loss = recon_loss
相当于一个mse损失,一个l1范数损失
train_dsr
训练集相比测试集:
- 初始化时通过遍历整个训练集,直接定义全局的im_max和im_min,但是其实后面根本没有用上
(???这岂不是浪费启动时间)
- 没有读取gt
(毕竟还是无监督)
, - plane_mask和生成的perlin噪声相乘得到msk作为getitem的anomaly_mask。
随机取一个0~1之间的no_anomaly值,如果>0.5,那么anomaly_mask全为0 - 最后rgb和深度图分别随机旋转-15, 15之间的角度
训练过程:
- 从拼接深度和rgb得到in_image开始,到sub_res_model_hi之前为止,这段过程是不计算梯度的
(其实也就是DiscreteLatentModelGroups这部分,但是中间加了些步骤,所以通过model.xxx的形式实现)
- ①得到quantized_t之后先经过下面的generate_fake_anomalies_joined函数得到anomaly_embedding_lo,然后分别和quantized_t经过upsample_t得到up_quantized_t和up_quantized_t_real,二者都是16,256,96,96。
然后这二者再分别依次和enc_b拼接、经过_pre_vq_conv_bot、经过_vq_vae_bot,最终分别得到quantized_b和quantized_b_real,二者都是16,256,96,96。
调整随机嵌入矩阵形状: 将 random_embeddings
重塑并转置,使其形状与输入嵌入 embeddings
一致。16,48,48,256
可选的嵌入打乱: 以随机概率使用 shuffle_patches
函数打乱嵌入块,增加异常嵌入的多样性。
返回重新组合后的图像,它的形状与输入相同,但每个图像内部的块已经被随机打乱。
应用异常掩码:
函数返回 anomaly_embedding->anomaly_embedding_lo
,它在原始特征中引入了模拟的异常,可用于训练模型进行异常检测。16,256,48,48
第二次的输入是zb,quantized_b,embedder_hi._embedding.weight,anomaly_mask,anomaly_strength_hi= (torch.rand(in_image.shape[0]) * 0.90 + 0.10).cuda()
- ②然后再把quantized_b和quantized_b_real分别输入给generate_fake_anomalies_joined,分别输出anomaly_embedding和anomaly_embedding_hi_usebot
use_both = torch.randint(0, 2,(in_image.shape[0],1,1,1)).cuda().float()
use_lo = torch.randint(0, 2,(in_image.shape[0],1,1,1)).cuda().float()
use_hi = (1 - use_lo)
anomaly_embedding_hi_usebot = generate_fake_anomalies_joined(zb_real,
quantized_b_real,
embedder_hi._embedding.weight,
anomaly_mask, strength=anomaly_strength_hi)
anomaly_embedding_lo_usebot = quantized_t
anomaly_embedding_hi_usetop = quantized_b_real
anomaly_embedding_lo_usetop = anomaly_embedding_lo
anomaly_embedding_hi_not_both = use_hi * anomaly_embedding_hi_usebot + use_lo * anomaly_embedding_hi_usetop
anomaly_embedding_lo_not_both = use_hi * anomaly_embedding_lo_usebot + use_lo * anomaly_embedding_lo_usetop
anomaly_embedding_hi = (anomaly_embedding * use_both + anomaly_embedding_hi_not_both * (1.0 - use_both)).detach().clone()
anomaly_embedding_lo = (anomaly_embedding_lo * use_both + anomaly_embedding_lo_not_both * (1.0 - use_both)).detach().clone()
anomaly_embedding_hi_copy = anomaly_embedding_hi.clone()
anomaly_embedding_lo_copy = anomaly_embedding_lo.clone()
- 开始计算梯度后,下面的部分也有变动
recon_feat_hi, recon_embeddings_hi, _ = sub_res_model_hi(anomaly_embedding_hi_copy, embedder_hi)
recon_feat_lo, recon_embeddings_lo, _ = sub_res_model_lo(anomaly_embedding_lo_copy, embedder_lo)
这里之前分别输入的是embeddings_hi和embeddings_lo
recon_feat_xx也就是unet部分的输出output
# Reconstruct the image from the anomalous features with the general appearance decoder
up_quantized_anomaly_t = model.upsample_t(anomaly_embedding_lo)
quant_join_anomaly = torch.cat((up_quantized_anomaly_t, anomaly_embedding_hi), dim=1)
recon_image_general = model._decoder_b(quant_join_anomaly)
虽然model.upsample_t和model._decoder_b在torch.no_grad()之外,
但是model并没有计入优化器
所以就算计算了梯度也不会更新它的权重
# Reconstruct the image from the reconstructed features
# with the object-specific image reconstruction module
up_quantized_recon_t = model.upsample_t(recon_embeddings_lo)
quant_join = torch.cat((up_quantized_recon_t, recon_embeddings_hi), dim=1)
recon_image_recon = model_decode(quant_join)
out_mask = decoder_seg(recon_image_recon,recon_image_general)
out_mask_sm = torch.softmax(out_mask, dim=1)
# Calculate losses
loss_feat_hi = torch.nn.functional.mse_loss(recon_feat_hi, quantized_b_real.detach())
loss_feat_lo = torch.nn.functional.mse_loss(recon_feat_lo, quantized_t.detach())
loss_l2_recon_img = torch.nn.functional.mse_loss(in_image, recon_image_recon)
total_recon_loss = loss_feat_lo + loss_feat_hi + loss_l2_recon_img*10
# Resize the ground truth anomaly map to closely match the augmented features
down_ratio_x_hi = int(anomaly_mask.shape[3] / quantized_b.shape[3])
anomaly_mask_hi = torch.nn.functional.max_pool2d(anomaly_mask,
(down_ratio_x_hi, down_ratio_x_hi)).float()
anomaly_mask_hi = torch.nn.functional.interpolate(anomaly_mask_hi, scale_factor=down_ratio_x_hi)
down_ratio_x_lo = int(anomaly_mask.shape[3] / quantized_t.shape[3])
anomaly_mask_lo = torch.nn.functional.max_pool2d(anomaly_mask,
(down_ratio_x_lo, down_ratio_x_lo)).float()
anomaly_mask_lo = torch.nn.functional.interpolate(anomaly_mask_lo, scale_factor=down_ratio_x_lo)
anomaly_mask = anomaly_mask_lo * use_both + (
anomaly_mask_lo * use_lo + anomaly_mask_hi * use_hi) * (1.0 - use_both)
#anomaly_mask = anomaly_mask * anomaly_type_sum
# Calculate the segmentation loss
segment_loss = loss_focal(out_mask_sm, anomaly_mask)
Focal Loss主要用于解决类别不平衡的问题
在像素级的异常检测任务中,
"类别不平衡"问题通常是指异常区域像素与正常区域像素之间的不平衡。
它通过减少对易分类对象的关注(通过降低它们的损失贡献)
来提高模型对困难或少见类别的关注度。
Focal Loss 的公式是-1 * alpha * (1 - pt)^gamma * log(pt)
其中 pt 是模型对正确类别的预测概率
alpha 和 gamma 是调节损失贡献的超参数。
l1_mask_loss = torch.mean(torch.abs(out_mask_sm - torch.cat((1.0 - anomaly_mask, anomaly_mask), dim=1)))
如果模型的预测在某些像素点上极度不准确(即预测值与真实值之间的差异很大),
L1 损失不会像平方差损失(L2 损失)那样对这些错误赋予过高的权重,
从而避免让模型过度适应那些极端的误差。
segment_loss = segment_loss + l1_mask_loss
# L1 is different than in the paper but may improve results in some cases