Skip to content

Commit 1c616c5

Browse files
awthomp
authored and facebook-github-bot committed Aug 14, 2020
Add complex tensor dtypes for the __cuda_array_interface__ spec (pytorch#42918)
Summary: Fixes pytorch#42860. The `__cuda_array_interface__` tensor specification is missing the appropriate datatypes for the newly merged complex64 and complex128 tensors. This PR addresses this issue by casting: * `torch.complex64` to 'c8' * `torch.complex128` to 'c16' Pull Request resolved: pytorch#42918 Reviewed By: izdeby Differential Revision: D23130219 Pulled By: anjali411 fbshipit-source-id: 5f8ee8446a71cad2f28811afdeae3a263a31ad11
1 parent c3fb152 commit 1c616c5

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed
 

‎test/test_numba_integration.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ def test_array_adaptor(self):
105105
"""Torch __cuda_array_adaptor__ exposes tensor data to numba.cuda."""
106106

107107
torch_dtypes = [
108+
torch.complex64,
109+
torch.complex128,
108110
torch.float16,
109111
torch.float32,
110112
torch.float64,
@@ -244,6 +246,8 @@ def test_from_cuda_array_interface(self):
244246
"""
245247

246248
dtypes = [
249+
numpy.complex64,
250+
numpy.complex128,
247251
numpy.float64,
248252
numpy.float32,
249253
numpy.int64,
@@ -263,31 +267,31 @@ def test_from_cuda_array_interface(self):
263267
numba_ary = numba.cuda.to_device(numpy_ary)
264268
torch_ary = torch.as_tensor(numba_ary, device="cuda")
265269
self.assertEqual(numba_ary.__cuda_array_interface__, torch_ary.__cuda_array_interface__)
266-
self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary))
270+
self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype))
267271

268272
# Check that `torch_ary` and `numba_ary` points to the same device memory
269273
torch_ary += 42
270-
self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary))
274+
self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype))
271275

272276
# Implicit-copy because `torch_ary` is a CPU array
273277
for numpy_ary in numpy_arys:
274278
numba_ary = numba.cuda.to_device(numpy_ary)
275279
torch_ary = torch.as_tensor(numba_ary, device="cpu")
276-
self.assertEqual(torch_ary.data.numpy(), numpy.asarray(numba_ary))
280+
self.assertEqual(torch_ary.data.numpy(), numpy.asarray(numba_ary, dtype=dtype))
277281

278282
# Check that `torch_ary` and `numba_ary` points to different memory
279283
torch_ary += 42
280-
self.assertEqual(torch_ary.data.numpy(), numpy.asarray(numba_ary) + 42)
284+
self.assertEqual(torch_ary.data.numpy(), numpy.asarray(numba_ary, dtype=dtype) + 42)
281285

282286
# Explicit-copy when using `torch.tensor()`
283287
for numpy_ary in numpy_arys:
284288
numba_ary = numba.cuda.to_device(numpy_ary)
285289
torch_ary = torch.tensor(numba_ary, device="cuda")
286-
self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary))
290+
self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype))
287291

288292
# Check that `torch_ary` and `numba_ary` points to different memory
289293
torch_ary += 42
290-
self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary) + 42)
294+
self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype) + 42)
291295

292296
@unittest.skipIf(not TEST_NUMPY, "No numpy")
293297
@unittest.skipIf(not TEST_CUDA, "No cuda")

‎torch/tensor.py

+2
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,8 @@ def __cuda_array_interface__(self):
690690
# CUDA devices are little-endian and tensors are stored in native byte
691691
# order. 1-byte entries are endian-agnostic.
692692
typestr = {
693+
torch.complex64: "<c8",
694+
torch.complex128: "<c16",
693695
torch.float16: "<f2",
694696
torch.float32: "<f4",
695697
torch.float64: "<f8",

0 commit comments

Comments
 (0)
Please sign in to comment.