Skip to content

Commit 272d676

Browse files
author
wz
committed
add drawing dilated convolution pixels to be used
1 parent 1b2e686 commit 272d676

File tree

1 file changed

+98
-0
lines changed
  • others_project/draw_dilated_conv

1 file changed

+98
-0
lines changed
+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
from matplotlib.colors import LinearSegmentedColormap
4+
5+
6+
def dilated_conv_one_pixel(center: (int, int),
7+
feature_map: np.ndarray,
8+
k: int = 3,
9+
r: int = 1,
10+
v: int = 1):
11+
"""
12+
膨胀卷积核中心在指定坐标center处时,统计哪些像素被利用到,
13+
并在利用到的像素位置处加上增量v
14+
Args:
15+
center: 膨胀卷积核中心的坐标
16+
feature_map: 记录每个像素使用次数的特征图
17+
k: 膨胀卷积核的kernel大小
18+
r: 膨胀卷积的dilation rate
19+
v: 使用次数增量
20+
"""
21+
assert divmod(3, 2)[1] == 1
22+
23+
# left-top: (x, y)
24+
left_top = (center[0] - ((k - 1) // 2) * r, center[1] - ((k - 1) // 2) * r)
25+
for i in range(k):
26+
for j in range(k):
27+
feature_map[left_top[1] + i * r][left_top[0] + j * r] += v
28+
29+
30+
def dilated_conv_all_map(dilated_map: np.ndarray,
31+
k: int = 3,
32+
r: int = 1):
33+
"""
34+
根据输出特征矩阵中哪些像素被使用以及使用次数,
35+
配合膨胀卷积k和r计算输入特征矩阵哪些像素被使用以及使用次数
36+
Args:
37+
dilated_map: 记录输出特征矩阵中每个像素被使用次数的特征图
38+
k: 膨胀卷积核的kernel大小
39+
r: 膨胀卷积的dilation rate
40+
"""
41+
new_map = np.zeros_like(dilated_map)
42+
for i in range(dilated_map.shape[0]):
43+
for j in range(dilated_map.shape[1]):
44+
if dilated_map[i][j] > 0:
45+
dilated_conv_one_pixel((j, i), new_map, k=k, r=r, v=dilated_map[i][j])
46+
47+
return new_map
48+
49+
50+
def plot_map(matrix: np.ndarray):
51+
plt.figure()
52+
53+
c_list = ['white', 'blue', 'red']
54+
new_cmp = LinearSegmentedColormap.from_list('chaos', c_list)
55+
plt.imshow(matrix, cmap=new_cmp)
56+
57+
ax = plt.gca()
58+
ax.set_xticks(np.arange(-0.5, matrix.shape[1], 1), minor=True)
59+
ax.set_yticks(np.arange(-0.5, matrix.shape[0], 1), minor=True)
60+
61+
# 显示color bar
62+
plt.colorbar()
63+
64+
# 在图中标注数量
65+
thresh = 5
66+
for x in range(matrix.shape[1]):
67+
for y in range(matrix.shape[0]):
68+
# 注意这里的matrix[y, x]不是matrix[x, y]
69+
info = int(matrix[y, x])
70+
ax.text(x, y, info,
71+
verticalalignment='center',
72+
horizontalalignment='center',
73+
color="white" if info > thresh else "black")
74+
ax.grid(which='minor', color='black', linestyle='-', linewidth=1.5)
75+
plt.show()
76+
plt.close()
77+
78+
79+
def main():
80+
# bottom to top
81+
dilated_rates = [1, 2, 3]
82+
# init feature map
83+
size = 31
84+
m = np.zeros(shape=(size, size), dtype=np.int32)
85+
center = size // 2
86+
m[center][center] = 1
87+
# print(m)
88+
# plot_map(m)
89+
90+
for index, dilated_r in enumerate(dilated_rates[::-1]):
91+
new_map = dilated_conv_all_map(m, r=dilated_r)
92+
m = new_map
93+
print(m)
94+
plot_map(m)
95+
96+
97+
if __name__ == '__main__':
98+
main()

0 commit comments

Comments
 (0)