Skip to content

Commit 1cd5ba4

Browse files
zou3519 authored and facebook-github-bot
committed Sep 16, 2020
Add batching rule for "is_complex", "conj" (pytorch#44649)
Summary: Pull Request resolved: pytorch#44649 To unblock pytorch#43208, which adds "is_complex" checks to backward formulas that are being tested for batched gradient support with vmap. Test Plan: - `pytest test/test_vmap.py -v` Reviewed By: anjali411 Differential Revision: D23685356 Pulled By: zou3519 fbshipit-source-id: 29e41a9296336f6d1008e3040cade4c643bf5ebf
1 parent cce7680 commit 1cd5ba4

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed
 

‎aten/src/ATen/BatchingRegistrations.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
522522
m.impl("_remove_batch_dim", native::_remove_batch_dim);
523523

524524
m.impl_UNBOXED("sum.dim_IntList", sum_batching_rule);
525+
m.impl("is_complex", native::is_complex);
526+
m.impl("conj", native::conj);
525527

526528
// view operations
527529
m.impl("chunk", chunk_batching_rule);
@@ -560,6 +562,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
560562
UNARY_POINTWISE(ceil);
561563
UNARY_POINTWISE(cos);
562564
UNARY_POINTWISE(cosh);
565+
UNARY_POINTWISE(_conj);
563566
UNARY_POINTWISE(digamma);
564567
UNARY_POINTWISE(exp);
565568
UNARY_POINTWISE(expm1);

‎test/test_vmap.py

+42
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,35 @@ def op(*tensors):
930930
test(vmap(get_op(0), in_dims=(0, 0)),
931931
(torch.rand(B1, 2), torch.rand(B0, B1, 3)), in_dims=(None, 0))
932932

933+
def test_conj(self):
    """Exercise the Batched-key rule for torch.conj under vmap.

    Runs self._vmap_test over single and doubly-nested vmaps with a
    variety of in_dims/out_dims, for both a real and a complex dtype,
    and finally checks that conj on a real tensor is a no-op that
    returns the very same storage.
    """
    op = torch.conj

    def check_dtype(dtype):
        # Fresh random input of the requested dtype and shape.
        def make(shape):
            return torch.randn(shape, dtype=dtype)

        B0, B1 = 7, 11
        run = self._vmap_test

        # Single level of vmap, various batch-dim placements.
        run(op, [make([B0, 3])])
        run(op, [make([2, 5, B0, 3])], in_dims=2)
        run(op, [make([2, 5, B0, 3])], in_dims=2, out_dims=2)

        # Two nested levels of vmap.
        run(vmap(op), [make([B0, B1])])
        run(vmap(op), [make([B1, 2, 5, B0, 3])], in_dims=2)
        run(vmap(op, in_dims=2), [make([2, 5, B0, B1, 3])],
            in_dims=2, out_dims=2)

    # Correctness across dtypes: conj is the identity on floats and
    # the actual conjugation on complex tensors.
    check_dtype(torch.float)
    check_dtype(torch.cfloat)

    # conj on a non-complex tensor must return the same tensor
    # (identical storage, not a copy) even through vmap.
    real_tensor = torch.randn(3)
    result = vmap(op)(real_tensor)
    self.assertEqual(result.data_ptr(), real_tensor.data_ptr())
961+
933962
def test_chunk(self):
934963
test = self._vmap_view_test
935964
op = torch.chunk
@@ -997,6 +1026,19 @@ def test_expand_as(self):
9971026
test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None))
9981027
test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5)))
9991028

1029+
def test_is_complex(self):
    """Check that Tensor.is_complex() works on the logical tensor under vmap.

    A function that branches on is_complex() should take the complex
    branch for every per-example slice of a complex batched tensor and
    the real branch for a real one.
    """
    ctensor = torch.randn(3, dtype=torch.cfloat)
    tensor = torch.randn(3)

    def foo(x):
        # 1 when the per-example tensor reports complex, else 0.
        return torch.tensor(1) if x.is_complex() else torch.tensor(0)

    self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1]))
    self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0]))
1041+
10001042
def test_movedim(self):
10011043
op = torch.movedim
10021044
test = self._vmap_view_test

0 commit comments

Comments (0)
Please sign in to comment.