Skip to content

Commit 8e02e2a

Browse files
committedSep 16, 2024
save model and visualization the process
1 parent 3a2af35 commit 8e02e2a

File tree

3 files changed

+62
-4
lines changed

3 files changed

+62
-4
lines changed
 

‎.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ rnn_model.pth
1414
rnn_model_ai.pth
1515
transformer_model_ai.pth
1616
transformer_model_ai_chinese.pth
17+
fourier_predictor_model.pth

‎fourier_signal_nn.png

326 KB
Loading

‎fourier_signal_nn.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,19 +89,30 @@ def forward(self, x):
8989
optimizer = optim.Adam(model.parameters(), lr=0.001)
9090

9191
# Train the model
92-
n_epochs = 1000
92+
n_epochs = 3000
93+
losses = []
94+
9395
for epoch in range(n_epochs):
9496
model.train()
95-
97+
9698
optimizer.zero_grad()
9799
outputs = model(X_train)
98-
100+
99101
loss = criterion(outputs, y_train)
100102
loss.backward()
101103
optimizer.step()
102-
104+
105+
losses.append(loss.item())
106+
103107
if epoch % 100 == 0:
104108
print(f'Epoch {epoch+1}/{n_epochs}, Loss: {loss.item()}')
109+
110+
torch.save(model.state_dict(), 'fourier_predictor_model.pth')
111+
112+
# Load the model (for future use)
113+
#loaded_model = FourierPredictor(length, n_components)
114+
#loaded_model.load_state_dict(torch.load('fourier_predictor_model.pth'))
115+
105116
# ```
106117

107118
# ### 6. Test the Model
@@ -116,6 +127,52 @@ def forward(self, x):
116127
predicted_coeffs = model(test_signal_tensor).numpy()
117128
print("Predicted Fourier Coefficients:", predicted_coeffs)
118129
print("True Fourier Coefficients:", test_coeffs)
130+
131+
import matplotlib.pyplot as plt
132+
133+
# Plot the loss curve
134+
plt.plot(losses)
135+
plt.xlabel('Epochs')
136+
plt.ylabel('Loss')
137+
plt.title('Training Loss Curve')
138+
plt.show()
139+
140+
####
141+
142+
# Generate a test signal
143+
test_signal, test_coeffs = generate_signal(1, n_components, length)
144+
test_signal_tensor = torch.tensor(test_signal, dtype=torch.float32)
145+
146+
# Get the predicted Fourier coefficients
147+
model.eval()
148+
with torch.no_grad():
149+
predicted_coeffs = model(test_signal_tensor).numpy()
150+
151+
# Generate signals based on true and predicted Fourier coefficients
152+
def reconstruct_signal(x, coeffs):
153+
signal = np.zeros_like(x)
154+
for amplitude, frequency, phase in coeffs[0]:
155+
signal += amplitude * np.sin(frequency * x + phase)
156+
return signal
157+
158+
# Create x values for the signal
159+
x = np.linspace(0, 2 * np.pi, length)
160+
161+
# Reconstruct the signals
162+
true_signal = reconstruct_signal(x, test_coeffs)
163+
predicted_signal = reconstruct_signal(x, predicted_coeffs)
164+
165+
# Plot the original, predicted, and true signals
166+
plt.figure(figsize=(10, 6))
167+
plt.plot(x, test_signal[0], label='Original Signal', color='blue', alpha=0.5)
168+
plt.plot(x, true_signal, label='True Reconstructed Signal', color='green', linestyle='--')
169+
plt.plot(x, predicted_signal, label='Predicted Reconstructed Signal', color='red', linestyle='--')
170+
plt.legend()
171+
plt.title('True vs Predicted Signal')
172+
plt.xlabel('Time')
173+
plt.ylabel('Amplitude')
174+
plt.show()
175+
119176
# ```
120177

121178
# ### Explanation:

0 commit comments

Comments
 (0)
Please sign in to comment.