-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrpn_train_ain1.py
2535 lines (2090 loc) · 81.1 KB
/
rpn_train_ain1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""all-in-one script for training the RPN model of Mask R-CNN."""
import os
import random
from typing import Union
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
# ---------- training ----------
EPS = 1e-5  # small constant keeping divisions finite (IoU union, box sizes)
# ---------- model ----------
STRIDE = 32  # down-sampling factor between the input image and the feature map
SIZE_RESIZE = 600  # the shorter image side is resized to this before cropping
SIZE_IMG = 512  # side length of the square crop fed to the network
W = SIZE_IMG  # input image width in pixels
H = SIZE_IMG  # input image height in pixels
SIZE_FM = SIZE_IMG // STRIDE  # feature map side length (512 // 32 = 16)
C = 3  # number of image channels
W_FM = W // STRIDE  # feature map width
H_FM = H // STRIDE  # feature map height
N_ANCHOR = 9  # anchors per grid cell: 3 scales x 3 aspect ratios
N_OBJ = 20  # per-image boxes/labels are padded or truncated to this count
N_SUPP_SCORE = 800  # number of boxes to keep after score suppression
N_SUPP_NMS = 20  # number of boxes to keep after NMS
N_ALIGN_GRID = 2  # number of grid cells to align (tuned)
C_FM = 2048  # number of channels in the backbone feature map
NMS_TH = 0.7  # IoU threshold for non-maximum suppression
# ---------- dataset ----------
BUFFER_SIZE = 100  # shuffle buffer size of the training pipeline
BATCH_SIZE_TR = 8  # training batch size
BATCH_SIZE_TE = 8  # evaluation batch size
N_CLASS = 20  # number of object classes
# ---------- RPN ----------
R_DROP = 0.2  # dropout rate
IOU_TH = 0.5  # IoU threshold for calculating mean Average Precision (mAP)
IOU_SCALE = 10000  # IoU values are multiplied by this to be compared as ints
# The thresholds below are IoU values pre-multiplied by IOU_SCALE. Their use is
# not visible here; presumably ACGT = anchor-vs-GT and GTAC = GT-vs-anchor
# matching (TODO confirm against the matching code).
NEG_TH_ACGT = int(IOU_SCALE * 0.30)  # IoU 0.30
POS_TH_ACGT = int(IOU_SCALE * 0.50)  # IoU 0.50, tuned for VOC 2007
POS_TH_GTAC = int(IOU_SCALE * 0.01)  # IoU 0.01
NUM_POS_RPN = 128  # number of positive anchors (presumably sampled per image)
NUM_NEG_RPN = 128  # number of negative anchors (presumably sampled per image)
# ---------- visualization ----------
COLOR_BOX = (0, 255, 0)  # green color for bounding boxes
# (0, 0, 255) is red only if the tuple is read in BGR order (OpenCV
# convention); NOTE(review): confirm the drawing library's channel order.
COLOR_TXT = (0, 0, 255)  # color for class tags
THICKNESS_BOX = 2  # bounding box line thickness
THICKNESS_TXT = 1  # text thickness
SIZE_FONT = 0.5  # font scale
# ---------- image normalization ----------
# Per-channel std and mean on a 0-255 scale, used by preprcs_tr / preprcs_te
# as (img - IMGNET_MEAN) / IMGNET_STD; close to the usual ImageNet statistics.
IMGNET_STD = np.array([58.393, 57.12, 57.375], dtype=np.float32)
IMGNET_MEAN = np.array([123.68, 116.78, 103.94], dtype=np.float32)
# Value type accepted by the bounding-box accessor functions: a TensorFlow
# tensor or a NumPy array.
TensorT = Union[tf.Tensor, np.ndarray]  # noqa: UP007
# TODO: use operation-level seed
def set_global_determinism(seed: int) -> None:
    """Seed every random number generator and request deterministic execution.

    Args:
        seed (int): value used for the Python, TensorFlow and NumPy RNGs.

    Note:
        ``PYTHONHASHSEED`` only influences interpreters started after it is
        set (e.g. subprocesses); the hash seed of the running process is
        fixed at start-up.
    """
    # ---------- seed all random number generators ----------
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)

    # ---------- request deterministic kernels ----------
    for flag_name in ("TF_DETERMINISTIC_OPS", "TF_CUDNN_DETERMINISTIC"):
        os.environ[flag_name] = "1"

    # Restrict TensorFlow to a single thread for both op scheduling and
    # intra-op work (presumably to remove scheduling-dependent variation).
    tf_threads = tf.config.threading
    tf_threads.set_inter_op_parallelism_threads(1)
    tf_threads.set_intra_op_parallelism_threads(1)


# Seed everything once, at import time, with a fixed value.
set_global_determinism(seed=100)
# =============================================================================
# SECTION: dataset
# =============================================================================
def resize(img: tf.Tensor, size: int) -> tf.Tensor:
    """Resize an image so that its shorter side equals ``size`` pixels.

    The aspect ratio is preserved: both sides are multiplied by the same
    factor ``size / min(height, width)`` and rounded to the nearest pixel.
    Bounding boxes in relative coordinates are unaffected by such a uniform
    resize, so none are taken or returned.

    Args:
        img (tf.Tensor): image of shape (height, width, channels)
        size (int): desired length of the shorter side in pixels

    Returns:
        tf.Tensor: resized image of shape (new_height, new_width, channels)
    """
    # Local names avoid shadowing the module-level box helpers `h` and `w`.
    h_img, w_img, _ = tf.unstack(tf.shape(img))
    # The larger of the two ratios belongs to the shorter side, which
    # therefore ends up exactly `size` pixels long.
    scale = tf.maximum(
        tf.cast(size / h_img, tf.float32),
        tf.cast(size / w_img, tf.float32),
    )
    new_h = tf.cast(tf.round(tf.cast(h_img, tf.float32) * scale), tf.int32)
    new_w = tf.cast(tf.round(tf.cast(w_img, tf.float32) * scale), tf.int32)
    return tf.image.resize(img, [new_h, new_w])
def rand_crop(
    img: tf.Tensor,
    bbx: tf.Tensor,
    h_target: int,
    w_target: int,
) -> tuple[tf.Tensor, tf.Tensor]:
    """Randomly crop an image and re-express its bounding boxes in the crop.

    The top-left corner of the crop is drawn uniformly from every offset at
    which an ``h_target`` x ``w_target`` window fits inside the image.

    Args:
        img (tf.Tensor): image of shape (height, width, channels); it must be
            at least ``h_target`` high and ``w_target`` wide, otherwise
            TensorFlow raises an error when drawing the offset.
        bbx (tf.Tensor): bounding boxes in relative YXYX coordinates
            (y_min, x_min, y_max, x_max) of the full image, shape
            [num_boxes, 4].
        h_target (int): height of the crop in pixels.
        w_target (int): width of the crop in pixels.

    Returns:
        tuple[tf.Tensor, tf.Tensor]: the cropped image of shape
        (h_target, w_target, channels) and the boxes in relative coordinates
        of the crop, clipped to [0, 1]. A box lying entirely outside the crop
        collapses to a zero-area box but is kept, so boxes stay aligned with
        their labels.
    """
    h_img, w_img, _ = tf.unstack(tf.shape(img))
    # tf.random.uniform draws integers from [minval, maxval). "+ 1" makes the
    # largest valid offset (image size - target size) reachable. It also lets
    # an image that already has the target size be cropped at offset 0,
    # instead of failing on an empty range.
    h_offset = tf.random.uniform(shape=[],
                                 minval=0,
                                 maxval=h_img - h_target + 1,
                                 dtype=tf.int32)
    w_offset = tf.random.uniform(shape=[],
                                 minval=0,
                                 maxval=w_img - w_target + 1,
                                 dtype=tf.int32)
    img_crop = tf.image.crop_to_bounding_box(img, h_offset, w_offset, h_target,
                                             w_target)

    # Relative coords of the full image -> pixels -> shift by the crop
    # offset -> relative coords of the crop.
    h_img_f = tf.cast(h_img, tf.float32)
    w_img_f = tf.cast(w_img, tf.float32)
    h_off_f = tf.cast(h_offset, tf.float32)
    w_off_f = tf.cast(w_offset, tf.float32)
    h_tgt_f = tf.cast(h_target, tf.float32)
    w_tgt_f = tf.cast(w_target, tf.float32)
    y_lo = (bbx[..., 0] * h_img_f - h_off_f) / h_tgt_f
    x_lo = (bbx[..., 1] * w_img_f - w_off_f) / w_tgt_f
    y_hi = (bbx[..., 2] * h_img_f - h_off_f) / h_tgt_f
    x_hi = (bbx[..., 3] * w_img_f - w_off_f) / w_tgt_f

    # Clip the boxes to the crop window.
    bbx_clip = tf.clip_by_value(
        tf.stack([y_lo, x_lo, y_hi, x_hi], axis=-1), 0.0, 1.0)
    return img_crop, bbx_clip
def batch_pad(
    tensor: tf.Tensor,
    max_box: int,
    value: int,
) -> tf.Tensor:
    """Pad or truncate the first axis of a tensor to exactly ``max_box`` rows.

    Used to give per-image labels or bounding boxes a fixed size so that
    samples can be batched.

    Args:
        tensor (tf.Tensor): tensor of rank >= 1 whose first axis indexes
            boxes, e.g. bounding boxes [num_boxes, 4] or labels
            [num_boxes, 1].
        max_box (int): maximum number of boxes
        value (int): value to pad with; cast to ``tensor.dtype``

    Returns:
        tf.Tensor: tensor whose first dimension is ``max_box`` and whose
        remaining dimensions are unchanged
    """
    # Keep at most `max_box` rows, then pad whatever is still missing. This is
    # branch-free, so it behaves the same eagerly and when traced (with or
    # without AutoGraph).
    tensor = tensor[:max_box]
    n_missing = max_box - tf.shape(tensor)[0]
    # Pad only the first axis; every trailing axis gets (0, 0), which makes
    # this work for any rank >= 1 instead of rank 2 only.
    paddings = tf.concat(
        [[[0, n_missing]],
         tf.zeros([tf.rank(tensor) - 1, 2], dtype=tf.int32)],
        axis=0,
    )
    return tf.pad(tensor,
                  paddings=paddings,
                  constant_values=tf.cast(value, tensor.dtype))
def data_augment(
    img: tf.Tensor,
    bbx: tf.Tensor,
) -> tuple[tf.Tensor, tf.Tensor]:
    """Apply at most one random augmentation to an image and its boxes.

    A single uniform draw in [0, 1) picks, with probability 0.25 each: a
    horizontal flip, additive Gaussian noise, a random brightness shift, or
    no change at all.

    Args:
        img (tf.Tensor): input image (float32)
        bbx (tf.Tensor): bounding boxes of the image in relative YXYX
            coordinates, shape [num_boxes, 4]

    Returns:
        tuple[tf.Tensor, tf.Tensor]: augmented image and adjusted bounding
        boxes (only the flip changes the boxes).
    """
    # Proportion for each operation and idle
    proportion = 0.25
    # Random number in [0, 1) that selects the augmentation (or none)
    rand_aug = tf.random.uniform((), minval=0, maxval=1, dtype=tf.float32)
    # Each condition is a single comparison: a chained comparison such as
    # `0 <= rand_aug < p` needs the Python truth value of an intermediate
    # tensor, which is not allowed while tracing (e.g. inside Dataset.map).
    # A plain tensor condition is converted to tf.cond by AutoGraph. The lower
    # bounds are implied by the preceding branches, and rand_aug >= 0 always.
    if rand_aug < proportion:
        # Horizontal flip: mirror the x-coordinates. The new x_min is
        # 1 - old x_max, and the new x_max is 1 - old x_min.
        img = tf.image.flip_left_right(img)
        y_lo, x_lo, y_hi, x_hi = tf.unstack(bbx, axis=1)
        bbx = tf.stack([y_lo, 1.0 - x_hi, y_hi, 1.0 - x_lo], axis=1)
    elif rand_aug < 2 * proportion:
        # Gaussian noise
        noise = tf.random.normal(
            shape=tf.shape(img),
            mean=0.0,
            stddev=0.5,
            dtype=tf.float32,
        )
        img = img + noise
    elif rand_aug < 3 * proportion:
        # Random brightness
        img = tf.image.random_brightness(img, max_delta=2.0)
    # rand_aug in [3 * proportion, 1): image and boxes are left unchanged.
    return img, bbx
def preprcs_tr(
        sample: dict) -> tuple[tf.Tensor, tuple[tf.Tensor, tf.Tensor]]:
    """Preprocess one training sample.

    Steps:
    1. Normalize with IMGNET_MEAN / IMGNET_STD.
    2. Resize so the shorter side is SIZE_RESIZE.
    3. Randomly crop to SIZE_IMG x SIZE_IMG.
    4. Apply at most one random augmentation.
    5. Pad or truncate boxes and labels to N_OBJ entries.

    Args:
        sample (dict): sample from the dataset
            e.g. {
                "image": tf.Tensor,
                "objects": {
                    "bbox": tf.Tensor,
                    "label": tf.Tensor,
                },
            }

    Returns:
        tuple[tf.Tensor, tuple[tf.Tensor, tf.Tensor]]: ``(img, (bbx, lab))``
        for a single, unbatched sample:
            - img: [SIZE_IMG, SIZE_IMG, 3] float32
            - bbx: [N_OBJ, 4] relative YXYX coordinates, padded with 0
            - lab: [N_OBJ, 1] class labels, padded with -1
    """
    img = sample["image"]
    # Normalize the image
    img = (tf.cast(img, dtype=tf.float32) - IMGNET_MEAN) / IMGNET_STD
    # Get the bounding boxes and labels
    bbx = sample["objects"]["bbox"]
    lab = sample["objects"]["label"]
    # Resize the image while preserving the aspect ratio. Relative boxes do
    # not change under a uniform resize.
    img = resize(img, SIZE_RESIZE)
    # Randomly crop the image and re-express the boxes in the crop
    img, bbx = rand_crop(img, bbx, SIZE_IMG, SIZE_IMG)
    # Randomly augment the image and bounding boxes
    img, bbx = data_augment(img, bbx)
    # Pad the labels and bounding boxes to a fixed size so samples batch
    bbx = batch_pad(bbx, max_box=N_OBJ, value=0)
    lab = batch_pad(lab[:, tf.newaxis], max_box=N_OBJ, value=-1)
    return img, (bbx, lab)
def preprcs_te(
        sample: dict) -> tuple[tf.Tensor, tuple[tf.Tensor, tf.Tensor]]:
    """Preprocess one evaluation sample.

    Same as the training pipeline, but without augmentation:
    1. Normalize with IMGNET_MEAN / IMGNET_STD.
    2. Resize so the shorter side is SIZE_RESIZE.
    3. Randomly crop to SIZE_IMG x SIZE_IMG.
    4. Pad or truncate boxes and labels to N_OBJ entries.

    Args:
        sample (dict): sample from the dataset
            e.g. {
                "image": tf.Tensor,
                "objects": {
                    "bbox": tf.Tensor,
                    "label": tf.Tensor,
                },
            }

    Returns:
        tuple[tf.Tensor, tuple[tf.Tensor, tf.Tensor]]: ``(img, (bbx, lbl))``
        for a single, unbatched sample:
            - img: [SIZE_IMG, SIZE_IMG, 3] float32
            - bbx: [N_OBJ, 4] relative YXYX coordinates, padded with 0
            - lbl: [N_OBJ, 1] class labels, padded with -1

    TODO: predict with full-sized images
    NOTE(review): the crop is random, so evaluation inputs (and metrics)
    vary between passes over the dataset.
    """
    img = sample["image"]
    # Normalize the image with ImageNet mean and std
    img = (tf.cast(img, dtype=tf.float32) - IMGNET_MEAN) / IMGNET_STD
    # Get the bounding boxes and labels
    bbx = sample["objects"]["bbox"]
    lbl = sample["objects"]["label"]
    # Resize the image while preserving the aspect ratio. Relative boxes do
    # not change under a uniform resize.
    img = resize(img, SIZE_RESIZE)
    # Randomly crop the image and re-express the boxes in the crop
    img, bbx = rand_crop(img, bbx, SIZE_IMG, SIZE_IMG)
    # Pad the labels and bounding boxes to a fixed size so samples batch
    bbx = batch_pad(bbx, max_box=N_OBJ, value=0)
    lbl = batch_pad(lbl[:, tf.newaxis], max_box=N_OBJ, value=-1)
    return img, (bbx, lbl)
def load_train_voc2007(
        n: int) -> tuple[tf.data.Dataset, tfds.core.DatasetInfo]:
    """Load the training split of Pascal VOC 2007.

    Args:
        n (int): number of training samples per batch

    Returns:
        tuple[tf.data.Dataset, tfds.core.DatasetInfo]: training dataset and
        dataset info. The dataset is shuffled, preprocessed with
        `preprcs_tr`, batched and prefetched. Each element is
        ``(images, (bounding boxes, labels))``:
            - images: [batch_size, H, W, 3]
            - bounding boxes: [batch_size, max_box, 4] in relative
              coordinates
            - labels: [batch_size, max_box, 1]
    """
    ds, ds_info = tfds.load(
        "voc/2007",
        split="train",
        shuffle_files=True,
        with_info=True,
    )
    # Shuffle the raw samples *before* preprocessing. The shuffle buffer then
    # holds decoded uint8 images rather than float32 SIZE_IMG x SIZE_IMG
    # crops, which are several times larger. The randomness of the resulting
    # stream is the same.
    ds = ds.shuffle(BUFFER_SIZE).map(
        preprcs_tr,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    ).batch(n).prefetch(tf.data.experimental.AUTOTUNE)
    return ds, ds_info
def load_test_voc2007(n: int) -> tuple[tf.data.Dataset, tfds.core.DatasetInfo]:
    """Load the evaluation data of Pascal VOC 2007 (its ``validation`` split).

    Args:
        n (int): number of evaluation samples per batch

    Returns:
        tuple[tf.data.Dataset, tfds.core.DatasetInfo]: evaluation dataset and
        dataset info. The dataset is not shuffled; it is preprocessed with
        `preprcs_te`, batched and prefetched. Each element is
        ``(images, (bounding boxes, labels))``:
            - images: [batch_size, H, W, 3]
            - bounding boxes: [batch_size, max_box, 4] in relative
              coordinates
            - labels: [batch_size, max_box, 1]

    NOTE(review): this evaluates on ``split="validation"``. TFDS also exposes
    a separate ``test`` split for voc/2007; confirm which one is intended.
    """
    ds_te, ds_info = tfds.load(
        "voc/2007",
        split="validation",
        shuffle_files=False,
        with_info=True,
    )
    ds_te = ds_te.map(
        preprcs_te,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    ).batch(n).prefetch(tf.data.experimental.AUTOTUNE)
    return ds_te, ds_info
# =============================================================================
# SECTION: anchor
# =============================================================================
def _scales(x: int) -> tuple[int, int, int]:
"""Get the scale sequence dynamically with closest (floor) power of 2.
Example:
>>> _scales(32)
(8, 16, 32)
>>> _scales(63)
(8, 16, 32)
>>> _scales(64)
(16, 32, 64)
Args:
x (int): minimum shape value (width or height) of the input image
Returns:
tuple[int, int, int]: three scales from small to large
"""
# closest (ceiling) power of 2
scale_max = 2**(x).bit_length()
return scale_max >> 3, scale_max >> 2, scale_max >> 1
def _scale_mat(x: int) -> np.ndarray:
    """Build the (9, 3) matrix that pairs every scale with every aspect ratio.

    Rows 3*j .. 3*j+2 hold the three scales (small to large) in column j and
    zeros elsewhere. Multiplying by a (3, 2) aspect-ratio table therefore
    yields the three scales of ratio 0, then those of ratio 1, then ratio 2.

    Args:
        x (int): minimum shape value (width or height) of the input image

    Returns:
        np.ndarray: float32 scale matrix of shape (9, 3)
    """
    scale_column = np.array(_scales(x), dtype=np.float32).reshape(3, 1)
    # kron(I_3, column) stacks the scale column block-diagonally: block j
    # occupies rows 3*j..3*j+2 of column j.
    return np.kron(np.eye(3, dtype=np.float32), scale_column)
def _one_hw(x: int) -> np.ndarray:
    """Return the (height, width) of the 9 anchors of a single grid cell.

    Rows 0-2 are the tall anchors (h:w = 2:1), rows 3-5 the square ones and
    rows 6-8 the wide ones (h:w = 1:2), each at the three scales of
    `_scales`. Each ratio pair multiplies to (approximately) 1, so an anchor
    keeps roughly the area of a ``scale x scale`` square.

    Args:
        x (int): minimum shape value (width or height) of the input image

    Returns:
        np.ndarray: float32 array (9, 2) of format (height, width)
    """
    sqrt2 = 1.4142135624
    half_sqrt2 = sqrt2 / 2
    # (height factor, width factor) per aspect ratio: tall, square, wide.
    ratio_table = np.array(
        [[sqrt2, half_sqrt2], [1, 1], [half_sqrt2, sqrt2]],
        dtype=np.float32,
    )
    return _scale_mat(x) @ ratio_table
def _hw(h: int, w: int, stride: int) -> np.ndarray:
    """Repeat the per-cell anchor sizes over every cell of the feature map.

    Args:
        h (int): feature map height
        w (int): feature map width
        stride (int): stride of the backbone e.g. 32

    Returns:
        np.ndarray: tensor (H, W, 9, 2) of format (height, width)
    """
    # Anchor scales are chosen from the smaller side of the input image.
    per_cell = _one_hw(min(w * stride, h * stride))
    # np.tile prepends singleton axes to the (9, 2) array, so this copies it
    # into every (row, column) cell.
    return np.tile(per_cell, (h, w, 1, 1))
def _center_coord(h: int, w: int, stride: int) -> np.ndarray:
    """Pixel coordinates of the center of every grid cell, once per anchor.

    Cell (i, j) is centered at (i * stride + stride // 2,
    j * stride + stride // 2).

    Args:
        h (int): feature map height
        w (int): feature map width
        stride (int): stride of the backbone e.g. 32

    Returns:
        np.ndarray: float32 tensor (H, W, N_ANCHOR, 2) of format (y, x)
    """
    rows, cols = np.meshgrid(
        np.arange(0, h, dtype=np.float32),
        np.arange(0, w, dtype=np.float32),
        indexing="ij",
    )  # each (H, W)
    # (H, W, 2) in (y, x) order -- rows first, NOT the other way around
    cell_idx = np.stack([rows, cols], axis=-1)
    # Same center for each of the N_ANCHOR anchors of a cell.
    cell_idx = np.repeat(cell_idx[:, :, np.newaxis, :], N_ANCHOR, axis=2)
    return cell_idx * stride + stride // 2
def get_abs_anchor(
    h: int,
    w: int,
    stride: int,
    *,
    flat: bool = True,
) -> np.ndarray:
    """Anchors of every grid cell in **ABSOLUTE** (pixel) coordinates.

    Args:
        h (int): feature map height
        w (int): feature map width
        stride (int): stride of the backbone e.g. 32
        flat (bool, optional): flatten the output tensor. Defaults to True.

    Returns:
        np.ndarray: anchors in absolute coordinates (y_min, x_min, y_max, x_max)
            - if flat: (N_ALL_AC, 4)
            - else: (H_FM, W_FM, 9, 4)
    """
    centers = _center_coord(h, w, stride)  # (H, W, 9, 2) as (y, x)
    half_sizes = _hw(h, w, stride) * 0.5  # (H, W, 9, 2) as (h/2, w/2)
    corners = np.concatenate(
        [centers - half_sizes, centers + half_sizes], axis=-1)
    return corners.reshape(-1, 4) if flat else corners
def get_rel_anchor(
    h: int,
    w: int,
    stride: int,
    *,
    flat: bool = True,
) -> np.ndarray:
    """Anchors of every grid cell in **RELATIVE** coordinates (fractions of
    the image height / width).

    Args:
        h (int): image height in pixels
        w (int): image width in pixels
        stride (int): stride of the backbone e.g. 32
        flat (bool, optional): flatten the output tensor. Defaults to True.

    Returns:
        np.ndarray: anchors of format (y_min, x_min, y_max, x_max); values
        outside [0, 1] are possible for anchors crossing the border
            - if flat: (N_ALL_AC, 4)
            - else: (H_FM, W_FM, 9, 4)
    """
    # Diagonal matrix dividing y-coordinates by h and x-coordinates by w.
    to_relative = np.diag([1. / h, 1. / w, 1. / h, 1. / w])
    absolute = get_abs_anchor(h // stride, w // stride, stride, flat=flat)
    return np.matmul(absolute, to_relative)
# Flattened anchors in relative coordinates, shape (N_ALL_AC, 4) with
# N_ALL_AC = H_FM * W_FM * N_ANCHOR; still includes anchors reaching far beyond
# the image border.
anchors_raw = get_rel_anchor(H, W, STRIDE, flat=True)
# An anchor is valid when it lies inside the image enlarged by 20% of its size
# on every side (relative bounds [-0.2, 1.2]) and has positive height and width.
BOUND_AC_LO = -0.2
BOUND_AC_HI = 1.2
# 1.0 / 0.0 mask of valid anchors, shape (N_ALL_AC,).
# NOTE: np.where with Python float literals yields float64, not float32.
MASK_RPNAC = np.where(
(anchors_raw[..., 0] >= BOUND_AC_LO) & # y_min >= lo
(anchors_raw[..., 1] >= BOUND_AC_LO) & # x_min >= lo
(anchors_raw[..., 2] <= BOUND_AC_HI) & # y_max <= hi
(anchors_raw[..., 3] <= BOUND_AC_HI) & # x_max <= hi
(anchors_raw[..., 2] > anchors_raw[..., 0]) & # y_max > y_min
(anchors_raw[..., 3] > anchors_raw[..., 1]), # x_max > x_min
1.0,
0.0,
)
# number of valid anchors (1384 for a 512 x 512 input with stride 32)
N_VAL_AC = int(MASK_RPNAC.sum())
RPNAC = anchors_raw[MASK_RPNAC == 1] # valid anchors (N_VAL_AC, 4)
# Mark the NumPy arrays read-only so accidental in-place writes raise.
RPNAC.flags.writeable = False
MASK_RPNAC.flags.writeable = False
# valid anchors as a float32 TensorFlow constant (N_VAL_AC, 4)
AC_VAL = tf.constant(RPNAC, dtype=tf.float32) # (N_VAL_AC, 4)
# =============================================================================
# SECTION: YXYX bounding boxes.
#
# shape: (N1, N2, ..., Nk, C) where C >= 4
#
# format: (y_min, x_min, y_max, x_max, objectness score, ...)
# =============================================================================
def xmin(bbox: TensorT) -> TensorT:
    """Return the left edge (x_min) of every box.

    Args:
        bbox (TensorT): YXYX boxes of shape (N1, N2, ..., Nk, C), C >= 4

    Returns:
        TensorT: x_min values of shape (N1, N2, ..., Nk)
    """
    column = 1  # layout: (y_min, x_min, y_max, x_max, ...)
    return bbox[..., column]
def ymin(bbox: TensorT) -> TensorT:
    """Return the top edge (y_min) of every box.

    Args:
        bbox (TensorT): YXYX boxes of shape (N1, N2, ..., Nk, C), C >= 4

    Returns:
        TensorT: y_min values of shape (N1, N2, ..., Nk)
    """
    column = 0  # layout: (y_min, x_min, y_max, x_max, ...)
    return bbox[..., column]
def xmax(bbox: TensorT) -> TensorT:
    """Return the right edge (x_max) of every box.

    Args:
        bbox (TensorT): YXYX boxes of shape (N1, N2, ..., Nk, C), C >= 4

    Returns:
        TensorT: x_max values of shape (N1, N2, ..., Nk)
    """
    column = 3  # layout: (y_min, x_min, y_max, x_max, ...)
    return bbox[..., column]
def ymax(bbox: TensorT) -> TensorT:
    """Return the bottom edge (y_max) of every box.

    Args:
        bbox (TensorT): YXYX boxes of shape (N1, N2, ..., Nk, C), C >= 4

    Returns:
        TensorT: y_max values of shape (N1, N2, ..., Nk)
    """
    column = 2  # layout: (y_min, x_min, y_max, x_max, ...)
    return bbox[..., column]
def rem(bbox: TensorT) -> TensorT:
    """Return the extra columns (everything after YXYX) of every box.

    Args:
        bbox (TensorT): YXYX boxes of shape (N1, N2, ..., Nk, C), C >= 4

    Returns:
        TensorT: remainder tensor of shape (N1, N2, ..., Nk, C - 4), e.g.
        objectness scores
    """
    n_coords = 4  # y_min, x_min, y_max, x_max
    return bbox[..., n_coords:]
def w(bbox: TensorT) -> TensorT:
    """Return the width (x_max - x_min) of every box.

    Args:
        bbox (TensorT): YXYX boxes of shape (N1, N2, ..., Nk, C), C >= 4

    Returns:
        TensorT: widths of shape (N1, N2, ..., Nk)
    """
    right, left = xmax(bbox), xmin(bbox)
    return right - left
def h(bbox: TensorT) -> TensorT:
    """Return the height (y_max - y_min) of every box.

    Args:
        bbox (TensorT): YXYX boxes of shape (N1, N2, ..., Nk, C), C >= 4

    Returns:
        TensorT: heights of shape (N1, N2, ..., Nk)
    """
    bottom, top = ymax(bbox), ymin(bbox)
    return bottom - top
def xctr(bbox: TensorT) -> TensorT:
    """Return the horizontal center (x_min + width / 2) of every box.

    Args:
        bbox (TensorT): YXYX boxes of shape (N1, N2, ..., Nk, C), C >= 4

    Returns:
        TensorT: x-centers of shape (N1, N2, ..., Nk)
    """
    half_width = 0.5 * w(bbox)
    return xmin(bbox) + half_width
def yctr(bbox: TensorT) -> TensorT:
    """Return the vertical center (y_min + height / 2) of every box.

    Args:
        bbox (TensorT): YXYX boxes of shape (N1, N2, ..., Nk, C), C >= 4

    Returns:
        TensorT: y-centers of shape (N1, N2, ..., Nk)
    """
    half_height = 0.5 * h(bbox)
    return ymin(bbox) + half_height
def area(bbox: TensorT) -> TensorT:
    """Return the area (height * width) of every box.

    No clamping is applied, so a box with y_max < y_min or x_max < x_min
    yields a negative or otherwise meaningless area.

    Args:
        bbox (TensorT): YXYX boxes of shape (N1, N2, ..., Nk, C), C >= 4

    Returns:
        TensorT: areas of shape (N1, N2, ..., Nk)
    """
    height, width = h(bbox), w(bbox)
    return height * width
def pmax(bbox: TensorT) -> TensorT:
    """Return the bottom-right corner (y_max, x_max) of every box.

    Args:
        bbox (TensorT): YXYX boxes of shape (N1, N2, ..., Nk, C), C >= 4

    Returns:
        TensorT: YX-max tensor of shape (N1, N2, ..., Nk, 2)
    """
    first, stop = 2, 4  # columns y_max, x_max
    return bbox[..., first:stop]
def pmin(bbox: TensorT) -> TensorT:
    """Return the top-left corner (y_min, x_min) of every box.

    Args:
        bbox (TensorT): YXYX boxes of shape (N1, N2, ..., Nk, C), C >= 4

    Returns:
        TensorT: YX-min tensor of shape (N1, N2, ..., Nk, 2)
    """
    first, stop = 0, 2  # columns y_min, x_min
    return bbox[..., first:stop]
def interarea(bbox_prd: tf.Tensor, bbox_lbl: tf.Tensor) -> tf.Tensor:
    """Element-wise intersection area of two sets of YXYX boxes.

    Args:
        bbox_prd (tf.Tensor): predicted boxes of shape (N1, N2, ..., Nk, C)
        bbox_lbl (tf.Tensor): label boxes of shape (N1, N2, ..., Nk, C);
            the two shapes only need to be broadcast-compatible

    Returns:
        tf.Tensor: intersection area tensor of shape (N1, N2, ..., Nk); zero
        for boxes that do not overlap
    """
    top_left = tf.maximum(pmin(bbox_prd), pmin(bbox_lbl))
    bottom_right = tf.minimum(pmax(bbox_prd), pmax(bbox_lbl))
    # (height, width) of the overlap, clamped at 0 for disjoint boxes
    extent = tf.maximum(bottom_right - top_left, 0.0)
    overlap_h, overlap_w = extent[..., 0], extent[..., 1]
    return overlap_h * overlap_w
def iou(bbox_prd: tf.Tensor, bbox_lbl: tf.Tensor) -> tf.Tensor:
    """Element-wise intersection-over-union of two sets of YXYX boxes.

    Args:
        bbox_prd (tf.Tensor): predicted boxes of shape (N1, N2, ..., Nk, C)
        bbox_lbl (tf.Tensor): label boxes of shape (N1, N2, ..., Nk, C);
            the two shapes only need to be broadcast-compatible

    Returns:
        tf.Tensor: IoU tensor of shape (N1, N2, ..., Nk). EPS in the
        denominator keeps the result finite (0) when both boxes have zero
        area, e.g. zero-padded labels.
    """
    overlap = tf.maximum(interarea(bbox_prd, bbox_lbl), 0.0)
    union = area(bbox_prd) + area(bbox_lbl) - overlap
    return overlap / (union + EPS)
def iou_mat(bbox_prd: tf.Tensor, bbox_lbl: tf.Tensor) -> tf.Tensor:
    """Calculate the IoU matrix between two sets of bounding boxes.

    Args:
        bbox_prd (tf.Tensor): predicted bounding boxes of shape (N1, C)
        bbox_lbl (tf.Tensor): ground truth bounding boxes of shape (N2, C)

    Returns:
        tf.Tensor: IoU tensor of shape (N1, N2); entry (i, j) is the IoU of
        prediction i and label j
    """
    # `iou` is built from broadcasting element-wise ops. Pairing (N1, 1, C)
    # with (1, N2, C) therefore gives the same (N1, N2) result as tiling,
    # without materializing two (N1, N2, C) copies.
    return iou(tf.expand_dims(bbox_prd, axis=1),
               tf.expand_dims(bbox_lbl, axis=0))
# TODO: GIoU
def iou_batch(bbox_prd: tf.Tensor, bbox_lbl: tf.Tensor) -> tf.Tensor:
    """Calculate the IoU matrix for each batch item of two sets of boxes.

    Args:
        bbox_prd (tf.Tensor): predicted bounding boxes of shape (B, N1, C)
        bbox_lbl (tf.Tensor): ground truth bounding boxes of shape (B, N2, C)

    Returns:
        tf.Tensor: IoU tensor of shape (B, N1, N2); entry (b, i, j) is the
        IoU of prediction i and label j of batch item b
    """
    # Broadcasting (B, N1, 1, C) against (B, 1, N2, C) gives the same
    # (B, N1, N2) result as tiling, without two (B, N1, N2, C) copies.
    return iou(tf.expand_dims(bbox_prd, axis=2),
               tf.expand_dims(bbox_lbl, axis=1))
def from_xywh(xywh: tf.Tensor) -> tf.Tensor:
    """Convert center-size boxes (x, y, w, h) to corners (ymin, xmin, ymax, xmax).

    Only the first four columns are used; any extra columns are dropped.

    Args:
        xywh (tf.Tensor): XYWH boxes (center x, center y, width, height) of
            shape (N1, N2, ..., Nk, 4)

    Returns:
        tf.Tensor: YXYX boxes of shape (N1, N2, ..., Nk, 4)
    """
    ctr_x, ctr_y, width, height = (xywh[..., k] for k in range(4))
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = [ctr_y - half_h, ctr_x - half_w, ctr_y + half_h, ctr_x + half_w]
    return tf.stack(corners, axis=-1)
def to_xywh(bbox: tf.Tensor) -> tf.Tensor:
    """Convert corner boxes (ymin, xmin, ymax, xmax, ...) to (x, y, w, h, ...).

    Columns after the first four (e.g. scores) are carried over unchanged.

    Args:
        bbox (tf.Tensor): YXYX boxes of shape (N1, N2, ..., Nk, C), C >= 4

    Returns:
        tf.Tensor: XYWH boxes (center x, center y, width, height, extras) of
        shape (N1, N2, ..., Nk, C)
    """
    center_size = [xctr(bbox), yctr(bbox), w(bbox), h(bbox)]
    extras = rem(bbox)
    return tf.concat([tf.stack(center_size, axis=-1), extras], axis=-1)
def clip(bbox: tf.Tensor, h: float, w: float) -> tf.Tensor:
    """Clamp YXYX boxes to the rectangle [0, h] x [0, w].

    Columns after the first four are carried over unchanged.

    Args:
        bbox (tf.Tensor): YXYX boxes of shape (N1, N2, ..., Nk, C), C >= 4
        h (float): upper bound for y coordinates (image height)
        w (float): upper bound for x coordinates (image width)

    Returns:
        tf.Tensor: clipped YXYX boxes of shape (N1, N2, ..., Nk, C)
    """
    # Pair each coordinate getter with its upper bound: y uses h, x uses w.
    getters_and_bounds = zip((ymin, xmin, ymax, xmax), (h, w, h, w))
    clipped = [
        tf.clip_by_value(getter(bbox), 0.0, upper)
        for getter, upper in getters_and_bounds
    ]
    return tf.concat([tf.stack(clipped, axis=-1), rem(bbox)], axis=-1)
# =============================================================================
# SECTION: Delta
# =============================================================================
def dx(delta: tf.Tensor) -> tf.Tensor:
    """Return the x-center offset of every delta.

    Deltas are laid out as (dx, dy, dw, dh) along the last axis.

    Args:
        delta (tf.Tensor): delta tensor of shape (..., 4), e.g. (H, W, 9, 4)

    Returns:
        tf.Tensor: dx values with the last axis removed, e.g. (H, W, 9)
    """
    component = 0  # position of dx in (dx, dy, dw, dh)
    return delta[..., component]
def dy(delta: tf.Tensor) -> tf.Tensor:
    """Return the y-center offset of every delta.

    Deltas are laid out as (dx, dy, dw, dh) along the last axis.

    Args:
        delta (tf.Tensor): delta tensor of shape (..., 4), e.g. (H, W, 9, 4)

    Returns:
        tf.Tensor: dy values with the last axis removed, e.g. (H, W, 9)
    """
    component = 1  # position of dy in (dx, dy, dw, dh)
    return delta[..., component]
def dw(delta: tf.Tensor) -> tf.Tensor:
    """Return the log-scale width change of every delta.

    Deltas are laid out as (dx, dy, dw, dh) along the last axis.

    Args:
        delta (tf.Tensor): delta tensor of shape (..., 4), e.g. (H, W, 9, 4)

    Returns:
        tf.Tensor: dw values with the last axis removed, e.g. (H, W, 9)
    """
    component = 2  # position of dw in (dx, dy, dw, dh)
    return delta[..., component]
def dh(delta: tf.Tensor) -> tf.Tensor:
    """Return the log-scale height change of every delta.

    Deltas are laid out as (dx, dy, dw, dh) along the last axis.

    Args:
        delta (tf.Tensor): delta tensor of shape (..., 4), e.g. (H, W, 9, 4)

    Returns:
        tf.Tensor: dh values with the last axis removed, e.g. (H, W, 9)
    """
    component = 3  # position of dh in (dx, dy, dw, dh)
    return delta[..., component]
# =============================================================================
# SECTION: utils
# =============================================================================
def delta2bbox(base: tf.Tensor, diff: tf.Tensor) -> tf.Tensor:
    """Apply regression deltas to base boxes, the inverse of `bbox2delta`.

    e.g.: anchor (base) + delta (diff) = predicted (bbox)

    The center moves by (dx * width, dy * height) of the base box, and the
    width / height are scaled by exp(dw) / exp(dh).

    NOTE(review): dw / dh are not clamped here, so very large predicted
    values make tf.exp overflow to inf.

    Args:
        base (tf.Tensor): YXYX base boxes of shape (N1, N2, ..., Nk, 4)
        diff (tf.Tensor): deltas (dx, dy, dw, dh) of shape (N1, N2, ..., Nk, 4)

    Returns:
        tf.Tensor: YXYX boxes of shape (N1, N2, ..., Nk, 4)
    """
    base_w = w(base)
    base_h = h(base)
    moved = tf.stack(
        [
            xctr(base) + base_w * dx(diff),
            yctr(base) + base_h * dy(diff),
            base_w * tf.exp(dw(diff)),
            base_h * tf.exp(dh(diff)),
        ],
        axis=-1,
    )  # XYWH
    return from_xywh(moved)
def bbox2delta(bbox_l: tf.Tensor, bbox_r: tf.Tensor) -> tf.Tensor:
    """Regression deltas that move ``bbox_r`` onto ``bbox_l``.

    e.g.:
        - GT (bbox_l) - anchor (bbox_r) = RPN target (delta)
        - GT (bbox_l) - ROI (bbox_r) = RCNN target (delta)

    Each delta is (dx, dy, dw, dh) = ((xc_l - xc_r) / w_r,
    (yc_l - yc_r) / h_r, log(w_l / w_r), log(h_l / h_r)), clipped to
    [-10, 10]. A zero-width or zero-height ``bbox_l`` (e.g. a zero-padded
    label) gives log(0) = -inf, which the clip turns into -10.

    Args:
        bbox_l (tf.Tensor): minuend bbox tensor (left operand) of shape
            (N1, N2, ..., Nk, C), where C >= 4
        bbox_r (tf.Tensor): subtrahend bbox tensor (right operand) of shape
            (N1, N2, ..., Nk, C), where C >= 4

    Returns:
        tf.Tensor: delta tensor of shape (N1, N2, ..., Nk, 4)
    """
    # Reference sizes are floored at EPS so the divisions stay finite.
    ref_w = tf.math.maximum(w(bbox_r), EPS)
    ref_h = tf.math.maximum(h(bbox_r), EPS)
    offsets = [
        (xctr(bbox_l) - xctr(bbox_r)) / ref_w,
        (yctr(bbox_l) - yctr(bbox_r)) / ref_h,
        tf.math.log(w(bbox_l) / ref_w),
        tf.math.log(h(bbox_l) / ref_h),
    ]
    return tf.clip_by_value(tf.stack(offsets, axis=-1), -10.0, 10.0)
# =============================================================================
# SECTION: RPN training utils
# =============================================================================
def sample_mask(mask: tf.Tensor, num: int) -> tf.Tensor:
"""Sample `num` anchors from `mask` for a batch of images.
TODO: more precise sampling
Args:
mask (tf.Tensor): 0/1 mask of anchors (B, N_ac)
num (int): number of positive anchors to sample
Returns:
tf.Tensor: 0/1 mask of anchors (B, N_ac)