Skip to content

Commit 1b50e94

Browse files
committed
Write mode added
1 parent 2927448 commit 1b50e94

File tree

2 files changed

+75
-34
lines changed

2 files changed

+75
-34
lines changed

docs/notebooks/monte_carlo_analysis/mc_example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -218,4 +218,4 @@
218218
)
219219

220220
##### Running the Monte Carlo Simulations
221-
test_dispersion.simulate(number_of_simulations=1, append=False, parallel=True)
221+
test_dispersion.simulate(number_of_simulations=50, append=False, parallel=True)

rocketpy/simulation/monte_carlo.py

+74-33
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
11
"""Defines the MonteCarlo class."""
22
import json
33
import os
4-
from multiprocess import JoinableQueue, Process, get_context
5-
from multiprocess.managers import BaseManager, NamespaceProxy
4+
import types
5+
from pathlib import Path
66
from time import process_time, time
7+
8+
import h5py
79
import numpy as np
810
import simplekml
9-
import types
11+
from multiprocess import Lock, Process, Queue
12+
from multiprocess.managers import BaseManager, NamespaceProxy
1013

14+
from rocketpy import Function
1115
from rocketpy._encoders import RocketPyEncoder
1216
from rocketpy.plots.monte_carlo_plots import _MonteCarloPlots
1317
from rocketpy.prints.monte_carlo_prints import _MonteCarloPrints
1418
from rocketpy.simulation.flight import Flight
19+
from rocketpy.simulation.sim_config.flight2serializer import flightv1_serializer
20+
from rocketpy.simulation.sim_config.serializer import function_serializer
1521
from rocketpy.stochastic import (
1622
StochasticEnvironment,
17-
StochasticRocket,
1823
StochasticFlight,
24+
StochasticRocket,
1925
)
20-
from rocketpy.simulation.sim_config.flight2serializer import flightv1_serializer
21-
from rocketpy.simulation.sim_config.run_sim import run_flight
2226
from rocketpy.tools import (
2327
generate_monte_carlo_ellipses,
2428
generate_monte_carlo_ellipses_coordinates,
@@ -129,7 +133,14 @@ def __init__(
129133
self.plots = _MonteCarloPlots(self)
130134
self._inputs_dict = {}
131135
self._last_print_len = 0 # used to print on the same line
132-
self.batch_path = batch_path
136+
137+
if batch_path is None:
138+
self.batch_path = Path.cwd() / "mc_simulations"
139+
else:
140+
self.batch_path = Path(batch_path)
141+
142+
if not os.path.exists(self.batch_path):
143+
os.makedirs(self.batch_path)
133144

134145
# Checks export_list
135146
self.export_list = self.__check_export_list(export_list)
@@ -210,9 +221,9 @@ def _run_in_parallel(self, number_of_simulations, n_workers=None):
210221
n_workers = os.cpu_count()
211222

212223
with MonteCarloManager() as manager:
213-
parallel_start = process_time()
224+
parallel_start = time()
214225
# initialize queue
215-
# simulation_queue = manager.JoinableQueue()
226+
write_lock = manager.Lock()
216227
sim_counter = manager.SimCounter()
217228

218229
# initialize stochastic objects
@@ -278,11 +289,16 @@ def _run_in_parallel(self, number_of_simulations, n_workers=None):
278289
sto_rocket,
279290
sto_flight,
280291
sim_counter,
281-
self.batch_path,
292+
write_lock,
293+
self.batch_path / 'montecarlo_output.h5',
282294
),
283295
)
284296
processes.append(p)
285297

298+
# Initialize write file
299+
with h5py.File(self.batch_path / 'montecarlo_output.h5', 'w') as _:
300+
pass
301+
286302
# Starts all the processes
287303
for p in processes:
288304
p.start()
@@ -291,7 +307,7 @@ def _run_in_parallel(self, number_of_simulations, n_workers=None):
291307
for p in processes:
292308
p.join()
293309

