@@ -89,19 +89,30 @@ def forward(self, x):
89
89
optimizer = optim .Adam (model .parameters (), lr = 0.001 )
90
90
91
91
# Train the model
92
- n_epochs = 1000
92
+ n_epochs = 3000
93
+ losses = []
94
+
93
95
for epoch in range (n_epochs ):
94
96
model .train ()
95
-
97
+
96
98
optimizer .zero_grad ()
97
99
outputs = model (X_train )
98
-
100
+
99
101
loss = criterion (outputs , y_train )
100
102
loss .backward ()
101
103
optimizer .step ()
102
-
104
+
105
+ losses .append (loss .item ())
106
+
103
107
if epoch % 100 == 0 :
104
108
print (f'Epoch { epoch + 1 } /{ n_epochs } , Loss: { loss .item ()} ' )
109
+
110
+ torch .save (model .state_dict (), 'fourier_predictor_model.pth' )
111
+
112
+ # Load the model (for future use)
113
+ #loaded_model = FourierPredictor(length, n_components)
114
+ #loaded_model.load_state_dict(torch.load('fourier_predictor_model.pth'))
115
+
105
116
# ```
106
117
107
118
# ### 6. Test the Model
@@ -116,6 +127,52 @@ def forward(self, x):
116
127
predicted_coeffs = model (test_signal_tensor ).numpy ()
117
128
print ("Predicted Fourier Coefficients:" , predicted_coeffs )
118
129
print ("True Fourier Coefficients:" , test_coeffs )
130
+
131
+ import matplotlib .pyplot as plt
132
+
133
+ # Plot the loss curve
134
+ plt .plot (losses )
135
+ plt .xlabel ('Epochs' )
136
+ plt .ylabel ('Loss' )
137
+ plt .title ('Training Loss Curve' )
138
+ plt .show ()
139
+
140
+ ####
141
+
142
+ # Generate a test signal
143
+ test_signal , test_coeffs = generate_signal (1 , n_components , length )
144
+ test_signal_tensor = torch .tensor (test_signal , dtype = torch .float32 )
145
+
146
+ # Get the predicted Fourier coefficients
147
+ model .eval ()
148
+ with torch .no_grad ():
149
+ predicted_coeffs = model (test_signal_tensor ).numpy ()
150
+
151
+ # Generate signals based on true and predicted Fourier coefficients
152
+ def reconstruct_signal (x , coeffs ):
153
+ signal = np .zeros_like (x )
154
+ for amplitude , frequency , phase in coeffs [0 ]:
155
+ signal += amplitude * np .sin (frequency * x + phase )
156
+ return signal
157
+
158
+ # Create x values for the signal
159
+ x = np .linspace (0 , 2 * np .pi , length )
160
+
161
+ # Reconstruct the signals
162
+ true_signal = reconstruct_signal (x , test_coeffs )
163
+ predicted_signal = reconstruct_signal (x , predicted_coeffs )
164
+
165
+ # Plot the original, predicted, and true signals
166
+ plt .figure (figsize = (10 , 6 ))
167
+ plt .plot (x , test_signal [0 ], label = 'Original Signal' , color = 'blue' , alpha = 0.5 )
168
+ plt .plot (x , true_signal , label = 'True Reconstructed Signal' , color = 'green' , linestyle = '--' )
169
+ plt .plot (x , predicted_signal , label = 'Predicted Reconstructed Signal' , color = 'red' , linestyle = '--' )
170
+ plt .legend ()
171
+ plt .title ('True vs Predicted Signal' )
172
+ plt .xlabel ('Time' )
173
+ plt .ylabel ('Amplitude' )
174
+ plt .show ()
175
+
119
176
# ```
120
177
121
178
# ### Explanation:
0 commit comments