1
1
"""Defines the MonteCarlo class."""
2
2
import json
3
3
import os
4
- from multiprocess import JoinableQueue , Process , get_context
5
- from multiprocess . managers import BaseManager , NamespaceProxy
4
+ import types
5
+ from pathlib import Path
6
6
from time import process_time , time
7
+
8
+ import h5py
7
9
import numpy as np
8
10
import simplekml
9
- import types
11
+ from multiprocess import Lock , Process , Queue
12
+ from multiprocess .managers import BaseManager , NamespaceProxy
10
13
14
+ from rocketpy import Function
11
15
from rocketpy ._encoders import RocketPyEncoder
12
16
from rocketpy .plots .monte_carlo_plots import _MonteCarloPlots
13
17
from rocketpy .prints .monte_carlo_prints import _MonteCarloPrints
14
18
from rocketpy .simulation .flight import Flight
19
+ from rocketpy .simulation .sim_config .flight2serializer import flightv1_serializer
20
+ from rocketpy .simulation .sim_config .serializer import function_serializer
15
21
from rocketpy .stochastic import (
16
22
StochasticEnvironment ,
17
- StochasticRocket ,
18
23
StochasticFlight ,
24
+ StochasticRocket ,
19
25
)
20
- from rocketpy .simulation .sim_config .flight2serializer import flightv1_serializer
21
- from rocketpy .simulation .sim_config .run_sim import run_flight
22
26
from rocketpy .tools import (
23
27
generate_monte_carlo_ellipses ,
24
28
generate_monte_carlo_ellipses_coordinates ,
@@ -129,7 +133,14 @@ def __init__(
129
133
self .plots = _MonteCarloPlots (self )
130
134
self ._inputs_dict = {}
131
135
self ._last_print_len = 0 # used to print on the same line
132
- self .batch_path = batch_path
136
+
137
+ if batch_path is None :
138
+ self .batch_path = Path .cwd () / "mc_simulations"
139
+ else :
140
+ self .batch_path = Path (batch_path )
141
+
142
+ if not os .path .exists (self .batch_path ):
143
+ os .makedirs (self .batch_path )
133
144
134
145
# Checks export_list
135
146
self .export_list = self .__check_export_list (export_list )
@@ -210,9 +221,9 @@ def _run_in_parallel(self, number_of_simulations, n_workers=None):
210
221
n_workers = os .cpu_count ()
211
222
212
223
with MonteCarloManager () as manager :
213
- parallel_start = process_time ()
224
+ parallel_start = time ()
214
225
# initialize queue
215
- # simulation_queue = manager.JoinableQueue ()
226
+ write_lock = manager .Lock ()
216
227
sim_counter = manager .SimCounter ()
217
228
218
229
# initialize stochastic objects
@@ -278,11 +289,16 @@ def _run_in_parallel(self, number_of_simulations, n_workers=None):
278
289
sto_rocket ,
279
290
sto_flight ,
280
291
sim_counter ,
281
- self .batch_path ,
292
+ write_lock ,
293
+ self .batch_path / 'montecarlo_output.h5' ,
282
294
),
283
295
)
284
296
processes .append (p )
285
297
298
+ # Initialize write file
299
+ with h5py .File (self .batch_path / 'montecarlo_output.h5' , 'w' ) as _ :
300
+ pass
301
+
286
302
# Starts all the processes
287
303
for p in processes :
288
304
p .start ()
@@ -291,7 +307,7 @@ def _run_in_parallel(self, number_of_simulations, n_workers=None):
291
307
for p in processes :
292
308
p .join ()
293
309
294
- parallel_end = process_time ()
310
+ parallel_end = time ()
295
311
296
312
print ("-" * 80 + "\n All workers joined, simulation complete." )
297
313
print (f"In total, { sim_counter .get_count ()} simulations were performed." )
@@ -306,13 +322,14 @@ def _run_simulation_worker(
306
322
sto_rocket ,
307
323
sto_flight ,
308
324
sim_counter ,
309
- batch_path ,
325
+ write_lock ,
326
+ file_path ,
310
327
):
311
328
"""Runs a simulation from a queue."""
312
329
313
330
for i in range (worker_no , n_sim , n_workers ):
314
331
sim_idx = sim_counter .increment ()
315
- sim_start = process_time ()
332
+ sim_start = time ()
316
333
317
334
env = sto_env .create_object ()
318
335
rocket = sto_rocket .create_object ()
@@ -338,7 +355,22 @@ def _run_simulation_worker(
338
355
339
356
flight_results = MonteCarlo .inspect_object_attributes (flight )
340
357
341
- sim_end = process_time ()
358
+ export_dict = {
359
+ str (i ): {
360
+ "inputs" : input_parameters ,
361
+ "outputs" : flight_results ,
362
+ }
363
+ }
364
+
365
+ # Export to file
366
+ write_lock .acquire ()
367
+
368
+ with h5py .File (file_path , 'a' ) as h5file :
369
+ MonteCarlo .dict_to_h5 (h5file , '/' , export_dict )
370
+
371
+ write_lock .release ()
372
+
373
+ sim_end = time ()
342
374
343
375
print (
344
376
"-" * 80
@@ -843,18 +875,41 @@ def inspect_object_attributes(obj):
843
875
if isinstance (
844
876
attr_value , (int , float , tuple , list , dict , object )
845
877
) and not attr_name .startswith ('__' ):
846
- result [attr_name ] = attr_value
878
+
879
+ if isinstance (attr_value , Function ):
880
+ result [attr_name ] = function_serializer (attr_value )
881
+ else :
882
+ result [attr_name ] = attr_value
847
883
return result
848
884
885
+ @staticmethod
886
+ def dict_to_h5 (h5_file , path , dic ):
887
+ """
888
+ ....
889
+ """
890
+ for key , item in dic .items ():
891
+ if isinstance (
892
+ item , (np .ndarray , np .int64 , np .float64 , str , bytes , int , float )
893
+ ):
894
+ h5_file [path + key ] = item
895
+ elif isinstance (item , Function ):
896
+ raise TypeError (
897
+ "Function objects should be preprocessed before saving."
898
+ )
899
+ elif isinstance (item , dict ):
900
+ MonteCarlo .dict_to_h5 (h5_file , path + key + '/' , item )
901
+ else :
902
+ pass # Implement other types as needed
903
+
849
904
850
905
class MonteCarloManager (BaseManager ):
851
906
def __init__ (self ):
852
907
super ().__init__ ()
853
- self .register ('JoinableQueue ' , JoinableQueue )
908
+ self .register ('Lock ' , Lock )
854
909
self .register ('SimCounter' , SimCounter )
855
- self .register ('StochasticEnvironment' , StochasticEnvironment , StochasticProxy )
856
- self .register ('StochasticRocket' , StochasticRocket , StochasticProxy )
857
- self .register ('StochasticFlight' , StochasticFlight , StochasticProxy )
910
+ self .register ('StochasticEnvironment' , StochasticEnvironment )
911
+ self .register ('StochasticRocket' , StochasticRocket )
912
+ self .register ('StochasticFlight' , StochasticFlight )
858
913
859
914
860
915
class SimCounter :
@@ -867,17 +922,3 @@ def increment(self) -> int:
867
922
868
923
def get_count (self ) -> int :
869
924
return self .count
870
-
871
-
872
- class StochasticProxy (NamespaceProxy ):
873
- _exposed_ = tuple (dir (StochasticEnvironment ))
874
-
875
- def __getattr__ (self , name ):
876
- result = super ().__getattr__ (name )
877
- if isinstance (result , types .MethodType ):
878
-
879
- def wrapper (* args , ** kwargs ):
880
- return self ._callmethod (name , args , kwargs )
881
-
882
- return wrapper
883
- return result
0 commit comments