Merge pull request #1088 from Louay-Ben-nessir/seb-ff-ippo-only
feat: sebulba ff_ippo
sash-a authored Nov 13, 2024
2 parents 83f207b + 097df80 commit 3df74da
Showing 35 changed files with 1,723 additions and 57 deletions.
2 changes: 1 addition & 1 deletion mava/advanced_usage/ff_ippo_store_experience.py
@@ -361,7 +361,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
action_head, _ = get_action_head(env)
action_head, _ = get_action_head(env.action_spec())
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

1 change: 1 addition & 0 deletions mava/configs/arch/anakin.yaml
@@ -1,4 +1,5 @@
# --- Anakin config ---
architecture_name: anakin

# --- Training ---
num_envs: 16 # Number of vectorised environments per device.
25 changes: 25 additions & 0 deletions mava/configs/arch/sebulba.yaml
@@ -0,0 +1,25 @@
# --- Sebulba config ---
architecture_name: sebulba

# --- Training ---
num_envs: 32 # number of environments per thread.

# --- Evaluation ---
evaluation_greedy: False # Evaluate the policy greedily. If True, the policy will select
# the action with the greatest logit. If False, the policy will sample
# from the logits.
num_eval_episodes: 32 # Number of episodes to evaluate per evaluation.
num_evaluation: 100 # Number of evenly spaced evaluations to perform during training.
num_absolute_metric_eval_episodes: 320 # Number of episodes to evaluate the absolute metric (the final evaluation).
absolute_metric: True # Whether the absolute metric should be computed. For more details
# on the absolute metric please see: https://arxiv.org/abs/2209.10485

# --- Sebulba devices config ---
n_threads_per_executor: 2 # number of rollout threads (env batches) per actor
actor_device_ids: [0] # ids of actor devices
learner_device_ids: [0] # ids of learner devices
rollout_queue_size: 5
# The size of the pipeline queue determines how off-policy training is allowed to become. A larger value permits more off-policy training.
# Too large a value with too many actors wastes updates on stale rollouts;
# too small a value loses the benefit of having multiple actors.
# A value of 1 with a single actor gives almost strictly on-policy training.
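
To make the rollout_queue_size trade-off concrete, here is a minimal, self-contained sketch of a bounded actor-to-learner pipeline built on Python's standard library. It is purely illustrative: the names (`rollout_pipeline`, `produce_rollouts`, `consume_rollouts`) are hypothetical and this is not the pipeline Mava ships; it only shows why a small queue keeps training close to on-policy while a large one lets rollouts go stale.

```python
import queue
import threading

ROLLOUT_QUEUE_SIZE = 5  # mirrors rollout_queue_size in the config above

# A bounded FIFO queue: actors block when the learner falls behind, which
# caps how stale (off-policy) any rollout can be by the time it is consumed.
rollout_pipeline: queue.Queue = queue.Queue(maxsize=ROLLOUT_QUEUE_SIZE)

def produce_rollouts(actor_id: int, n_rollouts: int) -> None:
    for step in range(n_rollouts):
        trajectory = {"actor": actor_id, "step": step}  # placeholder rollout
        rollout_pipeline.put(trajectory)  # blocks while the queue is full

def consume_rollouts(n_rollouts: int) -> None:
    for _ in range(n_rollouts):
        trajectory = rollout_pipeline.get()  # oldest rollout first
        # ... a learner update on `trajectory` would happen here ...
        rollout_pipeline.task_done()

actor = threading.Thread(target=produce_rollouts, args=(0, 20))
learner = threading.Thread(target=consume_rollouts, args=(20,))
actor.start()
learner.start()
actor.join()
learner.join()
```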
11 changes: 11 additions & 0 deletions mava/configs/default/ff_ippo_sebulba.yaml
@@ -0,0 +1,11 @@
defaults:
- logger: logger
- arch: sebulba
- system: ppo/ff_ippo
- network: mlp # [mlp, continuous_mlp, cnn]
- env: lbf_gym # [rware_gym, lbf_gym, smaclite_gym]
- _self_

hydra:
searchpath:
- file://mava/configs
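
As a usage sketch only: the new default config can be composed programmatically with Hydra's compose API and overridden the same way it would be on the command line. The `config_path` below is an assumption — it must point at the directory containing ff_ippo_sebulba.yaml relative to wherever the snippet runs.

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# Assumed relative path to the default configs; adjust to your working directory.
with initialize(config_path="mava/configs/default", version_base=None):
    cfg = compose(
        config_name="ff_ippo_sebulba",
        overrides=["env=rware_gym", "arch.num_envs=64"],  # example overrides
    )

print(OmegaConf.to_yaml(cfg))
```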
25 changes: 25 additions & 0 deletions mava/configs/env/lbf_gym.yaml
@@ -0,0 +1,25 @@
# ---Environment Configs---
defaults:
- _self_

env_name: LevelBasedForaging # Used for logging purposes.
scenario:
name: lbforaging
task_name: Foraging-8x8-2p-1f-v3

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used.
# This should not be changed.
implicit_agent_id: False
# Whether or not to log the winrate of this environment. This should not be changed as not all
# environments have a winrate metric.
log_win_rate: False

# Whether or not to sum the returned rewards over all of the agents.
use_shared_rewards: True

kwargs:
max_episode_steps: 100
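
Purely as an illustration of how the scenario, task_name, and kwargs above typically map onto a Gymnasium environment (assuming the lbforaging package is installed and registers this task ID); this is not the environment factory Mava uses internally.

```python
import gymnasium as gym
import lbforaging  # noqa: F401  # importing registers the Foraging-* tasks

# scenario.task_name and kwargs.max_episode_steps from the config above:
env = gym.make("Foraging-8x8-2p-1f-v3", max_episode_steps=100)
obs, info = env.reset(seed=0)
print(env.action_space, env.observation_space)
env.close()
```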
25 changes: 25 additions & 0 deletions mava/configs/env/rware_gym.yaml
@@ -0,0 +1,25 @@
# ---Environment Configs---
defaults:
- _self_

env_name: RobotWarehouse # Used for logging purposes.
scenario:
name: rware
task_name: rware-tiny-2ag-v2 # [rware-tiny-2ag-v2, rware-tiny-4ag-v2, rware-tiny-4ag-easy-v2, rware-small-4ag-v2]

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used.
# This should not be changed.
implicit_agent_id: False
# Whether or not to log the winrate of this environment. This should not be changed as not all
# environments have a winrate metric.
log_win_rate: False

# Whether or not to sum the returned rewards over all of the agents.
use_shared_rewards: True

kwargs:
max_episode_steps: 500
25 changes: 25 additions & 0 deletions mava/configs/env/smaclite_gym.yaml
@@ -0,0 +1,25 @@
# ---Environment Configs---
defaults:
- _self_

env_name: SMACLite # Used for logging purposes.
scenario:
name: smaclite
task_name: smaclite/2s3z-v0 # smaclite/ + ['10m_vs_11m-v0', '27m_vs_30m-v0', '3s5z_vs_3s6z-v0', '2s3z-v0', '3s5z-v0', '2c_vs_64zg-v0', '2s_vs_1sc-v0', '3s_vs_5z-v0']

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used.
# This should not be changed.
implicit_agent_id: False
# Whether or not to log the winrate of this environment. This should not be changed as not all
# environments have a winrate metric.
log_win_rate: True

# Whether or not to sum the returned rewards over all of the agents.
use_shared_rewards: True

kwargs:
max_episode_steps: 500
115 changes: 115 additions & 0 deletions mava/evaluator.py
@@ -19,6 +19,7 @@

import jax
import jax.numpy as jnp
import numpy as np
from chex import Array, PRNGKey
from flax.core.frozen_dict import FrozenDict
from jax import tree
@@ -36,6 +37,7 @@
RecActorApply,
State,
)
from mava.wrappers.gym import GymToJumanji

# Optional extras that are passed out of the actor and then into the actor in the next step
ActorState: TypeAlias = Dict[str, Any]
@@ -207,3 +209,116 @@ def eval_act_fn(
return action.squeeze(0), {_hidden_state: hidden_state}

return eval_act_fn


def get_sebulba_eval_fn(
env_maker: Callable[[int, int], GymToJumanji],
act_fn: EvalActFn,
config: DictConfig,
np_rng: np.random.Generator,
absolute_metric: bool,
) -> Tuple[EvalFn, Any]:
"""Creates a function that can be used to evaluate agents on a given environment.
Args:
----
env_maker: A function to create the environment instances.
act_fn: A function that takes in params, timestep, key and optionally a state
and returns actions and optionally a state (see `EvalActFn`).
config: The system config.
np_rng: Random number generator for seeding the environments.
absolute_metric: Whether or not this evaluator calculates the absolute_metric.
This determines how many evaluation episodes it does.
"""
n_devices = jax.device_count()
eval_episodes = (
config.arch.num_absolute_metric_eval_episodes
if absolute_metric
else config.arch.num_eval_episodes
)

n_parallel_envs = min(eval_episodes, config.arch.num_envs)
episode_loops = math.ceil(eval_episodes / n_parallel_envs)
env = env_maker(config, n_parallel_envs)

act_fn = jax.jit(
act_fn, device=jax.local_devices()[config.arch.actor_device_ids[0]]
) # Evaluate using the first actor device

# Warnings if num eval episodes is not divisible by num parallel envs.
if eval_episodes % n_parallel_envs != 0:
warnings.warn(
f"Number of evaluation episodes ({eval_episodes}) is not divisible by `num_envs` * "
f"`num_devices` ({n_parallel_envs} * {n_devices}). Some extra evaluations will be "
f"executed. New number of evaluation episodes = {episode_loops * n_parallel_envs}",
stacklevel=2,
)

def eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Metrics:
"""Evaluates the given params on an environment and returns relevent metrics.
Metrics are collected by the `RecordEpisodeMetrics` wrapper: episode return and length,
also win rate for environments that support it.
Returns: Dict[str, Array] - dictionary of metric name to metric values for each episode.
"""

def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
"""Simulates `num_envs` episodes."""

# Generate a list of random seeds within the 32-bit integer range, using a seeded RNG.
seeds = np_rng.integers(np.iinfo(np.int32).max, size=n_parallel_envs).tolist()
ts = env.reset(seed=seeds)

timesteps_array = [ts]

actor_state = init_act_state
finished_eps = ts.last()

while not finished_eps.all():
key, act_key = jax.random.split(key)
action, actor_state = act_fn(params, ts, act_key, actor_state)
cpu_action = jax.device_get(action)
ts = env.step(cpu_action)
timesteps_array.append(ts)

finished_eps = np.logical_or(finished_eps, ts.last())

timesteps = jax.tree.map(lambda *x: np.stack(x), *timesteps_array)

metrics = timesteps.extras["episode_metrics"]
if config.env.log_win_rate:
metrics["won_episode"] = timesteps.extras["won_episode"]

# find the first instance of done to get the metrics at that timestep; we don't
# care about subsequent steps because we only want the results from the first episode
done_idx = np.argmax(timesteps.last(), axis=0)
metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
del metrics["is_terminal_step"] # unneeded for logging

return key, metrics

# This loop is important because we don't want too many parallel envs.
# So in evaluation we use num_envs parallel envs and loop enough times
# that we run at least `eval_episodes` episodes.
metrics_array = []
for _ in range(episode_loops):
key, metric = _episode(key)
metrics_array.append(metric)

# flatten metrics
metrics: Metrics = jax.tree_map(lambda *x: np.array(x).reshape(-1), *metrics_array)
return metrics

def timed_eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Metrics:
"""Wrapper around eval function to time it and add in steps per second metric."""
start_time = time.time()

metrics = eval_fn(params, key, init_act_state)

end_time = time.time()
total_timesteps = jnp.sum(metrics["episode_length"])
metrics["steps_per_second"] = total_timesteps / (end_time - start_time)
return metrics

return timed_eval_fn, env
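
A hypothetical wiring sketch for the new evaluator, following the signatures above; `make_gym_env`, `eval_act_fn`, `params`, `init_act_state`, and `config` are placeholders for whatever the sebulba system provides, not real Mava symbols guaranteed to exist under these names.

```python
# Hypothetical usage; every name other than get_sebulba_eval_fn is a placeholder.
import jax
import numpy as np

np_rng = np.random.default_rng(config.system.seed)  # seeded RNG for the eval envs
evaluator, eval_env = get_sebulba_eval_fn(
    make_gym_env, eval_act_fn, config, np_rng, absolute_metric=False
)

key = jax.random.PRNGKey(config.system.seed)
eval_metrics = evaluator(params, key, init_act_state={})
print("mean episode return:", float(np.mean(eval_metrics["episode_return"])))
```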
13 changes: 0 additions & 13 deletions mava/systems/__init__.py

This file was deleted.

4 changes: 2 additions & 2 deletions mava/systems/mat/anakin/mat.py
@@ -41,14 +41,14 @@
)
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import (
merge_leading_dims,
unreplicate_batch_dim,
unreplicate_n_dims,
)
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics

@@ -355,7 +355,7 @@ def learner_setup(
init_x = env.observation_spec().generate_value()
init_x = tree.map(lambda x: x[None, ...], init_x)

_, action_space_type = get_action_head(env)
_, action_space_type = get_action_head(env.action_spec())

if action_space_type == "discrete":
init_action = jnp.zeros((1, config.system.num_agents), dtype=jnp.int32)
4 changes: 2 additions & 2 deletions mava/systems/ppo/anakin/ff_ippo.py
@@ -36,14 +36,14 @@
from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import (
merge_leading_dims,
unreplicate_batch_dim,
unreplicate_n_dims,
)
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics

@@ -362,7 +362,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
action_head, _ = get_action_head(env)
action_head, _ = get_action_head(env.action_spec())
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

4 changes: 2 additions & 2 deletions mava/systems/ppo/anakin/ff_mappo.py
@@ -35,10 +35,10 @@
from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics

@@ -346,7 +346,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
action_head, _ = get_action_head(env)
action_head, _ = get_action_head(env.action_spec())
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

4 changes: 2 additions & 2 deletions mava/systems/ppo/anakin/rec_ippo.py
@@ -49,10 +49,10 @@
)
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics

@@ -457,7 +457,7 @@ def learner_setup(
# Define network and optimisers.
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
action_head, _ = get_action_head(env)
action_head, _ = get_action_head(env.action_spec())
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)
4 changes: 2 additions & 2 deletions mava/systems/ppo/anakin/rec_mappo.py
@@ -49,10 +49,10 @@
)
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics

@@ -452,7 +452,7 @@ def learner_setup(
# Define network and optimiser.
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
action_head, _ = get_action_head(env)
action_head, _ = get_action_head(env.action_spec())
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)
(Diffs for the remaining changed files are not shown here.)
