namespace at { namespace native {

+ // Common code for all FFT functions
+ static inline Tensor _fft(
+     const Tensor &self, int64_t signal_ndim, bool complex_input,
+     bool complex_output, bool inverse, IntArrayRef signal_sizes,
+     fft_norm_mode normalization, bool onesided);
+
+ namespace {
+
+ // Promote inputs to FFT functions
+ // * Integers are promoted to the default floating type
+ // * If require_complex=True, all types are promoted to complex
+ // * Raises an error for half-precision dtypes to allow future support
+ ScalarType promote_type_fft(ScalarType type, bool require_complex) {
+   if (at::isComplexType(type)) {
+     return type;
+   }
+   // Promote integral to default float type
+   if (!at::isFloatingType(type)) {
+     type = c10::typeMetaToScalarType(c10::get_default_dtype());
+   }
+
+   TORCH_CHECK(type == kFloat || type == kDouble, "Unsupported dtype ", type);
+
+   if (!require_complex) {
+     return type;
+   }
+
+   // Promote to complex
+   switch (type) {
+     case kFloat: return kComplexFloat;
+     case kDouble: return kComplexDouble;
+     default: TORCH_INTERNAL_ASSERT(false, "Unhandled dtype");
+   }
+ }
+
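For reference, a minimal sketch of how the promotion rules above behave (illustrative only, not part of the diff; the helper name `check_promote_type_fft_examples` is hypothetical and it assumes the default dtype is kFloat):

    void check_promote_type_fft_examples() {
      // Integral dtypes promote to the default floating type
      TORCH_INTERNAL_ASSERT(promote_type_fft(kLong, /*require_complex=*/false) == kFloat);
      // With require_complex, the promoted float type is widened to complex
      TORCH_INTERNAL_ASSERT(promote_type_fft(kLong, /*require_complex=*/true) == kComplexFloat);
      // Already-complex dtypes pass through unchanged
      TORCH_INTERNAL_ASSERT(promote_type_fft(kComplexDouble, /*require_complex=*/false) == kComplexDouble);
      // kHalf would raise "Unsupported dtype" until half-precision FFT is supported
    }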
+ // Promote a tensor's dtype according to promote_type_fft
+ Tensor promote_tensor_fft(const Tensor& t, bool require_complex=false) {
+   auto cur_type = t.scalar_type();
+   auto new_type = promote_type_fft(cur_type, require_complex);
+   return (cur_type == new_type) ? t : t.to(new_type);
+ }
+
+ // Convert NumPy compatible normalization mode string to enum values
+ // NOTE: NumPy's normalization modes have direction-specific meanings. For example,
+ // "forward" translates to `by_n` for a forward transform and `none` for backward.
+ fft_norm_mode norm_from_string(c10::optional<std::string> norm, bool forward) {
+   if (!norm || *norm == "backward") {
+     return forward ? fft_norm_mode::none : fft_norm_mode::by_n;
+   }
+
+   if (*norm == "forward") {
+     return forward ? fft_norm_mode::by_n : fft_norm_mode::none;
+   }
+
+   if (*norm == "ortho") {
+     return fft_norm_mode::by_root_n;
+   }
+
+   TORCH_CHECK(false, "Invalid normalization mode: \"", *norm, "\"");
+ }
+
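Concretely, the mapping is mirror-symmetric between the two directions. A sketch of the expected values (illustrative only, not part of the diff; the helper name is hypothetical):

    void check_norm_from_string_examples() {
      // Default/"backward": scale only the inverse transform by 1/n
      TORCH_INTERNAL_ASSERT(norm_from_string(c10::nullopt, /*forward=*/true) == fft_norm_mode::none);
      TORCH_INTERNAL_ASSERT(norm_from_string(c10::nullopt, /*forward=*/false) == fft_norm_mode::by_n);
      // "forward": scale only the forward transform by 1/n
      TORCH_INTERNAL_ASSERT(norm_from_string(std::string("forward"), /*forward=*/true) == fft_norm_mode::by_n);
      TORCH_INTERNAL_ASSERT(norm_from_string(std::string("forward"), /*forward=*/false) == fft_norm_mode::none);
      // "ortho": scale both directions by 1/sqrt(n)
      TORCH_INTERNAL_ASSERT(norm_from_string(std::string("ortho"), /*forward=*/false) == fft_norm_mode::by_root_n);
    }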
+ // Fixes the shape of x such that x.size(dims[i]) == sizes[i],
+ // either by zero-padding, or by slicing x starting from 0.
+ Tensor resize_fft_input(Tensor x, IntArrayRef dims, IntArrayRef sizes) {
+   TORCH_INTERNAL_ASSERT(dims.size() == sizes.size());
+   bool must_copy = false;
+   auto x_sizes = x.sizes();
+   DimVector pad_amount(x_sizes.size() * 2);
+   for (int64_t i = 0; i < dims.size(); ++i) {
+     if (sizes[i] == -1) {
+       continue;
+     }
+
+     if (x_sizes[dims[i]] < sizes[i]) {
+       must_copy = true;
+       auto pad_idx = pad_amount.size() - 2 * dims[i] - 1;
+       pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]];
+     }
+
+     if (x_sizes[dims[i]] > sizes[i]) {
+       x = x.slice(dims[i], 0, sizes[i]);
+     }
+   }
+
+   // Only call pad if necessary since pad copies the entire tensor
+   return must_copy ? at::constant_pad_nd(x, pad_amount) : x;
+ }
+
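A worked example of the pad indexing (illustrative only, not part of the diff): at::constant_pad_nd takes pad amounts ordered from the last dimension outward, which is why pad_idx counts backwards from the end of pad_amount.

    // Given x.sizes() == {3, 4}:
    //   resize_fft_input(x, /*dims=*/{1}, /*sizes=*/{6})
    // allocates pad_amount with 4 zeros, laid out as {l_dim1, r_dim1, l_dim0, r_dim0};
    // pad_idx = 4 - 2*1 - 1 = 1, so pad_amount = {0, 2, 0, 0}
    // and the result has shape {3, 6} with two trailing zeros.
    //   resize_fft_input(x, /*dims=*/{1}, /*sizes=*/{2})
    // slices instead, giving shape {3, 2} without any copy.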
+ // Complex to real FFT
+ Tensor fft_c2r(Tensor input, c10::optional<int64_t> n_opt,
+                int64_t unwrapped_dim, c10::optional<std::string> norm_str,
+                bool forward) {
+   input = promote_tensor_fft(input, /*require_complex=*/true);
+   const auto input_dim = input.dim();
+   const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim);
+   const auto n = n_opt.value_or(2 * (input.sizes()[dim] - 1));
+   TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified");
+   if (n_opt) {
+     input = resize_fft_input(input, dim, n / 2 + 1);
+   }
+   // _fft only operates on the last dim, so transpose the selected dim to the end
+   const bool must_transpose = (dim != input_dim - 1);
+   if (must_transpose) {
+     input = at::transpose(input, -1, dim);
+   }
+   const auto norm = norm_from_string(norm_str, forward);
+   if (forward) {
+     // FIXME: _fft does not support complex_output=false with inverse=false
+     input = at::conj(input);
+   }
+   auto out = _fft(at::view_as_real(input),
+                   /*signal_ndim=*/1, /*complex_input=*/true,
+                   /*complex_output=*/false, /*inverse=*/true,
+                   /*signal_sizes=*/{n}, /*normalization=*/norm,
+                   /*onesided=*/true);
+   if (must_transpose) {
+     out = at::transpose(out, -1, dim);
+   }
+   return out;
+ }
+
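The default output length mirrors numpy.fft.irfft: a onesided spectrum with m entries is assumed to come from a real signal of even length 2*(m - 1). An illustrative sketch (not part of the diff):

    // input.sizes()[dim] == 5 (half spectrum) -> n defaults to 2*(5-1) == 8 real outputs
    // an explicit n first resizes the input along dim to n/2 + 1 == 5 complex entries,
    // zero-padding or truncating the spectrum as needed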
+ // Real to complex FFT
+ Tensor fft_r2c(Tensor input, c10::optional<int64_t> n_opt,
+                int64_t unwrapped_dim, c10::optional<std::string> norm_str,
+                bool forward, bool onesided) {
+   TORCH_CHECK(!input.is_complex(), "Expected a real input tensor to FFT");
+   input = promote_tensor_fft(input);
+   const auto input_dim = input.dim();
+   const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim);
+   const auto n = n_opt.value_or(input.sizes()[dim]);
+   TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified");
+   if (n_opt) {
+     input = resize_fft_input(input, dim, n);
+   }
+   // _fft only operates on the last dim, so transpose the selected dim to the end
+   const bool must_transpose = (dim != input_dim - 1);
+   if (must_transpose) {
+     input = at::transpose(input, -1, dim);
+   }
+   const auto norm = norm_from_string(norm_str, forward);
+   auto out = _fft(input, /*signal_ndim=*/1, /*complex_input=*/false,
+                   /*complex_output=*/true, /*inverse=*/false,
+                   /*signal_sizes=*/{n}, /*normalization=*/norm,
+                   /*onesided=*/onesided);
+   out = at::view_as_complex(out);
+   if (must_transpose) {
+     out = at::transpose(out, -1, dim);
+   }
+   if (!forward) {
+     // FIXME: _fft does not support complex_input=false with inverse=true
+     out = at::conj(out);
+   }
+   return out;
+ }
+
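Why conjugation implements the missing inverse mode here (a sketch of the identity, not part of the diff): for a real signal x of length n,

    // FFT(x)[k]       = sum_j x[j] * exp(-2*pi*i*j*k/n)
    // conj(FFT(x)[k]) = sum_j x[j] * exp(+2*pi*i*j*k/n)   (x is real, so conj(x[j]) = x[j])
    //                 = n * IFFT(x)[k]  before normalization
    // so a forward r2c pass followed by at::conj yields the inverse transform,
    // with the 1/n factor supplied via norm_from_string(..., /*forward=*/false).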
+ // Complex to complex FFT
+ Tensor fft_c2c(Tensor input, c10::optional<int64_t> n_opt,
+                int64_t unwrapped_dim, c10::optional<std::string> norm_str,
+                bool forward) {
+   TORCH_CHECK(input.is_complex(), "Expected a complex input tensor to FFT");
+   const auto input_dim = input.dim();
+   const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim);
+   const auto n = n_opt.value_or(input.sizes()[dim]);
+   TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified");
+   if (n_opt) {
+     input = resize_fft_input(input, dim, n);
+   }
+   // _fft only operates on the last dim, so transpose the selected dim to the end
+   const bool must_transpose = (dim != input_dim - 1);
+   if (must_transpose) {
+     input = at::transpose(input, -1, dim);
+   }
+   const auto norm = norm_from_string(norm_str, forward);
+   auto out = _fft(at::view_as_real(input),
+                   /*signal_ndim=*/1, /*complex_input=*/true,
+                   /*complex_output=*/true, /*inverse=*/!forward,
+                   /*signal_sizes=*/{}, /*normalization=*/norm,
+                   /*onesided=*/false);
+   out = at::view_as_complex(out);
+   if (must_transpose) {
+     out = at::transpose(out, -1, dim);
+   }
+   return out;
+ }
+
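Since _fft predates native complex tensor support, complex data travels through it as a float tensor with a trailing dimension of size 2. A sketch of the zero-copy round trip used above (illustrative only, not part of the diff):

    // auto as_float   = at::view_as_real(input);      // [..., n] complex -> [..., n, 2] float
    // auto as_complex = at::view_as_complex(as_float); // [..., n, 2] float -> [..., n] complex
    // Both are views over the same storage; no data is copied in either direction.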
+ } // namespace (anonymous)
+
// torch.fft.fft, analogous to NumPy's numpy.fft.fft
- Tensor fft_fft(const Tensor& self) {
-   TORCH_CHECK(self.is_complex(), "Expected a complex tensor.");
-   TORCH_CHECK(self.dim() == 1, "Expected a 1D tensor.");
+ Tensor fft_fft(const Tensor& self, c10::optional<int64_t> n, int64_t dim,
+                c10::optional<std::string> norm) {
+   return self.is_complex() ?
+     fft_c2c(self, n, dim, norm, /*forward=*/true) :
+     fft_r2c(self, n, dim, norm, /*forward=*/true, /*onesided=*/false);
+ }

-   auto result = at::fft(at::view_as_real(self), 1, false);
-   return at::view_as_complex(result);
+ Tensor fft_ifft(const Tensor& self, c10::optional<int64_t> n, int64_t dim,
+                 c10::optional<std::string> norm) {
+   return self.is_complex() ?
+     fft_c2c(self, n, dim, norm, /*forward=*/false) :
+     fft_r2c(self, n, dim, norm, /*forward=*/false, /*onesided=*/false);
 }

+ Tensor fft_rfft(const Tensor& self, c10::optional<int64_t> n, int64_t dim,
+                 c10::optional<std::string> norm) {
+   return fft_r2c(self, n, dim, norm, /*forward=*/true, /*onesided=*/true);
+ }
+
+ Tensor fft_irfft(const Tensor& self, c10::optional<int64_t> n, int64_t dim,
+                  c10::optional<std::string> norm) {
+   return fft_c2r(self, n, dim, norm, /*forward=*/false);
+ }
+
+ Tensor fft_hfft(const Tensor& self, c10::optional<int64_t> n, int64_t dim,
+                 c10::optional<std::string> norm) {
+   return fft_c2r(self, n, dim, norm, /*forward=*/true);
+ }
+
+ Tensor fft_ihfft(const Tensor& self, c10::optional<int64_t> n, int64_t dim,
+                  c10::optional<std::string> norm) {
+   return fft_r2c(self, n, dim, norm, /*forward=*/false, /*onesided=*/true);
+ }
+
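Each public entry point is a thin wrapper over one of the three kernels. A summary of the dispatch (illustrative only, not part of the diff):

    // fft_fft   -> c2c (complex input) or r2c (real input), forward, two-sided
    // fft_ifft  -> c2c or r2c, inverse, two-sided
    // fft_rfft  -> r2c, forward, onesided (n/2 + 1 outputs)
    // fft_irfft -> c2r, inverse
    // fft_hfft  -> c2r, forward (transform of a signal with Hermitian symmetry)
    // fft_ihfft -> r2c, inverse, onesided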
// This is a pass-through wrapper function that does the size check and
// inferences. The actual forward implementation function is called
// at::_fft_with_size which dispatches to _fft_cufft (CUDA) or _fft_mkl (CPU).
static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
                          const bool complex_input, const bool complex_output,
-                         const bool inverse, IntArrayRef signal_sizes, const bool normalized,
-                         const bool onesided) {
+                         const bool inverse, IntArrayRef signal_sizes,
+                         const fft_norm_mode normalization, const bool onesided) {

  TORCH_CHECK(signal_ndim >= 1 && signal_ndim <= 3,
              "Expected signal_ndim to be 1, 2, or 3, but got signal_ndim=",
@@ -122,7 +336,9 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,

  Tensor output = at::_fft_with_size(input, signal_ndim, complex_input,
                                     complex_output, inverse,
-                                    checked_signal_sizes, normalized, onesided,
+                                    checked_signal_sizes,
+                                    static_cast<int64_t>(normalization),
+                                    onesided,
                                     output_sizes);

  // unflatten the batch dims
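The enum is passed through the dispatcher as a plain integer, presumably because native function schemas accept int64_t rather than C++ enum types. A sketch of the round trip (illustrative only, not part of the diff):

    // caller side:  static_cast<int64_t>(norm)            // fft_norm_mode -> int64_t
    // backend side: static_cast<fft_norm_mode>(normalization)  // int64_t -> fft_norm_mode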
@@ -139,6 +355,25 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
  return output;
}

+ // Wrapper to preserve the historic signature of _fft_with_size
+ // NOTE: This is only used for TorchScript backwards compatibility; the new
+ // signature with normalization modes should be used in all other cases
+ Tensor _fft_with_size(const Tensor& input, int64_t signal_ndim,
+                       bool complex_input, bool complex_output,
+                       bool inverse, IntArrayRef checked_signal_sizes,
+                       bool normalized, bool onesided,
+                       IntArrayRef output_sizes) {
+   fft_norm_mode norm;
+   if (normalized) {
+     norm = fft_norm_mode::by_root_n;
+   } else {
+     norm = inverse ? fft_norm_mode::by_n : fft_norm_mode::none;
+   }
+   return at::_fft_with_size(
+       input, signal_ndim, complex_input, complex_output, inverse,
+       checked_signal_sizes, static_cast<int64_t>(norm), onesided, output_sizes);
+ }
+
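The legacy boolean flag collapses onto the new enum as follows (a summary, not part of the diff):

    // (normalized=true,  either direction) -> by_root_n  (orthonormal, 1/sqrt(n))
    // (normalized=false, inverse)          -> by_n       (classic ifft scaling, 1/n)
    // (normalized=false, forward)          -> none       (no scaling)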
// We call the following methods via CUDA hooks because they are really only
// valid when CUDA is available. See native/cuda/CuFFTPlanCache.h for more details.
int64_t _cufft_get_plan_cache_max_size(int64_t device_index) {
@@ -159,28 +394,32 @@ void _cufft_clear_plan_cache(int64_t device_index) {

Tensor fft(const Tensor& self, const int64_t signal_ndim, const bool normalized) {
  return _fft(self, signal_ndim, /*complex_input*/ true,
-             /*complex_output*/ true, /*inverse*/ false, {}, normalized,
+             /*complex_output*/ true, /*inverse*/ false, {},
+             normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none,
              /*onesided*/ false);
}

Tensor ifft(const Tensor& self, const int64_t signal_ndim, const bool normalized) {
  return _fft(self, signal_ndim, /*complex_input*/ true,
-             /*complex_output*/ true, /*inverse*/ true, {}, normalized,
+             /*complex_output*/ true, /*inverse*/ true, {},
+             normalized ? fft_norm_mode::by_root_n : fft_norm_mode::by_n,
              /*onesided*/ false);
}

Tensor rfft(const Tensor& self, const int64_t signal_ndim, const bool normalized,
            const bool onesided) {
  return _fft(self, signal_ndim, /*complex_input*/ false,
-             /*complex_output*/ true, /*inverse*/ false, {}, normalized,
+             /*complex_output*/ true, /*inverse*/ false, {},
+             normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none,
              onesided);
}

Tensor irfft(const Tensor& self, const int64_t signal_ndim, const bool normalized,
             const bool onesided, IntArrayRef signal_sizes) {
  return _fft(self, signal_ndim, /*complex_input*/ true,
              /*complex_output*/ false, /*inverse*/ true, signal_sizes,
-             normalized, onesided);
+             normalized ? fft_norm_mode::by_root_n : fft_norm_mode::by_n,
+             onesided);
}

template <typename Stream, typename T>