Skip to content

Commit 26a91a9

Browse files
kevinstephano authored and facebook-github-bot committed on Sep 15, 2020
[WIP][JIT] Add benchmarking support of NV Fuser with FP16 dtype support (pytorch#44101)
Summary: Modified files in `benchmarks/tensorexpr` to add support for NVIDIA's Fuser for the JIT compiler. Besides adding an option to select the NVIDIA fuser, this change:
* Adds FP16 datatype support.
* Fixes SOL/Algo calculations to use the size of the actual data type instead of being fixed to 4 bytes.
* Adds IR printing and kernel printing knobs.
* Adds a knob `input_iter` to create ranges of inputs, currently only for reductions.
* Adds further reduction support for inner- and outer-dimension reductions that are compatible with the `input_iter` knob.
* Adds `simple_element`, `reduce2d_inner`, and `reduce2d_outer` to isolate performance on elementwise and reduction operations in the most minimal fashion.

Pull Request resolved: pytorch#44101
Reviewed By: ngimel
Differential Revision: D23713658
Pulled By: bertmaher
fbshipit-source-id: d6b83cfab559aefe107c23b3c0f2df9923b3adc1
1 parent 2f4c31c commit 26a91a9

14 files changed

+345
-109
lines changed
 

‎benchmarks/tensorexpr/__main__.py

+87-7
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from . import matmul # noqa: F401
1212
# from . import normalization # noqa: F401
1313
# from . import pooling # noqa: F401
14-
# from . import reduction # noqa: F401
14+
from . import reduction # noqa: F401
1515
# from . import softmax # noqa: F401
1616
from . import rnn_eltwise # noqa: F401
1717
from . import swish # noqa: F401
@@ -45,6 +45,20 @@ def main():
4545
default="fwd,both",
4646
help="a comma separated list of running modes",
4747
)
48+
parser.add_argument(
49+
"--dtype",
50+
type=str,
51+
default="float32",
52+
help="a comma separated list of Data Types: {float32[default], float16}",
53+
)
54+
parser.add_argument(
55+
"--input-iter",
56+
type=str,
57+
default=None,
58+
help="a comma separated list of of Tensor dimensions that includes a start, \
59+
stop, and increment that can be constant or a power of 2 \
60+
{start:stop:inc,start:stop:pow2}",
61+
)
4862
parser.add_argument(
4963
"--engine",
5064
type=str,
@@ -79,14 +93,24 @@ def main():
7993
"--cuda_fuser",
8094
type=str,
8195
default="te",
82-
help="The Cuda fuser backend to use: one of {te, old, none}",
96+
help="The Cuda fuser backend to use: one of {te, nvf, old, none}",
8397
)
8498
parser.add_argument(
8599
"--output",
86100
type=str,
87101
default="stdout",
88102
help="The output format of the benchmark run {stdout[default], json}",
89103
)
104+
parser.add_argument(
105+
"--print-ir",
106+
action='store_true',
107+
help="Print the IR graph of the Fusion.",
108+
)
109+
parser.add_argument(
110+
"--print-kernel",
111+
action='store_true',
112+
help="Print generated kernel(s).",
113+
)
90114

91115
args = parser.parse_args()
92116

@@ -101,7 +125,13 @@ def main():
101125
torch._C._jit_set_profiling_executor(False)
102126
torch._C._jit_set_texpr_fuser_enabled(False)
103127
torch._C._jit_override_can_fuse_on_gpu(True)
104-
128+
elif args.cuda_fuser == "nvf":
129+
import torch
130+
torch._C._jit_set_profiling_executor(True)
131+
torch._C._jit_set_nvfuser_enabled(True)
132+
torch._C._jit_set_profiling_mode(True)
133+
else :
134+
raise ValueError("Undefined fuser: {}".format(args.cuda_fuser))
105135

106136
def set_global_threads(num_threads):
107137
os.environ["OMP_NUM_THREADS"] = str(num_threads)
@@ -133,13 +163,58 @@ def set_global_threads(num_threads):
133163

134164
modes = args.mode.split(",")
135165

166+
datatypes = args.dtype.split(",")
167+
for index, dtype in enumerate(datatypes):
168+
datatypes[index] = getattr(torch, dtype)
169+
if not datatypes[index] :
170+
raise AttributeError("DataType: {} is not valid!".format(dtype))
171+
136172
tensor_engine.set_engine_mode(args.engine)
137173

