Skip to content

Commit 6fa90b7

Browse files
committed
Merge remote-tracking branch 'origin/develop' into enh/parallel_montecarlo
2 parents d07fcc2 + 44beade commit 6fa90b7

File tree

5 files changed

+157
-15
lines changed

5 files changed

+157
-15
lines changed

.vscode/settings.json

+3
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
"copybutton",
7979
"cstride",
8080
"csys",
81+
"cumsum",
8182
"datapoints",
8283
"datetime",
8384
"dcsys",
@@ -149,6 +150,7 @@
149150
"IGRA",
150151
"imageio",
151152
"imread",
153+
"imshow",
152154
"intc",
153155
"interp",
154156
"Interquartile",
@@ -259,6 +261,7 @@
259261
"SRTM",
260262
"SRTMGL",
261263
"Stano",
264+
"STFT",
262265
"subintervals",
263266
"suptitle",
264267
"ticklabel",

CHANGELOG.md

+4-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ Attention: The newest changes should be on top -->
3232

3333
### Added
3434

35-
- ENH: Rocket Axis Definition [#635](https://github.com/RocketPy-Team/RocketPy/pull/635)
35+
- ENH: Add STFT function to Function class [#620](https://github.com/RocketPy-Team/RocketPy/pull/620)
36+
- ENH: Rocket Axis Definition [#635](https://github.com/RocketPy-Team/RocketPy/pull/635)
3637

3738
### Changed
3839

@@ -42,7 +43,8 @@ Attention: The newest changes should be on top -->
4243

4344
### Fixed
4445

45-
46+
- BUG: Pressure ISA Extrapolation as "linear" [#675](https://github.com/RocketPy-Team/RocketPy/pull/675)
47+
- BUG: fix the Frequency Response plot of Flight class [#653](https://github.com/RocketPy-Team/RocketPy/pull/653)
4648

4749
## [v1.4.2] - 2024-08-03
4850

rocketpy/environment/environment.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2382,7 +2382,7 @@ def load_international_standard_atmosphere(self): # pragma: no cover
23822382
DeprecationWarning,
23832383
)
23842384

2385-
@funcify_method("Height Above Sea Level (m)", "Pressure (Pa)", "spline", "linear")
2385+
@funcify_method("Height Above Sea Level (m)", "Pressure (Pa)", "spline", "natural")
23862386
def pressure_ISA(self):
23872387
"""Pressure, in Pa, as a function of height above sea level as defined
23882388
by the `International Standard Atmosphere ISO 2533`."""

rocketpy/mathutils/function.py

+136
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,142 @@ def to_frequency_domain(self, lower, upper, sampling_frequency, remove_dc=True):
10071007
extrapolation="zero",
10081008
)
10091009

1010+
def short_time_fft(
1011+
self,
1012+
lower,
1013+
upper,
1014+
sampling_frequency,
1015+
window_size,
1016+
step_size,
1017+
remove_dc=True,
1018+
only_positive=True,
1019+
):
1020+
r"""
1021+
Performs the Short-Time Fourier Transform (STFT) of the Function and
1022+
returns the result. The STFT is computed by applying the Fourier
1023+
transform to overlapping windows of the Function.
1024+
1025+
Parameters
1026+
----------
1027+
lower : float
1028+
Lower bound of the time range.
1029+
upper : float
1030+
Upper bound of the time range.
1031+
sampling_frequency : float
1032+
Sampling frequency at which to perform the Fourier transform.
1033+
window_size : float
1034+
Size of the window for the STFT, in seconds.
1035+
step_size : float
1036+
Step size for the window, in seconds.
1037+
remove_dc : bool, optional
1038+
If True, the DC component is removed from each window before
1039+
computing the Fourier transform.
1040+
only_positive: bool, optional
1041+
If True, only the positive frequencies are returned.
1042+
1043+
Returns
1044+
-------
1045+
list[Function]
1046+
A list of Functions, each representing the STFT of a window.
1047+
1048+
Examples
1049+
--------
1050+
1051+
>>> import numpy as np
1052+
>>> import matplotlib.pyplot as plt
1053+
>>> from rocketpy import Function
1054+
1055+
Generate a signal with varying frequency:
1056+
1057+
>>> T_x, N = 1 / 20 , 1000 # 20 Hz sampling rate for 50 s signal
1058+
>>> t_x = np.arange(N) * T_x # time indexes for signal
1059+
>>> f_i = 1 * np.arctan((t_x - t_x[N // 2]) / 2) + 5 # varying frequency
1060+
>>> signal = np.sin(2 * np.pi * np.cumsum(f_i) * T_x) # the signal
1061+
1062+
Create the Function object and perform the STFT:
1063+
1064+
>>> time_domain = Function(np.array([t_x, signal]).T)
1065+
>>> stft_result = time_domain.short_time_fft(
1066+
... lower=0,
1067+
... upper=50,
1068+
... sampling_frequency=95,
1069+
... window_size=2,
1070+
... step_size=0.5,
1071+
... )
1072+
1073+
Plot the spectrogram:
1074+
1075+
>>> Sx = np.abs([window[:, 1] for window in stft_result])
1076+
>>> t_lo, t_hi = t_x[0], t_x[-1]
1077+
>>> fig1, ax1 = plt.subplots(figsize=(10, 6))
1078+
>>> im1 = ax1.imshow(
1079+
... Sx.T,
1080+
... origin='lower',
1081+
... aspect='auto',
1082+
... extent=[t_lo, t_hi, 0, 50],
1083+
... cmap='viridis'
1084+
... )
1085+
>>> _ = ax1.set_title(rf"STFT (2$\,s$ Gaussian window, $\sigma_t=0.4\,$s)")
1086+
>>> _ = ax1.set(
1087+
... xlabel=f"Time $t$ in seconds",
1088+
... ylabel=f"Freq. $f$ in Hz)",
1089+
... xlim=(t_lo, t_hi)
1090+
... )
1091+
>>> _ = ax1.plot(t_x, f_i, 'r--', alpha=.5, label='$f_i(t)$')
1092+
>>> _ = fig1.colorbar(im1, label="Magnitude $|S_x(t, f)|$")
1093+
>>> # Shade areas where window slices stick out to the side
1094+
>>> for t0_, t1_ in [(t_lo, 1), (49, t_hi)]:
1095+
... _ = ax1.axvspan(t0_, t1_, color='w', linewidth=0, alpha=.2)
1096+
>>> # Mark signal borders with vertical line
1097+
>>> for t_ in [t_lo, t_hi]:
1098+
... _ = ax1.axvline(t_, color='y', linestyle='--', alpha=0.5)
1099+
>>> # Add legend and finalize plot
1100+
>>> _ = ax1.legend()
1101+
>>> fig1.tight_layout()
1102+
>>> # plt.show() # uncomment to show the plot
1103+
1104+
References
1105+
----------
1106+
Example adapted from the SciPy documentation:
1107+
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.ShortTimeFFT.html
1108+
"""
1109+
# Get the time domain data
1110+
sampling_time_step = 1.0 / sampling_frequency
1111+
sampling_range = np.arange(lower, upper, sampling_time_step)
1112+
sampled_points = self(sampling_range)
1113+
samples_per_window = int(window_size * sampling_frequency)
1114+
samples_skipped_per_step = int(step_size * sampling_frequency)
1115+
stft_results = []
1116+
1117+
max_start = len(sampled_points) - samples_per_window + 1
1118+
1119+
for start in range(0, max_start, samples_skipped_per_step):
1120+
windowed_samples = sampled_points[start : start + samples_per_window]
1121+
if remove_dc:
1122+
windowed_samples -= np.mean(windowed_samples)
1123+
fourier_amplitude = np.abs(
1124+
np.fft.fft(windowed_samples) / (samples_per_window / 2)
1125+
)
1126+
fourier_frequencies = np.fft.fftfreq(samples_per_window, sampling_time_step)
1127+
1128+
# Filter to keep only positive frequencies if specified
1129+
if only_positive:
1130+
positive_indices = fourier_frequencies > 0
1131+
fourier_frequencies = fourier_frequencies[positive_indices]
1132+
fourier_amplitude = fourier_amplitude[positive_indices]
1133+
1134+
stft_results.append(
1135+
Function(
1136+
source=np.array([fourier_frequencies, fourier_amplitude]).T,
1137+
inputs="Frequency (Hz)",
1138+
outputs="Amplitude",
1139+
interpolation="linear",
1140+
extrapolation="zero",
1141+
)
1142+
)
1143+
1144+
return stft_results
1145+
10101146
def low_pass_filter(self, alpha, file_path=None):
10111147
"""Implements a low pass filter with a moving average filter. This does
10121148
not mutate the original Function object, but returns a new one with the

rocketpy/plots/flight_plots.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -731,32 +731,33 @@ def stability_and_control_data(self): # pylint: disable=too-many-statements
731731
ax1.grid()
732732

733733
ax2 = plt.subplot(212)
734-
max_attitude = max(self.flight.attitude_frequency_response[:, 1])
734+
x_axis = np.arange(0, 5, 0.01)
735+
max_attitude = self.flight.attitude_frequency_response.max
735736
max_attitude = max_attitude if max_attitude != 0 else 1
736737
ax2.plot(
737-
self.flight.attitude_frequency_response[:, 0],
738-
self.flight.attitude_frequency_response[:, 1] / max_attitude,
738+
x_axis,
739+
self.flight.attitude_frequency_response(x_axis) / max_attitude,
739740
label="Attitude Angle",
740741
)
741-
max_omega1 = max(self.flight.omega1_frequency_response[:, 1])
742+
max_omega1 = self.flight.omega1_frequency_response.max
742743
max_omega1 = max_omega1 if max_omega1 != 0 else 1
743744
ax2.plot(
744-
self.flight.omega1_frequency_response[:, 0],
745-
self.flight.omega1_frequency_response[:, 1] / max_omega1,
745+
x_axis,
746+
self.flight.omega1_frequency_response(x_axis) / max_omega1,
746747
label=r"$\omega_1$",
747748
)
748-
max_omega2 = max(self.flight.omega2_frequency_response[:, 1])
749+
max_omega2 = self.flight.omega2_frequency_response.max
749750
max_omega2 = max_omega2 if max_omega2 != 0 else 1
750751
ax2.plot(
751-
self.flight.omega2_frequency_response[:, 0],
752-
self.flight.omega2_frequency_response[:, 1] / max_omega2,
752+
x_axis,
753+
self.flight.omega2_frequency_response(x_axis) / max_omega2,
753754
label=r"$\omega_2$",
754755
)
755-
max_omega3 = max(self.flight.omega3_frequency_response[:, 1])
756+
max_omega3 = self.flight.omega3_frequency_response.max
756757
max_omega3 = max_omega3 if max_omega3 != 0 else 1
757758
ax2.plot(
758-
self.flight.omega3_frequency_response[:, 0],
759-
self.flight.omega3_frequency_response[:, 1] / max_omega3,
759+
x_axis,
760+
self.flight.omega3_frequency_response(x_axis) / max_omega3,
760761
label=r"$\omega_3$",
761762
)
762763
ax2.set_title("Frequency Response")

0 commit comments

Comments
 (0)