294-
parallel_end = process_time()
310+
parallel_end = time()
295311

296312
print("-" * 80 + "\nAll workers joined, simulation complete.")
297313
print(f"In total, {sim_counter.get_count()} simulations were performed.")
@@ -306,13 +322,14 @@ def _run_simulation_worker(
306322
sto_rocket,
307323
sto_flight,
308324
sim_counter,
309-
batch_path,
325+
write_lock,
326+
file_path,
310327
):
311328
"""Runs a simulation from a queue."""
312329

313330
for i in range(worker_no, n_sim, n_workers):
314331
sim_idx = sim_counter.increment()
315-
sim_start = process_time()
332+
sim_start = time()
316333

317334
env = sto_env.create_object()
318335
rocket = sto_rocket.create_object()
@@ -338,7 +355,22 @@ def _run_simulation_worker(
338355

339356
flight_results = MonteCarlo.inspect_object_attributes(flight)
340357

341-
sim_end = process_time()
358+
export_dict = {
359+
str(i): {
360+
"inputs": input_parameters,
361+
"outputs": flight_results,
362+
}
363+
}
364+
365+
# Export to file
366+
write_lock.acquire()
367+
368+
with h5py.File(file_path, 'a') as h5file:
369+
MonteCarlo.dict_to_h5(h5file, '/', export_dict)
370+
371+
write_lock.release()
372+
373+
sim_end = time()
342374

343375
print(
344376
"-" * 80
@@ -843,18 +875,41 @@ def inspect_object_attributes(obj):
843875
if isinstance(
844876
attr_value, (int, float, tuple, list, dict, object)
845877
) and not attr_name.startswith('__'):
846-
result[attr_name] = attr_value
878+
879+
if isinstance(attr_value, Function):
880+
result[attr_name] = function_serializer(attr_value)
881+
else:
882+
result[attr_name] = attr_value
847883
return result
848884

885+
@staticmethod
886+
def dict_to_h5(h5_file, path, dic):
887+
"""
888+
....
889+
"""
890+
for key, item in dic.items():
891+
if isinstance(
892+
item, (np.ndarray, np.int64, np.float64, str, bytes, int, float)
893+
):
894+
h5_file[path + key] = item
895+
elif isinstance(item, Function):
896+
raise TypeError(
897+
"Function objects should be preprocessed before saving."
898+
)
899+
elif isinstance(item, dict):
900+
MonteCarlo.dict_to_h5(h5_file, path + key + '/', item)
901+
else:
902+
pass # Implement other types as needed
903+
849904

850905
class MonteCarloManager(BaseManager):
851906
def __init__(self):
852907
super().__init__()
853-
self.register('JoinableQueue', JoinableQueue)
908+
self.register('Lock', Lock)
854909
self.register('SimCounter', SimCounter)
855-
self.register('StochasticEnvironment', StochasticEnvironment, StochasticProxy)
856-
self.register('StochasticRocket', StochasticRocket, StochasticProxy)
857-
self.register('StochasticFlight', StochasticFlight, StochasticProxy)
910+
self.register('StochasticEnvironment', StochasticEnvironment)
911+
self.register('StochasticRocket', StochasticRocket)
912+
self.register('StochasticFlight', StochasticFlight)
858913

859914

860915
class SimCounter:
@@ -867,17 +922,3 @@ def increment(self) -> int:
867922

868923
def get_count(self) -> int:
869924
return self.count
870-
871-
872-
class StochasticProxy(NamespaceProxy):
873-
_exposed_ = tuple(dir(StochasticEnvironment))
874-
875-
def __getattr__(self, name):
876-
result = super().__getattr__(name)
877-
if isinstance(result, types.MethodType):
878-
879-
def wrapper(*args, **kwargs):
880-
return self._callmethod(name, args, kwargs)
881-
882-
return wrapper
883-
return result

0 commit comments

Comments
 (0)