138174
def run_default_configs(bench_cls, allow_skip=True):
139-
for mode, device, config in itertools.product(
140-
modes, devices, bench_cls.default_configs()
175+
for mode, device, dtype, config in itertools.product(
176+
modes, devices, datatypes, bench_cls.default_configs()
177+
):
178+
bench = bench_cls(mode, device, dtype, *config)
179+
bench.output_type = args.output
180+
bench.jit_mode = args.jit_mode
181+
if not bench.is_supported():
182+
if allow_skip:
183+
continue
184+
else:
185+
raise ValueError(
186+
"attempted to run an unsupported benchmark: %s" % (bench.desc())
187+
)
188+
bench.run(args)
189+
190+
def run_with_input_iter(bench_cls, input_iter, allow_skip=True):
191+
tensor_dim_specs = input_iter.split(',')
192+
tensor_dim_specs = [dim.split(':') for dim in tensor_dim_specs]
193+
194+
configs = []
195+
for start, stop, inc in tensor_dim_specs:
196+
dim_list = []
197+
if inc == 'pow2' :
198+
curr = int(start)
199+
while curr <= int(stop) :
200+
dim_list.append(curr)
201+
curr <<= 1
202+
elif inc == 'pow2+1' :
203+
curr = int(start)
204+
while curr <= int(stop) :
205+
dim_list.append(curr)
206+
curr -= 1
207+
curr <<= 1
208+
curr += 1
209+
else :
210+
dim_list = list(range(int(start), int(stop) + int(inc), int(inc)))
211+
configs.append(dim_list)
212+
configs = itertools.product(*configs)
213+
214+
for mode, device, dtype, config in itertools.product(
215+
modes, devices, datatypes, list(configs)
141216
):
142-
bench = bench_cls(mode, device, *config)
217+
bench = bench_cls(mode, device, dtype, *config)
143218
bench.output_type = args.output
144219
bench.jit_mode = args.jit_mode
145220
if not bench.is_supported():
@@ -163,7 +238,12 @@ def run_default_configs(bench_cls, allow_skip=True):
163238
for bench_cls in benchmark_classes:
164239
if name in bench_cls.module():
165240
match_class_name = True
166-
run_default_configs(bench_cls, allow_skip=True)
241+
if (args.input_iter is not None) and bench_cls.input_iterable() :
242+
run_with_input_iter(bench_cls, args.input_iter, allow_skip=True)
243+
else :
244+
if args.input_iter is not None :
245+
print("WARNING: Incompatible benchmark class called with input_iter arg: {}".format(name))
246+
run_default_configs(bench_cls, allow_skip=True)
167247

168248
if match_class_name:
169249
continue

‎benchmarks/tensorexpr/attention.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,23 @@
77

88

99
class BahdanauAttention(benchmark.Benchmark):
10-
def __init__(self, mode, device, b, t_q, t_k, n):
11-
super().__init__(mode, device)
10+
def __init__(self, mode, device, dtype, b, t_q, t_k, n):
11+
super().__init__(mode, device, dtype)
1212
self.b = b
1313
self.t_q = t_q
1414
self.t_k = t_k
1515
self.n = n
1616
self.att_query = self.rand(
17-
[b, t_q, n], device=device, requires_grad=self.requires_grad
17+
[b, t_q, n], device=device, dtype=dtype, requires_grad=self.requires_grad
1818
)
1919
self.att_keys = self.rand(
20-
[b, t_k, n], device=device, requires_grad=self.requires_grad
20+
[b, t_k, n], device=device, dtype=dtype, requires_grad=self.requires_grad
2121
)
2222
self.normalize_bias = self.rand(
23-
[n], device=device, requires_grad=self.requires_grad
23+
[n], device=device, dtype=dtype, requires_grad=self.requires_grad
2424
)
2525
self.linear_att = self.rand(
26-
[n], device=device, requires_grad=self.requires_grad
26+
[n], device=device, dtype=dtype, requires_grad=self.requires_grad
2727
)
2828
self.inputs = [
2929
self.att_query,

‎benchmarks/tensorexpr/benchmark.py

+49-16
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@
88

99

1010
class Benchmark(object):
11-
def __init__(self, mode, device):
11+
def __init__(self, mode, device, dtype):
1212
self.mode = mode
1313
self.deterministic = False
1414
self.device = device
15+
self.dtype = dtype
1516
self.output_type = "stdout"
17+
self.print_ir = False
18+
self.print_kernel = False
1619
if mode == "both":
1720
self.requires_grad = True
1821
elif mode == "fwd":
@@ -82,6 +85,14 @@ def compute_workload(self):
8285
"""return the number of scalar operations it takes to finish the tensor op"""
8386
return None
8487

88+
@staticmethod
89+
def input_iterable():
90+
"""A benchmark child class should return true if it utilizes the input iter arg"""
91+
return False
92+
93+
def dtype_to_bytes(self):
    """Return the size in bytes of a single element of ``self.dtype``.

    The size is read off a throwaway scalar tensor built with that dtype.
    """
    probe = torch.tensor(0, dtype=self.dtype)
    return probe.element_size()
95+
8596
@staticmethod
8697
def default_configs():
8798
"""return a list of defualt configs for this benchmark"""
@@ -90,8 +101,8 @@ def default_configs():
90101
def is_supported(self):
91102
return True
92103

93-
def rand(self, shape, device=None, requires_grad=False):
94-
v = self.engine.rand(shape, device=device, requires_grad=requires_grad)
104+
def rand(self, shape, device=None, dtype=None, requires_grad=False):
105+
v = self.engine.rand(shape, device=device, dtype=dtype, requires_grad=requires_grad)
95106
if requires_grad:
96107
self.grad_variables.append(v)
97108
return v
@@ -109,16 +120,34 @@ def compute(self):
109120
return self.forward(*self.inputs)
110121

111122
def run(self, args):
112-
torch._C._jit_override_can_fuse_on_gpu(True)
113-
torch._C._jit_set_texpr_fuser_enabled(args.cuda_fuser == "te")
114-
with cuda_pointwise_context(
115-
args.cuda_pointwise_loop_levels,
116-
args.cuda_pointwise_block_count,
117-
args.cuda_pointwise_block_size,
118-
):
119-
return self.run_impl()
120-
121-
def run_impl(self):
123+
self.print_ir = args.print_ir
124+
if args.cuda_fuser == "old" :
125+
torch._C._jit_override_can_fuse_on_gpu(True)
126+
if args.print_kernel :
127+
os.environ['PYTORCH_FUSION_DEBUG'] = '1'
128+
return self.run_impl(True)
129+
elif args.cuda_fuser == "te" :
130+
torch._C._jit_set_texpr_fuser_enabled(True)
131+
with cuda_pointwise_context(
132+
args.cuda_pointwise_loop_levels,
133+
args.cuda_pointwise_block_count,
134+
args.cuda_pointwise_block_size,
135+
):
136+
return self.run_impl(True)
137+
elif args.cuda_fuser == "nvf" :
138+
torch._C._jit_set_nvfuser_enabled(True)
139+
torch._C._jit_set_profiling_executor(True)
140+
torch._C._jit_set_profiling_mode(True)
141+
torch._C._jit_override_can_fuse_on_cpu(False)
142+
torch._C._jit_override_can_fuse_on_gpu(False)
143+
torch._C._jit_set_bailout_depth(20)
144+
if args.print_kernel :
145+
os.environ['PYTORCH_CUDA_FUSER_DEBUG'] = '1'
146+
return self.run_impl(True)
147+
else :
148+
return self.run_impl(False)
149+
150+
def run_impl(self, use_fuser):
122151
warmups = 10
123152
if self.device == "cuda":
124153
iters = 1000
@@ -134,14 +163,18 @@ def run_impl(self):
134163
time_start = time.time()
135164

136165
if i == 0:
137-
if self.jit_mode == "trace":
166+
if self.jit_mode == "trace" and use_fuser :
138167
self.bm_jit = torch.jit.trace(
139168
self.forward, example_inputs=self.inputs, check_trace=False
140169
)
141170
if callable(getattr(self, "reference", None)):
142171
self.check()
143172
else:
144173
print("Warning: no reference result for ", self.module())
174+
elif i == 1:
175+
# The fusion graph is visible after the first iter is executed
176+
if self.jit_mode == "trace" and use_fuser and self.print_ir :
177+
print(self.bm_jit.graph_for(*self.inputs))
145178
z = self.compute()
146179
if self.mode == "both":
147180
if self.result_grad is None:
@@ -159,8 +192,8 @@ def run_impl(self):
159192
result_dict = {
160193
"desc": self.desc(),
161194
"us": iter_time * 1e6,
162-
"sol": memory_workload["sol"] / iter_time / 1e9,
163-
"algorithmic": memory_workload["algorithmic"] / iter_time / 1e9,
195+
"sol": memory_workload["sol"] * self.dtype_to_bytes() / iter_time / 1e9,
196+
"algorithmic": memory_workload["algorithmic"] * self.dtype_to_bytes() / iter_time / 1e9,
164197
}
165198
if compute_workload:
166199
result_dict["compute_workload"] = compute_workload / iter_time / 1e9

‎benchmarks/tensorexpr/broadcast.py

+26-26
Original file line numberDiff line numberDiff line change
@@ -5,33 +5,33 @@
55

66

77
class BroadcastMulBench(benchmark.Benchmark):
8-
def __init__(self, mode, device, case, M, N, K):
9-
super().__init__(mode, device)
8+
def __init__(self, mode, device, dtype, case, M, N, K):
9+
super().__init__(mode, device, dtype)
1010
self.case = case
1111
self.M = M
1212
self.N = N
1313
self.K = K
1414

1515
if case == "row":
1616
self.d1 = self.rand(
17-
[M, N, 1], device=device, requires_grad=self.requires_grad
17+
[M, N, 1], device=device, dtype=dtype, requires_grad=self.requires_grad
1818
)
1919
self.d2 = self.rand(
20-
[M, 1, K], device=device, requires_grad=self.requires_grad
20+
[M, 1, K], device=device, dtype=dtype, requires_grad=self.requires_grad
2121
)
2222
elif case == "mid":
2323
self.d1 = self.rand(
24-
[M, N, 1], device=device, requires_grad=self.requires_grad
24+
[M, N, 1], device=device, dtype=dtype, requires_grad=self.requires_grad
2525
)
2626
self.d2 = self.rand(
27-
[1, N, K], device=device, requires_grad=self.requires_grad
27+
[1, N, K], device=device, dtype=dtype, requires_grad=self.requires_grad
2828
)
2929
elif case == "col":
3030
self.d1 = self.rand(
31-
[M, 1, K], device=device, requires_grad=self.requires_grad
31+
[M, 1, K], device=device, dtype=dtype, requires_grad=self.requires_grad
3232
)
3333
self.d2 = self.rand(
34-
[1, N, K], device=device, requires_grad=self.requires_grad
34+
[1, N, K], device=device, dtype=dtype, requires_grad=self.requires_grad
3535
)
3636
else:
3737
raise ValueError("invalid case: %s" % (case))
@@ -60,52 +60,52 @@ def memory_workload(self):
6060
sol_count = (1) + (1)
6161
algorithmic_count = 1 + (1 + 1)
6262

63-
buffer_size = self.M * self.N * self.K * 4
63+
buffer_size = self.M * self.N * self.K
6464
return {
6565
"sol": buffer_size * sol_count,
6666
"algorithmic": buffer_size * algorithmic_count,
6767
}
6868

6969

7070
class BroadcastRowBench(BroadcastMulBench):
71-
def __init__(self, mode, device, M, N, K):
72-
super(BroadcastRowBench, self).__init__(mode, device, "row", M, N, K)
71+
def __init__(self, mode, device, dtype, M, N, K):
72+
super(BroadcastRowBench, self).__init__(mode, device, dtype, "row", M, N, K)
7373

7474
@staticmethod
7575
def module():
7676
return "broadcast_row"
7777

7878

7979
class BroadcastMidBench(BroadcastMulBench):
80-
def __init__(self, mode, device, M, N, K):
81-
super(BroadcastMidBench, self).__init__(mode, device, "mid", M, N, K)
80+
def __init__(self, mode, device, dtype, M, N, K):
81+
super(BroadcastMidBench, self).__init__(mode, device, dtype, "mid", M, N, K)
8282

8383
@staticmethod
8484
def module():
8585
return "broadcast_mid"
8686

8787

8888
class BroadcastColBench(BroadcastMulBench):
89-
def __init__(self, mode, device, M, N, K):
90-
super(BroadcastColBench, self).__init__(mode, device, "col", M, N, K)
89+
def __init__(self, mode, device, dtype, M, N, K):
90+
super(BroadcastColBench, self).__init__(mode, device, dtype, "col", M, N, K)
9191

9292
@staticmethod
9393
def module():
9494
return "broadcast_col"
9595

9696

9797
class BroadcastThreeArgs(benchmark.Benchmark):
98-
def __init__(self, mode, device, M, N, K, L):
99-
super().__init__(mode, device)
98+
def __init__(self, mode, device, dtype, M, N, K, L):
99+
super().__init__(mode, device, dtype)
100100
self.M = M
101101
self.N = N
102102
self.K = K
103103
self.L = L
104104

105-
self.d1 = self.rand([M, N], device=device, requires_grad=self.requires_grad)
106-
self.d2 = self.rand([K, M, 1], device=device, requires_grad=self.requires_grad)
105+
self.d1 = self.rand([M, N], device=device, dtype=dtype, requires_grad=self.requires_grad)
106+
self.d2 = self.rand([K, M, 1], device=device, dtype=dtype, requires_grad=self.requires_grad)
107107
self.d3 = self.rand(
108-
[L, K, 1, 1], device=device, requires_grad=self.requires_grad
108+
[L, K, 1, 1], device=device, dtype=dtype, requires_grad=self.requires_grad
109109
)
110110

111111
self.inputs = [self.d1, self.d2, self.d3]
@@ -160,15 +160,15 @@ class BroadcastBench(benchmark.Benchmark):
160160
unary_op_np_func = None
161161
split_input = True
162162

163-
def __init__(self, mode, device, M, N, K):
164-
super().__init__(mode, device)
163+
def __init__(self, mode, device, dtype, M, N, K):
164+
super().__init__(mode, device, dtype)
165165
self.M = M
166166
self.N = N
167167
self.K = K
168-
self.d1 = self.rand([M, N], device=device, requires_grad=self.requires_grad)
169-
self.d2 = self.rand([K, 1, N], device=device, requires_grad=self.requires_grad)
170-
self.d3 = self.rand([M, N], device=device, requires_grad=self.requires_grad)
171-
self.d4 = self.rand([K, M, 1], device=device, requires_grad=self.requires_grad)
168+
self.d1 = self.rand([M, N], device=device, dtype=dtype, requires_grad=self.requires_grad)
169+
self.d2 = self.rand([K, 1, N], device=device, dtype=dtype, requires_grad=self.requires_grad)
170+
self.d3 = self.rand([M, N], device=device, dtype=dtype, requires_grad=self.requires_grad)
171+
self.d4 = self.rand([K, M, 1], device=device, dtype=dtype, requires_grad=self.requires_grad)
172172
self.inputs = [self.d1, self.d2, self.d3, self.d4]
173173

174174
def _eval(self, d1, d2, d3, d4, binary_op, unary_op):

‎benchmarks/tensorexpr/conv.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33

44
class ConvImplBench(benchmark.Benchmark):
5-
def __init__(self, case, mode, device, kernel_size, N, iC, H, W, oC):
6-
super().__init__(mode, device)
5+
def __init__(self, case, mode, device, dtype, kernel_size, N, iC, H, W, oC):
6+
super().__init__(mode, device, dtype)
77
self.case = case
88
self.kernel_size = kernel_size
99
self.N = N
@@ -41,13 +41,12 @@ def memory_workload(self):
4141
algorithmic_count = {"i": 1 + (1 + 1), "o": 1 + (1 + 1), "k": 1 + (1 + 1)}
4242

4343
buffer_size = {
44-
"i": self.N * self.iC * self.H * self.W * 4,
45-
"o": self.N * self.oC * self.H * self.W * 4,
44+
"i": self.N * self.iC * self.H * self.W,
45+
"o": self.N * self.oC * self.H * self.W,
4646
"k": self.oC
4747
* (self.iC / self.groups)
4848
* self.kernel_size
49-
* self.kernel_size
50-
* 4,
49+
* self.kernel_size,
5150
}
5251
sol_size = 0
5352
algorithmic_size = 0

‎benchmarks/tensorexpr/elementwise.py

+58-7
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ class ElementBench(benchmark.Benchmark):
1515
unary_op_np_func = None
1616
split_input = True
1717

18-
def __init__(self, mode, device, N):
19-
super().__init__(mode, device)
18+
def __init__(self, mode, device, dtype, N):
19+
super().__init__(mode, device, dtype)
2020
self.N = N
21-
self.d1 = self.rand([N], device=device, requires_grad=self.requires_grad)
22-
self.d2 = self.rand([N], device=device, requires_grad=self.requires_grad)
23-
self.d3 = self.rand([N], device=device, requires_grad=self.requires_grad)
24-
self.d4 = self.rand([N], device=device, requires_grad=self.requires_grad)
21+
self.d1 = self.rand([N], device=device, dtype=dtype, requires_grad=self.requires_grad)
22+
self.d2 = self.rand([N], device=device, dtype=dtype, requires_grad=self.requires_grad)
23+
self.d3 = self.rand([N], device=device, dtype=dtype, requires_grad=self.requires_grad)
24+
self.d4 = self.rand([N], device=device, dtype=dtype, requires_grad=self.requires_grad)
2525
self.inputs = [self.d1, self.d2, self.d3, self.d4]
2626
self.deterministic = "rand" not in self.op_str
2727

@@ -32,6 +32,7 @@ def binary_op(x, y):
3232
if not unary_op:
3333
def unary_op(x):
3434
return x
35+
3536
if self.split_input:
3637
d1 = unary_op(d1)
3738
d2 = unary_op(d2)
@@ -88,7 +89,7 @@ def memory_workload(self):
8889
sol_count = 1
8990
algorithmic_count = 1
9091

91-
buffer_size = self.N * 4
92+
buffer_size = self.N
9293
return {
9394
"sol": buffer_size * sol_count,
9495
"algorithmic": buffer_size * algorithmic_count,
@@ -157,3 +158,53 @@ def register_element_ops():
157158

158159
# benchmark.register_benchmark_class(ElementMulBench)
159160
register_element_ops()
161+
162+
class SimpleElementBench(benchmark.Benchmark):
    """Minimal elementwise benchmark: two chained scalar adds on a 1-D tensor.

    The single input is a random tensor of ``N`` elements created with the
    requested ``device`` and ``dtype``.
    """

    def __init__(self, mode, device, dtype, N):
        super().__init__(mode, device, dtype)
        self.N = N
        self.data = self.rand([N], device=device, dtype=dtype, requires_grad=self.requires_grad)
        self.inputs = [self.data]

    def forward(self, data):
        """Add 0.001 and then 0.002 to ``data`` as two separate operations."""
        a = data + 0.001
        b = a + 0.002
        return b

    def reference(self):
        """Recompute ``forward`` on the NumPy view of the input.

        This class only owns ``self.data``; the previous body read
        ``self.d1``..``self.d4`` and ``self._eval``, which are not defined
        here, so it failed with AttributeError as soon as it was called.
        """
        # NOTE(review): self.numpy is assumed to convert a tensor to a NumPy
        # array, as the original body used it that way -- confirm.
        data = self.numpy(self.data)
        # Keep the two additions separate so the steps mirror forward().
        a = data + 0.001
        return a + 0.002

    def config(self):
        """Return the constructor arguments that follow mode/device/dtype."""
        return [self.N]

    @staticmethod
    def input_iterable():
        """This benchmark opts in to the input-iter argument."""
        return True

    @classmethod
    def module(cls):
        """Return the name this benchmark is selected by."""
        return "simple_element"

    def memory_workload(self):
        """Return the "sol" and "algorithmic" workloads, both ``2 * N``.

        The original code used a count of 2 for "fwd" and for every other
        mode alike, so the mode check is dropped. No dtype size is applied
        here; the values are plain multiples of ``N``.
        """
        sol_count = 2
        algorithmic_count = 2

        buffer_size = self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default configuration list: one run with N = 2**25."""
        return [[1 << 25]]
209+
210+
benchmark.register_benchmark_class(SimpleElementBench)

‎benchmarks/tensorexpr/matmul.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33

44

55
class MatMulBench(benchmark.Benchmark):
6-
def __init__(self, mode, device, B, M, N, K):
7-
super().__init__(mode, device)
6+
def __init__(self, mode, device, dtype, B, M, N, K):
7+
super().__init__(mode, device, dtype)
88
self.B = B
99
self.M = M
1010
self.N = N
1111
self.K = K
12-
self.d1 = self.rand([B, M, N], device=device, requires_grad=self.requires_grad)
13-
self.d2 = self.rand([B, N, K], device=device, requires_grad=self.requires_grad)
12+
self.d1 = self.rand([B, M, N], device=device, dtype=dtype, requires_grad=self.requires_grad)
13+
self.d2 = self.rand([B, N, K], device=device, dtype=dtype, requires_grad=self.requires_grad)
1414
self.inputs = [self.d1, self.d2]
1515

1616
def forward(self, d1, d2):
@@ -40,7 +40,6 @@ def memory_workload(self):
4040
+ self.B * self.M * self.N
4141
+ self.B * self.N * self.K
4242
)
43-
buffer_size *= 4
4443
return {
4544
"sol": buffer_size * sol_count,
4645
"algorithmic": buffer_size * algorithmic_count,

‎benchmarks/tensorexpr/normalization.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,20 @@
33

44

55
class NormalizationBench(benchmark.Benchmark):
6-
def __init__(self, mode, device, N, C, H, W):
7-
super().__init__(mode, device)
6+
def __init__(self, mode, device, dtype, N, C, H, W):
7+
super().__init__(mode, device, dtype)
88
self.N = N
99
self.C = C
1010
self.H = H
1111
self.W = W
1212

1313
self.data = self.nchw_rand(
1414
[self.N, self.C, self.H, self.W],
15-
device=device,
15+
device=device, dtype=dtype,
1616
requires_grad=self.requires_grad,
1717
)
18-
self.running_mean = self.rand([self.C], device=device)
19-
self.running_var = self.rand([self.C], device=device)
18+
self.running_mean = self.rand([self.C], device=device, dtype=dtype)
19+
self.running_var = self.rand([self.C], device=device, dtype=dtype)
2020
self.training = self.mode == "both"
2121

2222
def config(self):

‎benchmarks/tensorexpr/pooling.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
class PoolingBench(benchmark.Benchmark):
5-
def __init__(self, case, mode, device, kernel_size, N, C, H, W):
5+
def __init__(self, case, mode, device, dtype, kernel_size, N, C, H, W):
66
super().__init__(mode, device)
77
self.case = case
88
self.kernel_size = kernel_size
@@ -11,7 +11,7 @@ def __init__(self, case, mode, device, kernel_size, N, C, H, W):
1111
self.H = H
1212
self.W = W
1313
self.data = self.rand(
14-
[N, C, H, W], device=device, requires_grad=self.requires_grad
14+
[N, C, H, W], device=device, dtype=dtype, requires_grad=self.requires_grad
1515
)
1616

1717
def forward(self):
@@ -32,7 +32,7 @@ def memory_workload(self):
3232
sol_count = (1 + 1) + (1 + 1)
3333
algorithmic_count = (1 + 1) + (2 + 1)
3434

35-
buffer_size = self.N * self.C * self.H * self.W * 4
35+
buffer_size = self.N * self.C * self.H * self.W
3636
return {
3737
"sol": buffer_size * sol_count,
3838
"algorithmic": buffer_size * algorithmic_count,

‎benchmarks/tensorexpr/pt_engine.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22

33

44
class TorchTensorEngine(object):
5-
def rand(self, shape, device=None, requires_grad=False):
6-
return torch.rand(shape, device=device, requires_grad=requires_grad)
5+
def rand(self, shape, device=None, dtype=None, requires_grad=False):
6+
return torch.rand(shape, device=device, dtype=dtype, requires_grad=requires_grad)
7+
8+
def randn(self, shape, device=None, dtype=None, requires_grad=False):
    """Create a tensor of the given shape via ``torch.randn``.

    ``device``, ``dtype`` and ``requires_grad`` are forwarded unchanged.
    """
    options = {
        "device": device,
        "dtype": dtype,
        "requires_grad": requires_grad,
    }
    return torch.randn(shape, **options)
710

811
def nchw_rand(self, shape, device=None, requires_grad=False):
912
return self.rand(shape, device=device, requires_grad=requires_grad)

‎benchmarks/tensorexpr/reduction.py

+85-14
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22

33

44
class ReduceBench(benchmark.Benchmark):
5-
def __init__(self, mode, device, case, M, N, K):
6-
super().__init__(mode, device)
5+
def __init__(self, mode, device, dtype, case, M, N, K):
6+
super().__init__(mode, device, dtype)
77
self.case = case
88
self.M = M
99
self.N = N
1010
self.K = K
1111

12-
self.data = self.rand(
13-
[M, N, K], device=device, requires_grad=self.requires_grad
14-
)
12+
self.inputs = [self.randn(
13+
[M, N, K], device=device, dtype=dtype, requires_grad=self.requires_grad
14+
)]
1515
if case == "row":
1616
self.dims = [1, 2]
1717
elif case == "mid":
@@ -21,8 +21,9 @@ def __init__(self, mode, device, case, M, N, K):
2121
else:
2222
raise ValueError("invalid case: %s" % case)
2323

24-
def forward(self):
25-
y = self.sum(self.data, self.dims)
24+
def forward(self, inputs):
25+
x = self.add(inputs, 0.001)
26+
y = self.sum(x, self.dims)
2627
return y
2728

2829
def config(self):
@@ -47,40 +48,110 @@ def memory_workload(self):
4748
sol_count = (1) + (1)
4849
algorithmic_count = 1 + 1
4950

50-
buffer_size = self.M * self.N * self.K * 4
51+
buffer_size = self.M * self.N * self.K
5152
return {
5253
"sol": buffer_size * sol_count,
5354
"algorithmic": buffer_size * algorithmic_count,
5455
}
5556

5657

5758
class ReduceRowBench(ReduceBench):
58-
def __init__(self, mode, device, M, N, K):
59-
super(ReduceRowBench, self).__init__(mode, device, "row", M, N, K)
59+
def __init__(self, mode, device, dtype, M, N, K):
60+
super(ReduceRowBench, self).__init__(mode, device, dtype, "row", M, N, K)
6061

6162
@staticmethod
6263
def module():
6364
return "reduce_row"
6465

6566

6667
class ReduceMidBench(ReduceBench):
67-
def __init__(self, mode, device, M, N, K):
68-
super(ReduceMidBench, self).__init__(mode, device, "mid", M, N, K)
68+
def __init__(self, mode, device, dtype, M, N, K):
69+
super(ReduceMidBench, self).__init__(mode, device, dtype, "mid", M, N, K)
6970

7071
@staticmethod
7172
def module():
7273
return "reduce_mid"
7374

7475

7576
class ReduceColBench(ReduceBench):
76-
def __init__(self, mode, device, M, N, K):
77-
super(ReduceColBench, self).__init__(mode, device, "col", M, N, K)
77+
def __init__(self, mode, device, dtype, M, N, K):
78+
super(ReduceColBench, self).__init__(mode, device, dtype, "col", M, N, K)
7879

7980
@staticmethod
8081
def module():
8182
return "reduce_col"
8283

84+
class Reduce2DBench(benchmark.Benchmark):
    """
    A benchmark class to validate 2 dimensional reduction performance.
    Only a simple add is fused to induce the fuser and isolate reduction perf.

    ``red_dim`` selects the dimension that is summed over and must be 0 or 1;
    ``dim0`` and ``dim1`` are the sizes of the two input dimensions.
    """

    def __init__(self, mode, device, dtype, red_dim, dim0, dim1):
        super().__init__(mode, device, dtype)

        # Reject a bad reduction dimension before the input tensor is
        # created, so an invalid configuration fails without allocating it.
        if red_dim not in (0, 1):
            raise ValueError("invalid reduction dimension: {}".format(red_dim))

        self.red_dim = red_dim
        self.dim0 = dim0
        self.dim1 = dim1

        self.inputs = [self.randn(
            [dim0, dim1], device=device, dtype=dtype, requires_grad=self.requires_grad
        )]

    def forward(self, inputs):
        """Add 0.001 to the input, then sum over ``self.red_dim``."""
        x = self.add(inputs, 0.001)
        y = self.sum(x, [self.red_dim])
        return y

    def config(self):
        """Return ``[red_dim, dim0, dim1]``."""
        return [self.red_dim, self.dim0, self.dim1]

    @staticmethod
    def default_configs():
        """Return the default ``[red_dim, dim0, dim1]`` configurations."""
        return [
            [1, 640, 524288],
        ]

    @staticmethod
    def module():
        """Return the name this benchmark is selected by."""
        return "reduce2d"

    @staticmethod
    def input_iterable():
        """This benchmark opts in to the input-iter argument."""
        return True

    def memory_workload(self):
        """Return the "sol" and "algorithmic" workloads for the forward pass.

        Both are ``dim0 * dim1`` plus the size of the dimension that is kept
        (``dim1`` when reducing dimension 0, ``dim0`` otherwise). No dtype
        size is applied here.
        """
        assert self.mode == "fwd", "Only the forward operation is modeled!"

        buffer_size = self.dim0 * self.dim1
        if self.red_dim == 0:
            buffer_size += self.dim1
        else:
            buffer_size += self.dim0
        return {
            "sol": buffer_size,
            "algorithmic": buffer_size,
        }
136+
137+
class Reduce2DInnerBench(Reduce2DBench):
    """2-D reduction benchmark with the reduction dimension fixed to 1."""

    def __init__(self, mode, device, dtype, dim0, dim1):
        super().__init__(mode, device, dtype, 1, dim0, dim1)

    @staticmethod
    def default_configs():
        """Return the parent's default configs without the leading red_dim.

        The parent's entries are ``[red_dim, dim0, dim1]``, but this
        constructor accepts only ``(dim0, dim1)`` and supplies ``red_dim``
        itself. Unpacking an inherited entry into the constructor would pass
        one positional argument too many and raise TypeError.
        """
        return [cfg[1:] for cfg in Reduce2DBench.default_configs()]

    @staticmethod
    def module():
        """Return the name this benchmark is selected by."""
        return "reduce2d_inner"
144+
145+
class Reduce2DOuterBench(Reduce2DBench):
    """2-D reduction benchmark with the reduction dimension fixed to 0."""

    def __init__(self, mode, device, dtype, dim0, dim1):
        super().__init__(mode, device, dtype, 0, dim0, dim1)

    @staticmethod
    def default_configs():
        """Return the parent's default configs without the leading red_dim.

        The parent's entries are ``[red_dim, dim0, dim1]``, but this
        constructor accepts only ``(dim0, dim1)`` and supplies ``red_dim``
        itself. Unpacking an inherited entry into the constructor would pass
        one positional argument too many and raise TypeError.
        """
        return [cfg[1:] for cfg in Reduce2DBench.default_configs()]

    @staticmethod
    def module():
        """Return the name this benchmark is selected by."""
        return "reduce2d_outer"
83152

84153
benchmark.register_benchmark_class(ReduceRowBench)
85154
benchmark.register_benchmark_class(ReduceMidBench)
86155
benchmark.register_benchmark_class(ReduceColBench)
156+
benchmark.register_benchmark_class(Reduce2DInnerBench)
157+
benchmark.register_benchmark_class(Reduce2DOuterBench)

‎benchmarks/tensorexpr/rnn_eltwise.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,24 @@
22
import torch
33

44
class RNNEltwise(benchmark.Benchmark):
5-
def __init__(self, mode, device, b, hs):
6-
super().__init__(mode, device)
5+
def __init__(self, mode, device, dtype, b, hs):
6+
super().__init__(mode, device, dtype)
77
self.b = b
88
self.hs = hs
99
self.input = self.rand(
10-
[b, 4 * hs], device=device, requires_grad=self.requires_grad
10+
[b, 4 * hs], device=device, dtype=dtype, requires_grad=self.requires_grad
1111
)
1212
self.hx = self.rand(
13-
[b, 4 * hs], device=device, requires_grad=self.requires_grad
13+
[b, 4 * hs], device=device, dtype=dtype, requires_grad=self.requires_grad
1414
)
1515
self.cx = self.rand(
16-
[b, hs], device=device, requires_grad=self.requires_grad
16+
[b, hs], device=device, dtype=dtype, requires_grad=self.requires_grad
1717
)
1818
self.b_ih = self.rand(
19-
[b, 4 * hs], device=device, requires_grad=self.requires_grad
19+
[b, 4 * hs], device=device, dtype=dtype, requires_grad=self.requires_grad
2020
)
2121
self.b_hh = self.rand(
22-
[b, 4 * hs], device=device, requires_grad=self.requires_grad
22+
[b, 4 * hs], device=device, dtype=dtype, requires_grad=self.requires_grad
2323
)
2424
self.inputs = [
2525
self.input,

‎benchmarks/tensorexpr/softmax.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def memory_workload(self):
3131
sol_count = (1 + 1) + (1 + 1)
3232
algorithmic_count = (3 + 1) + (3 + 1)
3333

34-
buffer_size = self.M * self.N * 4
34+
buffer_size = self.M * self.N
3535
return {
3636
"sol": buffer_size * sol_count,
3737
"algorithmic": buffer_size * algorithmic_count,

‎benchmarks/tensorexpr/swish.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33

44

55
class SwishBench(benchmark.Benchmark):
6-
def __init__(self, mode, device, M, N):
7-
super().__init__(mode, device)
6+
def __init__(self, mode, device, dtype, M, N):
7+
super().__init__(mode, device, dtype)
88
self.M = M
99
self.N = N
10-
self.data = self.rand([M, N], device=device, requires_grad=self.requires_grad)
10+
self.data = self.rand([M, N], device=device, dtype=dtype, requires_grad=self.requires_grad)
1111
self.inputs = [self.data]
1212
self.zeros = torch.zeros(M, N, device=device)
1313
self.six = self.zeros + 6.0
@@ -36,7 +36,7 @@ def memory_workload(self):
3636
sol_count = (1 + 1) + (1 + 1)
3737
algorithmic_count = (3 + 1) + (3 + 1)
3838

39-
buffer_size = self.M * self.N * 4
39+
buffer_size = self.M * self.N
4040
return {
4141
"sol": buffer_size * sol_count,
4242
"algorithmic": buffer_size * algorithmic_count,

0 commit comments

Comments
 (0)
Please sign in to comment.