 import math
 from contextlib import contextmanager
 from itertools import product
+import itertools

 from torch.testing._internal.common_utils import \
     (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes, onlyOnCPUAndCUDA, precisionOverride,
      skipCPUIfNoMkl, skipCUDAIfRocm, deviceCountAtLeast, onlyCUDA)

 from distutils.version import LooseVersion
-from typing import Optional
+from typing import Optional, List


 if TEST_NUMPY:
@@ -115,6 +116,7 @@ def method_fn(t):

     @skipCPUIfNoMkl
     @skipCUDAIfRocm
+    @onlyOnCPUAndCUDA
     @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
     @precisionOverride({torch.complex64: 1e-4, torch.float: 1e-4})
     @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
@@ -226,11 +228,13 @@ def test_fft_round_trip(self, device, dtype):
     def test_empty_fft(self, device, dtype):
         t = torch.empty(0, device=device, dtype=dtype)
         match = r"Invalid number of data points \([-\d]*\) specified"
-        fft_functions = [torch.fft.fft, torch.fft.ifft, torch.fft.hfft,
-                         torch.fft.irfft]
+        fft_functions = [torch.fft.fft, torch.fft.fftn,
+                         torch.fft.ifft, torch.fft.ifftn,
+                         torch.fft.irfft, torch.fft.irfftn,
+                         torch.fft.hfft]
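+        # hfft accepts complex input, so it is exercised for every dtype here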
         # Real-only functions
         if not dtype.is_complex:
-            fft_functions += [torch.fft.rfft, torch.fft.ihfft]
+            fft_functions += [torch.fft.rfft, torch.fft.rfftn, torch.fft.ihfft]

         for fn in fft_functions:
             with self.assertRaisesRegex(RuntimeError, match):
@@ -242,6 +246,9 @@ def test_fft_invalid_dtypes(self, device):
         with self.assertRaisesRegex(RuntimeError, "Expected a real input tensor"):
             torch.fft.rfft(t)

+        with self.assertRaisesRegex(RuntimeError, "Expected a real input tensor"):
+            torch.fft.rfftn(t)
+
         with self.assertRaisesRegex(RuntimeError, "Expected a real input tensor"):
             torch.fft.ihfft(t)
@@ -292,14 +299,17 @@ def test_fft_half_errors(self, device, dtype):
         # TODO: Remove torch.half error when complex32 is fully implemented
         x = torch.randn(64, device=device).to(dtype)
         fft_functions = (torch.fft.fft, torch.fft.ifft,
+                         torch.fft.fftn, torch.fft.ifftn,
                          torch.fft.rfft, torch.fft.irfft,
+                         torch.fft.rfftn, torch.fft.irfftn,
                          torch.fft.hfft, torch.fft.ihfft)
         for fn in fft_functions:
             with self.assertRaisesRegex(RuntimeError, "Unsupported dtype "):
                 fn(x)

     @skipCPUIfNoMkl
     @skipCUDAIfRocm
+    @onlyOnCPUAndCUDA
     @dtypes(torch.double, torch.complex128)  # gradcheck requires double
     def test_fft_backward(self, device, dtype):
         test_args = list(product(
@@ -340,6 +350,166 @@ def test_fn(x):

         self.assertTrue(torch.autograd.gradcheck(test_fn, (input,)))

+    # nd-fft tests
+
+    @skipCPUIfNoMkl
+    @skipCUDAIfRocm
+    @onlyOnCPUAndCUDA
+    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
+    @precisionOverride({torch.complex64: 1e-4, torch.float: 1e-4})
+    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
+    def test_fftn_numpy(self, device, dtype):
+        norm_modes = ((None, "forward", "backward", "ortho")
+                      if LooseVersion(np.__version__) >= '1.20.0'
+                      else (None, "ortho"))
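+        # NumPy added the "forward" and "backward" norm modes in 1.20.0,
+        # so older versions are only compared with the default and "ortho" modes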
+
+        # input_ndim, s, dim
+        transform_desc = [
+            *product(range(2, 5), (None,), (None, (0,), (0, -1))),
+            *product(range(2, 5), (None, (4, 10)), (None,)),
+            (6, None, None),
+            (5, None, (1, 3, 4)),
+            (3, None, (0, -1)),
+            (3, None, (1,)),
+            (1, None, (0,)),
+            (4, (10, 10), None),
+            (4, (10, 10), (0, 1))
+        ]
+
+        fft_functions = ['fftn', 'ifftn', 'irfftn']
+        # Real-only functions
+        if not dtype.is_complex:
+            fft_functions += ['rfftn']
+
+        for input_ndim, s, dim in transform_desc:
+            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
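+            # islice(cycle(4..8)) draws input_ndim sizes, keeping test tensors small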
+            input = torch.randn(*shape, device=device, dtype=dtype)
+            for fname, norm in product(fft_functions, norm_modes):
+                torch_fn = getattr(torch.fft, fname)
+                numpy_fn = getattr(np.fft, fname)
+
+                def fn(t: torch.Tensor, s: Optional[List[int]], dim: Optional[List[int]], norm: Optional[str]):
+                    return torch_fn(t, s, dim, norm)
+
+                torch_fns = (torch_fn, torch.jit.script(fn))
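+                # Check both the eager op and a TorchScript-compiled wrapper against NumPy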
+
+                expected = numpy_fn(input.cpu().numpy(), s, dim, norm)
+                exact_dtype = dtype in (torch.double, torch.complex128)
+                for fn in torch_fns:
+                    actual = fn(input, s, dim, norm)
+                    self.assertEqual(actual, expected, exact_dtype=exact_dtype)
+
+    @skipCUDAIfRocm
+    @skipCPUIfNoMkl
+    @onlyOnCPUAndCUDA
+    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
+    def test_fftn_round_trip(self, device, dtype):
+        norm_modes = (None, "forward", "backward", "ortho")
+
+        # input_ndim, dim
+        transform_desc = [
+            *product(range(2, 5), (None, (0,), (0, -1))),
+            *product(range(2, 5), (None,)),
+            (7, None),
+            (5, (1, 3, 4)),
+            (3, (0, -1)),
+            (3, (1,)),
+            (1, 0),
+        ]
+
+        fft_functions = [(torch.fft.fftn, torch.fft.ifftn)]
+
+        # Real-only functions
+        if not dtype.is_complex:
+            fft_functions += [(torch.fft.rfftn, torch.fft.irfftn)]
+
+        for input_ndim, dim in transform_desc:
+            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
+            x = torch.randn(*shape, device=device, dtype=dtype)
+
+            for (forward, backward), norm in product(fft_functions, norm_modes):
+                if isinstance(dim, tuple):
+                    s = [x.size(d) for d in dim]
+                else:
+                    s = x.size() if dim is None else x.size(dim)
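+                # Pass explicit sizes so irfftn can restore the original signal length
+                # (the half-spectrum alone doesn't determine odd vs. even length)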
+
+                kwargs = {'s': s, 'dim': dim, 'norm': norm}
+                y = backward(forward(x, **kwargs), **kwargs)
+                # For real input, ifftn(fftn(x)) will convert to complex
+                self.assertEqual(x, y, exact_dtype=(
+                    forward != torch.fft.fftn or x.is_complex()))
+
+    @skipCPUIfNoMkl
+    @skipCUDAIfRocm
+    @onlyOnCPUAndCUDA
+    @dtypes(torch.double, torch.complex128)  # gradcheck requires double
+    def test_fftn_backward(self, device, dtype):
+        # input_ndim, s, dim
+        transform_desc = [
+            *product((2, 3), (None,), (None, (0,), (0, -1))),
+            *product((2, 3), (None, (4, 10)), (None,)),
+            (4, None, None),
+            (3, (10, 10), (0, 1)),
+            (2, (1, 1), (0, 1)),
+            (2, None, (1,)),
+            (1, None, (0,)),
+            (1, (11,), (0,)),
+        ]
+        norm_modes = (None, "forward", "backward", "ortho")
+
+        fft_functions = ['fftn', 'ifftn', 'irfftn']
+        # Real-only functions
+        if not dtype.is_complex:
+            fft_functions += ['rfftn']
+
+        for input_ndim, s, dim in transform_desc:
+            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
+            input = torch.randn(*shape, device=device, dtype=dtype)
+
+            for fname, norm in product(fft_functions, norm_modes):
+                torch_fn = getattr(torch.fft, fname)
+
+                # Workaround for gradcheck's poor support for complex input
+                # Use real input instead and put view_as_complex into the graph
+                if dtype.is_complex:
+                    def test_fn(x):
+                        return torch_fn(torch.view_as_complex(x), s, dim, norm)
+                    inputs = (torch.view_as_real(input).detach().requires_grad_(),)
+                else:
+                    def test_fn(x):
+                        return torch_fn(x, s, dim, norm)
+                    inputs = (input.detach().requires_grad_(),)
+
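+                # gradcheck validates the analytic gradient against numeric differentiation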
+                self.assertTrue(torch.autograd.gradcheck(test_fn, inputs))
+
+    @skipCUDAIfRocm
+    @skipCPUIfNoMkl
+    @onlyOnCPUAndCUDA
+    def test_fftn_invalid(self, device):
+        a = torch.rand(10, 10, 10, device=device)
+        fft_funcs = (torch.fft.fftn, torch.fft.ifftn,
+                     torch.fft.rfftn, torch.fft.irfftn)
+
+        for func in fft_funcs:
+            with self.assertRaisesRegex(RuntimeError, "FFT dims must be unique"):
+                func(a, dim=(0, 1, 0))
+
+            with self.assertRaisesRegex(RuntimeError, "FFT dims must be unique"):
+                func(a, dim=(2, -1))
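+            # dim=2 and dim=-1 alias the same axis of a 3-D tensor, hence "not unique"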
+
+            with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"):
+                func(a, s=(1,), dim=(0, 1))
+
+            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
+                func(a, dim=(3,))
+
+            with self.assertRaisesRegex(RuntimeError, "tensor only has 3 dimensions"):
+                func(a, s=(10, 10, 10, 10))
+
+        c = torch.complex(a, a)
+        with self.assertRaisesRegex(RuntimeError, "Expected a real input"):
+            torch.fft.rfftn(c)
+
     # Legacy fft tests
     def _test_fft_ifft_rfft_irfft(self, device, dtype):
         def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x):