#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import math
from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F


class IOUloss(nn.Module):
    def __init__(self, reduction="none", loss_type="iou"):
        super(IOUloss, self).__init__()
        self.reduction = reduction
        self.loss_type = loss_type

    def forward(self, pred, target):
        assert pred.shape[0] == target.shape[0]

        pred = pred.view(-1, 4)
        target = target.view(-1, 4)
        # Boxes are (cx, cy, w, h); convert to corners to intersect.
        tl = torch.max(
            (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
        )
        br = torch.min(
            (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
        )

        area_p = torch.prod(pred[:, 2:], 1)
        area_g = torch.prod(target[:, 2:], 1)

        # en is 1 only where the boxes actually overlap (tl < br on both axes).
        en = (tl < br).type(tl.type()).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en
        area_u = area_p + area_g - area_i
        iou = (area_i) / (area_u + 1e-16)

        if self.loss_type == "iou":
            loss = 1 - iou ** 2
        elif self.loss_type == "giou":
            # Smallest enclosing box for the GIoU penalty term.
            c_tl = torch.min(
                (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
            )
            c_br = torch.max(
                (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
            )
            area_c = torch.prod(c_br - c_tl, 1)
            giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
            loss = 1 - giou.clamp(min=-1.0, max=1.0)

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss

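
# A minimal smoke test for IOUloss (a hypothetical helper, not part of the
# original repo): boxes are (cx, cy, w, h), so an identical pred/target pair
# gives IoU = 1 and, with loss_type "iou", a loss of 1 - 1 ** 2 = 0.
def _demo_iou_loss():
    criterion = IOUloss(reduction="mean", loss_type="iou")
    boxes = torch.tensor([[50.0, 50.0, 20.0, 20.0]])
    print(criterion(boxes, boxes.clone()))  # tensor(0.)
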

class YOLOLoss(nn.Module):
    def __init__(self, num_classes, fp16, strides=[8, 16, 32]):
        super().__init__()
        self.num_classes = num_classes
        self.strides = strides

        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
        self.iou_loss = IOUloss(reduction="none")
        self.grids = [torch.zeros(1)] * len(strides)
        self.fp16 = fp16

    def forward(self, inputs, labels=None):
        outputs = []
        x_shifts = []
        y_shifts = []
        expanded_strides = []

        #-----------------------------------------------#
        #   inputs (ordered to match strides [8, 16, 32], e.g. 640 input):
        #              [[batch_size, num_classes + 5, 80, 80]
        #               [batch_size, num_classes + 5, 40, 40]
        #               [batch_size, num_classes + 5, 20, 20]]
        #   outputs    [[batch_size, 6400, num_classes + 5]
        #               [batch_size, 1600, num_classes + 5]
        #               [batch_size, 400, num_classes + 5]]
        #   x_shifts   [[1, 6400]
        #               [1, 1600]
        #               [1, 400]]
        #-----------------------------------------------#
        for k, (stride, output) in enumerate(zip(self.strides, inputs)):
            output, grid = self.get_output_and_grid(output, k, stride)
            x_shifts.append(grid[:, :, 0])
            y_shifts.append(grid[:, :, 1])
            expanded_strides.append(torch.ones_like(grid[:, :, 0]) * stride)
            outputs.append(output)

        return self.get_losses(x_shifts, y_shifts, expanded_strides, labels, torch.cat(outputs, 1))

    def get_output_and_grid(self, output, k, stride):
        grid = self.grids[k]
        hsize, wsize = output.shape[-2:]
        # The cached grid has shape [1, hsize, wsize, 2], so its spatial dims
        # live in shape[1:3]; rebuild only when the feature map size changes.
        if grid.shape[1:3] != output.shape[2:4]:
            yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, hsize, wsize, 2).type(output.type())
            self.grids[k] = grid
        grid = grid.view(1, -1, 2)

        output = output.flatten(start_dim=2).permute(0, 2, 1)
        # Decode: xy is an offset from the grid cell, wh is predicted in log
        # space; both are scaled back to input-image pixels via the stride.
        output[..., :2] = (output[..., :2] + grid.type_as(output)) * stride
        output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
        return output, grid

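    #-----------------------------------------------#
    #   Worked decode example (illustrative numbers, not from the source):
    #   on the stride-8 map, a prediction at grid cell (x=3, y=2) with raw
    #   outputs (0.5, 0.25, log 4, log 2) decodes to
    #       cx = (0.5 + 3) * 8 = 28,   cy = (0.25 + 2) * 8 = 18,
    #       w  = 4 * 8 = 32,           h  = 2 * 8 = 16,
    #   all in input-image pixels.
    #-----------------------------------------------#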
    def get_losses(self, x_shifts, y_shifts, expanded_strides, labels, outputs):
        #-----------------------------------------------#
        #   [batch, n_anchors_all, 4]
        #-----------------------------------------------#
        bbox_preds = outputs[:, :, :4]
        #-----------------------------------------------#
        #   [batch, n_anchors_all, 1]
        #-----------------------------------------------#
        obj_preds = outputs[:, :, 4:5]
        #-----------------------------------------------#
        #   [batch, n_anchors_all, n_cls]
        #-----------------------------------------------#
        cls_preds = outputs[:, :, 5:]

        total_num_anchors = outputs.shape[1]
        #-----------------------------------------------#
        #   x_shifts            [1, n_anchors_all]
        #   y_shifts            [1, n_anchors_all]
        #   expanded_strides    [1, n_anchors_all]
        #-----------------------------------------------#
        x_shifts = torch.cat(x_shifts, 1).type_as(outputs)
        y_shifts = torch.cat(y_shifts, 1).type_as(outputs)
        expanded_strides = torch.cat(expanded_strides, 1).type_as(outputs)

        cls_targets = []
        reg_targets = []
        obj_targets = []
        fg_masks = []

        num_fg = 0.0
        for batch_idx in range(outputs.shape[0]):
            num_gt = len(labels[batch_idx])
            if num_gt == 0:
                cls_target = outputs.new_zeros((0, self.num_classes))
                reg_target = outputs.new_zeros((0, 4))
                obj_target = outputs.new_zeros((total_num_anchors, 1))
                fg_mask = outputs.new_zeros(total_num_anchors).bool()
            else:
                #-----------------------------------------------#
                #   gt_bboxes_per_image     [num_gt, 4]
                #   gt_classes              [num_gt]
                #   bboxes_preds_per_image  [n_anchors_all, 4]
                #   cls_preds_per_image     [n_anchors_all, num_classes]
                #   obj_preds_per_image     [n_anchors_all, 1]
                #-----------------------------------------------#
                gt_bboxes_per_image = labels[batch_idx][..., :4].type_as(outputs)
                gt_classes = labels[batch_idx][..., 4].type_as(outputs)
                bboxes_preds_per_image = bbox_preds[batch_idx]
                cls_preds_per_image = cls_preds[batch_idx]
                obj_preds_per_image = obj_preds[batch_idx]

                gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = self.get_assignments(
                    num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes,
                    bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image,
                    expanded_strides, x_shifts, y_shifts,
                )
                torch.cuda.empty_cache()
                num_fg += num_fg_img
                # Class target is the one-hot class scaled by the matched IoU
                # (a soft label); regression target is the matched GT box.
                cls_target = F.one_hot(gt_matched_classes.to(torch.int64), self.num_classes).float() * pred_ious_this_matching.unsqueeze(-1)
                obj_target = fg_mask.unsqueeze(-1)
                reg_target = gt_bboxes_per_image[matched_gt_inds]
            cls_targets.append(cls_target)
            reg_targets.append(reg_target)
            obj_targets.append(obj_target.type(cls_target.type()))
            fg_masks.append(fg_mask)

        cls_targets = torch.cat(cls_targets, 0)
        reg_targets = torch.cat(reg_targets, 0)
        obj_targets = torch.cat(obj_targets, 0)
        fg_masks = torch.cat(fg_masks, 0)

        num_fg = max(num_fg, 1)
        loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum()
        loss_obj = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum()
        loss_cls = (self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum()
        reg_weight = 5.0
        loss = reg_weight * loss_iou + loss_obj + loss_cls

        return loss / num_fg

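    #-------------------------------------------------------#
    #   SimOTA label assignment (implemented below): each ground-truth
    #   box keeps only candidate anchors whose centers fall inside the
    #   box or within center_radius strides of its center, then picks
    #   the cheapest anchors under
    #       cost = cls_BCE + 3.0 * (-log(IoU)) + 1e5 * not_in_both
    #   with a per-GT dynamic k derived from the sum of its top-10 IoUs.
    #-------------------------------------------------------#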
    @torch.no_grad()
    def get_assignments(self, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image, expanded_strides, x_shifts, y_shifts):
        #-------------------------------------------------------#
        #   fg_mask                 [n_anchors_all]
        #   is_in_boxes_and_center  [num_gt, len(fg_mask)]
        #-------------------------------------------------------#
        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt)

        #-------------------------------------------------------#
        #   fg_mask                 [n_anchors_all]
        #   bboxes_preds_per_image  [fg_mask, 4]
        #   cls_preds_              [fg_mask, num_classes]
        #   obj_preds_              [fg_mask, 1]
        #-------------------------------------------------------#
        bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
        cls_preds_ = cls_preds_per_image[fg_mask]
        obj_preds_ = obj_preds_per_image[fg_mask]
        num_in_boxes_anchor = bboxes_preds_per_image.shape[0]

        #-------------------------------------------------------#
        #   pair_wise_ious      [num_gt, fg_mask]
        #-------------------------------------------------------#
        pair_wise_ious = self.bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)
        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)

        #-------------------------------------------------------#
        #   cls_preds_          [num_gt, fg_mask, num_classes]
        #   gt_cls_per_image    [num_gt, fg_mask, num_classes]
        #-------------------------------------------------------#
        if self.fp16:
            with torch.cuda.amp.autocast(enabled=False):
                cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
                gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
                pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)
        else:
            cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
            gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
            pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)
        del cls_preds_

        cost = pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center).float()

        num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
        del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
        return gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg

    def bboxes_iou(self, bboxes_a, bboxes_b, xyxy=True):
        if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
            raise IndexError

        if xyxy:
            tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
            br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
            area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
            area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
        else:
            tl = torch.max(
                (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
                (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
            )
            br = torch.min(
                (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
                (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
            )
            area_a = torch.prod(bboxes_a[:, 2:], 1)
            area_b = torch.prod(bboxes_b[:, 2:], 1)
        en = (tl < br).type(tl.type()).prod(dim=2)
        area_i = torch.prod(br - tl, 2) * en
        return area_i / (area_a[:, None] + area_b - area_i)

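    #-------------------------------------------------------#
    #   Candidate filtering below applies two tests per (gt, anchor)
    #   pair: "in box" (anchor center inside the GT box) and "in
    #   center" (anchor center inside a square of half-width
    #   center_radius * stride around the GT center). Passing either
    #   test makes the anchor a candidate; passing both removes the
    #   1e5 penalty from the assignment cost.
    #-------------------------------------------------------#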
    def get_in_boxes_info(self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt, center_radius=2.5):
        #-------------------------------------------------------#
        #   expanded_strides_per_image  [n_anchors_all]
        #   x_centers_per_image         [num_gt, n_anchors_all]
        #   y_centers_per_image         [num_gt, n_anchors_all]
        #-------------------------------------------------------#
        expanded_strides_per_image = expanded_strides[0]
        x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
        y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)

        #-------------------------------------------------------#
        #   gt_bboxes_per_image_l/r/t/b     [num_gt, n_anchors_all]
        #-------------------------------------------------------#
        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)

        #-------------------------------------------------------#
        #   bbox_deltas     [num_gt, n_anchors_all, 4]
        #-------------------------------------------------------#
        b_l = x_centers_per_image - gt_bboxes_per_image_l
        b_r = gt_bboxes_per_image_r - x_centers_per_image
        b_t = y_centers_per_image - gt_bboxes_per_image_t
        b_b = gt_bboxes_per_image_b - y_centers_per_image
        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)

        #-------------------------------------------------------#
        #   is_in_boxes     [num_gt, n_anchors_all]
        #   is_in_boxes_all [n_anchors_all]
        #-------------------------------------------------------#
        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0

        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)

        #-------------------------------------------------------#
        #   center_deltas   [num_gt, n_anchors_all, 4]
        #-------------------------------------------------------#
        c_l = x_centers_per_image - gt_bboxes_per_image_l
        c_r = gt_bboxes_per_image_r - x_centers_per_image
        c_t = y_centers_per_image - gt_bboxes_per_image_t
        c_b = gt_bboxes_per_image_b - y_centers_per_image
        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)

        #-------------------------------------------------------#
        #   is_in_centers       [num_gt, n_anchors_all]
        #   is_in_centers_all   [n_anchors_all]
        #-------------------------------------------------------#
        is_in_centers = center_deltas.min(dim=-1).values > 0.0
        is_in_centers_all = is_in_centers.sum(dim=0) > 0

        #-------------------------------------------------------#
        #   is_in_boxes_anchor      [n_anchors_all]
        #   is_in_boxes_and_center  [num_gt, is_in_boxes_anchor]
        #-------------------------------------------------------#
        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
        is_in_boxes_and_center = is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
        return is_in_boxes_anchor, is_in_boxes_and_center

    def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
        #-------------------------------------------------------#
        #   cost                [num_gt, fg_mask]
        #   pair_wise_ious      [num_gt, fg_mask]
        #   gt_classes          [num_gt]
        #   fg_mask             [n_anchors_all]
        #   matching_matrix     [num_gt, fg_mask]
        #-------------------------------------------------------#
        matching_matrix = torch.zeros_like(cost)

        #------------------------------------------------------------#
        #   Take the n_candidate_k highest-IoU anchors per GT box and
        #   sum their IoUs to decide how many anchors each box gets.
        #   topk_ious           [num_gt, n_candidate_k]
        #   dynamic_ks          [num_gt]
        #   matching_matrix     [num_gt, fg_mask]
        #------------------------------------------------------------#
        n_candidate_k = min(10, pair_wise_ious.size(1))
        topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)

        for gt_idx in range(num_gt):
            #------------------------------------------------------------#
            #   Pick the dynamic_k lowest-cost anchors for each GT box.
            #------------------------------------------------------------#
            _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
            matching_matrix[gt_idx][pos_idx] = 1.0
        del topk_ious, dynamic_ks, pos_idx

        #------------------------------------------------------------#
        #   anchor_matching_gt  [fg_mask]
        #------------------------------------------------------------#
        anchor_matching_gt = matching_matrix.sum(0)
        if (anchor_matching_gt > 1).sum() > 0:
            #------------------------------------------------------------#
            #   When an anchor is assigned to several GT boxes, keep
            #   only the GT box with the lowest cost.
            #------------------------------------------------------------#
            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
            matching_matrix[:, anchor_matching_gt > 1] *= 0.0
            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
        #------------------------------------------------------------#
        #   fg_mask_inboxes [fg_mask]
        #   num_fg is the number of positive anchors.
        #------------------------------------------------------------#
        fg_mask_inboxes = matching_matrix.sum(0) > 0.0
        num_fg = fg_mask_inboxes.sum().item()

        #------------------------------------------------------------#
        #   Update fg_mask in place to keep only the matched anchors.
        #------------------------------------------------------------#
        fg_mask[fg_mask.clone()] = fg_mask_inboxes

        #------------------------------------------------------------#
        #   Look up the GT index (and class) for each positive anchor.
        #------------------------------------------------------------#
        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
        gt_matched_classes = gt_classes[matched_gt_inds]

        pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes]
        return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds

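
# A hedged end-to-end sketch of YOLOLoss (a hypothetical smoke test, not part
# of the original repo). Assumes a 640x640 input with 20 classes; `labels`
# holds one [num_gt, 5] tensor of (cx, cy, w, h, class) per image, in
# input-image pixels.
def _demo_yolo_loss():
    num_classes = 20
    criterion = YOLOLoss(num_classes, fp16=False)
    # Head outputs for strides 8 / 16 / 32: 80x80, 40x40, 20x20 maps.
    inputs = [torch.randn(2, num_classes + 5, s, s) for s in (80, 40, 20)]
    labels = [
        torch.tensor([[320.0, 320.0, 100.0, 80.0, 1.0]]),
        torch.tensor([[100.0, 200.0, 50.0, 60.0, 3.0]]),
    ]
    print(criterion(inputs, labels))  # scalar loss, normalized by num_fg
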
def is_parallel(model):
    # Returns True if model is of type DP or DDP
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


def de_parallel(model):
    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
    return model.module if is_parallel(model) else model


def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a, options to only include [...] and to exclude [...]
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        else:
            setattr(a, k, v)


class ModelEMA:
    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Create EMA
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        # if next(model.parameters()).device.type != 'cpu':
        #     self.ema.half()  # FP16 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = de_parallel(model).state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)


def weights_init(net, init_type='normal', init_gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and classname.find('Conv') != -1:
            if init_type == 'normal':
                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif classname.find('BatchNorm2d') != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)
    print('initialize network with %s type' % init_type)
    net.apply(init_func)


def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio=0.05, warmup_lr_ratio=0.1, no_aug_iter_ratio=0.05, step_num=10):
    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
        if iters <= warmup_total_iters:
            # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
            lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
        elif iters >= total_iters - no_aug_iter:
            lr = min_lr
        else:
            lr = min_lr + 0.5 * (lr - min_lr) * (
                1.0 + math.cos(math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
            )
        return lr

    def step_lr(lr, decay_rate, step_size, iters):
        if step_size < 1:
            raise ValueError("step_size must be at least 1.")
        n = iters // step_size
        out_lr = lr * decay_rate ** n
        return out_lr

    if lr_decay_type == "cos":
        warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
        warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
        no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
        func = partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
    else:
        decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
        step_size = total_iters / step_num
        func = partial(step_lr, lr, decay_rate, step_size)

    return func

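
# A hedged usage sketch for ModelEMA (a hypothetical toy loop; `model` and
# `optimizer` stand in for the real training objects):
def _demo_model_ema():
    model = nn.Linear(4, 2)
    ema = ModelEMA(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for _ in range(3):
        loss = model(torch.randn(8, 4)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(model)  # refresh the moving average after each step
    # ema.ema now holds the smoothed weights for evaluation.
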
def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
    lr = lr_scheduler_func(epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
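
# A hedged sketch of driving the schedulers above (illustrative numbers, not
# from the source):
def _demo_lr_schedule():
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    lr_scheduler_func = get_lr_scheduler("cos", lr=1e-2, min_lr=1e-4, total_iters=100)
    for epoch in range(100):
        set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
        # ... one epoch of training at optimizer.param_groups[0]['lr'] ...
    # After the short warmup, the lr follows a cosine decay down to min_lr.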