import numpy as np
import torch

try:
    from torchvision.ops import nms, boxes
except ImportError:
    # torchvision is only needed by non_max_suppression(); keep the pure
    # numpy/torch helpers importable without it and fail with a clear error
    # at call time instead.
    nms = None
    boxes = None


def yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image):
    """Map normalised centre/size boxes back onto the original image.

    Args:
        box_xy: ``(..., 2)`` array of box centres as ``(x, y)``, normalised
            to the network input.
        box_wh: ``(..., 2)`` array of box sizes as ``(w, h)``, normalised to
            the network input.
        input_shape: ``(height, width)`` of the network input.
        image_shape: ``(height, width)`` of the original image.
        letterbox_image: ``True`` if the image was resized with aspect-ratio
            preserving padding (letterboxing); the padding offset is then
            removed before scaling.

    Returns:
        ``(..., 4)`` array of ``(y_min, x_min, y_max, x_max)`` in pixels of
        the original image. The input arrays are not modified.
    """
    # Put the y axis first so the boxes can be multiplied directly with the
    # (height, width) image shape.
    box_yx = box_xy[..., ::-1]
    box_hw = box_wh[..., ::-1]
    input_shape = np.array(input_shape)
    image_shape = np.array(image_shape)

    if letterbox_image:
        # new_shape is the size of the resized image content inside the
        # padded network input; offset is the (normalised) shift of that
        # content relative to the top-left corner of the input.
        new_shape = np.round(image_shape * np.min(input_shape / image_shape))
        offset = (input_shape - new_shape) / 2. / input_shape
        scale = input_shape / new_shape

        box_yx = (box_yx - offset) * scale
        # Out-of-place on purpose: box_hw is a view of the caller's box_wh,
        # so ``*=`` would silently modify the caller's array (and fail for
        # integer inputs).
        box_hw = box_hw * scale

    box_mins = box_yx - (box_hw / 2.)
    box_maxes = box_yx + (box_hw / 2.)
    corrected = np.concatenate(
        [box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]],
        axis=-1,
    )
    # Normalised -> pixel coordinates of the original image.
    corrected *= np.concatenate([image_shape, image_shape], axis=-1)
    return corrected


def decode_outputs(outputs, input_shape):
    """Decode raw per-level head outputs into normalised boxes and scores.

    Args:
        outputs: sequence of feature maps, each shaped
            ``(batch, 4 + 1 + num_classes, h, w)`` (e.g. 80x80, 40x40 and
            20x20 for a 640x640 input).
        input_shape: ``(height, width)`` of the network input.

    Returns:
        Tensor shaped ``(batch, sum(h * w), 4 + 1 + num_classes)``. Channels
        0-3 are ``(cx, cy, w, h)`` normalised by the input width/height;
        channels 4+ are sigmoid-activated objectness and class scores.
    """
    grids = []
    strides = []
    hw = [x.shape[-2:] for x in outputs]

    # (batch, C, h, w) per level -> (batch, C, h * w), concatenated over all
    # levels -> (batch, C, N), then transposed to (batch, N, C).
    outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1)

    # Objectness and per-class probabilities.
    outputs[:, :, 4:] = torch.sigmoid(outputs[:, :, 4:])

    for h, w in hw:
        # Grid cell coordinates in row-major order: entry i * w + j is
        # (x=j, y=i). Built without torch.meshgrid so the result does not
        # depend on that function's (deprecated) default indexing.
        grid_y = torch.arange(h).repeat_interleave(w)
        grid_x = torch.arange(w).repeat(h)
        grids.append(torch.stack((grid_x, grid_y), 1).view(1, -1, 2))
        # The stride is derived from the height only, i.e. the same stride
        # is assumed for both axes.
        strides.append(torch.full((1, h * w, 1), input_shape[0] / h))

    # Stack all levels: (1, N, 2) and (1, N, 1), matching dtype and device
    # of the predictions.
    grids = torch.cat(grids, dim=1).to(outputs)
    strides = torch.cat(strides, dim=1).to(outputs)

    # Decode centre offsets and log-sizes into input-pixel units.
    outputs[..., :2] = (outputs[..., :2] + grids) * strides
    outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides

    # Normalise x / w by the input width and y / h by the input height.
    outputs[..., [0, 2]] = outputs[..., [0, 2]] / input_shape[1]
    outputs[..., [1, 3]] = outputs[..., [1, 3]] / input_shape[0]
    return outputs


def non_max_suppression(prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
    """Filter decoded predictions by confidence and class-wise NMS.

    Args:
        prediction: tensor shaped ``(batch, num_anchors, 5 + num_classes)``
            holding ``(cx, cy, w, h, obj_conf, class scores...)``. It is not
            modified.
        num_classes: number of class score channels.
        input_shape: ``(height, width)`` of the network input.
        image_shape: ``(height, width)`` of the original image.
        letterbox_image: whether the image was letterboxed (see
            ``yolo_correct_boxes``).
        conf_thres: minimum ``obj_conf * class_conf`` to keep a candidate.
        nms_thres: IoU threshold used by the per-class NMS.

    Returns:
        A list with one entry per image. Each entry is ``None`` when the
        image has no anchors at all, otherwise a ``(num_kept, 7)`` numpy
        array of ``(y1, x1, y2, x2, obj_conf, class_conf, class_pred)`` with
        the box in pixels of the original image (possibly zero rows).

    Raises:
        ImportError: if torchvision is not installed.
    """
    if boxes is None:
        raise ImportError("non_max_suppression requires torchvision (torchvision.ops.boxes.batched_nms)")

    # Convert (cx, cy, w, h) to (x1, y1, x2, y2) on a copy so the caller's
    # tensor keeps its centre/size format.
    prediction = prediction.clone()
    centers = prediction[:, :, :2].clone()
    half_wh = prediction[:, :, 2:4] / 2
    prediction[:, :, :2] = centers - half_wh
    prediction[:, :, 2:4] = centers + half_wh

    output = [None] * len(prediction)
    # Loop over the images of the batch (usually just one).
    for i, image_pred in enumerate(prediction):
        if not image_pred.size(0):
            continue

        # Best class per anchor: class_conf and class_pred are both
        # (num_anchors, 1).
        class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)

        # First-pass filter on the combined confidence. The mask is already
        # 1-D; squeezing it would collapse it to 0-d for a single anchor and
        # break the boolean indexing below.
        conf_mask = image_pred[:, 4] * class_conf[:, 0] >= conf_thres

        # detections: (num_kept, 7) =
        #   x1, y1, x2, y2, obj_conf, class_conf, class_pred
        detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]

        # Class-aware NMS in a single call.
        keep = boxes.batched_nms(
            detections[:, :4],
            detections[:, 4] * detections[:, 5],
            detections[:, 6],
            nms_thres,
        )

        result = detections[keep].detach().cpu().numpy()
        box_xy = (result[:, 0:2] + result[:, 2:4]) / 2
        box_wh = result[:, 2:4] - result[:, 0:2]
        # Note: yolo_correct_boxes returns (y1, x1, y2, x2).
        result[:, :4] = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
        output[i] = result
    return output