|
| 1 | +import numpy as np |
| 2 | +import matplotlib.pyplot as plt |
| 3 | +from matplotlib.colors import LinearSegmentedColormap |
| 4 | + |
| 5 | + |
| 6 | +def dilated_conv_one_pixel(center: (int, int), |
| 7 | + feature_map: np.ndarray, |
| 8 | + k: int = 3, |
| 9 | + r: int = 1, |
| 10 | + v: int = 1): |
| 11 | + """ |
| 12 | + 膨胀卷积核中心在指定坐标center处时,统计哪些像素被利用到, |
| 13 | + 并在利用到的像素位置处加上增量v |
| 14 | + Args: |
| 15 | + center: 膨胀卷积核中心的坐标 |
| 16 | + feature_map: 记录每个像素使用次数的特征图 |
| 17 | + k: 膨胀卷积核的kernel大小 |
| 18 | + r: 膨胀卷积的dilation rate |
| 19 | + v: 使用次数增量 |
| 20 | + """ |
| 21 | + assert divmod(3, 2)[1] == 1 |
| 22 | + |
| 23 | + # left-top: (x, y) |
| 24 | + left_top = (center[0] - ((k - 1) // 2) * r, center[1] - ((k - 1) // 2) * r) |
| 25 | + for i in range(k): |
| 26 | + for j in range(k): |
| 27 | + feature_map[left_top[1] + i * r][left_top[0] + j * r] += v |
| 28 | + |
| 29 | + |
| 30 | +def dilated_conv_all_map(dilated_map: np.ndarray, |
| 31 | + k: int = 3, |
| 32 | + r: int = 1): |
| 33 | + """ |
| 34 | + 根据输出特征矩阵中哪些像素被使用以及使用次数, |
| 35 | + 配合膨胀卷积k和r计算输入特征矩阵哪些像素被使用以及使用次数 |
| 36 | + Args: |
| 37 | + dilated_map: 记录输出特征矩阵中每个像素被使用次数的特征图 |
| 38 | + k: 膨胀卷积核的kernel大小 |
| 39 | + r: 膨胀卷积的dilation rate |
| 40 | + """ |
| 41 | + new_map = np.zeros_like(dilated_map) |
| 42 | + for i in range(dilated_map.shape[0]): |
| 43 | + for j in range(dilated_map.shape[1]): |
| 44 | + if dilated_map[i][j] > 0: |
| 45 | + dilated_conv_one_pixel((j, i), new_map, k=k, r=r, v=dilated_map[i][j]) |
| 46 | + |
| 47 | + return new_map |
| 48 | + |
| 49 | + |
| 50 | +def plot_map(matrix: np.ndarray): |
| 51 | + plt.figure() |
| 52 | + |
| 53 | + c_list = ['white', 'blue', 'red'] |
| 54 | + new_cmp = LinearSegmentedColormap.from_list('chaos', c_list) |
| 55 | + plt.imshow(matrix, cmap=new_cmp) |
| 56 | + |
| 57 | + ax = plt.gca() |
| 58 | + ax.set_xticks(np.arange(-0.5, matrix.shape[1], 1), minor=True) |
| 59 | + ax.set_yticks(np.arange(-0.5, matrix.shape[0], 1), minor=True) |
| 60 | + |
| 61 | + # 显示color bar |
| 62 | + plt.colorbar() |
| 63 | + |
| 64 | + # 在图中标注数量 |
| 65 | + thresh = 5 |
| 66 | + for x in range(matrix.shape[1]): |
| 67 | + for y in range(matrix.shape[0]): |
| 68 | + # 注意这里的matrix[y, x]不是matrix[x, y] |
| 69 | + info = int(matrix[y, x]) |
| 70 | + ax.text(x, y, info, |
| 71 | + verticalalignment='center', |
| 72 | + horizontalalignment='center', |
| 73 | + color="white" if info > thresh else "black") |
| 74 | + ax.grid(which='minor', color='black', linestyle='-', linewidth=1.5) |
| 75 | + plt.show() |
| 76 | + plt.close() |
| 77 | + |
| 78 | + |
| 79 | +def main(): |
| 80 | + # bottom to top |
| 81 | + dilated_rates = [1, 2, 3] |
| 82 | + # init feature map |
| 83 | + size = 31 |
| 84 | + m = np.zeros(shape=(size, size), dtype=np.int32) |
| 85 | + center = size // 2 |
| 86 | + m[center][center] = 1 |
| 87 | + # print(m) |
| 88 | + # plot_map(m) |
| 89 | + |
| 90 | + for index, dilated_r in enumerate(dilated_rates[::-1]): |
| 91 | + new_map = dilated_conv_all_map(m, r=dilated_r) |
| 92 | + m = new_map |
| 93 | + print(m) |
| 94 | + plot_map(m) |
| 95 | + |
| 96 | + |
| 97 | +if __name__ == '__main__': |
| 98 | + main() |
0 commit comments