@@ -229,11 +229,11 @@ def test_def(self):
             # m.impl("test_def", [](const Tensor& x) { return x })
             lambda m: m.impl_t_t("foo"),
             # m.impl("test_def", kCPU, [](const Tensor& x) { return x })
-            lambda m: m.impl_t_t("foo", dispatch="cpu"),
+            lambda m: m.impl_t_t("foo", dispatch="CPU"),
             # m.impl("test_def", kAutograd, [](const Tensor& x) { return x })
-            lambda m: m.impl_t_t("foo", dispatch="autograd"),
+            lambda m: m.impl_t_t("foo", dispatch="Autograd"),
             # m.impl("test_def", kAutogradCPU, [](const Tensor& x) { return x })
-            lambda m: m.impl_t_t("foo", dispatch="autogradcpu")
+            lambda m: m.impl_t_t("foo", dispatch="AutogradCPU")
         ]).state
         self.assertExpectedInline(state, '''\
 name: test::foo
@@ -262,11 +262,11 @@ def test_def_with_inference(self):
             # m.def("foo", [](const Tensor & x) { return x })
             lambda m: m.def_name_t_t("foo"),
             # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "cpu"),
+            lambda m: m.impl_t_t("foo", "CPU"),
             # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "autograd"),
+            lambda m: m.impl_t_t("foo", "Autograd"),
             # m.impl("foo", torch::kAutogradCPU, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "autogradcpu")
+            lambda m: m.impl_t_t("foo", "AutogradCPU")
         ]).state
         self.assertExpectedInline(state, '''\
 name: test::foo
@@ -296,11 +296,11 @@ def test_impl_only(self):
             # m.impl("foo", [](const Tensor& x) { return x })
             lambda m: m.impl_t_t("foo"),
             # m.impl("foo", torch::kCPU, [](const Tensor& x) { return x })
-            lambda m: m.impl_t_t("foo", "cpu"),
+            lambda m: m.impl_t_t("foo", "CPU"),
             # m.impl("foo", torch::kAutograd, [](const Tensor& x) { return x })
-            lambda m: m.impl_t_t("foo", "autograd"),
+            lambda m: m.impl_t_t("foo", "Autograd"),
             # m.impl("foo", torch::kAutogradCPU, [](const Tensor& x) { return x })
-            lambda m: m.impl_t_t("foo", "autogradcpu")
+            lambda m: m.impl_t_t("foo", "AutogradCPU")
         ]).state
         self.assertExpectedInline(state, '''\
 name: test::foo
@@ -316,13 +316,13 @@ def test_computed_table(self):
             # m.def("foo", [](const Tensor & x) { return x })
             lambda m: m.def_name_t_t("foo"),
             # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "cpu", debug="fn_cpu"),
+            lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
             # m.impl("foo", torch::kCUDA, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "xla", debug="fn_xla"),
+            lambda m: m.impl_t_t("foo", "XLA", debug="fn_xla"),
             # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "autograd", debug="fn_autograd"),
+            lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"),
             # m.impl("foo", torch::kAutogradCPU, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "autogradcpu", debug="fn_autogradcpu")
+            lambda m: m.impl_t_t("foo", "AutogradCPU", debug="fn_autogradcpu")
         ])
         state, table = result.state, result.table
         self.assertExpectedInline(state, '''\
@@ -351,12 +351,12 @@ def test_computed_table(self):
 ''')

     def test_computed_table_with_cpu_catchall(self):
-        global_m = C._dispatch_library("IMPL", "_", "autogradcpu")
+        global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
         result = self.commute("foo", [
             # m.def("foo", [](const Tensor & x) { return x })
             lambda m: m.def_name_t_t("foo"),
             # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "cpu"),
+            lambda m: m.impl_t_t("foo", "CPU"),
         ])
         state, table = result.state, result.table
         self.assertExpectedInline(state, '''\
@@ -382,12 +382,12 @@ def test_computed_table_with_cpu_catchall(self):
 ''')

     def test_computed_table_with_math(self):
-        global_m = C._dispatch_library("IMPL", "_", "autogradcpu")
+        global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
         result = self.commute("foo", [
             # m.def("foo(Tensor x) -> Tensor")
             lambda m: m.def_("foo(Tensor x) -> Tensor"),
             # m.impl("foo", torch::kMath, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "math"),
+            lambda m: m.impl_t_t("foo", "Math"),
         ])
         state, table = result.state, result.table
         self.assertExpectedInline(state, '''\
@@ -412,14 +412,14 @@ def test_computed_table_with_math(self):
 ''')

     def test_computed_table_with_cpu_math(self):
-        global_m = C._dispatch_library("IMPL", "_", "autogradcpu")
+        global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
         result = self.commute("foo", [
             # m.def("foo(Tensor x) -> Tensor")
             lambda m: m.def_("foo(Tensor x) -> Tensor"),
             # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "cpu", debug="fn_cpu"),
+            lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
             # m.impl("foo", torch::kMath, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "math", debug="fn_math"),
+            lambda m: m.impl_t_t("foo", "Math", debug="fn_math"),
         ])
         state, table = result.state, result.table
         self.assertExpectedInline(state, '''\
@@ -445,12 +445,12 @@ def test_computed_table_with_cpu_math(self):
 ''')

     def test_computed_table_with_autograd(self):
-        global_m = C._dispatch_library("IMPL", "_", "autogradcpu")
+        global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
         result = self.commute("foo", [
             # m.def("foo(Tensor x) -> Tensor")
             lambda m: m.def_("foo(Tensor x) -> Tensor"),
             # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "autograd"),
+            lambda m: m.impl_t_t("foo", "Autograd"),
         ])
         state, table = result.state, result.table
         self.assertExpectedInline(state, '''\
@@ -476,11 +476,11 @@ def test_computed_table_with_cpu_autograd_math_catchall(self):
             # m.def("foo", [](const Tensor & x) { return x })
             lambda m: m.def_name_t_t("foo"),
             # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "cpu", debug="fn_cpu"),
+            lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
             # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "autograd", debug="fn_autograd"),
+            lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"),
             # m.impl("foo", torch::kMath, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "math", debug="fn_math"),
+            lambda m: m.impl_t_t("foo", "Math", debug="fn_math"),
         ])
         state, table = result.state, result.table
         self.assertExpectedInline(state, '''\
@@ -512,9 +512,9 @@ def test_computed_table_with_cpu_autograd_catchall(self):
             # m.def("foo", [](const Tensor & x) { return x })
             lambda m: m.def_name_t_t("foo"),
             # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "cpu", debug="fn_cpu"),
+            lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
             # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
-            lambda m: m.impl_t_t("foo", "autograd", debug="fn_autograd"),
+            lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"),
         ])
         state, table = result.state, result.table
         self.assertExpectedInline(state, '''\
@@ -538,6 +538,39 @@ def test_computed_table_with_cpu_autograd_catchall(self):
 AutogradCPU: fn_autograd [autograd kernel]
 AutogradCUDA: fn_autograd [autograd kernel]
 AutogradXLA: fn_autograd [autograd kernel]
+''')
+
+    def test_computed_table_with_ambiguous_autogradother(self):
+        result = self.commute("foo", [
+            # m.def("foo", [](const Tensor & x) { return x })
+            lambda m: m.def_name_t_t("foo"),
+            # m.impl("foo", torch::kMath, [](const Tensor & x) { return x })
+            lambda m: m.impl_t_t("foo", "Math", debug="fn_math"),
+            # m.impl("foo", torch::kQuantizedCPU, [](const Tensor & x) { return x })
+            lambda m: m.impl_t_t("foo", "QuantizedCPU", debug="fn_quantizedcpu"),
+        ])
+        state, table = result.state, result.table
+        self.assertExpectedInline(state, '''\
+name: test::foo
+schema: test::foo(Tensor _0) -> (Tensor _0)
+debug: registered at /dev/null:0
+alias analysis kind: CONSERVATIVE
+QuantizedCPU: fn_quantizedcpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
+Math[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
+catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
+''')
+
+        # computed dispatch table is too big, so we only check on a few entries we're interested in.
+        extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
+
+        self.assertExpectedInline(extracted_table, '''\
+CPU: fn_math [math kernel]
+CUDA: fn_math [math kernel]
+XLA: fn_math [math kernel]
+AutogradOther: ambiguous_autogradother [ambiguous autogradother]
+AutogradCPU: fn_math [math kernel]
+AutogradCUDA: fn_math [math kernel]
+AutogradXLA: fn_math [math kernel]
 ''')

     # Can't do this yet for BC reasons
@@ -631,7 +664,7 @@ def test_multiple_def_alias_mismatch(self):
         )

     def test_multiple_fallback(self):
-        global_m = C._dispatch_library("IMPL", "_", "xla")
+        global_m = C._dispatch_library("IMPL", "_", "XLA")
         global_m.fallback_fallthrough(),
         try:
             global_m.fallback_fallthrough(),