2024-11-27 23:22:08 +08:00

181 lines
8.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import torch
from torchvision.ops import nms, boxes
def yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image):
#-----------------------------------------------------------------#
# 把y轴放前面是因为方便预测框和图像的宽高进行相乘
#-----------------------------------------------------------------#
box_yx = box_xy[..., ::-1]
box_hw = box_wh[..., ::-1]
input_shape = np.array(input_shape)
image_shape = np.array(image_shape)
if letterbox_image:
#-----------------------------------------------------------------#
# 这里求出来的offset是图像有效区域相对于图像左上角的偏移情况
# new_shape指的是宽高缩放情况
#-----------------------------------------------------------------#
new_shape = np.round(image_shape * np.min(input_shape/image_shape))
offset = (input_shape - new_shape)/2./input_shape
scale = input_shape/new_shape
box_yx = (box_yx - offset) * scale
box_hw *= scale
box_mins = box_yx - (box_hw / 2.)
box_maxes = box_yx + (box_hw / 2.)
boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
boxes *= np.concatenate([image_shape, image_shape], axis=-1)
return boxes
def decode_outputs(outputs, input_shape):
grids = []
strides = []
hw = [x.shape[-2:] for x in outputs]
#---------------------------------------------------#
# outputs输入前代表每个特征层的预测结果
# batch_size, 4 + 1 + num_classes, 80, 80 => batch_size, 4 + 1 + num_classes, 6400
# batch_size, 5 + num_classes, 40, 40
# batch_size, 5 + num_classes, 20, 20
# batch_size, 4 + 1 + num_classes, 6400 + 1600 + 400 -> batch_size, 4 + 1 + num_classes, 8400
# 堆叠后为batch_size, 8400, 5 + num_classes
#---------------------------------------------------#
outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1)
#---------------------------------------------------#
# 获得每一个特征点属于每一个种类的概率
#---------------------------------------------------#
outputs[:, :, 4:] = torch.sigmoid(outputs[:, :, 4:])
for h, w in hw:
#---------------------------#
# 根据特征层的高宽生成网格点
#---------------------------#
grid_y, grid_x = torch.meshgrid([torch.arange(h), torch.arange(w)])
#---------------------------#
# 1, 6400, 2
# 1, 1600, 2
# 1, 400, 2
#---------------------------#
grid = torch.stack((grid_x, grid_y), 2).view(1, -1, 2)
shape = grid.shape[:2]
grids.append(grid)
strides.append(torch.full((shape[0], shape[1], 1), input_shape[0] / h))
#---------------------------#
# 将网格点堆叠到一起
# 1, 6400, 2
# 1, 1600, 2
# 1, 400, 2
#
# 1, 8400, 2
#---------------------------#
grids = torch.cat(grids, dim=1).type(outputs.type())
strides = torch.cat(strides, dim=1).type(outputs.type())
#------------------------#
# 根据网格点进行解码
#------------------------#
outputs[..., :2] = (outputs[..., :2] + grids) * strides
outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
#-----------------#
# 归一化
#-----------------#
outputs[..., [0,2]] = outputs[..., [0,2]] / input_shape[1]
outputs[..., [1,3]] = outputs[..., [1,3]] / input_shape[0]
return outputs
def non_max_suppression(prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
#----------------------------------------------------------#
# 将预测结果的格式转换成左上角右下角的格式。
# prediction [batch_size, num_anchors, 85]
#----------------------------------------------------------#
box_corner = prediction.new(prediction.shape)
box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
prediction[:, :, :4] = box_corner[:, :, :4]
output = [None for _ in range(len(prediction))]
#----------------------------------------------------------#
# 对输入图片进行循环,一般只会进行一次
#----------------------------------------------------------#
for i, image_pred in enumerate(prediction):
#----------------------------------------------------------#
# 对种类预测部分取max。
# class_conf [num_anchors, 1] 种类置信度
# class_pred [num_anchors, 1] 种类
#----------------------------------------------------------#
class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
#----------------------------------------------------------#
# 利用置信度进行第一轮筛选
#----------------------------------------------------------#
conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()
if not image_pred.size(0):
continue
#-------------------------------------------------------------------------#
# detections [num_anchors, 7]
# 7的内容为x1, y1, x2, y2, obj_conf, class_conf, class_pred
#-------------------------------------------------------------------------#
detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
detections = detections[conf_mask]
nms_out_index = boxes.batched_nms(
detections[:, :4],
detections[:, 4] * detections[:, 5],
detections[:, 6],
nms_thres,
)
output[i] = detections[nms_out_index]
# #------------------------------------------#
# # 获得预测结果中包含的所有种类
# #------------------------------------------#
# unique_labels = detections[:, -1].cpu().unique()
# if prediction.is_cuda:
# unique_labels = unique_labels.cuda()
# detections = detections.cuda()
# for c in unique_labels:
# #------------------------------------------#
# # 获得某一类得分筛选后全部的预测结果
# #------------------------------------------#
# detections_class = detections[detections[:, -1] == c]
# #------------------------------------------#
# # 使用官方自带的非极大抑制会速度更快一些!
# #------------------------------------------#
# keep = nms(
# detections_class[:, :4],
# detections_class[:, 4] * detections_class[:, 5],
# nms_thres
# )
# max_detections = detections_class[keep]
# # # 按照存在物体的置信度排序
# # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
# # detections_class = detections_class[conf_sort_index]
# # # 进行非极大抑制
# # max_detections = []
# # while detections_class.size(0):
# # # 取出这一类置信度最高的一步一步往下判断判断重合程度是否大于nms_thres如果是则去除掉
# # max_detections.append(detections_class[0].unsqueeze(0))
# # if len(detections_class) == 1:
# # break
# # ious = bbox_iou(max_detections[-1], detections_class[1:])
# # detections_class = detections_class[1:][ious < nms_thres]
# # # 堆叠
# # max_detections = torch.cat(max_detections).data
# # Add max detections to outputs
# output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
if output[i] is not None:
output[i] = output[i].cpu().numpy()
box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2]
output[i][:, :4] = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
return output