Skip to content

Commit

Permalink
[DNL] executorch export faster-rcnn
Browse files Browse the repository at this point in the history
  • Loading branch information
yipjustin committed Dec 6, 2024
1 parent 6279faa commit 916ba03
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,11 @@ def batched_nms(
_log_api_usage_once(batched_nms)
# Benchmarks that drove the following thresholds are at
# https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
else:
return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
#if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
# return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
#else:
# return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)


@torch.jit._script_if_tracing
Expand Down Expand Up @@ -104,7 +105,8 @@ def _batched_nms_vanilla(
) -> Tensor:
# Based on Detectron2 implementation, just manually call nms() on each class independently
keep_mask = torch.zeros_like(scores, dtype=torch.bool)
for class_id in torch.unique(idxs):
#for class_id in torch.unique(idxs):
for class_id in idxs:
curr_indices = torch.where(idxs == class_id)[0]
curr_keep_indices = nms(boxes[curr_indices], scores[curr_indices], iou_threshold)
keep_mask[curr_indices[curr_keep_indices]] = True
Expand Down

0 comments on commit 916ba03

Please sign in to comment.