11
11
12
12
13
13
class CartPoleCrashing (CartPoleEnv ):
14
- """A CartPole env that crashes from time to time.
14
+ """A CartPole env that crashes (or stalls) from time to time.
15
15
16
16
Useful for testing faulty sub-env (within a vectorized env) handling by
17
- RolloutWorkers .
17
+ EnvRunners .
18
18
19
19
After crashing, the env expects a `reset()` call next (calling `step()` will
20
20
result in yet another error), which may or may not take a very long time to
21
21
complete. This simulates the env having to reinitialize some sub-processes, e.g.
22
22
an external connection.
23
+
24
+ The env can also be configured to stall (and do nothing during a call to `step()`)
25
+ from time to time for a configurable amount of time.
23
26
"""
24
27
25
28
def __init__ (self , config = None ):
26
29
super ().__init__ ()
27
30
28
- config = config or {}
31
+ self . config = config if config is not None else {}
29
32
30
33
# Crash probability (in each `step()`).
31
34
self .p_crash = config .get ("p_crash" , 0.005 )
35
+ # Crash probability when `reset()` is called.
32
36
self .p_crash_reset = config .get ("p_crash_reset" , 0.0 )
37
+ # Crash exactly after every n steps. If a 2-tuple, will uniformly sample
38
+ # crash timesteps from in between the two given values.
33
39
self .crash_after_n_steps = config .get ("crash_after_n_steps" )
34
- # Only crash (with prob=p_crash) if on certain worker indices.
40
+ self ._crash_after_n_steps = None
41
+ assert (
42
+ self .crash_after_n_steps is None
43
+ or isinstance (self .crash_after_n_steps , int )
44
+ or (
45
+ isinstance (self .crash_after_n_steps , tuple )
46
+ and len (self .crash_after_n_steps ) == 2
47
+ )
48
+ )
49
+ # Only ever crash, if on certain worker indices.
35
50
faulty_indices = config .get ("crash_on_worker_indices" , None )
36
51
if faulty_indices and config .worker_index not in faulty_indices :
37
52
self .p_crash = 0.0
38
53
self .p_crash_reset = 0.0
39
54
self .crash_after_n_steps = None
55
+
56
+ # Stall probability (in each `step()`).
57
+ self .p_stall = config .get ("p_stall" , 0.0 )
58
+ # Stall probability when `reset()` is called.
59
+ self .p_stall_reset = config .get ("p_stall_reset" , 0.0 )
60
+ # Stall exactly after every n steps.
61
+ self .stall_after_n_steps = config .get ("stall_after_n_steps" )
62
+ self ._stall_after_n_steps = None
63
+ # Amount of time to stall. If a 2-tuple, will uniformly sample from in between
64
+ # the two given values.
65
+ self .stall_time_sec = config .get ("stall_time_sec" )
66
+ assert (
67
+ self .stall_time_sec is None
68
+ or isinstance (self .stall_time_sec , (int , float ))
69
+ or (
70
+ isinstance (self .stall_time_sec , tuple ) and len (self .stall_time_sec ) == 2
71
+ )
72
+ )
73
+
74
+ # Only ever stall, if on certain worker indices.
75
+ faulty_indices = config .get ("stall_on_worker_indices" , None )
76
+ if faulty_indices and config .worker_index not in faulty_indices :
77
+ self .p_stall = 0.0
78
+ self .p_stall_reset = 0.0
79
+ self .stall_after_n_steps = None
80
+
40
81
# Timestep counter for the ongoing episode.
41
82
self .timesteps = 0
42
83
43
84
# Time in seconds to initialize (in this c'tor).
85
+ sample = 0.0
44
86
if "init_time_s" in config :
45
- init_time_s = config .get ("init_time_s" , 0 )
46
- else :
47
- init_time_s = np .random .randint (
48
- config .get ("init_time_s_min" , 0 ),
49
- config .get ("init_time_s_max" , 1 ),
87
+ sample = (
88
+ config ["init_time_s" ]
89
+ if not isinstance (config ["init_time_s" ], tuple )
90
+ else np .random .uniform (
91
+ config ["init_time_s" ][0 ], config ["init_time_s" ][1 ]
92
+ )
50
93
)
51
- print (f"Initializing crashing env with init-delay of { init_time_s } sec ..." )
52
- time .sleep (init_time_s )
94
+
95
+ print (f"Initializing crashing env (with init-delay of { sample } sec) ..." )
96
+ time .sleep (sample )
53
97
54
98
# No env pre-checking?
55
99
self ._skip_env_checking = config .get ("skip_env_checking" , False )
@@ -61,30 +105,81 @@ def __init__(self, config=None):
61
105
def reset (self , * , seed = None , options = None ):
62
106
# Reset timestep counter for the new episode.
63
107
self .timesteps = 0
108
+ self ._crash_after_n_steps = None
109
+
64
110
# Should we crash?
65
- if self ._rng .rand () < self .p_crash_reset or (
66
- self .crash_after_n_steps is not None and self .crash_after_n_steps == 0
67
- ):
111
+ if self ._should_crash (p = self .p_crash_reset ):
68
112
raise EnvError (
69
- "Simulated env crash in `reset()`! Feel free to use any "
70
- "other exception type here instead."
113
+ f"Simulated env crash on worker={ self .config .worker_index } "
114
+ f"env-idx={ self .config .vector_index } during `reset()`! "
115
+ "Feel free to use any other exception type here instead."
71
116
)
117
+ # Should we stall for a while?
118
+ self ._stall_if_necessary (p = self .p_stall_reset )
119
+
72
120
return super ().reset ()
73
121
74
122
@override (CartPoleEnv )
75
123
def step (self , action ):
76
124
# Increase timestep counter for the ongoing episode.
77
125
self .timesteps += 1
126
+
78
127
# Should we crash?
79
- if self ._rng .rand () < self .p_crash or (
80
- self .crash_after_n_steps and self .crash_after_n_steps == self .timesteps
81
- ):
128
+ if self ._should_crash (p = self .p_crash ):
82
129
raise EnvError (
83
- "Simulated env crash in `step()`! Feel free to use any "
84
- "other exception type here instead."
130
+ f"Simulated env crash on worker={ self .config .worker_index } "
131
+ f"env-idx={ self .config .vector_index } during `step()`! "
132
+ "Feel free to use any other exception type here instead."
85
133
)
86
- # No crash.
134
+ # Should we stall for a while?
135
+ self ._stall_if_necessary (p = self .p_stall )
136
+
87
137
return super ().step (action )
88
138
139
+ def _should_crash (self , p ):
140
+ rnd = self ._rng .rand ()
141
+ if rnd < p :
142
+ print (f"Should crash! ({ rnd } < { p } )" )
143
+ return True
144
+ elif self .crash_after_n_steps is not None :
145
+ if self ._crash_after_n_steps is None :
146
+ self ._crash_after_n_steps = (
147
+ self .crash_after_n_steps
148
+ if not isinstance (self .crash_after_n_steps , tuple )
149
+ else np .random .randint (
150
+ self .crash_after_n_steps [0 ], self .crash_after_n_steps [1 ]
151
+ )
152
+ )
153
+ if self ._crash_after_n_steps == self .timesteps :
154
+ print (f"Should crash! (after { self .timesteps } steps)" )
155
+ return True
156
+
157
+ return False
158
+
159
+ def _stall_if_necessary (self , p ):
160
+ stall = False
161
+ if self ._rng .rand () < p :
162
+ stall = True
163
+ elif self .stall_after_n_steps is not None :
164
+ if self ._stall_after_n_steps is None :
165
+ self ._stall_after_n_steps = (
166
+ self .stall_after_n_steps
167
+ if not isinstance (self .stall_after_n_steps , tuple )
168
+ else np .random .randint (
169
+ self .stall_after_n_steps [0 ], self .stall_after_n_steps [1 ]
170
+ )
171
+ )
172
+ if self ._stall_after_n_steps == self .timesteps :
173
+ stall = True
174
+
175
+ if stall :
176
+ sec = (
177
+ self .stall_time_sec
178
+ if not isinstance (self .stall_time_sec , tuple )
179
+ else np .random .uniform (self .stall_time_sec [0 ], self .stall_time_sec [1 ])
180
+ )
181
+ print (f" -> will stall for { sec } sec ..." )
182
+ time .sleep (sec )
183
+
89
184
90
185
MultiAgentCartPoleCrashing = make_multi_agent (lambda config : CartPoleCrashing (config ))
0 commit comments