Skip to content

Commit da7863f

Browse files
peterbell10 authored and facebook-github-bot committed on Sep 20, 2020
Add one dimensional FFTs to torch.fft namespace (pytorch#43011)
Summary: Pull Request resolved: pytorch#43011 Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D23751850 Pulled By: mruberry fbshipit-source-id: 8dc5fec75102d8809eeb85a3d347ba1b5de45b33
1 parent 49db7b5 commit da7863f

File tree

15 files changed

+1066
-110
lines changed

15 files changed

+1066
-110
lines changed
 

‎aten/src/ATen/native/SpectralOps.cpp

+251-12
Original file line number | Diff line number | Diff line change
@@ -18,22 +18,236 @@
1818

1919
namespace at { namespace native {
2020

21+
// Common code for all FFT functions
22+
static inline Tensor _fft(
23+
const Tensor &self, int64_t signal_ndim, bool complex_input,
24+
const bool complex_output, bool inverse, IntArrayRef signal_sizes,
25+
fft_norm_mode normalization, bool onesided);
26+
27+
namespace {
28+
29+
// Promote inputs to FFT functions
30+
// * Integers are promoted to the default floating type
31+
// * If require_complex=True, all types are promoted to complex
32+
// * Raises an error for half-precision dtypes to allow future support
33+
ScalarType promote_type_fft(ScalarType type, bool require_complex) {
34+
if (at::isComplexType(type)) {
35+
return type;
36+
}
37+
// Promote integral to default float type
38+
if (!at::isFloatingType(type)) {
39+
type = c10::typeMetaToScalarType(c10::get_default_dtype());
40+
}
41+
42+
TORCH_CHECK(type == kFloat || type == kDouble, "Unsupported dtype ", type);
43+
44+
if (!require_complex) {
45+
return type;
46+
}
47+
48+
// Promote to complex
49+
switch (type) {
50+
case kFloat: return kComplexFloat;
51+
case kDouble: return kComplexDouble;
52+
default: TORCH_INTERNAL_ASSERT(false, "Unhandled dtype");
53+
}
54+
}
55+
56+
// Promote a tensor's dtype according to promote_type_fft
57+
Tensor promote_tensor_fft(const Tensor& t, bool require_complex=false) {
58+
auto cur_type = t.scalar_type();
59+
auto new_type = promote_type_fft(cur_type, require_complex);
60+
return (cur_type == new_type) ? t : t.to(new_type);
61+
}
62+
63+
// Convert NumPy compatible normalization mode string to enum values
64+
// NOTE: NumPy's normalization modes have direction-specific meanings. For example,
65+
// "forward" translates to `by_n` for a forward transform and `none` for backward.
66+
fft_norm_mode norm_from_string(c10::optional<std::string> norm, bool forward) {
67+
if (!norm || *norm == "backward") {
68+
return forward ? fft_norm_mode::none : fft_norm_mode::by_n;
69+
}
70+
71+
if (*norm == "forward") {
72+
return forward ? fft_norm_mode::by_n : fft_norm_mode::none;
73+
}
74+
75+
if (*norm == "ortho") {
76+
return fft_norm_mode::by_root_n;
77+
}
78+
79+
TORCH_CHECK(false, "Invalid normalization mode: \"", *norm, "\"")
80+
}
81+
82+
// Fixes the shape of x such that x.size(dims[i]) == sizes[i],
83+
// either by zero-padding, or by slicing x starting from 0.
84+
Tensor resize_fft_input(Tensor x, IntArrayRef dims, IntArrayRef sizes) {
85+
TORCH_INTERNAL_ASSERT(dims.size() == sizes.size());
86+
bool must_copy = false;
87+
auto x_sizes = x.sizes();
88+
DimVector pad_amount(x_sizes.size() * 2);
89+
for (int64_t i = 0; i < dims.size(); ++i) {
90+
if (sizes[i] == -1) {
91+
continue;
92+
}
93+
94+
if (x_sizes[dims[i]] < sizes[i]) {
95+
must_copy = true;
96+
auto pad_idx = pad_amount.size() - 2 * dims[i] - 1;
97+
pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]];
98+
}
99+
100+
if (x_sizes[dims[i]] > sizes[i]) {
101+
x = x.slice(dims[i], 0, sizes[i]);
102+
}
103+
}
104+
105+
// Only call pad if necessary since pad copies the entire tensor
106+
return must_copy ? at::constant_pad_nd(x, pad_amount) : x;
107+
}
108+
109+
// Complex to real FFT
110+
Tensor fft_c2r(Tensor input, c10::optional<int64_t> n_opt,
111+
int64_t unwrapped_dim, c10::optional<std::string> norm_str,
112+
bool forward) {
113+
input = promote_tensor_fft(input, /*require_complex=*/true);
114+
const auto input_dim = input.dim();
115+
const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim);
116+
const auto n = n_opt.value_or(2*(input.sizes()[dim] - 1));
117+
TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified");
118+
if (n_opt) {
119+
input = resize_fft_input(input, dim, n/2 + 1);
120+
}
121+
// _fft only operates on the last dim, so transpose the selected dim to the end
122+
const bool must_transpose = (dim != input_dim - 1);
123+
if (must_transpose) {
124+
input = at::transpose(input, -1, dim);
125+
}
126+
const auto norm = norm_from_string(norm_str, forward);
127+
if (forward) {
128+
// FIXME: _fft does not support complex_output=false with inverse=false
129+
input = at::conj(input);
130+
}
131+
auto out = _fft(at::view_as_real(input),
132+
/*signal_ndim=*/1, /*complex_input=*/true,
133+
/*complex_output=*/false, /*inverse=*/true,
134+
/*signal_sizes=*/{n}, /*normalization=*/norm,
135+
/*onesided=*/true);
136+
if (must_transpose) {
137+
out = at::transpose(out, -1, dim);
138+
}
139+
return out;
140+
}
141+
142+
// Real to complex FFT
143+
Tensor fft_r2c(Tensor input, c10::optional<int64_t> n_opt,
144+
int64_t unwrapped_dim, c10::optional<std::string> norm_str,
145+
bool forward, bool onesided) {
146+
TORCH_CHECK(!input.is_complex(), "Expected a real input tensor to FFT");
147+
input = promote_tensor_fft(input);
148+
const auto input_dim = input.dim();
149+
const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim);
150+
const auto n = n_opt.value_or(input.sizes()[dim]);
151+
TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified");
152+
if (n_opt) {
153+
input = resize_fft_input(input, dim, n);
154+
}
155+
// _fft only operates on the last dim, so transpose the selected dim to the end
156+
const bool must_transpose = (dim != input_dim - 1);
157+
if (must_transpose) {
158+
input = at::transpose(input, -1, dim);
159+
}
160+
const auto norm = norm_from_string(norm_str, forward);
161+
auto out = _fft(input, /*signal_ndim=*/1, /*complex_input=*/false,
162+
/*complex_output=*/true, /*inverse=*/false,
163+
/*signal_sizes=*/{n}, /*normalization=*/norm,
164+
/*onesided=*/onesided);
165+
out = at::view_as_complex(out);
166+
if (must_transpose) {
167+
out = at::transpose(out, -1, dim);
168+
}
169+
if (!forward) {
170+
// FIXME: _fft does not support complex_input=false with inverse=true
171+
out = at::conj(out);
172+
}
173+
return out;
174+
}
175+
176+
// Complex to complex FFT
177+
Tensor fft_c2c(Tensor input, c10::optional<int64_t> n_opt,
178+
int64_t unwrapped_dim, c10::optional<std::string> norm_str,
179+
bool forward) {
180+
TORCH_CHECK(input.is_complex(), "Expected a complex input tensor to FFT");
181+
const auto input_dim = input.dim();
182+
const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim);
183+
const auto n = n_opt.value_or(input.sizes()[dim]);
184+
TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified");
185+
if (n_opt) {
186+
input = resize_fft_input(input, dim, n);
187+
}
188+
// _fft only operates on the last dim, so transpose the selected dim to the end
189+
const bool must_transpose = (dim != input_dim - 1);
190+
if (must_transpose) {
191+
input = at::transpose(input, -1, dim);
192+
}
193+
const auto norm = norm_from_string(norm_str, forward);
194+
auto out = _fft(at::view_as_real(input),
195+
/*signal_ndim=*/1, /*complex_input=*/true,
196+
/*complex_output=*/true, /*inverse=*/!forward,
197+
/*signal_sizes=*/{}, /*normalization=*/norm,
198+
/*onesided=*/false);
199+
out = at::view_as_complex(out);
200+
if (must_transpose) {
201+
out = at::transpose(out, -1, dim);
202+
}
203+
return out;
204+
}
205+
206+
}
207+
21208
// torch.fft.fft, analogous to NumPy's numpy.fft.fft
22-
Tensor fft_fft(const Tensor& self) {
23-
TORCH_CHECK(self.is_complex(), "Expected a complex tensor.");
24-
TORCH_CHECK(self.dim() == 1, "Expected a 1D tensor.");
209+
Tensor fft_fft(const Tensor& self, c10::optional<int64_t> n, int64_t dim,
210+
c10::optional<std::string> norm) {
211+
return self.is_complex() ?
212+
fft_c2c(self, n, dim, norm, /*forward=*/true) :
213+
fft_r2c(self, n, dim, norm, /*forward=*/true, /*onesided=*/false);
214+
}
25215

26-
auto result = at::fft(at::view_as_real(self), 1, false);
27-
return at::view_as_complex(result);
216+
Tensor fft_ifft(const Tensor& self, c10::optional<int64_t> n, int64_t dim,
217+
c10::optional<std::string> norm) {
218+
return self.is_complex() ?
219+
fft_c2c(self, n, dim, norm, /*forward=*/false) :
220+
fft_r2c(self, n, dim, norm, /*forward=*/false, /*onesided=*/false);
28221
}
29222

223+
Tensor fft_rfft(const Tensor& self, c10::optional<int64_t> n, int64_t dim,
224+
c10::optional<std::string> norm) {
225+
return fft_r2c(self, n, dim, norm, /*forward=*/true, /*onesided=*/true);
226+
}
227+
228+
Tensor fft_irfft(const Tensor& self, c10::optional<int64_t> n, int64_t dim,
229+
c10::optional<std::string> norm) {
230+
return fft_c2r(self, n, dim, norm, /*forward=*/false);
231+
}
232+
233+
Tensor fft_hfft(const Tensor& self, c10::optional<int64_t> n, int64_t dim,
234+
c10::optional<std::string> norm) {
235+
return fft_c2r(self, n, dim, norm, /*forward=*/true);
236+
}
237+
238+
Tensor fft_ihfft(const Tensor& self, c10::optional<int64_t> n, int64_t dim,
239+
c10::optional<std::string> norm) {
240+
return fft_r2c(self, n, dim, norm, /*forward=*/false, /*onesided=*/true);
241+
}
242+
243+
30244
// This is a pass-through wrapper function that does the size check and
31245
// inferences. The actual forward implementation function is called
32246
// at::_fft_with_size which dispatches to _fft_cufft (CUDA) or _fft_mkl (CPU).
33247
static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
34248
const bool complex_input, const bool complex_output,
35-
const bool inverse, IntArrayRef signal_sizes, const bool normalized,
36-
const bool onesided) {
249+
const bool inverse, IntArrayRef signal_sizes,
250+
const fft_norm_mode normalization, const bool onesided) {
37251

38252
TORCH_CHECK(signal_ndim >= 1 && signal_ndim <= 3,
39253
"Expected signal_ndim to be 1, 2, or 3, but got signal_ndim=",
@@ -122,7 +336,9 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
122336

123337
Tensor output = at::_fft_with_size(input, signal_ndim, complex_input,
124338
complex_output, inverse,
125-
checked_signal_sizes, normalized, onesided,
339+
checked_signal_sizes,
340+
static_cast<int64_t>(normalization),
341+
onesided,
126342
output_sizes);
127343

128344
// unflatten the batch dims
@@ -139,6 +355,25 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
139355
return output;
140356
}
141357

358+
// Wrapper to preserve the historic signature of _fft_with_size
359+
// NOTE: This is only used for torchscript backwards compatibility and the new
360+
// signature with normalization modes should be used in all other cases
361+
Tensor _fft_with_size(const Tensor& input, int64_t signal_ndim,
362+
bool complex_input, bool complex_output,
363+
bool inverse, IntArrayRef checked_signal_sizes,
364+
bool normalized, bool onesided,
365+
IntArrayRef output_sizes) {
366+
fft_norm_mode norm;
367+
if (normalized) {
368+
norm = fft_norm_mode::by_root_n;
369+
} else {
370+
norm = inverse ? fft_norm_mode::by_n : fft_norm_mode::none;
371+
}
372+
return at::_fft_with_size(
373+
input, signal_ndim, complex_input, complex_output, inverse,
374+
checked_signal_sizes, static_cast<int64_t>(norm), onesided, output_sizes);
375+
}
376+
142377
// We call the following methods via CUDA hooks because they are really only
143378
// valid when CUDA is available. See native/cuda/CuFFTPlanCache.h for more details.
144379
int64_t _cufft_get_plan_cache_max_size(int64_t device_index) {
@@ -159,28 +394,32 @@ void _cufft_clear_plan_cache(int64_t device_index) {
159394

160395
Tensor fft(const Tensor& self, const int64_t signal_ndim, const bool normalized) {
161396
return _fft(self, signal_ndim, /* complex_input */ true,
162-
/* complex_output */ true, /* inverse */ false, {}, normalized,
397+
/* complex_output */ true, /* inverse */ false, {},
398+
normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none,
163399
/* onesided */ false);
164400
}
165401

166402
Tensor ifft(const Tensor& self, const int64_t signal_ndim, const bool normalized) {
167403
return _fft(self, signal_ndim, /* complex_input */ true,
168-
/* complex_output */ true, /* inverse */ true, {}, normalized,
404+
/* complex_output */ true, /* inverse */ true, {},
405+
normalized ? fft_norm_mode::by_root_n : fft_norm_mode::by_n,
169406
/* onesided */ false);
170407
}
171408

172409
Tensor rfft(const Tensor& self, const int64_t signal_ndim, const bool normalized,
173410
const bool onesided) {
174411
return _fft(self, signal_ndim, /* complex_input */ false,
175-
/* complex_output */ true, /* inverse */ false, {}, normalized,
412+
/* complex_output */ true, /* inverse */ false, {},
413+
normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none,
176414
onesided);
177415
}
178416

179417
Tensor irfft(const Tensor& self, const int64_t signal_ndim, const bool normalized,
180418
const bool onesided, IntArrayRef signal_sizes) {
181419
return _fft(self, signal_ndim, /* complex_input */ true,
182420
/* complex_output */ false, /* inverse */ true, signal_sizes,
183-
normalized, onesided);
421+
normalized ? fft_norm_mode::by_root_n : fft_norm_mode::by_n,
422+
onesided);
184423
}
185424

186425
template <typename Stream, typename T>

‎aten/src/ATen/native/SpectralOpsUtils.h

+7
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,13 @@
66

77
namespace at { namespace native {
88

9+
// Normalization types used in _fft_with_size
10+
enum class fft_norm_mode {
11+
none, // No normalization
12+
by_root_n, // Divide by sqrt(signal_size)
13+
by_n, // Divide by signal_size
14+
};
15+
916
// NOTE [ Fourier Transform Conjugate Symmetry ]
1017
//
1118
// Real-to-complex Fourier transform satisfies the conjugate symmetry. That is,

‎aten/src/ATen/native/cuda/SpectralOps.cu

+9-7
Original file line number | Diff line number | Diff line change
@@ -175,7 +175,7 @@ static void _fft_fill_with_conjugate_symmetry_(Tensor& input,
175175
static inline Tensor _run_cufft(
176176
const CuFFTConfig &config, Tensor& input, int64_t signal_ndim,
177177
bool complex_input, bool complex_output, bool inverse,
178-
IntArrayRef checked_signal_sizes, bool normalized, bool onesided,
178+
IntArrayRef checked_signal_sizes, fft_norm_mode norm, bool onesided,
179179
IntArrayRef output_sizes, bool input_was_cloned
180180
) {
181181
if (config.should_clone_input() && !input_was_cloned) {
@@ -235,12 +235,12 @@ static inline Tensor _run_cufft(
235235
inverse ? CUFFT_INVERSE : CUFFT_FORWARD));
236236
#endif
237237

238-
// rescale if needed by normalized flag or inverse transform
238+
// rescale if requested
239239
auto size_last_signal_dim = checked_signal_sizes[signal_ndim - 1];
240-
if (normalized || inverse) {
240+
if (norm != fft_norm_mode::none) {
241241
auto signal_numel = at::prod_intlist(checked_signal_sizes);
242242
double scale_denom;
243-
if (normalized) {
243+
if (norm == fft_norm_mode::by_root_n) {
244244
scale_denom = std::sqrt(static_cast<double>(signal_numel));
245245
} else {
246246
scale_denom = static_cast<double>(signal_numel);
@@ -324,7 +324,7 @@ void cufft_clear_plan_cache_impl(int64_t device_index) {
324324
// Currently not utilizing multi GPUs so this can be potentially sped up.
325325
Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim,
326326
bool complex_input, bool complex_output, bool inverse,
327-
IntArrayRef checked_signal_sizes, bool normalized, bool onesided,
327+
IntArrayRef checked_signal_sizes, int64_t normalization, bool onesided,
328328
IntArrayRef output_sizes) {
329329

330330
CuFFTParamsLRUCache& plan_cache = cufft_get_plan_cache(self.device().index());
@@ -377,14 +377,16 @@ Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim,
377377
complex_output, checked_signal_sizes,
378378
onesided, output_sizes);
379379
return _run_cufft(config, input, signal_ndim, complex_input,
380-
complex_output, inverse, checked_signal_sizes, normalized,
380+
complex_output, inverse, checked_signal_sizes,
381+
static_cast<fft_norm_mode>(normalization),
381382
onesided, output_sizes, input_was_cloned);
382383
}
383384
}
384385
CuFFTConfig config(input, signal_ndim, complex_input, complex_output,
385386
checked_signal_sizes, onesided, output_sizes);
386387
return _run_cufft(config, input, signal_ndim, complex_input,
387-
complex_output, inverse, checked_signal_sizes, normalized,
388+
complex_output, inverse, checked_signal_sizes,
389+
static_cast<fft_norm_mode>(normalization),
388390
onesided, output_sizes, input_was_cloned);
389391
}
390392

0 commit comments

Comments
 (0)
Please sign in to comment.