Commit 7001982

authored Dec 8, 2023
[RLlib] Add and enhance fault-tolerance tests for APPO. (#40743)
1 parent 563f7d8 commit 7001982

19 files changed: +520 -294 lines changed
 

‎rllib/BUILD

+42 -2

@@ -222,6 +222,48 @@ py_test(
     args = ["--dir=tuned_examples/appo"]
 )
 
+# Tests against crashing or hanging environments.
+# Single-agent: Crash only.
+py_test(
+    name = "learning_tests_cartpole_crashing_appo",
+    main = "tests/run_regression_tests.py",
+    tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "crashing_cartpole"],
+    size = "large",
+    srcs = ["tests/run_regression_tests.py"],
+    data = ["tuned_examples/appo/cartpole-crashing-recreate-workers-appo.py"],
+    args = ["--dir=tuned_examples/appo", "--num-cpus=6"]
+)
+# Single-agent: Crash and stall.
+py_test(
+    name = "learning_tests_cartpole_crashing_and_stalling_appo",
+    main = "tests/run_regression_tests.py",
+    tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "crashing_cartpole"],
+    size = "large",
+    srcs = ["tests/run_regression_tests.py"],
+    data = ["tuned_examples/appo/cartpole-crashing-and-stalling-recreate-workers-appo.py"],
+    args = ["--dir=tuned_examples/appo", "--num-cpus=6"]
+)
+# Multi-agent: Crash only.
+py_test(
+    name = "learning_tests_multi_agent_cartpole_crashing_appo",
+    main = "tests/run_regression_tests.py",
+    tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "crashing_cartpole"],
+    size = "large",
+    srcs = ["tests/run_regression_tests.py"],
+    data = ["tuned_examples/appo/multi-agent-cartpole-crashing-recreate-workers-appo.py"],
+    args = ["--dir=tuned_examples/appo", "--num-cpus=6"]
+)
+# Multi-agent: Crash and stall.
+py_test(
+    name = "learning_tests_multi_agent_cartpole_crashing_and_stalling_appo",
+    main = "tests/run_regression_tests.py",
+    tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "crashing_cartpole"],
+    size = "large",
+    srcs = ["tests/run_regression_tests.py"],
+    data = ["tuned_examples/appo/multi-agent-cartpole-crashing-and-stalling-recreate-workers-appo.py"],
+    args = ["--dir=tuned_examples/appo", "--num-cpus=6"]
+)
+
 # CQL
 py_test(
     name = "learning_tests_pendulum_cql",

@@ -1569,7 +1611,6 @@ py_test(
     args = ["TestCheckpointRestorePPO"]
 )
 
-
 py_test(
     name = "tests/test_checkpoint_restore_ppo_gpu",
     main = "tests/test_algorithm_checkpoint_restore.py",

@@ -1588,7 +1629,6 @@ py_test(
     args = ["TestCheckpointRestoreOffPolicy"]
 )
 
-
 py_test(
     name = "tests/test_checkpoint_restore_off_policy_gpu",
     main = "tests/test_algorithm_checkpoint_restore.py",

‎rllib/algorithms/algorithm.py

+3

@@ -1564,6 +1564,9 @@ def restore_workers(self, workers: WorkerSet) -> None:
         restored = workers.probe_unhealthy_workers()
 
         if restored:
+            # Count the restored workers.
+            self._counters["total_num_restored_workers"] += len(restored)
+
             from_worker = workers.local_worker() or self.workers.local_worker()
             # Get the state of the correct (reference) worker. E.g. The local worker
             # of the main WorkerSet.
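
The new counter lives on the Algorithm object itself (the same `_counters` dict the APPO fault-tolerance test below asserts on), so it can be read back directly after training. A minimal sketch, assuming a locally built APPO algorithm named `algo` (env choice and worker counts are illustrative):

from ray.rllib.algorithms.appo import APPOConfig

algo = (
    APPOConfig()
    .environment("CartPole-v1")
    .rollouts(num_rollout_workers=2)
    .fault_tolerance(recreate_failed_workers=True)
    .build()
)
for _ in range(3):
    algo.train()
# Counter added in this commit; stays 0 if no worker ever had to be restored.
print(algo._counters["total_num_restored_workers"])
algo.stop()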

‎rllib/algorithms/impala/impala.py

+30 -4

@@ -860,12 +860,11 @@ def default_resource_request(
             strategy=cf.placement_strategy,
         )
 
-    def concatenate_batches_and_pre_queue(self, batches: List[SampleBatch]):
+    def concatenate_batches_and_pre_queue(self, batches: List[SampleBatch]) -> None:
         """Concatenate batches that are being returned from rollout workers
 
         Args:
-            batches: batches of experiences from rollout workers
-
+            batches: List of batches of experiences from EnvRunners.
         """
 
         def aggregate_into_larger_batch():

@@ -878,6 +877,33 @@ def aggregate_into_larger_batch():
                 self.batch_being_built = []
 
         for batch in batches:
+            # TODO (sven): Strange bug in tf/tf2 after a RolloutWorker crash and proper
+            # restart. The bug is related to (old, non-V2) connectors being used and
+            # seems to happen inside the AgentCollector's `add_action_reward_next_obs`
+            # method, at the end of which the number of vf_preds (and all other
+            # extra action outs) in the batch is one smaller than the number of obs/
+            # actions/rewards, which leads to a malformed train batch. IMPALA/APPO then
+            # crash inside the loss function (during v-trace operations). The following
+            # if-block prevents this from happening and it can be removed once we are
+            # on the new API stack for good (and use the new connectors and also no
+            # longer AgentCollectors, RolloutWorkers, Policies, TrajectoryView API,
+            # etc..):
+            if (
+                self.config.batch_mode == "truncate_episodes"
+                and self.config.enable_connectors
+                and self.config.recreate_failed_workers
+                and self.config.framework_str in ["tf", "tf2"]
+            ):
+                if any(
+                    SampleBatch.VF_PREDS in pb
+                    and (
+                        pb[SampleBatch.VF_PREDS].shape[0]
+                        != pb[SampleBatch.REWARDS].shape[0]
+                    )
+                    for pb in batch.policy_batches.values()
+                ):
+                    continue
+
             self.batch_being_built.append(batch)
             aggregate_into_larger_batch()

@@ -929,7 +955,7 @@ def get_samples_from_workers(
             sample_batches = [(0, sample_batch)]
         else:
             # Not much we can do. Return empty list and wait.
-            return []
+            sample_batches = []
 
         return sample_batches
 
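
For context, the guard above skips a train batch whenever one of its per-policy batches carries fewer value-function predictions than rewards. A minimal sketch of that length check, with a plain dict standing in for a per-policy batch (values made up):

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

# Column lengths as they can look after the tf/tf2 connector bug described above:
# one fewer vf_preds entry than rewards/obs/actions.
pb = {
    SampleBatch.REWARDS: np.zeros(50),
    SampleBatch.VF_PREDS: np.zeros(49),
}
malformed = (
    SampleBatch.VF_PREDS in pb
    and pb[SampleBatch.VF_PREDS].shape[0] != pb[SampleBatch.REWARDS].shape[0]
)
print(malformed)  # -> True; such a batch is skipped by the if-block above.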

‎rllib/algorithms/tests/test_callbacks.py

+18 -10

@@ -11,9 +11,10 @@
 from ray.rllib.evaluation.episode import Episode
 from ray.rllib.examples.env.random_env import RandomEnv
 from ray.rllib.utils.test_utils import framework_iterator
+from ray import tune
 
 
-class OnWorkerCreatedCallbacks(DefaultCallbacks):
+class OnWorkersRecreatedCallbacks(DefaultCallbacks):
     def on_workers_recreated(
         self,
         *,

@@ -109,11 +110,13 @@ def tearDownClass(cls):
         ray.shutdown()
 
     def test_on_workers_recreated_callback(self):
+        tune.register_env("env", lambda cfg: CartPoleCrashing(cfg))
+
         config = (
             APPOConfig()
-            .environment(CartPoleCrashing)
-            .callbacks(OnWorkerCreatedCallbacks)
-            .rollouts(num_rollout_workers=2)
+            .environment("env")
+            .callbacks(OnWorkersRecreatedCallbacks)
+            .rollouts(num_rollout_workers=3)
             .fault_tolerance(recreate_failed_workers=True)
         )
 

@@ -122,19 +125,24 @@ def test_on_workers_recreated_callback(self):
         original_worker_ids = algo.workers.healthy_worker_ids()
         for id_ in original_worker_ids:
             self.assertTrue(algo._counters[f"worker_{id_}_recreated"] == 0)
+        self.assertTrue(algo._counters["total_num_workers_recreated"] == 0)
 
         # After building the algorithm, we should have 2 healthy (remote) workers.
-        self.assertTrue(len(original_worker_ids) == 2)
+        self.assertTrue(len(original_worker_ids) == 3)
 
         # Train a bit (and have the envs/workers crash a couple of times).
-        for _ in range(3):
-            algo.train()
+        for _ in range(5):
+            print(algo.train())
 
-        # After training, each new worker should have been recreated at least once.
+        # After training, the `on_workers_recreated` callback should have captured
+        # the exact worker IDs recreated (the exact number of times) as the actor
+        # manager itself. This confirms that the callback is triggered correctly,
+        # always.
         new_worker_ids = algo.workers.healthy_worker_ids()
-        self.assertTrue(len(new_worker_ids) == 2)
+        self.assertTrue(len(new_worker_ids) == 3)
         for id_ in new_worker_ids:
-            self.assertTrue(algo._counters[f"worker_{id_}_recreated"] >= 1)
+            # num_restored = algo.workers.restored_actors_history[id_]
+            self.assertTrue(algo._counters[f"worker_{id_}_recreated"] > 1)
         algo.stop()
 
     def test_on_init_and_checkpoint_loaded(self):
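
The renamed OnWorkersRecreatedCallbacks class (its body is not part of this hunk) is what increments the `worker_{id}_recreated` and `total_num_workers_recreated` counters asserted above. A rough sketch of such a callback; the exact keyword arguments of the `on_workers_recreated` hook are assumed here, since they are not shown in this diff:

from ray.rllib.algorithms.callbacks import DefaultCallbacks

class CountRecreationsCallbacks(DefaultCallbacks):
    def on_workers_recreated(self, *, algorithm=None, worker_ids=None, **kwargs):
        # Bump one counter per recreated worker ID plus a global total, mirroring
        # the counters the test above checks via `algo._counters`.
        for worker_id in worker_ids or []:
            algorithm._counters[f"worker_{worker_id}_recreated"] += 1
            algorithm._counters["total_num_workers_recreated"] += 1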

‎rllib/evaluation/collectors/agent_collector.py

+1 -2

@@ -9,7 +9,7 @@
 
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.view_requirement import ViewRequirement
-from ray.rllib.utils.framework import try_import_tf, try_import_torch
+from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.spaces.space_utils import (
     flatten_to_single_ndarray,
     get_dummy_batch_for_space,

@@ -24,7 +24,6 @@
 
 logger = logging.getLogger(__name__)
 
-_, tf, _ = try_import_tf()
 torch, _ = try_import_torch()
 
 
‎rllib/evaluation/collectors/simple_list_collector.py

-3

@@ -426,9 +426,6 @@ def postprocess_episode(
         episode_id = episode.episode_id
         policy_collector_group = episode.batch_builder
 
-        # TODO: (sven) Once we implement multi-agent communication channels,
-        # we have to resolve the restriction of only sending other agent
-        # batches from the same policy to the postprocess methods.
         # Build SampleBatches for the given episode.
         pre_batches = {}
         for (eps_id, agent_id), collector in self.agent_collectors.items():

‎rllib/examples/env/cartpole_crashing.py

+117 -22

@@ -11,45 +11,89 @@
 
 
 class CartPoleCrashing(CartPoleEnv):
-    """A CartPole env that crashes from time to time.
+    """A CartPole env that crashes (or stalls) from time to time.
 
     Useful for testing faulty sub-env (within a vectorized env) handling by
-    RolloutWorkers.
+    EnvRunners.
 
     After crashing, the env expects a `reset()` call next (calling `step()` will
     result in yet another error), which may or may not take a very long time to
     complete. This simulates the env having to reinitialize some sub-processes, e.g.
     an external connection.
+
+    The env can also be configured to stall (and do nothing during a call to `step()`)
+    from time to time for a configurable amount of time.
     """
 
     def __init__(self, config=None):
         super().__init__()
 
-        config = config or {}
+        self.config = config if config is not None else {}
 
         # Crash probability (in each `step()`).
         self.p_crash = config.get("p_crash", 0.005)
+        # Crash probability when `reset()` is called.
         self.p_crash_reset = config.get("p_crash_reset", 0.0)
+        # Crash exactly after every n steps. If a 2-tuple, will uniformly sample
+        # crash timesteps from in between the two given values.
         self.crash_after_n_steps = config.get("crash_after_n_steps")
-        # Only crash (with prob=p_crash) if on certain worker indices.
+        self._crash_after_n_steps = None
+        assert (
+            self.crash_after_n_steps is None
+            or isinstance(self.crash_after_n_steps, int)
+            or (
+                isinstance(self.crash_after_n_steps, tuple)
+                and len(self.crash_after_n_steps) == 2
+            )
+        )
+        # Only ever crash, if on certain worker indices.
         faulty_indices = config.get("crash_on_worker_indices", None)
         if faulty_indices and config.worker_index not in faulty_indices:
             self.p_crash = 0.0
             self.p_crash_reset = 0.0
             self.crash_after_n_steps = None
+
+        # Stall probability (in each `step()`).
+        self.p_stall = config.get("p_stall", 0.0)
+        # Stall probability when `reset()` is called.
+        self.p_stall_reset = config.get("p_stall_reset", 0.0)
+        # Stall exactly after every n steps.
+        self.stall_after_n_steps = config.get("stall_after_n_steps")
+        self._stall_after_n_steps = None
+        # Amount of time to stall. If a 2-tuple, will uniformly sample from in between
+        # the two given values.
+        self.stall_time_sec = config.get("stall_time_sec")
+        assert (
+            self.stall_time_sec is None
+            or isinstance(self.stall_time_sec, (int, float))
+            or (
+                isinstance(self.stall_time_sec, tuple) and len(self.stall_time_sec) == 2
+            )
+        )
+
+        # Only ever stall, if on certain worker indices.
+        faulty_indices = config.get("stall_on_worker_indices", None)
+        if faulty_indices and config.worker_index not in faulty_indices:
+            self.p_stall = 0.0
+            self.p_stall_reset = 0.0
+            self.stall_after_n_steps = None
+
         # Timestep counter for the ongoing episode.
         self.timesteps = 0
 
         # Time in seconds to initialize (in this c'tor).
+        sample = 0.0
         if "init_time_s" in config:
-            init_time_s = config.get("init_time_s", 0)
-        else:
-            init_time_s = np.random.randint(
-                config.get("init_time_s_min", 0),
-                config.get("init_time_s_max", 1),
+            sample = (
+                config["init_time_s"]
+                if not isinstance(config["init_time_s"], tuple)
+                else np.random.uniform(
+                    config["init_time_s"][0], config["init_time_s"][1]
+                )
             )
-        print(f"Initializing crashing env with init-delay of {init_time_s}sec ...")
-        time.sleep(init_time_s)
+
+        print(f"Initializing crashing env (with init-delay of {sample}sec) ...")
+        time.sleep(sample)
 
         # No env pre-checking?
         self._skip_env_checking = config.get("skip_env_checking", False)

@@ -61,30 +105,81 @@ def __init__(self, config=None):
     def reset(self, *, seed=None, options=None):
         # Reset timestep counter for the new episode.
         self.timesteps = 0
+        self._crash_after_n_steps = None
+
         # Should we crash?
-        if self._rng.rand() < self.p_crash_reset or (
-            self.crash_after_n_steps is not None and self.crash_after_n_steps == 0
-        ):
+        if self._should_crash(p=self.p_crash_reset):
             raise EnvError(
-                "Simulated env crash in `reset()`! Feel free to use any "
-                "other exception type here instead."
+                f"Simulated env crash on worker={self.config.worker_index} "
+                f"env-idx={self.config.vector_index} during `reset()`! "
+                "Feel free to use any other exception type here instead."
             )
+        # Should we stall for a while?
+        self._stall_if_necessary(p=self.p_stall_reset)
+
         return super().reset()
 
     @override(CartPoleEnv)
     def step(self, action):
         # Increase timestep counter for the ongoing episode.
         self.timesteps += 1
+
         # Should we crash?
-        if self._rng.rand() < self.p_crash or (
-            self.crash_after_n_steps and self.crash_after_n_steps == self.timesteps
-        ):
+        if self._should_crash(p=self.p_crash):
             raise EnvError(
-                "Simulated env crash in `step()`! Feel free to use any "
-                "other exception type here instead."
+                f"Simulated env crash on worker={self.config.worker_index} "
+                f"env-idx={self.config.vector_index} during `step()`! "
+                "Feel free to use any other exception type here instead."
            )
-        # No crash.
+        # Should we stall for a while?
+        self._stall_if_necessary(p=self.p_stall)
+
         return super().step(action)
 
+    def _should_crash(self, p):
+        rnd = self._rng.rand()
+        if rnd < p:
+            print(f"Should crash! ({rnd} < {p})")
+            return True
+        elif self.crash_after_n_steps is not None:
+            if self._crash_after_n_steps is None:
+                self._crash_after_n_steps = (
+                    self.crash_after_n_steps
+                    if not isinstance(self.crash_after_n_steps, tuple)
+                    else np.random.randint(
+                        self.crash_after_n_steps[0], self.crash_after_n_steps[1]
+                    )
+                )
+            if self._crash_after_n_steps == self.timesteps:
+                print(f"Should crash! (after {self.timesteps} steps)")
+                return True
+
+        return False
+
+    def _stall_if_necessary(self, p):
+        stall = False
+        if self._rng.rand() < p:
+            stall = True
+        elif self.stall_after_n_steps is not None:
+            if self._stall_after_n_steps is None:
+                self._stall_after_n_steps = (
+                    self.stall_after_n_steps
+                    if not isinstance(self.stall_after_n_steps, tuple)
+                    else np.random.randint(
+                        self.stall_after_n_steps[0], self.stall_after_n_steps[1]
+                    )
+                )
+            if self._stall_after_n_steps == self.timesteps:
+                stall = True
+
+        if stall:
+            sec = (
+                self.stall_time_sec
+                if not isinstance(self.stall_time_sec, tuple)
+                else np.random.uniform(self.stall_time_sec[0], self.stall_time_sec[1])
+            )
+            print(f" -> will stall for {sec}sec ...")
+            time.sleep(sec)
+
 
 MultiAgentCartPoleCrashing = make_multi_agent(lambda config: CartPoleCrashing(config))
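
All of the crash and stall knobs above are plain `env_config` entries, so a setup exercising the new stalling options looks like the following sketch (probability values are illustrative; the tuned examples below use their own settings):

from ray import tune
from ray.rllib.algorithms.appo import APPOConfig
from ray.rllib.examples.env.cartpole_crashing import CartPoleCrashing

tune.register_env("cartpole_crashing", lambda cfg: CartPoleCrashing(cfg))

config = (
    APPOConfig()
    .environment(
        "cartpole_crashing",
        env_config={
            "p_crash": 0.001,  # crash prob per `step()`
            "crash_after_n_steps": (200, 400),  # or: crash after a sampled step count
            "p_stall": 0.001,  # stall prob per `step()`
            "stall_time_sec": (2, 5),  # stall duration sampled from this range
        },
        # The env checker does not handle the simulated crashes.
        disable_env_checking=True,
    )
    .rollouts(num_rollout_workers=2)
    .fault_tolerance(recreate_failed_workers=True)
)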

‎rllib/execution/multi_gpu_learner_thread.py

+6 -1

@@ -140,7 +140,12 @@ def __init__(
 
     @override(LearnerThread)
     def step(self) -> None:
-        assert self.loader_thread.is_alive()
+        if not self.loader_thread.is_alive():
+            raise RuntimeError(
+                "The `_MultiGPULoaderThread` has died! Will therefore also terminate "
+                "the `MultiGPULearnerThread`."
+            )
+
         with self.load_wait_timer:
             buffer_idx, released = self.ready_tower_stacks_buffer.get()

‎rllib/policy/eager_tf_policy_v2.py

+1 -2

@@ -1033,9 +1033,8 @@ def _compute_actions_helper(
                     episodes=episodes,
                 )
             else:
+                # Try `action_distribution_fn`.
                 if is_overridden(self.action_distribution_fn):
-                    # Try new action_distribution_fn signature, supporting
-                    # state_batches and seq_lens.
                     (
                         dist_inputs,
                         self.dist_class,

rllib/tuned_examples/appo/cartpole-crashing-and-stalling-recreate-workers-appo.py

+74

@@ -0,0 +1,74 @@
"""
Tests, whether APPO can learn in a fault-tolerant fashion.

Workers will be configured to automatically get recreated upon failures (here: within
the environment).
The environment we use here is configured to crash with a certain probability on each
`step()` and/or `reset()` call. Additionally, the environment is configured to stall
with a configured probability on each `step()` call for a certain amount of time.
"""
from ray.rllib.algorithms.appo import APPOConfig
from ray.rllib.examples.env.cartpole_crashing import CartPoleCrashing
from ray import tune

tune.register_env("env", lambda cfg: CartPoleCrashing(cfg))


stop = {
    "evaluation/sampler_results/episode_reward_mean": 400.0,
    "num_env_steps_sampled": 250000,
}

config = (
    APPOConfig()
    .environment(
        "env",
        env_config={
            "p_crash": 0.0001,  # prob to crash during step()
            "p_crash_reset": 0.001,  # prob to crash during reset()
            "crash_on_worker_indices": [1, 2],
            "init_time_s": 2.0,
            "p_stall": 0.0005,  # prob to stall during step()
            "p_stall_reset": 0.001,  # prob to stall during reset()
            "stall_time_sec": (2, 5),  # stall between 2 and 10sec.
            "stall_on_worker_indices": [2, 3],
        },
        # Disable env checking. Env checker doesn't handle Exceptions from
        # user envs, and will crash rollout worker.
        disable_env_checking=True,
    )
    .rollouts(
        num_rollout_workers=1,
        num_envs_per_worker=1,
    )
    # Switch on resiliency (recreate any failed worker).
    .fault_tolerance(
        recreate_failed_workers=True,
    )
    .evaluation(
        evaluation_num_workers=1,
        evaluation_interval=1,
        evaluation_duration=25,
        evaluation_duration_unit="episodes",
        evaluation_parallel_to_training=True,
        enable_async_evaluation=True,
        evaluation_config=APPOConfig.overrides(
            explore=False,
            env_config={
                # Make eval workers solid.
                # This test is to prove that we can learn with crashing envs,
                # not evaluate with crashing envs.
                "p_crash": 0.0,
                "p_crash_reset": 0.0,
                "init_time_s": 0.0,
                "p_stall": 0.0,
                "p_stall_reset": 0.0,
            },
        ),
    )
)


# algo = config.framework("tf2").build()
# for _ in range(1000):
#     print(algo.train())
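
These new tuned-example files expose a `config` object and a `stop` dict instead of a YAML spec. As a rough sketch (the Tuner call below is assumed, not part of the file), they can be run through Ray Tune like this:

from ray import air, tune

# `config` and `stop` as defined in the tuned example above.
tuner = tune.Tuner(
    "APPO",
    param_space=config.to_dict(),
    run_config=air.RunConfig(stop=stop),
)
results = tuner.fit()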

rllib/tuned_examples/appo/cartpole-crashing-recreate-workers-appo.py

+62

@@ -0,0 +1,62 @@
"""
Tests, whether APPO can learn in a fault-tolerant fashion.

Workers will be configured to automatically get recreated upon failures (here: within
the environment).
The environment we use here is configured to crash with a certain probability on each
`step()` and/or `reset()` call.
"""
from ray.rllib.algorithms.appo import APPOConfig
from ray.rllib.examples.env.cartpole_crashing import CartPoleCrashing
from ray import tune

tune.register_env("env", lambda cfg: CartPoleCrashing(cfg))


stop = {
    "evaluation/sampler_results/episode_reward_mean": 400.0,
    "num_env_steps_sampled": 250000,
}

config = (
    APPOConfig()
    .environment(
        "env",
        env_config={
            # Crash roughly every 500 ts.
            "p_crash": 0.0005,  # prob to crash during step()
            "p_crash_reset": 0.005,  # prob to crash during reset()
            "crash_on_worker_indices": [1, 2],
        },
        # Disable env checking. Env checker doesn't handle Exceptions from
        # user envs, and will crash rollout worker.
        disable_env_checking=True,
    )
    .rollouts(
        num_rollout_workers=3,
        num_envs_per_worker=1,
    )
    # Switch on resiliency (recreate any failed worker).
    .fault_tolerance(
        recreate_failed_workers=True,
    )
    .evaluation(
        evaluation_num_workers=1,
        evaluation_interval=1,
        evaluation_duration=25,
        evaluation_duration_unit="episodes",
        evaluation_parallel_to_training=True,
        enable_async_evaluation=True,
        evaluation_config=APPOConfig.overrides(
            explore=False,
            env_config={
                # Make eval workers solid.
                # This test is to prove that we can learn with crashing envs,
                # not evaluate with crashing envs.
                "p_crash": 0.0,
                "p_crash_reset": 0.0,
                "init_time_s": 0.0,
            },
        ),
    )
)

rllib/tuned_examples/appo/multi-agent-cartpole-crashing-and-stalling-recreate-workers-appo.py

+70

@@ -0,0 +1,70 @@
"""
Tests, whether APPO can learn in a fault-tolerant fashion in a
multi-agent setting.

Workers will be configured to automatically get recreated upon failures (here: within
the environment).
The environment we use here is configured to crash with a certain probability on each
`step()` and/or `reset()` call.
"""
from ray.rllib.algorithms.appo import APPOConfig
from ray.rllib.examples.env.cartpole_crashing import MultiAgentCartPoleCrashing
from ray import tune

tune.register_env("ma_env", lambda cfg: MultiAgentCartPoleCrashing(cfg))

stop = {
    "evaluation/sampler_results/episode_reward_mean": 800.0,
    "num_env_steps_sampled": 250000,
}

config = (
    APPOConfig()
    .environment(
        "ma_env",
        env_config={
            "num_agents": 2,
            # Crash roughly every 300 ts. This should be ok to measure 180.0
            # reward (episodes are 200 ts long).
            "p_crash": 0.00005,  # prob to crash during step()
            "p_crash_reset": 0.0005,  # prob to crash during reset()
            "init_time_s": 2.0,
            "p_stall": 0.001,  # prob to stall during step()
            "p_stall_reset": 0.001,  # prob to stall during reset()
            "stall_time_sec": (2, 5),  # stall between 2 and 10sec.
            "stall_on_worker_indices": [2, 3],
        },
        # Disable env checking. Env checker doesn't handle Exceptions from
        # user envs, and will crash rollout worker.
        disable_env_checking=True,
    )
    .rollouts(
        num_rollout_workers=3,
        num_envs_per_worker=1,
    )
    # Switch on resiliency (recreate any failed worker).
    .fault_tolerance(
        recreate_failed_workers=True,
    )
    .evaluation(
        evaluation_num_workers=1,
        evaluation_interval=1,
        evaluation_duration=25,
        evaluation_duration_unit="episodes",
        evaluation_parallel_to_training=True,
        enable_async_evaluation=True,
        evaluation_config=APPOConfig.overrides(
            explore=False,
            env_config={
                # Make eval workers solid.
                # This test is to prove that we can learn with crashing envs,
                # not evaluate with crashing envs.
                "p_crash": 0.0,
                "p_crash_reset": 0.0,
                "init_time_s": 0.0,
                "p_stall": 0.0,
                "p_stall_reset": 0.0,
            },
        ),
    )
)

rllib/tuned_examples/appo/multi-agent-cartpole-crashing-recreate-workers-appo.py

+63

@@ -0,0 +1,63 @@
"""
Tests, whether APPO can learn in a fault-tolerant fashion in a
multi-agent setting.

Workers will be configured to automatically get recreated upon failures (here: within
the environment).
The environment we use here is configured to crash with a certain probability on each
`step()` and/or `reset()` call.
"""
from ray.rllib.algorithms.appo import APPOConfig
from ray.rllib.examples.env.cartpole_crashing import MultiAgentCartPoleCrashing
from ray import tune

tune.register_env("ma_env", lambda cfg: MultiAgentCartPoleCrashing(cfg))

stop = {
    "evaluation/sampler_results/episode_reward_mean": 800.0,
    "num_env_steps_sampled": 250000,
}

config = (
    APPOConfig()
    .environment(
        "ma_env",
        env_config={
            "num_agents": 2,
            # Crash roughly every 300 ts. This should be ok to measure 180.0
            # reward (episodes are 200 ts long).
            "p_crash": 0.0005,  # prob to crash during step()
            "p_crash_reset": 0.005,  # prob to crash during reset()
        },
        # Disable env checking. Env checker doesn't handle Exceptions from
        # user envs, and will crash rollout worker.
        disable_env_checking=True,
    )
    .rollouts(
        num_rollout_workers=4,
        num_envs_per_worker=1,
    )
    # Switch on resiliency (recreate any failed worker).
    .fault_tolerance(
        recreate_failed_workers=True,
    )
    .evaluation(
        evaluation_num_workers=1,
        evaluation_interval=1,
        evaluation_duration=25,
        evaluation_duration_unit="episodes",
        evaluation_parallel_to_training=True,
        enable_async_evaluation=True,
        evaluation_config=APPOConfig.overrides(
            explore=False,
            env_config={
                # Make eval workers solid.
                # This test is to prove that we can learn with crashing envs,
                # not evaluate with crashing envs.
                "p_crash": 0.0,
                "p_crash_reset": 0.0,
                "init_time_s": 0.0,
            },
        ),
    )
)

‎rllib/tuned_examples/appo/multi-agent-cartpole-crashing-restart-env-appo.yaml

-53
This file was deleted.

‎rllib/tuned_examples/pg/cartpole-crashing-pg.yaml

-45
This file was deleted.

‎rllib/tuned_examples/pg/cartpole-crashing-with-remote-envs-pg.yaml

-47
This file was deleted.

‎rllib/tuned_examples/pg/multi-agent-cartpole-crashing-restart-sub-envs-pg.yaml

-47
This file was deleted.

‎rllib/tuned_examples/pg/multi-agent-cartpole-crashing-with-remote-envs-pg.yaml

-49
This file was deleted.

‎rllib/utils/actor_manager.py

+33 -7

@@ -179,7 +179,7 @@ def apply(
         except Exception as e:
             # Actor should be recreated by Ray.
             if self.config.recreate_failed_workers:
-                logger.exception("Worker exception, recreating: {}".format(e))
+                logger.exception(f"Worker exception caught during `apply()`: {e}")
                 # Small delay to allow logs messages to propagate.
                 time.sleep(self.config.delay_between_worker_restarts_s)
                 # Kill this worker so Ray Core can restart it.

@@ -260,6 +260,7 @@ def __init__(
         # Actors are stored in a map and indexed by a unique id.
         self.__actors: Mapping[int, ActorHandle] = {}
         self.__remote_actor_states: Mapping[int, self._ActorState] = {}
+        self.__restored_actors = set()
         self.add_actors(actors or [])
 
         # Maps outstanding async requests to the ids of the actors that

@@ -328,6 +329,7 @@ def remove_actor(self, actor_id: int) -> ActorHandle:
         # Remove the actor from the pool.
         del self.__actors[actor_id]
         del self.__remote_actor_states[actor_id]
+        self.__restored_actors.discard(actor_id)
         self._remove_async_state(actor_id)
 
         return actor

@@ -376,6 +378,15 @@ def set_actor_state(self, actor_id: int, healthy: bool) -> None:
         """
         if actor_id not in self.__remote_actor_states:
             raise ValueError(f"Unknown actor id: {actor_id}")
+
+        was_healthy = self.__remote_actor_states[actor_id].is_healthy
+        # Set from unhealthy to healthy -> Add to restored set.
+        if not was_healthy and healthy:
+            self.__restored_actors.add(actor_id)
+        # Set from healthy to unhealthy -> Remove from restored set.
+        elif was_healthy and not healthy:
+            self.__restored_actors.discard(actor_id)
+
         self.__remote_actor_states[actor_id].is_healthy = healthy
 
         if not healthy:

@@ -389,6 +400,7 @@ def clear(self):
             ray.kill(actor)
         self.__actors.clear()
         self.__remote_actor_states.clear()
+        self.__restored_actors.clear()
         self.__in_flight_req_to_actor_id.clear()
 
     def __call_actors(

@@ -487,8 +499,9 @@ def __fetch_result(
                 result = ray.get(r)
                 remote_results.add_result(actor_id, ResultOrError(result=result), tag)
 
+                # Actor came back from an unhealthy state. Mark this actor as healthy
+                # and add it to our restored set.
                 if mark_healthy and not self.is_actor_healthy(actor_id):
-                    # Yay, mark this actor as healthy.
                     logger.info(f"brining actor {actor_id} back into service.")
                     self.set_actor_state(actor_id, healthy=True)
                     self._num_actor_restarts += 1

@@ -498,7 +511,7 @@ def __fetch_result(
 
                 # Mark the actor as unhealthy.
                 # TODO(jungong): Using RayError here to preserve historical behavior.
-                #  It may very likely be better to use RayActorError here.
+                # It may very likely be better to use RayActorError here.
                 if isinstance(e, RayError):
                     # Take this actor out of service and wait for Ray Core to
                     # restore it.

@@ -790,17 +803,27 @@ def probe_unhealthy_actors(
             mark_healthy: Whether to mark actors healthy if they respond to the ping.
 
         Returns:
-            A list of actor ids that are restored.
+            A list of actor IDs that were restored by the `ping` AND those actors that
+            were previously restored via other remote requests. The cached set of
+            such previously restored actors will be erased in this call.
         """
+        # Collect recently restored actors (from `self.__fetch_result` calls other than
+        # the one triggered here via the `ping`).
+        restored_actors = list(self.__restored_actors)
+        self.__restored_actors.clear()
+
+        # Probe all unhealthy actors via a simple `ping()`.
         unhealthy_actor_ids = [
             actor_id
             for actor_id in self.actor_ids()
             if not self.is_actor_healthy(actor_id)
         ]
+        # No unhealthy actors currently -> Return recently restored ones.
         if not unhealthy_actor_ids:
-            # Great, nothing to do.
-            return []
+            return restored_actors
 
+        # Some unhealthy actors -> `ping()` all of them to trigger a new fetch and
+        # capture all restored ones.
         remote_results = self.foreach_actor(
             func=lambda actor: actor.ping(),
             remote_actor_ids=unhealthy_actor_ids,

@@ -809,7 +832,10 @@ def probe_unhealthy_actors(
             mark_healthy=mark_healthy,
         )
 
-        return [result.actor_id for result in remote_results if result.ok]
+        # Return previously restored actors AND actors restored via the `ping()` call.
+        return restored_actors + [
+            result.actor_id for result in remote_results if result.ok
+        ]
 
     def actors(self):
         # TODO(jungong) : remove this API once WorkerSet.remote_workers()
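
The effect of the new restored-actor cache is that a restoration observed by any remote call is no longer lost before the next health probe. A rough usage sketch (manager construction and the `workers` actor handles are illustrative, not part of this diff):

from ray.rllib.utils.actor_manager import FaultTolerantActorManager

# `workers` is assumed to be a list of remote actor handles (e.g. RolloutWorkers).
manager = FaultTolerantActorManager(actors=workers)

# Any remote call issued with mark_healthy=True may bring a previously failed
# actor back into service; its ID is remembered in the restored set.
manager.foreach_actor(lambda w: w.ping(), mark_healthy=True)

# Later, even if no actor is unhealthy anymore at this point, the probe still
# reports the IDs restored above (and clears the cached set).
restored_ids = manager.probe_unhealthy_actors(mark_healthy=True)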

0 commit comments
