10 March 2023
Paper: YOLOv4: Optimal Speed and Accuracy of Object Detection
YOLOv4 的损失函数可以分为三部分,分别是 定位损失、分类损失和对象损失。它的总损失函数为这三个部分的加权和。具体地,可以表示为:
其中, 是边框回归损失, 是分类损失, 是物体置信度损失,、、 和 是调节不同损失权重的参数, 表示分割网格的数量, 是每个单元格预测的边界框数量,$\mathbb{1}_{i,j}^{obj}$ 表示当前第 $i$ 行 $j$ 列的网格($\mathbb{1}^{obj}$ 为有物体为1,没有则为0)
边框回归损失($\mathcal{L}_{\text{coord}}$):该部分损失主要用于调整预测边界框的位置和大小,通过计算预测边界框的中心点和宽高与真实边界框的中心点和宽高之间的差距来计算。
分类损失($\mathcal{L}_{\text{cls}}$):该部分损失主要用于对预测目标的类别进行分类。计算方式为对每个类别的预测概率与真实标签之间的平方差。
物体置信度损失($\mathcal{L}_{\text{obj}}$):该部分损失主要用于判断每个单元格是否包含目标。当单元格中存在目标时,该部分损失会考虑分类损失和定位损失;当单元格中不存在目标时,只考虑分类损失。
无目标损失($\lambda_{\text{noobj}} \sum_{i=0}^{S^2} \sum_{j=0}^{B-1} \mathbb{1}_{i,j}^{noobj} (C_i - \hat{C}_i)^2$):该部分损失用于惩罚预测框中心点落在不包含目标的单元格中的情况。
# 求出预测框左上角右下角
b1_xy = b1[..., :2]
b1_wh = b1[..., 2:4]
b1_wh_half = b1_wh / 2.
b1_mins = b1_xy - b1_wh_half
b1_maxes = b1_xy + b1_wh_half
# 求出真实框左上角右下角
b2_xy = b2[..., :2]
b2_wh = b2[..., 2:4]
b2_wh_half = b2_wh / 2.
b2_mins = b2_xy - b2_wh_half
b2_maxes = b2_xy + b2_wh_half
# 求真实框和预测框所有的iou
intersect_mins = torch.max(b1_mins, b2_mins) # 相交区域的左上角
intersect_maxes = torch.min(b1_maxes, b2_maxes) # 相交区域的右下角
intersect_wh = torch.max(intersect_maxes - intersect_mins, \
torch.zeros_like(intersect_maxes)) # 相交区域的长和宽
intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] # 相交区域面积
b1_area = b1_wh[..., 0] * b1_wh[..., 1] # 预测框面积
b2_area = b2_wh[..., 0] * b2_wh[..., 1] # 真实(gt)框面积
union_area = b1_area + b2_area - intersect_area
iou = intersect_area / torch.clamp(union_area,min = 1e-6) # IOU 计算
v = (4 / (math.pi ** 2)) \
* torch.pow((torch.atan(b1_wh[..., 0] / torch.clamp(b1_wh[..., 1], min = 1e-6)) \
- torch.atan(b2_wh[..., 0] / torch.clamp(b2_wh[..., 1], min = 1e-6))), 2)
alpha = v / torch.clamp((1.0 - iou + v), min=1e-6)
def box_ciou(b1, b2):
"""
输入为:
----------
b1: tensor, shape=(batch, w, h, anchor_num, 4), xywh
b2: tensor, shape=(batch, w, h, anchor_num, 4), xywh
返回为:
-------
ciou: tensor, shape=(batch, w, h, anchor_num, 1)
"""
# 求出预测框左上角右下角
b1_xy = b1[..., :2]
b1_wh = b1[..., 2:4]
b1_wh_half = b1_wh / 2.
b1_mins = b1_xy - b1_wh_half
b1_maxes = b1_xy + b1_wh_half
# 求出真实框左上角右下角
b2_xy = b2[..., :2]
b2_wh = b2[..., 2:4]
b2_wh_half = b2_wh / 2.
b2_mins = b2_xy - b2_wh_half
b2_maxes = b2_xy + b2_wh_half
# 求真实框和预测框所有的iou
intersect_mins = torch.max(b1_mins, b2_mins) # 相交区域的左上角
intersect_maxes = torch.min(b1_maxes, b2_maxes) # 相交区域的右下角
intersect_wh = torch.max(intersect_maxes - intersect_mins, \
torch.zeros_like(intersect_maxes)) # 相交区域的长和宽
intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] # 相交区域面积
b1_area = b1_wh[..., 0] * b1_wh[..., 1] # 预测框面积
b2_area = b2_wh[..., 0] * b2_wh[..., 1] # 真实(gt)框面积
union_area = b1_area + b2_area - intersect_area
iou = intersect_area / torch.clamp(union_area,min = 1e-6) # IOU 计算
# 计算中心的差距
center_distance = torch.sum(torch.pow((b1_xy - b2_xy), 2), axis=-1)
# 找到包裹两个框的最小框的左上角和右下角
enclose_mins = torch.min(b1_mins, b2_mins)
enclose_maxes = torch.max(b1_maxes, b2_maxes)
enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes))
# 计算对角线距离
enclose_diagonal = torch.sum(torch.pow(enclose_wh,2), axis=-1)
ciou = iou - 1.0 * (center_distance) / torch.clamp(enclose_diagonal,min = 1e-6)
v = (4 / (math.pi ** 2)) \
* torch.pow((torch.atan(b1_wh[..., 0] / torch.clamp(b1_wh[..., 1], min = 1e-6)) \
- torch.atan(b2_wh[..., 0] / torch.clamp(b2_wh[..., 1], min = 1e-6))), 2)
alpha = v / torch.clamp((1.0 - iou + v), min=1e-6)
ciou = ciou - alpha * v
return ciou