I. Importing the modules
```python
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn

from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
```
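This code imports several modules and functions from PyTorch and SciPy, plus two helper functions from a project-local module.
1. `import torch`: imports PyTorch, which provides the tensor operations used throughout this file.
2. `from scipy.optimize import linear_sum_assignment`: imports SciPy's solver for the linear sum assignment problem. Given a cost matrix, it finds the one-to-one assignment of rows to columns with the minimum total cost. This is the problem the Hungarian algorithm solves, which is where `HungarianMatcher` gets its name.
3. `from torch import nn`: imports PyTorch's neural-network module. `nn` contains the classes and functions for building network layers, including the `nn.Module` base class used below.
4. `from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou`: imports `box_cxcywh_to_xyxy` and `generalized_box_iou` from the project-local module `util.box_ops`. Judging by their names and how they are used below, the first converts boxes from (center x, center y, width, height) format to (x0, y0, x1, y1) corner format, and the second computes the pairwise generalized IoU (GIoU) between two sets of boxes.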
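To make the two box helpers concrete, here is a minimal pure-Python sketch for a single pair of boxes. The names `cxcywh_to_xyxy` and `giou` are hypothetical stand-ins, not the `util.box_ops` API: the real helpers operate on tensors, and `generalized_box_iou` returns a full matrix of pairwise values. The sketch also assumes valid boxes with positive width and height.

```python
def cxcywh_to_xyxy(box):
    """Convert (center x, center y, width, height) to (x0, y0, x1, y1)."""
    cx, cy, w, h = box
    return (cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h)


def giou(a, b):
    """Generalized IoU of two corner-format boxes; assumes x1 > x0 and y1 > y0."""
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    # Intersection rectangle (zero if the boxes do not overlap).
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = area_a + area_b - inter
    iou = inter / union
    # Smallest axis-aligned box enclosing both inputs.
    cw = max(a[2], b[2]) - min(a[0], b[0])
    ch = max(a[3], b[3]) - min(a[1], b[1])
    enclose = cw * ch
    # GIoU subtracts the fraction of the enclosing box not covered by the union.
    return iou - (enclose - union) / enclose


box = cxcywh_to_xyxy((0.5, 0.5, 1.0, 1.0))
print(box)                               # (0.0, 0.0, 1.0, 1.0)
print(giou(box, box))                    # 1.0
print(giou((0, 0, 1, 1), (2, 0, 3, 1)))  # disjoint boxes: about -0.333
```

GIoU ranges from -1 to 1. Unlike plain IoU, which is 0 for every pair of disjoint boxes, it still distinguishes near misses from far misses, so its negation is a usable matching cost.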
II. The HungarianMatcher module
```python
class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"

    @torch.no_grad()
    def forward(self, outputs, targets):
        """ Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be omitted.
        cost_class = -out_prob[:, tgt_ids]

        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # Compute the giou cost between boxes
        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

        # Final cost matrix
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
```
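This code defines a PyTorch module named `HungarianMatcher`, which computes an assignment (matching) between the network's outputs and the targets. It is used in object detection, where the network's predictions and the ground-truth targets must be paired one-to-one before the loss can be computed.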
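The core step is the linear sum assignment: given a cost matrix with one row per prediction and one column per target, pick a distinct row for every column so that the total cost is minimal. The brute-force sketch below (the helper `brute_force_assignment` is hypothetical) shows what `linear_sum_assignment` computes on a small matrix. It tries every permutation, so it only works for tiny inputs, whereas SciPy's solver runs in polynomial time. It assumes at least as many predictions as targets, which is the usual case for this matcher.

```python
from itertools import permutations


def brute_force_assignment(cost):
    """Return (rows, cols) minimizing the summed cost; assumes len(cost) >= len(cost[0])."""
    n_rows, n_cols = len(cost), len(cost[0])
    assert n_rows >= n_cols, "sketch assumes at least as many predictions as targets"
    best_total, best_rows = None, None
    # Each candidate assigns a distinct row (prediction) to every column (target).
    for rows in permutations(range(n_rows), n_cols):
        total = sum(cost[r][c] for c, r in enumerate(rows))
        if best_total is None or total < best_total:
            best_total, best_rows = total, rows
    # Like SciPy, report the pairs sorted by row (prediction) index.
    pairs = sorted((r, c) for c, r in enumerate(best_rows))
    return [r for r, _ in pairs], [c for _, c in pairs]


# 3 predictions (rows) x 2 targets (columns).
cost = [
    [0.9, 0.1],
    [0.2, 0.8],
    [0.5, 0.5],
]
rows, cols = brute_force_assignment(cost)
print(rows, cols)  # [0, 1] [1, 0]: prediction 2 stays unmatched
```

Every target is matched exactly once and min(num_queries, num_target_boxes) pairs come back, which is the guarantee stated in the `forward` docstring.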
1. The `__init__()` method
```python
class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"
```
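This code defines the PyTorch module `HungarianMatcher`, which performs the matching step in object detection. A detailed explanation follows:

- `class HungarianMatcher(nn.Module):` defines a Python class that inherits from `nn.Module` and represents the Hungarian matcher.
- Docstring: describes the class and its purpose. The targets do not include a "no object" entry, so there are usually more predictions than targets. The best predictions are matched one-to-one and the rest stay unmatched, which means they are treated as "no object".
- `def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):` the constructor, which creates a matcher instance. It takes three optional parameters:
  - `cost_class`: relative weight of the classification error in the matching cost, default 1.
  - `cost_bbox`: relative weight of the L1 error of the box coordinates in the matching cost, default 1.
  - `cost_giou`: relative weight of the GIoU term in the matching cost, default 1.
- `super().__init__()`: calls the parent-class constructor so the module is initialized correctly.
- `self.cost_class`, `self.cost_bbox`, `self.cost_giou`: store the three weights as instance attributes for use in `forward`.
- `assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0`: ensures at least one weight is non-zero. If all three were zero, the cost matrix would be all zeros and every assignment would be equally good, so an `AssertionError` is raised instead to reject the meaningless input.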
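The purpose of this class is to find the optimal one-to-one matching between the network's predictions and the ground-truth targets, so that the loss can be computed and the detector optimized. The weights set the relative importance of the classification error, the L1 box-coordinate error, and the GIoU term, and can be tuned to the task. The Hungarian matching assigns every target to exactly one prediction; each prediction is therefore either matched to one target or left unmatched, and unmatched predictions are treated as "no object" when the loss is computed.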
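To see how the weights shift the matching, the sketch below combines two made-up 2x2 cost components (a classification cost and an L1 box cost) with different weights, then picks the cheaper of the two possible one-to-one assignments. The helpers `combine` and `best_square_assignment` are hypothetical. Unlike the original, `combine` raises `ValueError` instead of using `assert`, because assertions are skipped when Python runs with `-O`.

```python
from itertools import permutations


def combine(weights, components):
    """Weighted sum of equally shaped cost matrices."""
    if all(w == 0 for w in weights):
        raise ValueError("all costs can't be 0")
    n_rows, n_cols = len(components[0]), len(components[0][0])
    return [
        [sum(w * comp[i][j] for w, comp in zip(weights, components)) for j in range(n_cols)]
        for i in range(n_rows)
    ]


def best_square_assignment(C):
    """Target index for each prediction, minimizing total cost (square matrix, brute force)."""
    n = len(C)
    return min(permutations(range(n)), key=lambda p: sum(C[i][p[i]] for i in range(n)))


# Negated class probabilities: prediction 0 favors target 0's class, prediction 1 target 1's.
cost_class = [[-0.9, -0.1], [-0.1, -0.9]]
# L1 box distances: prediction 0's box lies near target 1, prediction 1's near target 0.
cost_bbox = [[1.0, 0.1], [0.1, 1.0]]

print(best_square_assignment(combine([1, 0.1], [cost_class, cost_bbox])))  # (0, 1): class dominates
print(best_square_assignment(combine([1, 5], [cost_class, cost_bbox])))    # (1, 0): boxes dominate
```

With a small box weight, the identity assignment wins because the class term dominates; raising the box weight to 5 flips the matching to follow box proximity instead.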
2. The `forward()` method
```python
    @torch.no_grad()
    def forward(self, outputs, targets):
        """ Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be omitted.
        cost_class = -out_prob[:, tgt_ids]

        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # Compute the giou cost between boxes
        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

        # Final cost matrix
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
```
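The `forward` method is the main method of `HungarianMatcher`: it performs the Hungarian matching that pairs predictions with targets. Line by line:
1. `@torch.no_grad()`: a decorator that disables gradient tracking inside the method. The matching only decides which prediction is paired with which target and no gradient flows through that decision, so tracking would only waste memory and time.
2. `def forward(self, outputs, targets):`: the forward method, which performs the matching. It takes two arguments:
   - `outputs`: a dict containing at least these two entries:
     - "pred_logits": a tensor of shape [batch_size, num_queries, num_classes] with the classification logits.
     - "pred_boxes": a tensor of shape [batch_size, num_queries, 4] with the predicted box coordinates.
   - `targets`: a list of length batch_size; each element is a dict containing these two entries:
     - "labels": a tensor of shape [num_target_boxes] with the target class labels.
     - "boxes": a tensor of shape [num_target_boxes, 4] with the target box coordinates.
3. `bs, num_queries = outputs["pred_logits"].shape[:2]`: reads the batch size and the number of queries.
4. `out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)`: merges the batch and query dimensions and applies softmax over the class dimension to get predicted class probabilities. Result shape: [batch_size * num_queries, num_classes].
5. `out_bbox = outputs["pred_boxes"].flatten(0, 1)`: merges the predicted boxes the same way, giving shape [batch_size * num_queries, 4].
6. `tgt_ids = torch.cat([v["labels"] for v in targets])`: concatenates the labels of all images into one tensor of shape [total number of target boxes in the batch].
7. `tgt_bbox = torch.cat([v["boxes"] for v in targets])`: concatenates the target boxes of all images into one tensor of shape [total number of target boxes, 4].
8. `cost_class = -out_prob[:, tgt_ids]`: the classification cost. For every prediction and every target, it picks the predicted probability of that target's class. Instead of the negative log-likelihood used in the loss, the matching uses 1 - p(target class); the constant 1 does not change which assignment is optimal, so it is dropped, leaving -p. Shape: [batch_size * num_queries, total target boxes].
9. `cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)`: the box cost, i.e. the pairwise L1 distance between every predicted box and every target box, computed in the (cx, cy, w, h) format the boxes are stored in.
10. `cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))`: the GIoU cost. Both sets of boxes are converted to corner format, and the pairwise GIoU is negated so that more overlap means lower cost.
11. `C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou`: the weighted sum of the classification, box, and GIoU costs gives the final matching cost matrix `C`.
12. `C = C.view(bs, num_queries, -1).cpu()`: reshapes `C` to [batch_size, num_queries, total target boxes] and moves it to the CPU, because `linear_sum_assignment` is a SciPy function that works on CPU data.
13. `sizes = [len(v["boxes"]) for v in targets]`: collects the number of target boxes in each image.
14. `indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]`: `C.split(sizes, -1)` cuts the last (target) dimension into per-image chunks. The i-th chunk `c` has shape [batch_size, num_queries, sizes[i]], and `c[i]` keeps only image i's predictions against image i's targets; costs between predictions and targets of different images are computed in the batched step but discarded here. `linear_sum_assignment` then finds the optimal matching for each per-image matrix of shape [num_queries, sizes[i]].
15. `return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]`: returns a list with one entry per image. Each entry is a tuple of two int64 tensors holding the indices of the selected predictions and the indices of the corresponding targets. The number of matches per image is min(num_queries, num_target_boxes).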
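The batched-then-split trick in steps 6, 8, 12 and 14 is the least obvious part, so here is a pure-Python sketch of it on nested lists. The helper names and probability values are hypothetical. It computes the classification cost `-out_prob[:, tgt_ids]` for all predictions against all targets in the batch, reshapes it to [batch_size][num_queries][total targets], then keeps only each image's own block, mirroring `C.split(sizes, -1)` followed by `c[i]`. A brute-force solver stands in for `linear_sum_assignment`.

```python
from itertools import permutations


def class_cost(out_prob, tgt_ids):
    """-out_prob[:, tgt_ids]: one column per target, across the whole batch."""
    return [[-row[t] for t in tgt_ids] for row in out_prob]


def per_image_blocks(C, sizes):
    """Keep image i's predictions against image i's targets only."""
    blocks, start = [], 0
    for i, n in enumerate(sizes):
        blocks.append([row[start:start + n] for row in C[i]])
        start += n
    return blocks


def assign(cost):
    """Brute-force stand-in for linear_sum_assignment (rows >= cols assumed)."""
    n_cols = len(cost[0])
    rows = min(permutations(range(len(cost)), n_cols),
               key=lambda r: sum(cost[r[c]][c] for c in range(n_cols)))
    pairs = sorted(zip(rows, range(n_cols)))
    return [r for r, _ in pairs], [c for _, c in pairs]


bs, num_queries = 2, 2
# Flattened softmax outputs, shape [bs * num_queries][num_classes] with 3 classes.
out_prob = [
    [0.1, 0.2, 0.7],  # image 0, query 0
    [0.6, 0.3, 0.1],  # image 0, query 1
    [0.8, 0.1, 0.1],  # image 1, query 0
    [0.2, 0.5, 0.3],  # image 1, query 1
]
labels = [[2], [0, 1]]                                        # per-image target labels
tgt_ids = [label for per_img in labels for label in per_img]  # concatenated: [2, 0, 1]
sizes = [len(per_img) for per_img in labels]                  # [1, 2]

flat = class_cost(out_prob, tgt_ids)                                  # [4][3]
C = [flat[i * num_queries:(i + 1) * num_queries] for i in range(bs)]  # [2][2][3]
blocks = per_image_blocks(C, sizes)
print(blocks)                       # [[[-0.7], [-0.1]], [[-0.8, -0.1], [-0.2, -0.5]]]
print([assign(b) for b in blocks])  # [([0], [0]), ([0, 1], [0, 1])]
```

Cross-image entries, such as image 0's queries scored against image 1's targets, are computed in the batched step but never reach the solver.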
III. The `build_matcher()` function
```python
def build_matcher(args):
    return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou)
```
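The `build_matcher` function builds a `HungarianMatcher` instance, configuring the weights of its cost terms from the arguments passed in:

- `def build_matcher(args):` defines a function that takes a single argument, `args`, which carries the matcher's cost weights.
- `return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou)`: creates and returns a `HungarianMatcher` instance whose weights are taken from `args`:
  - `cost_class`: weight of the classification term in the matching cost.
  - `cost_bbox`: weight of the L1 box-coordinate term in the matching cost.
  - `cost_giou`: weight of the GIoU term in the matching cost.

In this way, `build_matcher` creates a matcher configured from the given arguments and returns it for later use.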
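A minimal sketch of how `args` might be produced with `argparse`. The flag names mirror the attributes that `build_matcher` reads. The default values and the `MatcherStub` class, a torch-free stand-in for `HungarianMatcher` that only stores the weights, are assumptions for illustration; check your own training script's argument parser for the real defaults.

```python
import argparse
from dataclasses import dataclass


@dataclass
class MatcherStub:
    """Hypothetical stand-in for HungarianMatcher: stores the three weights only."""
    cost_class: float = 1
    cost_bbox: float = 1
    cost_giou: float = 1

    def __post_init__(self):
        # Same guard as HungarianMatcher.__init__, written as an explicit check.
        if self.cost_class == 0 and self.cost_bbox == 0 and self.cost_giou == 0:
            raise ValueError("all costs can't be 0")


def build_matcher(args):
    return MatcherStub(cost_class=args.set_cost_class,
                       cost_bbox=args.set_cost_bbox,
                       cost_giou=args.set_cost_giou)


parser = argparse.ArgumentParser()
# Illustrative defaults; the real values depend on the training script.
parser.add_argument("--set_cost_class", type=float, default=1.0)
parser.add_argument("--set_cost_bbox", type=float, default=5.0)
parser.add_argument("--set_cost_giou", type=float, default=2.0)

args = parser.parse_args(["--set_cost_bbox", "3"])  # override one weight explicitly
matcher = build_matcher(args)
print(matcher)  # MatcherStub(cost_class=1.0, cost_bbox=3.0, cost_giou=2.0)
```

Passing an explicit list to `parse_args` keeps the example independent of the real command line; in a training script you would call `parser.parse_args()` with no arguments.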