-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoof.py
107 lines (78 loc) · 3.05 KB
/
oof.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import time
import cv2
import matplotlib.pyplot as plt
import numpy as np
from omegaconf import OmegaConf
def _softmax(x: np.ndarray, temp: float = 1.0, axis: int = 0):
x = x / temp
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum(axis=axis, keepdims=True)
def plot_weights(weights):
    """Draw every entry of ``weights`` as an image in a four-column grid.

    A new matplotlib figure is created; nothing is returned and the figure is
    neither shown nor saved here.

    Args:
        weights: Array indexed along its first axis; ``weights[k]`` is passed
            to ``imshow`` for the k-th panel.
    """
    count = weights.shape[0]
    columns = 4
    # Ceiling division: a partially filled last row still needs its own row.
    rows = -(-count // columns)

    width, height = plt.figaspect(rows / columns)
    figure = plt.figure(figsize=(width * 2, height * 2), layout="constrained")
    grid = figure.add_gridspec(rows, columns, wspace=0, hspace=0)

    for index in range(count):
        # Row-major placement: fill each row left to right before moving down.
        row, col = divmod(index, columns)
        panel = figure.add_subplot(grid[row, col])
        panel.imshow(weights[index])
def weights_from_color(cfg: OmegaConf, imgs: np.ndarray, weights: np.ndarray):
    """Weight each pixel by how close its colour is to a target Lab colour.

    Every image in the stack is converted with ``cv2.COLOR_BGR2LAB``, the
    Euclidean distance to ``cfg.color.target_lab`` is taken over the channel
    axis, and a softmax of the negated distances is computed across the stack
    (axis 0) with temperature ``cfg.color.T``.

    Args:
        cfg: Configuration providing ``color.target_lab`` (three values) and
            ``color.T``.
        imgs: Stack of images, iterated along axis 0.
        weights: Unused; accepted so all weight functions share one signature.

    Returns:
        Per-pixel weights with the channel axis kept as a singleton dimension.
    """
    lab_stack = np.stack(
        [cv2.cvtColor(frame, cv2.COLOR_BGR2LAB) for frame in imgs], axis=0
    )
    # Shape the target so it broadcasts against (stack, row, column, channel).
    target = np.array(cfg.color.target_lab).reshape(1, 1, 1, 3)
    distance = np.linalg.norm(lab_stack - target, axis=-1, keepdims=True)
    # Negate so that smaller distances receive larger weights.
    return _softmax(-distance, temp=cfg.color.T, axis=0)
def weights_from_outlier(cfg: OmegaConf, imgs: np.ndarray, weights: np.ndarray):
    """Weight each pixel by its squared grey-level deviation from the stack mean.

    Every image is converted with ``cv2.COLOR_BGR2GRAY``; the squared
    difference to the mean over the stack (axis 0) is fed through a softmax
    across the stack with temperature ``cfg.outlier.T``, so larger deviations
    receive larger weights.

    Args:
        cfg: Configuration providing ``outlier.T``.
        imgs: Stack of images, iterated along axis 0.
        weights: Unused; accepted so all weight functions share one signature.

    Returns:
        Per-pixel weights with a trailing singleton channel dimension.
    """
    gray_stack = np.stack(
        [cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) for frame in imgs], axis=0
    )
    # Append a singleton channel axis to the greyscale stack.
    gray_stack = gray_stack[..., np.newaxis]
    stack_mean = np.mean(gray_stack, axis=0, keepdims=True)
    squared_dev = np.square(gray_stack - stack_mean)
    return _softmax(squared_dev, temp=cfg.outlier.T, axis=0)
def weights_from_blend_masks(cfg: OmegaConf, imgs: np.ndarray, weights: np.ndarray):
    """Normalise the supplied blend masks so they sum to one across the stack.

    Args:
        cfg: Unused; accepted so all weight functions share one signature.
        imgs: Unused; accepted so all weight functions share one signature.
        weights: Blend masks, stacked along axis 0.

    Returns:
        ``weights`` divided by their sum over axis 0.
    """
    # The small epsilon keeps positions where every mask is zero from
    # dividing by zero; those positions simply stay at zero.
    totals = np.sum(weights, axis=0, keepdims=True) + 1e-12
    return weights / totals
def main():
    """Blend an image stack with per-pixel weights and save the result figures.

    Configuration is read from ``oof.yml`` and overridden by ``key=value``
    command-line arguments. The keys used are ``rawpath`` (archive holding the
    ``imgs`` and ``weights`` arrays), ``weight_filter`` (one of ``color``,
    ``outlier`` or ``baseline``), ``integrate.min_weight`` and ``show``.

    Two PNG files are written to ``tmp/``: the blended image with a zoomed
    inset, and the per-pixel maximum of the weights over the stack.

    Raises:
        KeyError: If ``cfg.weight_filter`` is not a known mode.
    """
    import os  # local import: only needed to create the output directory

    cfg = OmegaConf.merge(OmegaConf.load("oof.yml"), OmegaConf.from_cli())

    # Close the archive as soon as both arrays have been read into memory.
    with np.load(cfg.rawpath) as data:
        imgs = data["imgs"].astype(np.float32) / 255
        weights = data["weights"]

    mode_to_fn = {
        "color": weights_from_color,
        "outlier": weights_from_outlier,
        "baseline": weights_from_blend_masks,
    }
    w = mode_to_fn[cfg.weight_filter](cfg, imgs, weights)
    # Drop contributions below the configured threshold (no renormalisation).
    w = np.where(w < cfg.integrate.min_weight, 0.0, w)
    # plot_weights(w)  # debugging aid: inspect the individual weight maps

    # Weighted sum over the stack, converted back to 8-bit.
    out = ((w * imgs).sum(0) * 255).astype(np.uint8)

    # The output directory may not exist on a fresh checkout.
    os.makedirs("tmp", exist_ok=True)
    # One timestamp for both files so the pair always shares the same stamp.
    now = time.strftime("%Y%m%d-%H%M%S")

    fs = plt.figaspect(out.shape[0] / out.shape[1])
    fig, ax = plt.subplots(figsize=(fs[0] * 2, fs[1] * 2))
    # The channel order is reversed for display.
    ax.imshow(out[..., ::-1], origin="upper")
    axins = ax.inset_axes(
        [0.6, 0.1, 0.3, 0.3],
        xlim=(700, 780),
        # With origin="upper" row indices grow downwards, so the limits are
        # given as (bottom, top) = (730, 650). Ascending limits would show the
        # zoomed region vertically mirrored relative to the main image.
        ylim=(730, 650),
        xticklabels=[],
        yticklabels=[],
    )
    axins.imshow(out[..., ::-1], origin="upper")
    ax.indicate_inset_zoom(axins, edgecolor="black")
    ax.set_aspect("equal")
    fig.tight_layout()
    fig.savefig(f"tmp/oof-{cfg.weight_filter}-{now}.png", dpi=300)

    # Second figure: the largest weight any image of the stack has per pixel.
    fig, ax = plt.subplots(figsize=(fs[0] * 2, fs[1] * 2))
    ax.imshow(w.max(0), origin="upper")
    ax.set_aspect("equal")
    fig.tight_layout()
    fig.savefig(f"tmp/oof-weights-{cfg.weight_filter}-{now}.png", dpi=300)

    if cfg.show:
        plt.show()
if __name__ == "__main__":
    # Example invocation; key=value arguments override the values in oof.yml:
    #   python oof.py rawpath=tmp\stitch-20241011-193007.npz
    main()