
Commit 6a2e9eb

peterbell10 authored and facebook-github-bot committed Sep 24, 2020
torch.fft: Multi-dimensional transforms (pytorch#44550)
Summary:
Pull Request resolved: pytorch#44550

Part of the `torch.fft` work (pytorchgh-42175).

This adds n-dimensional transforms: `fftn`, `ifftn`, `rfftn` and `irfftn`. This is aiming for correctness first, with the implementation on top of the existing `_fft_with_size` restrictions. I plan to follow up later with a more efficient rewrite that makes `_fft_with_size` work with arbitrary numbers of dimensions.

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D23846032

Pulled By: mruberry

fbshipit-source-id: e6950aa8be438ec5cb95fb10bd7b8bc9ffb7d824
1 parent 070fe15 commit 6a2e9eb
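
A quick sketch of the Python API this commit adds (behavior inferred from the diff below; not part of the commit itself):

    import torch

    x = torch.randn(10, 10, dtype=torch.complex128)

    # By default every dimension is transformed
    X = torch.fft.fftn(x)

    # `dim` selects the transformed dimensions; `s` zero-pads or trims each
    # one first. When only `s` is given, `dim` defaults to the last len(s) dims.
    Xp = torch.fft.fftn(x, s=(16, 16))

    # ifftn with matching arguments inverts fftn
    assert torch.allclose(torch.fft.ifftn(X), x)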

File tree

7 files changed: +662 −7 lines changed
 

aten/src/ATen/WrapDimUtils.h  +11 −3
@@ -30,14 +30,15 @@ static inline int64_t maybe_wrap_dim(int64_t dim, const std::vector<std::vector<
   return maybe_wrap_dim(dim, tensor_sizes[0].size());
 }
 
-// wrap each of dims basing on dim_post_expr
-static inline void maybe_wrap_dims(std::vector<int64_t>& dims, int64_t dim_post_expr) {
+// wrap each dim in the dims array, taking dim_post_expr as the true number of dimensions
+static inline void maybe_wrap_dims_n(int64_t* dims, int64_t ndims, int64_t dim_post_expr) {
   if (dim_post_expr <= 0) {
     dim_post_expr = 1; // this will make range [-1, 0]
   }
   int64_t min = -dim_post_expr;
   int64_t max = dim_post_expr - 1;
-  for (auto& dim : dims) {
+  for (int64_t i = 0; i < ndims; ++i) {
+    auto& dim = dims[i];
     if (dim < min || dim > max) {
       TORCH_CHECK_INDEX(false,
         "Dimension out of range (expected to be in range of [",
@@ -47,6 +48,13 @@ static inline void maybe_wrap_dims(std::vector<int64_t>& dims, int64_t dim_post_
   }
 }
 
+// Wrap each dim in a contiguous container, taking dim_post_expr as the true number of dimensions
+// E.g. could also be std::array or c10::SmallVector
+template <typename Container>
+inline void maybe_wrap_dims(Container& dims, int64_t dim_post_expr) {
+  return maybe_wrap_dims_n(dims.data(), dims.size(), dim_post_expr);
+}
+
 // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible
 // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors
 // to be "skipped" (both for wrap dimension behavior and dimension size checking).
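
For intuition, the wrapping rule implemented by `maybe_wrap_dims_n` corresponds to this minimal Python sketch (a hypothetical helper, not part of the commit):

    def wrap_dims(dims, ndim):
        # Each dim must lie in [-ndim, ndim - 1]; negative dims wrap around,
        # e.g. -1 -> ndim - 1. Scalars (ndim == 0) still accept [-1, 0].
        ndim = max(ndim, 1)
        for d in dims:
            if not -ndim <= d <= ndim - 1:
                raise IndexError(
                    f"Dimension out of range (expected to be in range of "
                    f"[{-ndim}, {ndim - 1}], but got {d})")
        return [d % ndim for d in dims]

    assert wrap_dims([0, -1], 4) == [0, 3]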

aten/src/ATen/native/SpectralOps.cpp  +168 −0
@@ -203,6 +203,116 @@ Tensor fft_c2c(Tensor input, c10::optional<int64_t> n_opt,
   return out;
 }
 
+// Dimensions to transform, and the signal shape in those dimensions
+struct ShapeAndDims {
+  DimVector shape, dim;
+};
+
+// Pre-process n-dimensional fft's `s` and `dim` arguments.
+// Wraps dimensions and applies defaulting behavior.
+// Also checks transform dims are unique and transform shape is non-empty.
+ShapeAndDims canonicalize_fft_shape_and_dim_args(
+    Tensor input, c10::optional<IntArrayRef> shape, c10::optional<IntArrayRef> dim) {
+  const int64_t input_dim = input.dim();
+  const IntArrayRef input_sizes = input.sizes();
+  ShapeAndDims ret;
+
+  if (dim) {
+    ret.dim.resize(dim->size());
+    std::copy(dim->begin(), dim->end(), ret.dim.begin());
+    maybe_wrap_dims(ret.dim, input_dim);
+
+    // Check dims are unique
+    DimVector copy = ret.dim;
+    std::sort(copy.begin(), copy.end());
+    auto duplicate = std::adjacent_find(copy.begin(), copy.end());
+    TORCH_CHECK(duplicate == copy.end(), "FFT dims must be unique");
+  }
+
+  if (shape) {
+    // Has shape, may have dim
+    TORCH_CHECK(!dim || dim->size() == shape->size(),
+                "When given, dim and shape arguments must have the same length");
+    TORCH_CHECK(shape->size() <= input_dim,
+                "Got shape with ", shape->size(), " values but input tensor "
+                "only has ", input_dim, " dimensions.");
+    const int64_t transform_ndim = shape->size();
+    // If shape is given, dims defaults to the last shape.size() dimensions
+    if (!dim) {
+      ret.dim.resize(transform_ndim);
+      std::iota(ret.dim.begin(), ret.dim.end(), input_dim - transform_ndim);
+    }
+
+    // Translate shape of -1 to the default length
+    ret.shape.resize(transform_ndim);
+    for (int64_t i = 0; i < transform_ndim; ++i) {
+      const auto n = (*shape)[i];
+      ret.shape[i] = n == -1 ? input_sizes[ret.dim[i]] : n;
+    }
+  } else if (!dim) {
+    // No shape, no dim
+    ret.dim.resize(input_dim);
+    std::iota(ret.dim.begin(), ret.dim.end(), int64_t{0});
+    ret.shape.resize(input_dim);
+    std::copy(input_sizes.begin(), input_sizes.end(), ret.shape.begin());
+  } else {
+    // No shape, has dim
+    ret.shape.resize(ret.dim.size());
+    for (int64_t i = 0; i < ret.dim.size(); ++i) {
+      ret.shape[i] = input_sizes[ret.dim[i]];
+    }
+  }
+
+  for (int64_t i = 0; i < ret.shape.size(); ++i) {
+    TORCH_CHECK(ret.shape[i] > 0,
+                "Invalid number of data points (", ret.shape[i], ") specified");
+  }
+
+  return ret;
+}
+
+// Complex to complex n-dimensional fft
+Tensor fftn_c2c(
+    const Tensor& input, IntArrayRef shape, IntArrayRef dim,
+    c10::optional<std::string> norm_str, bool forward) {
+  TORCH_CHECK(input.is_complex(), "Expected a complex input tensor to FFT");
+  const auto input_dim = input.dim();
+
+  Tensor x = resize_fft_input(input, dim, shape);
+  x = at::view_as_real(x);
+
+  const int64_t transform_ndim = dim.size();
+  const auto norm = norm_from_string(norm_str, forward);
+  // _fft_with_size only supports 3 dimensions being transformed at a time.
+  // This limit is inherited from cuFFT.
+  constexpr int64_t max_signal_ndim = 3;
+
+  // Transform n dimensions, up to 3 at a time
+  // TODO: rewrite _fft_with_size to transform more than 3 dimensions at once.
+  for (int64_t i = 0; i < transform_ndim; i += max_signal_ndim) {
+    const int64_t signal_ndim = std::min(transform_ndim - i, max_signal_ndim);
+    DimVector source_dim(signal_ndim);
+    DimVector dest_dim(signal_ndim);
+
+    for (int64_t j = 0; j < signal_ndim; ++j) {
+      source_dim[j] = dim[i + j];
+      dest_dim[j] = j + (input_dim - signal_ndim);
+    }
+
+    // _fft operates on up-to the last 3 dims, so move selected dims to the end
+    x = at::movedim(x, source_dim, dest_dim);
+
+    x = _fft(x, signal_ndim, /*complex_input=*/true, /*complex_output=*/true,
+             /*inverse=*/!forward, /*signal_sizes=*/{}, /*normalization=*/norm,
+             /*onesided=*/false);
+
+    // Move transform dims back to their original order
+    x = at::movedim(x, dest_dim, source_dim);
+  }
+
+  return at::view_as_complex(x);
+}
+
 }
 
 // torch.fft.fft, analogous to NumPy's numpy.fft.fft
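
Ignoring the view_as_real plumbing, the chunking loop in fftn_c2c behaves like this Python sketch (`fft_last_dims` is a hypothetical stand-in for `_fft_with_size`, which only transforms up to the last 3 dimensions):

    import torch

    def fftn_chunked(x, dims, fft_last_dims):
        max_signal_ndim = 3  # cuFFT limit inherited by _fft_with_size
        for i in range(0, len(dims), max_signal_ndim):
            chunk = list(dims[i:i + max_signal_ndim])
            dest = list(range(x.dim() - len(chunk), x.dim()))
            # Move this chunk of transform dims to the end...
            x = torch.movedim(x, chunk, dest)
            # ...transform the last len(chunk) dims...
            x = fft_last_dims(x, len(chunk))
            # ...and restore the original dimension order.
            x = torch.movedim(x, dest, chunk)
        return x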
@@ -240,6 +350,64 @@ Tensor fft_ihfft(const Tensor& self, c10::optional<int64_t> n, int64_t dim,
   return fft_r2c(self, n, dim, norm, /*forward=*/false, /*onesided=*/true);
 }
 
+Tensor fft_fftn(const Tensor& self, c10::optional<IntArrayRef> s,
+                c10::optional<IntArrayRef> dim,
+                c10::optional<std::string> norm) {
+  auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
+  // TODO: For real input, perform rfftn then mirror with conjugate symmetry
+  Tensor input = promote_tensor_fft(self, /*require_complex=*/true);
+  return fftn_c2c(input, desc.shape, desc.dim, norm, /*forward=*/true);
+}
+
+Tensor fft_ifftn(const Tensor& self, c10::optional<IntArrayRef> s,
+                 c10::optional<IntArrayRef> dim,
+                 c10::optional<std::string> norm) {
+  auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
+  Tensor input = promote_tensor_fft(self, /*require_complex=*/true);
+  return fftn_c2c(input, desc.shape, desc.dim, norm, /*forward=*/false);
+}
+
+Tensor fft_rfftn(const Tensor& self, c10::optional<IntArrayRef> s,
+                 c10::optional<IntArrayRef> dim,
+                 c10::optional<std::string> norm) {
+  auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
+  TORCH_CHECK(desc.shape.size() > 0, "rfftn must transform at least one axis");
+
+  const auto last_dim = desc.dim.back();
+  const auto last_shape = desc.shape.back();
+  desc.shape.pop_back();
+  desc.dim.pop_back();
+
+  // rfft on last dim to get hermitian complex shape
+  auto x = native::fft_rfft(self, last_shape, last_dim, norm);
+  // Normal fft on remaining dims
+  return fftn_c2c(x, desc.shape, desc.dim, norm, /*forward=*/true);
+}
+
+Tensor fft_irfftn(const Tensor& self, c10::optional<IntArrayRef> s,
+                  c10::optional<IntArrayRef> dim,
+                  c10::optional<std::string> norm) {
+  auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
+  TORCH_CHECK(desc.shape.size() > 0, "irfftn must transform at least one axis");
+
+  const auto last_dim = desc.dim.back();
+  const auto last_shape = [&]() -> c10::optional<int64_t> {
+    // If shape is defaulted in the last dimension,
+    // pass nullopt to irfft and let it calculate the default size
+    if (!s.has_value() || (s->back() == -1)) {
+      return c10::nullopt;
+    }
+    return desc.shape.back();
+  }();
+  desc.shape.pop_back();
+  desc.dim.pop_back();
+
+  // Normal ifft for all but last dim
+  Tensor x = promote_tensor_fft(self, /*require_complex=*/true);
+  x = fftn_c2c(x, desc.shape, desc.dim, norm, /*forward=*/false);
+  // Then 1d irfft on last dim to get real output
+  return native::fft_irfft(x, last_shape, last_dim, norm);
+}
 
 // This is a pass-through wrapper function that does the size check and
 // inferences. The actual forward implementation function is called
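
The rfftn/irfftn decomposition above (a real-to-complex transform on the last dim, complex-to-complex transforms on the rest) can be checked numerically; a sketch using the Python bindings this commit registers:

    import torch

    t = torch.randn(6, 8, 10, dtype=torch.double)

    # rfftn == rfft on the last dim, then c2c fft over the remaining dims
    manual = torch.fft.fft(torch.fft.fft(torch.fft.rfft(t, dim=-1), dim=1), dim=0)
    assert torch.allclose(manual, torch.fft.rfftn(t))

    # irfftn undoes rfftn when given the original signal shape
    assert torch.allclose(torch.fft.irfftn(torch.fft.rfftn(t), s=t.shape), t)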

aten/src/ATen/native/native_functions.yaml  +20 −0
@@ -7943,6 +7943,26 @@
   use_c10_dispatcher: full
   variants: function
 
+- func: fft_fftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+  python_module: fft
+  use_c10_dispatcher: full
+  variants: function
+
+- func: fft_ifftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+  python_module: fft
+  use_c10_dispatcher: full
+  variants: function
+
+- func: fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+  python_module: fft
+  use_c10_dispatcher: full
+  variants: function
+
+- func: fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+  python_module: fft
+  use_c10_dispatcher: full
+  variants: function
+
 - func: fft(Tensor self, int signal_ndim, bool normalized=False) -> Tensor
   use_c10_dispatcher: full
   variants: function, method
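
If I read the schema convention correctly, `int[1]?` declares an optional int list whose single-int form is broadcast to a one-element list, so both spellings below should be accepted (an illustrative sketch, not from the commit):

    import torch

    x = torch.randn(4, 4, dtype=torch.complex64)
    torch.fft.fftn(x, s=8, dim=0)        # int broadcasts to a one-element list
    torch.fft.fftn(x, s=(8,), dim=(0,))  # equivalent explicit form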

docs/source/fft.rst  +4 −0
@@ -19,7 +19,11 @@ Functions
 
 .. autofunction:: fft
 .. autofunction:: ifft
+.. autofunction:: fftn
+.. autofunction:: ifftn
 .. autofunction:: rfft
 .. autofunction:: irfft
+.. autofunction:: rfftn
+.. autofunction:: irfftn
 .. autofunction:: hfft
 .. autofunction:: ihfft

test/test_spectral_ops.py  +174 −4
@@ -3,6 +3,7 @@
 import math
 from contextlib import contextmanager
 from itertools import product
+import itertools
 
 from torch.testing._internal.common_utils import \
     (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA)
@@ -11,7 +12,7 @@
     skipCPUIfNoMkl, skipCUDAIfRocm, deviceCountAtLeast, onlyCUDA)
 
 from distutils.version import LooseVersion
-from typing import Optional
+from typing import Optional, List
 
 
 if TEST_NUMPY:
@@ -115,6 +116,7 @@ def method_fn(t):
 
     @skipCPUIfNoMkl
     @skipCUDAIfRocm
+    @onlyOnCPUAndCUDA
     @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
     @precisionOverride({torch.complex64: 1e-4, torch.float: 1e-4})
     @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
@@ -226,11 +228,13 @@ def test_fft_round_trip(self, device, dtype):
     def test_empty_fft(self, device, dtype):
         t = torch.empty(0, device=device, dtype=dtype)
         match = r"Invalid number of data points \([-\d]*\) specified"
-        fft_functions = [torch.fft.fft, torch.fft.ifft, torch.fft.hfft,
-                         torch.fft.irfft]
+        fft_functions = [torch.fft.fft, torch.fft.fftn,
+                         torch.fft.ifft, torch.fft.ifftn,
+                         torch.fft.irfft, torch.fft.irfftn,
+                         torch.fft.hfft]
         # Real-only functions
         if not dtype.is_complex:
-            fft_functions += [torch.fft.rfft, torch.fft.ihfft]
+            fft_functions += [torch.fft.rfft, torch.fft.rfftn, torch.fft.ihfft]
 
         for fn in fft_functions:
             with self.assertRaisesRegex(RuntimeError, match):
@@ -242,6 +246,9 @@ def test_fft_invalid_dtypes(self, device):
         with self.assertRaisesRegex(RuntimeError, "Expected a real input tensor"):
             torch.fft.rfft(t)
 
+        with self.assertRaisesRegex(RuntimeError, "Expected a real input tensor"):
+            torch.fft.rfftn(t)
+
         with self.assertRaisesRegex(RuntimeError, "Expected a real input tensor"):
             torch.fft.ihfft(t)
@@ -292,14 +299,17 @@ def test_fft_half_errors(self, device, dtype):
         # TODO: Remove torch.half error when complex32 is fully implemented
         x = torch.randn(64, device=device).to(dtype)
         fft_functions = (torch.fft.fft, torch.fft.ifft,
+                         torch.fft.fftn, torch.fft.ifftn,
                          torch.fft.rfft, torch.fft.irfft,
+                         torch.fft.rfftn, torch.fft.irfftn,
                          torch.fft.hfft, torch.fft.ihfft)
         for fn in fft_functions:
             with self.assertRaisesRegex(RuntimeError, "Unsupported dtype "):
                 fn(x)
 
     @skipCPUIfNoMkl
     @skipCUDAIfRocm
+    @onlyOnCPUAndCUDA
     @dtypes(torch.double, torch.complex128) # gradcheck requires double
     def test_fft_backward(self, device, dtype):
         test_args = list(product(
@@ -340,6 +350,166 @@ def test_fn(x):
 
         self.assertTrue(torch.autograd.gradcheck(test_fn, (input,)))
 
+    # nd-fft tests
+
+    @skipCPUIfNoMkl
+    @skipCUDAIfRocm
+    @onlyOnCPUAndCUDA
+    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
+    @precisionOverride({torch.complex64: 1e-4, torch.float: 1e-4})
+    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
+    def test_fftn_numpy(self, device, dtype):
+        norm_modes = ((None, "forward", "backward", "ortho")
+                      if LooseVersion(np.__version__) >= '1.20.0'
+                      else (None, "ortho"))
+
+        # input_ndim, s, dim
+        transform_desc = [
+            *product(range(2, 5), (None,), (None, (0,), (0, -1))),
+            *product(range(2, 5), (None, (4, 10)), (None,)),
+            (6, None, None),
+            (5, None, (1, 3, 4)),
+            (3, None, (0, -1)),
+            (3, None, (1,)),
+            (1, None, (0,)),
+            (4, (10, 10), None),
+            (4, (10, 10), (0, 1))
+        ]
+
+        fft_functions = ['fftn', 'ifftn', 'irfftn']
+        # Real-only functions
+        if not dtype.is_complex:
+            fft_functions += ['rfftn']
+
+        for input_ndim, s, dim in transform_desc:
+            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
+            input = torch.randn(*shape, device=device, dtype=dtype)
+            for fname, norm in product(fft_functions, norm_modes):
+                torch_fn = getattr(torch.fft, fname)
+                numpy_fn = getattr(np.fft, fname)
+
+                def fn(t: torch.Tensor, s: Optional[List[int]], dim: Optional[List[int]], norm: Optional[str]):
+                    return torch_fn(t, s, dim, norm)
+
+                torch_fns = (torch_fn, torch.jit.script(fn))
+
+                expected = numpy_fn(input.cpu().numpy(), s, dim, norm)
+                exact_dtype = dtype in (torch.double, torch.complex128)
+                for fn in torch_fns:
+                    actual = fn(input, s, dim, norm)
+                    self.assertEqual(actual, expected, exact_dtype=exact_dtype)
+
+    @skipCUDAIfRocm
+    @skipCPUIfNoMkl
+    @onlyOnCPUAndCUDA
+    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
+    def test_fftn_round_trip(self, device, dtype):
+        norm_modes = (None, "forward", "backward", "ortho")
+
+        # input_ndim, dim
+        transform_desc = [
+            *product(range(2, 5), (None, (0,), (0, -1))),
+            *product(range(2, 5), (None,)),
+            (7, None),
+            (5, (1, 3, 4)),
+            (3, (0, -1)),
+            (3, (1,)),
+            (1, 0),
+        ]
+
+        fft_functions = [(torch.fft.fftn, torch.fft.ifftn)]
+
+        # Real-only functions
+        if not dtype.is_complex:
+            fft_functions += [(torch.fft.rfftn, torch.fft.irfftn)]
+
+        for input_ndim, dim in transform_desc:
+            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
+            x = torch.randn(*shape, device=device, dtype=dtype)
+
+            for (forward, backward), norm in product(fft_functions, norm_modes):
+                if isinstance(dim, tuple):
+                    s = [x.size(d) for d in dim]
+                else:
+                    s = x.size() if dim is None else x.size(dim)
+
+                kwargs = {'s': s, 'dim': dim, 'norm': norm}
+                y = backward(forward(x, **kwargs), **kwargs)
+                # For real input, ifftn(fftn(x)) will convert to complex
+                self.assertEqual(x, y, exact_dtype=(
+                    forward != torch.fft.fftn or x.is_complex()))
+
+    @skipCPUIfNoMkl
+    @skipCUDAIfRocm
+    @onlyOnCPUAndCUDA
+    @dtypes(torch.double, torch.complex128) # gradcheck requires double
+    def test_fftn_backward(self, device, dtype):
+        # input_ndim, s, dim
+        transform_desc = [
+            *product((2, 3), (None,), (None, (0,), (0, -1))),
+            *product((2, 3), (None, (4, 10)), (None,)),
+            (4, None, None),
+            (3, (10, 10), (0, 1)),
+            (2, (1, 1), (0, 1)),
+            (2, None, (1,)),
+            (1, None, (0,)),
+            (1, (11,), (0,)),
+        ]
+        norm_modes = (None, "forward", "backward", "ortho")
+
+        fft_functions = ['fftn', 'ifftn', 'irfftn']
+        # Real-only functions
+        if not dtype.is_complex:
+            fft_functions += ['rfftn']
+
+        for input_ndim, s, dim in transform_desc:
+            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
+            input = torch.randn(*shape, device=device, dtype=dtype)
+
+            for fname, norm in product(fft_functions, norm_modes):
+                torch_fn = getattr(torch.fft, fname)
+
+                # Workaround for gradcheck's poor support for complex input
+                # Use real input instead and put view_as_complex into the graph
+                if dtype.is_complex:
+                    def test_fn(x):
+                        return torch_fn(torch.view_as_complex(x), s, dim, norm)
+                    inputs = (torch.view_as_real(input).detach().requires_grad_(),)
+                else:
+                    def test_fn(x):
+                        return torch_fn(x, s, dim, norm)
+                    inputs = (input.detach().requires_grad_(),)
+
+                self.assertTrue(torch.autograd.gradcheck(test_fn, inputs))
+
+    @skipCUDAIfRocm
+    @skipCPUIfNoMkl
+    @onlyOnCPUAndCUDA
+    def test_fftn_invalid(self, device):
+        a = torch.rand(10, 10, 10, device=device)
+        fft_funcs = (torch.fft.fftn, torch.fft.ifftn,
+                     torch.fft.rfftn, torch.fft.irfftn)
+
+        for func in fft_funcs:
+            with self.assertRaisesRegex(RuntimeError, "FFT dims must be unique"):
+                func(a, dim=(0, 1, 0))
+
+            with self.assertRaisesRegex(RuntimeError, "FFT dims must be unique"):
+                func(a, dim=(2, -1))
+
+            with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"):
+                func(a, s=(1,), dim=(0, 1))
+
+            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
+                func(a, dim=(3,))
+
+            with self.assertRaisesRegex(RuntimeError, "tensor only has 3 dimensions"):
+                func(a, s=(10, 10, 10, 10))
+
+        c = torch.complex(a, a)
+        with self.assertRaisesRegex(RuntimeError, "Expected a real input"):
+            torch.fft.rfftn(c)
+
     # Legacy fft tests
     def _test_fft_ifft_rfft_irfft(self, device, dtype):
         def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x):
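
The gradcheck workaround used in test_fftn_backward generalizes beyond the test suite; a standalone sketch of the same trick:

    import torch

    # gradcheck works best with real inputs here, so feed it the real view and
    # rebuild the complex tensor inside the graph with view_as_complex.
    def fftn_from_real(x_real):
        return torch.fft.fftn(torch.view_as_complex(x_real), norm="ortho")

    x = torch.randn(4, 4, 2, dtype=torch.double, requires_grad=True)  # last dim: (re, im)
    torch.autograd.gradcheck(fftn_from_real, (x,))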

torch/csrc/api/include/torch/fft.h  +60 −0
@@ -35,6 +35,36 @@ inline Tensor ifft(const Tensor& self,
   return torch::fft_ifft(self, n, dim, norm);
 }
 
+/// Computes the N dimensional fast Fourier transform over given dimensions.
+/// See https://pytorch.org/docs/master/fft.html#torch.fft.fftn.
+///
+/// Example:
+/// ```
+/// auto t = torch::randn({128, 128}, torch::kComplexDouble);
+/// torch::fft::fftn(t);
+/// ```
+inline Tensor fftn(const Tensor& self,
+                   c10::optional<IntArrayRef> s=c10::nullopt,
+                   c10::optional<IntArrayRef> dim=c10::nullopt,
+                   c10::optional<std::string> norm=c10::nullopt) {
+  return torch::fft_fftn(self, s, dim, norm);
+}
+
+/// Computes the N dimensional inverse fast Fourier transform over given dimensions.
+/// See https://pytorch.org/docs/master/fft.html#torch.fft.ifftn.
+///
+/// Example:
+/// ```
+/// auto t = torch::randn({128, 128}, torch::kComplexDouble);
+/// torch::fft::ifftn(t);
+/// ```
+inline Tensor ifftn(const Tensor& self,
+                    c10::optional<IntArrayRef> s=c10::nullopt,
+                    c10::optional<IntArrayRef> dim=c10::nullopt,
+                    c10::optional<std::string> norm=c10::nullopt) {
+  return torch::fft_ifftn(self, s, dim, norm);
+}
+
 /// Computes the 1 dimensional FFT of real input with onesided Hermitian output.
 /// See https://pytorch.org/docs/master/fft.html#torch.fft.rfft.
 ///
@@ -69,6 +99,36 @@ inline Tensor irfft(const Tensor& self,
   return torch::fft_irfft(self, n, dim, norm);
 }
 
+/// Computes the N dimensional FFT of real input with onesided Hermitian output.
+/// See https://pytorch.org/docs/master/fft.html#torch.fft.rfftn.
+///
+/// Example:
+/// ```
+/// auto t = torch::randn({128, 128}, torch::kDouble);
+/// torch::fft::rfftn(t);
+/// ```
+inline Tensor rfftn(const Tensor& self,
+                    c10::optional<IntArrayRef> s=c10::nullopt,
+                    c10::optional<IntArrayRef> dim=c10::nullopt,
+                    c10::optional<std::string> norm=c10::nullopt) {
+  return torch::fft_rfftn(self, s, dim, norm);
+}
+
+/// Computes the inverse of torch.fft.rfftn.
+/// See https://pytorch.org/docs/master/fft.html#torch.fft.irfftn.
+///
+/// Example:
+/// ```
+/// auto t = torch::randn({128, 128}, torch::kComplexDouble);
+/// torch::fft::irfftn(t);
+/// ```
+inline Tensor irfftn(const Tensor& self,
+                     c10::optional<IntArrayRef> s=c10::nullopt,
+                     c10::optional<IntArrayRef> dim=c10::nullopt,
+                     c10::optional<std::string> norm=c10::nullopt) {
+  return torch::fft_irfftn(self, s, dim, norm);
+}
+
 /// Computes the 1 dimensional FFT of a onesided Hermitian signal
 ///
 /// The input represents a Hermitian symmetric time domain signal. The returned

torch/fft/__init__.py  +225 −0
@@ -87,6 +87,101 @@
     tensor([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j])
 """)
 
+fftn = _add_docstr(_fft.fft_fftn, r"""
+fftn(input, s=None, dim=None, norm=None) -> Tensor
+
+Computes the N dimensional discrete Fourier transform of :attr:`input`.
+
+Note:
+
+    The Fourier domain representation of any real signal satisfies the
+    Hermitian property: ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])``. This
+    function always returns all positive and negative frequency terms even
+    though, for real inputs, half of these values are redundant.
+    :func:`~torch.fft.rfftn` returns the more compact one-sided representation
+    where only the positive frequencies of the last dimension are returned.
+
+Args:
+    input (Tensor): the input tensor
+    s (Tuple[int], optional): Signal size in the transformed dimensions.
+        If given, each dimension ``dim[i]`` will either be zero-padded or
+        trimmed to the length ``s[i]`` before computing the FFT.
+        If a length ``-1`` is specified, no padding is done in that dimension.
+        Default: ``s = [input.size(d) for d in dim]``
+    dim (Tuple[int], optional): Dimensions to be transformed.
+        Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
+    norm (str, optional): Normalization mode. For the forward transform
+        (:func:`~torch.fft.fftn`), these correspond to:
+
+        * ``"forward"`` - normalize by ``1/n``
+        * ``"backward"`` - no normalization
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)
+
+        Where ``n = prod(s)`` is the logical FFT size.
+        Calling the backward transform (:func:`~torch.fft.ifftn`) with the same
+        normalization mode will apply an overall normalization of ``1/n``
+        between the two transforms. This is required to make
+        :func:`~torch.fft.ifftn` the exact inverse.
+
+        Default is ``"backward"`` (no normalization).
+
+Example:
+
+    >>> import torch.fft
+    >>> x = torch.rand(10, 10, dtype=torch.complex64)
+    >>> fftn = torch.fft.fftn(x)
+
+    The discrete Fourier transform is separable, so :func:`~torch.fft.fftn`
+    here is equivalent to two one-dimensional :func:`~torch.fft.fft` calls:
+
+    >>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
+    >>> torch.allclose(fftn, two_ffts)
+
+""")
+
+ifftn = _add_docstr(_fft.fft_ifftn, r"""
+ifftn(input, s=None, dim=None, norm=None) -> Tensor
+
+Computes the N dimensional inverse discrete Fourier transform of :attr:`input`.
+
+Args:
+    input (Tensor): the input tensor
+    s (Tuple[int], optional): Signal size in the transformed dimensions.
+        If given, each dimension ``dim[i]`` will either be zero-padded or
+        trimmed to the length ``s[i]`` before computing the IFFT.
+        If a length ``-1`` is specified, no padding is done in that dimension.
+        Default: ``s = [input.size(d) for d in dim]``
+    dim (Tuple[int], optional): Dimensions to be transformed.
+        Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
+    norm (str, optional): Normalization mode. For the backward transform
+        (:func:`~torch.fft.ifftn`), these correspond to:
+
+        * ``"forward"`` - no normalization
+        * ``"backward"`` - normalize by ``1/n``
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)
+
+        Where ``n = prod(s)`` is the logical IFFT size.
+        Calling the forward transform (:func:`~torch.fft.fftn`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.ifftn`
+        the exact inverse.
+
+        Default is ``"backward"`` (normalize by ``1/n``).
+
+Example:
+
+    >>> import torch.fft
+    >>> x = torch.rand(10, 10, dtype=torch.complex64)
+    >>> ifftn = torch.fft.ifftn(x)
+
+    The discrete Fourier transform is separable, so :func:`~torch.fft.ifftn`
+    here is equivalent to two one-dimensional :func:`~torch.fft.ifft` calls:
+
+    >>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1)
+    >>> torch.allclose(ifftn, two_iffts)
+
+""")
+
 rfft = _add_docstr(_fft.fft_rfft, r"""
 rfft(input, n=None, dim=-1, norm=None) -> Tensor
@@ -199,6 +294,136 @@
     tensor([0.0000, 1.0000, 2.0000, 3.0000, 4.0000])
 """)
 
+rfftn = _add_docstr(_fft.fft_rfftn, r"""
+rfftn(input, s=None, dim=None, norm=None) -> Tensor
+
+Computes the N-dimensional discrete Fourier transform of real :attr:`input`.
+
+The FFT of a real signal is Hermitian-symmetric,
+``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])``, so the full
+:func:`~torch.fft.fftn` output contains redundant information.
+:func:`~torch.fft.rfftn` instead omits the negative frequencies in the
+last dimension.
+
+Args:
+    input (Tensor): the input tensor
+    s (Tuple[int], optional): Signal size in the transformed dimensions.
+        If given, each dimension ``dim[i]`` will either be zero-padded or
+        trimmed to the length ``s[i]`` before computing the real FFT.
+        If a length ``-1`` is specified, no padding is done in that dimension.
+        Default: ``s = [input.size(d) for d in dim]``
+    dim (Tuple[int], optional): Dimensions to be transformed.
+        Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
+    norm (str, optional): Normalization mode. For the forward transform
+        (:func:`~torch.fft.rfftn`), these correspond to:
+
+        * ``"forward"`` - normalize by ``1/n``
+        * ``"backward"`` - no normalization
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real FFT orthonormal)
+
+        Where ``n = prod(s)`` is the logical FFT size.
+        Calling the backward transform (:func:`~torch.fft.irfftn`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.irfftn`
+        the exact inverse.
+
+        Default is ``"backward"`` (no normalization).
+
+Example:
+
+    >>> import torch.fft
+    >>> t = torch.rand(10, 10)
+    >>> rfftn = torch.fft.rfftn(t)
+    >>> rfftn.size()
+    torch.Size([10, 6])
+
+    Compared against the full output from :func:`~torch.fft.fftn`, we have all
+    elements up to the Nyquist frequency.
+
+    >>> fftn = torch.fft.fftn(t)
+    >>> torch.allclose(fftn[..., :6], rfftn)
+    True
+
+    The discrete Fourier transform is separable, so :func:`~torch.fft.rfftn`
+    here is equivalent to a combination of :func:`~torch.fft.fft` and
+    :func:`~torch.fft.rfft`:
+
+    >>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0)
+    >>> torch.allclose(rfftn, two_ffts)
+
+""")
+
+irfftn = _add_docstr(_fft.fft_irfftn, r"""
+irfftn(input, s=None, dim=None, norm=None) -> Tensor
+
+Computes the inverse of :func:`~torch.fft.rfftn`.
+
+:attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier
+domain, as produced by :func:`~torch.fft.rfftn`. By the Hermitian property, the
+output will be real-valued.
+
+Note:
+    Some input frequencies must be real-valued to satisfy the Hermitian
+    property. In these cases the imaginary component will be ignored.
+    For example, any imaginary component in the zero-frequency term cannot
+    be represented in a real output and so will always be ignored.
+
+Note:
+    The correct interpretation of the Hermitian input depends on the length of
+    the original data, as given by :attr:`s`. This is because each input shape
+    could correspond to either an odd or even length signal. By default, the
+    signal is assumed to be even length and odd signals will not round-trip
+    properly. So, it is recommended to always pass the signal shape :attr:`s`.
+
+Args:
+    input (Tensor): the input tensor
+    s (Tuple[int], optional): Signal size in the transformed dimensions.
+        If given, each dimension ``dim[i]`` will either be zero-padded or
+        trimmed to the length ``s[i]`` before computing the real FFT.
+        If a length ``-1`` is specified, no padding is done in that dimension.
+        Defaults to even output in the last dimension:
+        ``s[-1] = 2*(input.size(dim[-1]) - 1)``.
+    dim (Tuple[int], optional): Dimensions to be transformed.
+        The last dimension must be the half-Hermitian compressed dimension.
+        Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
+    norm (str, optional): Normalization mode. For the backward transform
+        (:func:`~torch.fft.irfftn`), these correspond to:
+
+        * ``"forward"`` - no normalization
+        * ``"backward"`` - normalize by ``1/n``
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal)
+
+        Where ``n = prod(s)`` is the logical IFFT size.
+        Calling the forward transform (:func:`~torch.fft.rfftn`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.irfftn`
+        the exact inverse.
+
+        Default is ``"backward"`` (normalize by ``1/n``).
+
+Example:
+
+    >>> import torch.fft
+    >>> t = torch.rand(10, 9)
+    >>> T = torch.fft.rfftn(t)
+
+    Without specifying the output length to :func:`~torch.fft.irfftn`, the output
+    will not round-trip properly because the input is odd-length in the last
+    dimension:
+
+    >>> torch.fft.irfftn(T).size()
+    torch.Size([10, 10])
+
+    So, it is recommended to always pass the signal shape :attr:`s`.
+
+    >>> roundtrip = torch.fft.irfftn(T, t.size())
+    >>> roundtrip.size()
+    torch.Size([10, 9])
+    >>> torch.allclose(roundtrip, t)
+    True
+
+""")
+
 hfft = _add_docstr(_fft.fft_hfft, r"""
 hfft(input, n=None, dim=-1, norm=None) -> Tensor

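As a quick sanity check of the normalization modes these docstrings describe (a sketch, not part of the commit):

    import torch

    x = torch.randn(4, 4, dtype=torch.complex128)
    n = x.numel()  # logical FFT size: n = prod(s) = 16

    # "forward" scales the forward transform by 1/n; "backward" leaves it unscaled
    assert torch.allclose(torch.fft.fftn(x, norm="forward"),
                          torch.fft.fftn(x) / n)

    # "ortho" scales by 1/sqrt(n), making the transform unitary
    assert torch.allclose(torch.fft.fftn(x, norm="ortho"),
                          torch.fft.fftn(x) / n ** 0.5)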