Skip to content

Commit 10f2875

Browse files
Ailing Zhang (with facebook-github-bot)
Ailing Zhang
authored and committed Sep 22, 2020
Align casing in test_dispatch with dispatch keys. (pytorch#44933)
Summary: Pull Request resolved: pytorch#44933 Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D23778247 Pulled By: ailzhang fbshipit-source-id: bc3725eae670b03543015afe763cb3bb16baf8f6
1 parent 1fd48a9 commit 10f2875

File tree

2 files changed

+68
-34
lines changed

2 files changed

+68
-34
lines changed
 

test/test_dispatch.py

+61-28
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,11 @@ def test_def(self):
229229
# m.impl("test_def", [](const Tensor& x) { return x })
230230
lambda m: m.impl_t_t("foo"),
231231
# m.impl("test_def", kCPU, [](const Tensor& x) { return x })
232-
lambda m: m.impl_t_t("foo", dispatch="cpu"),
232+
lambda m: m.impl_t_t("foo", dispatch="CPU"),
233233
# m.impl("test_def", kAutograd, [](const Tensor& x) { return x })
234-
lambda m: m.impl_t_t("foo", dispatch="autograd"),
234+
lambda m: m.impl_t_t("foo", dispatch="Autograd"),
235235
# m.impl("test_def", kAutogradCPU, [](const Tensor& x) { return x })
236-
lambda m: m.impl_t_t("foo", dispatch="autogradcpu")
236+
lambda m: m.impl_t_t("foo", dispatch="AutogradCPU")
237237
]).state
238238
self.assertExpectedInline(state, '''\
239239
name: test::foo
@@ -262,11 +262,11 @@ def test_def_with_inference(self):
262262
# m.def("foo", [](const Tensor & x) { return x })
263263
lambda m: m.def_name_t_t("foo"),
264264
# m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
265-
lambda m: m.impl_t_t("foo", "cpu"),
265+
lambda m: m.impl_t_t("foo", "CPU"),
266266
# m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
267-
lambda m: m.impl_t_t("foo", "autograd"),
267+
lambda m: m.impl_t_t("foo", "Autograd"),
268268
# m.impl("foo", torch::kAutogradCPU, [](const Tensor & x) { return x })
269-
lambda m: m.impl_t_t("foo", "autogradcpu")
269+
lambda m: m.impl_t_t("foo", "AutogradCPU")
270270
]).state
271271
self.assertExpectedInline(state, '''\
272272
name: test::foo
@@ -296,11 +296,11 @@ def test_impl_only(self):
296296
# m.impl("foo", [](const Tensor& x) { return x })
297297
lambda m: m.impl_t_t("foo"),
298298
# m.impl("foo", torch::kCPU, [](const Tensor& x) { return x })
299-
lambda m: m.impl_t_t("foo", "cpu"),
299+
lambda m: m.impl_t_t("foo", "CPU"),
300300
# m.impl("foo", torch::kAutograd, [](const Tensor& x) { return x })
301-
lambda m: m.impl_t_t("foo", "autograd"),
301+
lambda m: m.impl_t_t("foo", "Autograd"),
302302
# m.impl("foo", torch::kAutogradCPU, [](const Tensor& x) { return x })
303-
lambda m: m.impl_t_t("foo", "autogradcpu")
303+
lambda m: m.impl_t_t("foo", "AutogradCPU")
304304
]).state
305305
self.assertExpectedInline(state, '''\
306306
name: test::foo
@@ -316,13 +316,13 @@ def test_computed_table(self):
316316
# m.def("foo", [](const Tensor & x) { return x })
317317
lambda m: m.def_name_t_t("foo"),
318318
# m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
319-
lambda m: m.impl_t_t("foo", "cpu", debug="fn_cpu"),
319+
lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
320320
# m.impl("foo", torch::kCUDA, [](const Tensor & x) { return x })
321-
lambda m: m.impl_t_t("foo", "xla", debug="fn_xla"),
321+
lambda m: m.impl_t_t("foo", "XLA", debug="fn_xla"),
322322
# m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
323-
lambda m: m.impl_t_t("foo", "autograd", debug="fn_autograd"),
323+
lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"),
324324
# m.impl("foo", torch::kAutogradCPU, [](const Tensor & x) { return x })
325-
lambda m: m.impl_t_t("foo", "autogradcpu", debug="fn_autogradcpu")
325+
lambda m: m.impl_t_t("foo", "AutogradCPU", debug="fn_autogradcpu")
326326
])
327327
state, table = result.state, result.table
328328
self.assertExpectedInline(state, '''\
@@ -351,12 +351,12 @@ def test_computed_table(self):
351351
''')
352352

353353
def test_computed_table_with_cpu_catchall(self):
354-
global_m = C._dispatch_library("IMPL", "_", "autogradcpu")
354+
global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
355355
result = self.commute("foo", [
356356
# m.def("foo", [](const Tensor & x) { return x })
357357
lambda m: m.def_name_t_t("foo"),
358358
# m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
359-
lambda m: m.impl_t_t("foo", "cpu"),
359+
lambda m: m.impl_t_t("foo", "CPU"),
360360
])
361361
state, table = result.state, result.table
362362
self.assertExpectedInline(state, '''\
@@ -382,12 +382,12 @@ def test_computed_table_with_cpu_catchall(self):
382382
''')
383383

384384
def test_computed_table_with_math(self):
385-
global_m = C._dispatch_library("IMPL", "_", "autogradcpu")
385+
global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
386386
result = self.commute("foo", [
387387
# m.def("foo(Tensor x) -> Tensor")
388388
lambda m: m.def_("foo(Tensor x) -> Tensor"),
389389
# m.impl("foo", torch::kMath, [](const Tensor & x) { return x })
390-
lambda m: m.impl_t_t("foo", "math"),
390+
lambda m: m.impl_t_t("foo", "Math"),
391391
])
392392
state, table = result.state, result.table
393393
self.assertExpectedInline(state, '''\
@@ -412,14 +412,14 @@ def test_computed_table_with_math(self):
412412
''')
413413

414414
def test_computed_table_with_cpu_math(self):
415-
global_m = C._dispatch_library("IMPL", "_", "autogradcpu")
415+
global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
416416
result = self.commute("foo", [
417417
# m.def("foo(Tensor x) -> Tensor")
418418
lambda m: m.def_("foo(Tensor x) -> Tensor"),
419419
# m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
420-
lambda m: m.impl_t_t("foo", "cpu", debug="fn_cpu"),
420+
lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
421421
# m.impl("foo", torch::kMath, [](const Tensor & x) { return x })
422-
lambda m: m.impl_t_t("foo", "math", debug="fn_math"),
422+
lambda m: m.impl_t_t("foo", "Math", debug="fn_math"),
423423
])
424424
state, table = result.state, result.table
425425
self.assertExpectedInline(state, '''\
@@ -445,12 +445,12 @@ def test_computed_table_with_cpu_math(self):
445445
''')
446446

447447
def test_computed_table_with_autograd(self):
448-
global_m = C._dispatch_library("IMPL", "_", "autogradcpu")
448+
global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
449449
result = self.commute("foo", [
450450
# m.def("foo(Tensor x) -> Tensor")
451451
lambda m: m.def_("foo(Tensor x) -> Tensor"),
452452
# m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
453-
lambda m: m.impl_t_t("foo", "autograd"),
453+
lambda m: m.impl_t_t("foo", "Autograd"),
454454
])
455455
state, table = result.state, result.table
456456
self.assertExpectedInline(state, '''\
@@ -476,11 +476,11 @@ def test_computed_table_with_cpu_autograd_math_catchall(self):
476476
# m.def("foo", [](const Tensor & x) { return x })
477477
lambda m: m.def_name_t_t("foo"),
478478
# m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
479-
lambda m: m.impl_t_t("foo", "cpu", debug="fn_cpu"),
479+
lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
480480
# m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
481-
lambda m: m.impl_t_t("foo", "autograd", debug="fn_autograd"),
481+
lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"),
482482
# m.impl("foo", torch::kMath, [](const Tensor & x) { return x })
483-
lambda m: m.impl_t_t("foo", "math", debug="fn_math"),
483+
lambda m: m.impl_t_t("foo", "Math", debug="fn_math"),
484484
])
485485
state, table = result.state, result.table
486486
self.assertExpectedInline(state, '''\
@@ -512,9 +512,9 @@ def test_computed_table_with_cpu_autograd_catchall(self):
512512
# m.def("foo", [](const Tensor & x) { return x })
513513
lambda m: m.def_name_t_t("foo"),
514514
# m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
515-
lambda m: m.impl_t_t("foo", "cpu", debug="fn_cpu"),
515+
lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
516516
# m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
517-
lambda m: m.impl_t_t("foo", "autograd", debug="fn_autograd"),
517+
lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"),
518518
])
519519
state, table = result.state, result.table
520520
self.assertExpectedInline(state, '''\
@@ -538,6 +538,39 @@ def test_computed_table_with_cpu_autograd_catchall(self):
538538
AutogradCPU: fn_autograd [autograd kernel]
539539
AutogradCUDA: fn_autograd [autograd kernel]
540540
AutogradXLA: fn_autograd [autograd kernel]
541+
''')
542+
543+
def test_computed_table_with_ambiguous_autogradother(self):
544+
result = self.commute("foo", [
545+
# m.def("foo", [](const Tensor & x) { return x })
546+
lambda m: m.def_name_t_t("foo"),
547+
# m.impl("foo", torch::kMath, [](const Tensor & x) { return x })
548+
lambda m: m.impl_t_t("foo", "Math", debug="fn_math"),
549+
# m.impl("foo", torch::kQuantizedCPU, [](const Tensor & x) { return x })
550+
lambda m: m.impl_t_t("foo", "QuantizedCPU", debug="fn_quantizedcpu"),
551+
])
552+
state, table = result.state, result.table
553+
self.assertExpectedInline(state, '''\
554+
name: test::foo
555+
schema: test::foo(Tensor _0) -> (Tensor _0)
556+
debug: registered at /dev/null:0
557+
alias analysis kind: CONSERVATIVE
558+
QuantizedCPU: fn_quantizedcpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
559+
Math[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
560+
catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
561+
''')
562+
563+
# computed dispatch table is too big, so we only check on a few entries we're interested in.
564+
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
565+
566+
self.assertExpectedInline(extracted_table, '''\
567+
CPU: fn_math [math kernel]
568+
CUDA: fn_math [math kernel]
569+
XLA: fn_math [math kernel]
570+
AutogradOther: ambiguous_autogradother [ambiguous autogradother]
571+
AutogradCPU: fn_math [math kernel]
572+
AutogradCUDA: fn_math [math kernel]
573+
AutogradXLA: fn_math [math kernel]
541574
''')
542575

543576
# Can't do this yet for BC reasons
@@ -631,7 +664,7 @@ def test_multiple_def_alias_mismatch(self):
631664
)
632665

633666
def test_multiple_fallback(self):
634-
global_m = C._dispatch_library("IMPL", "_", "xla")
667+
global_m = C._dispatch_library("IMPL", "_", "XLA")
635668
global_m.fallback_fallthrough(),
636669
try:
637670
global_m.fallback_fallthrough(),

torch/csrc/utils/python_dispatch.cpp

+7-6
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,13 @@ torch::Library::Kind parseKind(const std::string& k) {
2727

2828
c10::optional<c10::DispatchKey> parseDispatchKey(const std::string& k) {
2929
static std::unordered_map<std::string, c10::DispatchKey> key_map = {
30-
{"cpu", c10::DispatchKey::CPU},
31-
{"cuda", c10::DispatchKey::CUDA},
32-
{"xla", c10::DispatchKey::XLA},
33-
{"math", c10::DispatchKey::Math},
34-
{"autograd", c10::DispatchKey::Autograd},
35-
{"autogradcpu", c10::DispatchKey::AutogradCPU},
30+
{"CPU", c10::DispatchKey::CPU},
31+
{"CUDA", c10::DispatchKey::CUDA},
32+
{"XLA", c10::DispatchKey::XLA},
33+
{"QuantizedCPU", c10::DispatchKey::QuantizedCPU},
34+
{"Math", c10::DispatchKey::Math},
35+
{"Autograd", c10::DispatchKey::Autograd},
36+
{"AutogradCPU", c10::DispatchKey::AutogradCPU},
3637
{"", c10::DispatchKey::Undefined},
3738
};
3839
auto it = key_map.find(k);

0 commit comments

Comments
 (0)
Please sign in to comment.