@@ -1007,6 +1007,142 @@ def to_frequency_domain(self, lower, upper, sampling_frequency, remove_dc=True):
1007
1007
extrapolation = "zero" ,
1008
1008
)
1009
1009
1010
+ def short_time_fft (
1011
+ self ,
1012
+ lower ,
1013
+ upper ,
1014
+ sampling_frequency ,
1015
+ window_size ,
1016
+ step_size ,
1017
+ remove_dc = True ,
1018
+ only_positive = True ,
1019
+ ):
1020
+ r"""
1021
+ Performs the Short-Time Fourier Transform (STFT) of the Function and
1022
+ returns the result. The STFT is computed by applying the Fourier
1023
+ transform to overlapping windows of the Function.
1024
+
1025
+ Parameters
1026
+ ----------
1027
+ lower : float
1028
+ Lower bound of the time range.
1029
+ upper : float
1030
+ Upper bound of the time range.
1031
+ sampling_frequency : float
1032
+ Sampling frequency at which to perform the Fourier transform.
1033
+ window_size : float
1034
+ Size of the window for the STFT, in seconds.
1035
+ step_size : float
1036
+ Step size for the window, in seconds.
1037
+ remove_dc : bool, optional
1038
+ If True, the DC component is removed from each window before
1039
+ computing the Fourier transform.
1040
+ only_positive: bool, optional
1041
+ If True, only the positive frequencies are returned.
1042
+
1043
+ Returns
1044
+ -------
1045
+ list[Function]
1046
+ A list of Functions, each representing the STFT of a window.
1047
+
1048
+ Examples
1049
+ --------
1050
+
1051
+ >>> import numpy as np
1052
+ >>> import matplotlib.pyplot as plt
1053
+ >>> from rocketpy import Function
1054
+
1055
+ Generate a signal with varying frequency:
1056
+
1057
+ >>> T_x, N = 1 / 20 , 1000 # 20 Hz sampling rate for 50 s signal
1058
+ >>> t_x = np.arange(N) * T_x # time indexes for signal
1059
+ >>> f_i = 1 * np.arctan((t_x - t_x[N // 2]) / 2) + 5 # varying frequency
1060
+ >>> signal = np.sin(2 * np.pi * np.cumsum(f_i) * T_x) # the signal
1061
+
1062
+ Create the Function object and perform the STFT:
1063
+
1064
+ >>> time_domain = Function(np.array([t_x, signal]).T)
1065
+ >>> stft_result = time_domain.short_time_fft(
1066
+ ... lower=0,
1067
+ ... upper=50,
1068
+ ... sampling_frequency=95,
1069
+ ... window_size=2,
1070
+ ... step_size=0.5,
1071
+ ... )
1072
+
1073
+ Plot the spectrogram:
1074
+
1075
+ >>> Sx = np.abs([window[:, 1] for window in stft_result])
1076
+ >>> t_lo, t_hi = t_x[0], t_x[-1]
1077
+ >>> fig1, ax1 = plt.subplots(figsize=(10, 6))
1078
+ >>> im1 = ax1.imshow(
1079
+ ... Sx.T,
1080
+ ... origin='lower',
1081
+ ... aspect='auto',
1082
+ ... extent=[t_lo, t_hi, 0, 50],
1083
+ ... cmap='viridis'
1084
+ ... )
1085
+ >>> _ = ax1.set_title(rf"STFT (2$\,s$ Gaussian window, $\sigma_t=0.4\,$s)")
1086
+ >>> _ = ax1.set(
1087
+ ... xlabel=f"Time $t$ in seconds",
1088
+ ... ylabel=f"Freq. $f$ in Hz)",
1089
+ ... xlim=(t_lo, t_hi)
1090
+ ... )
1091
+ >>> _ = ax1.plot(t_x, f_i, 'r--', alpha=.5, label='$f_i(t)$')
1092
+ >>> _ = fig1.colorbar(im1, label="Magnitude $|S_x(t, f)|$")
1093
+ >>> # Shade areas where window slices stick out to the side
1094
+ >>> for t0_, t1_ in [(t_lo, 1), (49, t_hi)]:
1095
+ ... _ = ax1.axvspan(t0_, t1_, color='w', linewidth=0, alpha=.2)
1096
+ >>> # Mark signal borders with vertical line
1097
+ >>> for t_ in [t_lo, t_hi]:
1098
+ ... _ = ax1.axvline(t_, color='y', linestyle='--', alpha=0.5)
1099
+ >>> # Add legend and finalize plot
1100
+ >>> _ = ax1.legend()
1101
+ >>> fig1.tight_layout()
1102
+ >>> # plt.show() # uncomment to show the plot
1103
+
1104
+ References
1105
+ ----------
1106
+ Example adapted from the SciPy documentation:
1107
+ https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.ShortTimeFFT.html
1108
+ """
1109
+ # Get the time domain data
1110
+ sampling_time_step = 1.0 / sampling_frequency
1111
+ sampling_range = np .arange (lower , upper , sampling_time_step )
1112
+ sampled_points = self (sampling_range )
1113
+ samples_per_window = int (window_size * sampling_frequency )
1114
+ samples_skipped_per_step = int (step_size * sampling_frequency )
1115
+ stft_results = []
1116
+
1117
+ max_start = len (sampled_points ) - samples_per_window + 1
1118
+
1119
+ for start in range (0 , max_start , samples_skipped_per_step ):
1120
+ windowed_samples = sampled_points [start : start + samples_per_window ]
1121
+ if remove_dc :
1122
+ windowed_samples -= np .mean (windowed_samples )
1123
+ fourier_amplitude = np .abs (
1124
+ np .fft .fft (windowed_samples ) / (samples_per_window / 2 )
1125
+ )
1126
+ fourier_frequencies = np .fft .fftfreq (samples_per_window , sampling_time_step )
1127
+
1128
+ # Filter to keep only positive frequencies if specified
1129
+ if only_positive :
1130
+ positive_indices = fourier_frequencies > 0
1131
+ fourier_frequencies = fourier_frequencies [positive_indices ]
1132
+ fourier_amplitude = fourier_amplitude [positive_indices ]
1133
+
1134
+ stft_results .append (
1135
+ Function (
1136
+ source = np .array ([fourier_frequencies , fourier_amplitude ]).T ,
1137
+ inputs = "Frequency (Hz)" ,
1138
+ outputs = "Amplitude" ,
1139
+ interpolation = "linear" ,
1140
+ extrapolation = "zero" ,
1141
+ )
1142
+ )
1143
+
1144
+ return stft_results
1145
+
1010
1146
def low_pass_filter (self , alpha , file_path = None ):
1011
1147
"""Implements a low pass filter with a moving average filter. This does
1012
1148
not mutate the original Function object, but returns a new one with the
0 commit comments