|
5 | 5 |
|
6 | 6 |
|
7 | 7 | class BroadcastMulBench(benchmark.Benchmark):
|
8 |
| - def __init__(self, mode, device, case, M, N, K): |
9 |
| - super().__init__(mode, device) |
| 8 | + def __init__(self, mode, device, dtype, case, M, N, K): |
| 9 | + super().__init__(mode, device, dtype) |
10 | 10 | self.case = case
|
11 | 11 | self.M = M
|
12 | 12 | self.N = N
|
13 | 13 | self.K = K
|
14 | 14 |
|
15 | 15 | if case == "row":
|
16 | 16 | self.d1 = self.rand(
|
17 |
| - [M, N, 1], device=device, requires_grad=self.requires_grad |
| 17 | + [M, N, 1], device=device, dtype=dtype, requires_grad=self.requires_grad |
18 | 18 | )
|
19 | 19 | self.d2 = self.rand(
|
20 |
| - [M, 1, K], device=device, requires_grad=self.requires_grad |
| 20 | + [M, 1, K], device=device, dtype=dtype, requires_grad=self.requires_grad |
21 | 21 | )
|
22 | 22 | elif case == "mid":
|
23 | 23 | self.d1 = self.rand(
|
24 |
| - [M, N, 1], device=device, requires_grad=self.requires_grad |
| 24 | + [M, N, 1], device=device, dtype=dtype, requires_grad=self.requires_grad |
25 | 25 | )
|
26 | 26 | self.d2 = self.rand(
|
27 |
| - [1, N, K], device=device, requires_grad=self.requires_grad |
| 27 | + [1, N, K], device=device, dtype=dtype, requires_grad=self.requires_grad |
28 | 28 | )
|
29 | 29 | elif case == "col":
|
30 | 30 | self.d1 = self.rand(
|
31 |
| - [M, 1, K], device=device, requires_grad=self.requires_grad |
| 31 | + [M, 1, K], device=device, dtype=dtype, requires_grad=self.requires_grad |
32 | 32 | )
|
33 | 33 | self.d2 = self.rand(
|
34 |
| - [1, N, K], device=device, requires_grad=self.requires_grad |
| 34 | + [1, N, K], device=device, dtype=dtype, requires_grad=self.requires_grad |
35 | 35 | )
|
36 | 36 | else:
|
37 | 37 | raise ValueError("invalid case: %s" % (case))
|
@@ -60,52 +60,52 @@ def memory_workload(self):
|
60 | 60 | sol_count = (1) + (1)
|
61 | 61 | algorithmic_count = 1 + (1 + 1)
|
62 | 62 |
|
63 |
| - buffer_size = self.M * self.N * self.K * 4 |
| 63 | + buffer_size = self.M * self.N * self.K |
64 | 64 | return {
|
65 | 65 | "sol": buffer_size * sol_count,
|
66 | 66 | "algorithmic": buffer_size * algorithmic_count,
|
67 | 67 | }
|
68 | 68 |
|
69 | 69 |
|
70 | 70 | class BroadcastRowBench(BroadcastMulBench):
|
71 |
| - def __init__(self, mode, device, M, N, K): |
72 |
| - super(BroadcastRowBench, self).__init__(mode, device, "row", M, N, K) |
| 71 | + def __init__(self, mode, device, dtype, M, N, K): |
| 72 | + super(BroadcastRowBench, self).__init__(mode, device, dtype, "row", M, N, K) |
73 | 73 |
|
74 | 74 | @staticmethod
|
75 | 75 | def module():
|
76 | 76 | return "broadcast_row"
|
77 | 77 |
|
78 | 78 |
|
79 | 79 | class BroadcastMidBench(BroadcastMulBench):
|
80 |
| - def __init__(self, mode, device, M, N, K): |
81 |
| - super(BroadcastMidBench, self).__init__(mode, device, "mid", M, N, K) |
| 80 | + def __init__(self, mode, device, dtype, M, N, K): |
| 81 | + super(BroadcastMidBench, self).__init__(mode, device, dtype, "mid", M, N, K) |
82 | 82 |
|
83 | 83 | @staticmethod
|
84 | 84 | def module():
|
85 | 85 | return "broadcast_mid"
|
86 | 86 |
|
87 | 87 |
|
88 | 88 | class BroadcastColBench(BroadcastMulBench):
|
89 |
| - def __init__(self, mode, device, M, N, K): |
90 |
| - super(BroadcastColBench, self).__init__(mode, device, "col", M, N, K) |
| 89 | + def __init__(self, mode, device, dtype, M, N, K): |
| 90 | + super(BroadcastColBench, self).__init__(mode, device, dtype, "col", M, N, K) |
91 | 91 |
|
92 | 92 | @staticmethod
|
93 | 93 | def module():
|
94 | 94 | return "broadcast_col"
|
95 | 95 |
|
96 | 96 |
|
97 | 97 | class BroadcastThreeArgs(benchmark.Benchmark):
|
98 |
| - def __init__(self, mode, device, M, N, K, L): |
99 |
| - super().__init__(mode, device) |
| 98 | + def __init__(self, mode, device, dtype, M, N, K, L): |
| 99 | + super().__init__(mode, device, dtype) |
100 | 100 | self.M = M
|
101 | 101 | self.N = N
|
102 | 102 | self.K = K
|
103 | 103 | self.L = L
|
104 | 104 |
|
105 |
| - self.d1 = self.rand([M, N], device=device, requires_grad=self.requires_grad) |
106 |
| - self.d2 = self.rand([K, M, 1], device=device, requires_grad=self.requires_grad) |
| 105 | + self.d1 = self.rand([M, N], device=device, dtype=dtype, requires_grad=self.requires_grad) |
| 106 | + self.d2 = self.rand([K, M, 1], device=device, dtype=dtype, requires_grad=self.requires_grad) |
107 | 107 | self.d3 = self.rand(
|
108 |
| - [L, K, 1, 1], device=device, requires_grad=self.requires_grad |
| 108 | + [L, K, 1, 1], device=device, dtype=dtype, requires_grad=self.requires_grad |
109 | 109 | )
|
110 | 110 |
|
111 | 111 | self.inputs = [self.d1, self.d2, self.d3]
|
@@ -160,15 +160,15 @@ class BroadcastBench(benchmark.Benchmark):
|
160 | 160 | unary_op_np_func = None
|
161 | 161 | split_input = True
|
162 | 162 |
|
163 |
| - def __init__(self, mode, device, M, N, K): |
164 |
| - super().__init__(mode, device) |
| 163 | + def __init__(self, mode, device, dtype, M, N, K): |
| 164 | + super().__init__(mode, device, dtype) |
165 | 165 | self.M = M
|
166 | 166 | self.N = N
|
167 | 167 | self.K = K
|
168 |
| - self.d1 = self.rand([M, N], device=device, requires_grad=self.requires_grad) |
169 |
| - self.d2 = self.rand([K, 1, N], device=device, requires_grad=self.requires_grad) |
170 |
| - self.d3 = self.rand([M, N], device=device, requires_grad=self.requires_grad) |
171 |
| - self.d4 = self.rand([K, M, 1], device=device, requires_grad=self.requires_grad) |
| 168 | + self.d1 = self.rand([M, N], device=device, dtype=dtype, requires_grad=self.requires_grad) |
| 169 | + self.d2 = self.rand([K, 1, N], device=device, dtype=dtype, requires_grad=self.requires_grad) |
| 170 | + self.d3 = self.rand([M, N], device=device, dtype=dtype, requires_grad=self.requires_grad) |
| 171 | + self.d4 = self.rand([K, M, 1], device=device, dtype=dtype, requires_grad=self.requires_grad) |
172 | 172 | self.inputs = [self.d1, self.d2, self.d3, self.d4]
|
173 | 173 |
|
174 | 174 | def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
|
|
0 commit comments