From 38a61aa4f268e4b0e78cfa5785eeb91d3b98c3f7 Mon Sep 17 00:00:00 2001 From: Edan Toledo <42650996+EdanToledo@users.noreply.github.com> Date: Thu, 22 Aug 2024 22:37:37 +0100 Subject: [PATCH] chore: refactor types (#106) refactor types to remove redundant variables, align future APIs. Additionally, add a new github workflow. --- .github/workflows/run_algs.yaml | 36 ++++++++++++ bash_scripts/run-algorithms.sh | 17 ++++++ stoix/base_types.py | 23 ++++++-- stoix/evaluator.py | 39 +++---------- stoix/systems/awr/ff_awr.py | 6 +- stoix/systems/awr/ff_awr_continuous.py | 6 +- stoix/systems/ddpg/ff_d4pg.py | 8 ++- stoix/systems/ddpg/ff_ddpg.py | 8 ++- stoix/systems/ddpg/ff_td3.py | 8 ++- stoix/systems/mpo/ff_mpo.py | 6 +- stoix/systems/mpo/ff_mpo_continuous.py | 6 +- stoix/systems/mpo/ff_vmpo.py | 6 +- stoix/systems/mpo/ff_vmpo_continuous.py | 6 +- stoix/systems/ppo/anakin/ff_dpo_continuous.py | 8 ++- stoix/systems/ppo/anakin/ff_ppo.py | 8 ++- stoix/systems/ppo/anakin/ff_ppo_continuous.py | 8 ++- stoix/systems/ppo/anakin/ff_ppo_penalty.py | 8 ++- .../ppo/anakin/ff_ppo_penalty_continuous.py | 8 ++- stoix/systems/ppo/anakin/rec_ppo.py | 6 +- stoix/systems/ppo/sebulba/ff_ppo.py | 58 ++++++++++--------- stoix/systems/q_learning/ff_c51.py | 8 ++- stoix/systems/q_learning/ff_ddqn.py | 8 ++- stoix/systems/q_learning/ff_dqn.py | 8 ++- stoix/systems/q_learning/ff_dqn_reg.py | 8 ++- stoix/systems/q_learning/ff_mdqn.py | 8 ++- stoix/systems/q_learning/ff_qr_dqn.py | 8 ++- stoix/systems/q_learning/ff_rainbow.py | 8 ++- stoix/systems/sac/ff_sac.py | 8 ++- stoix/systems/search/evaluator.py | 7 +-- stoix/systems/search/ff_az.py | 6 +- stoix/systems/search/ff_mz.py | 6 +- stoix/systems/search/ff_sampled_az.py | 6 +- stoix/systems/search/ff_sampled_mz.py | 6 +- stoix/systems/vpg/ff_reinforce.py | 8 ++- stoix/systems/vpg/ff_reinforce_continuous.py | 8 ++- stoix/utils/sebulba_utils.py | 26 ++++++++- 36 files changed, 260 insertions(+), 156 deletions(-) create mode 100644 .github/workflows/run_algs.yaml create mode 100644 bash_scripts/run-algorithms.sh diff --git a/.github/workflows/run_algs.yaml b/.github/workflows/run_algs.yaml new file mode 100644 index 00000000..b586bd72 --- /dev/null +++ b/.github/workflows/run_algs.yaml @@ -0,0 +1,36 @@ +name: Check Algorithms 🧪 + +on: + push: + branches: + - main + pull_request: + branches: + - main + workflow_dispatch: + +jobs: + test-algorithms: + name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" + runs-on: "${{ matrix.os }}" + timeout-minutes: 30 + + strategy: + matrix: + python-version: ["3.10"] + os: [ubuntu-latest] + + steps: + - name: Checkout stoix + uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: "${{ matrix.python-version }}" + - name: Install python dependencies 🔧 + run: pip install . + + - name: Make Bash Script Executable + run: chmod +x bash_scripts/run-algorithms.sh + + - name: Run Bash Script + run: ./bash_scripts/run-algorithms.sh diff --git a/bash_scripts/run-algorithms.sh b/bash_scripts/run-algorithms.sh new file mode 100644 index 00000000..c4cee0a7 --- /dev/null +++ b/bash_scripts/run-algorithms.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +echo "Running All Algorithms..." 
+ +python stoix/systems/ppo/anakin/ff_ppo.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8 +python stoix/systems/ppo/anakin/ff_ppo_continuous.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8 +python stoix/systems/q_learning/ff_dqn.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8 +python stoix/systems/q_learning/ff_ddqn.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8 +python stoix/systems/q_learning/ff_mdqn.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8 +python stoix/systems/q_learning/ff_c51.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8 +python stoix/systems/q_learning/ff_qr_dqn.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8 +python stoix/systems/sac/ff_sac.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8 +python stoix/systems/ddpg/ff_ddpg.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8 +python stoix/systems/ddpg/ff_td3.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8 +python stoix/systems/vpg/ff_reinforce.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8 +python stoix/systems/awr/ff_awr.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8 +python stoix/systems/mpo/ff_mpo.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8 diff --git a/stoix/base_types.py b/stoix/base_types.py index 00ddcf43..af26199b 100644 --- a/stoix/base_types.py +++ b/stoix/base_types.py @@ -161,7 +161,14 @@ class OnlineAndTarget(NamedTuple): ) -class ExperimentOutput(NamedTuple, Generic[StoixState]): +class SebulbaExperimentOutput(NamedTuple, Generic[StoixState]): + """Experiment output.""" + + learner_state: StoixState + train_metrics: Dict[str, chex.Array] + + +class AnakinExperimentOutput(NamedTuple, Generic[StoixState]): """Experiment output.""" learner_state: StoixState @@ -169,10 +176,18 @@ class ExperimentOutput(NamedTuple, Generic[StoixState]): train_metrics: Dict[str, chex.Array] +class EvaluationOutput(NamedTuple, Generic[StoixState]): + """Evaluation output.""" + + learner_state: StoixState + episode_metrics: Dict[str, chex.Array] + + RNNObservation: TypeAlias = Tuple[Observation, Done] -LearnerFn = Callable[[StoixState], ExperimentOutput[StoixState]] -SebulbaLearnerFn = Callable[[StoixState, StoixTransition], ExperimentOutput[StoixState]] -EvalFn = Callable[[FrozenDict, chex.PRNGKey], ExperimentOutput[StoixState]] +LearnerFn = Callable[[StoixState], AnakinExperimentOutput[StoixState]] +SebulbaLearnerFn = Callable[[StoixState, StoixTransition], SebulbaExperimentOutput[StoixState]] +EvalFn = Callable[[FrozenDict, chex.PRNGKey], EvaluationOutput[StoixState]] +SebulbaEvalFn = Callable[[FrozenDict, chex.PRNGKey], Dict[str, chex.Array]] ActorApply = Callable[..., DistributionLike] diff --git a/stoix/evaluator.py b/stoix/evaluator.py index 2ed3a100..2a868908 100644 --- a/stoix/evaluator.py +++ b/stoix/evaluator.py @@ -1,6 +1,6 @@ import math import time -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import chex import flax.linen as nn @@ 
-18,11 +18,12 @@ EnvFactory, EvalFn, EvalState, - ExperimentOutput, + EvaluationOutput, RecActFn, RecActorApply, RNNEvalState, RNNObservation, + SebulbaEvalFn, ) from stoix.utils.jax_utils import unreplicate_batch_dim @@ -133,7 +134,7 @@ def not_done(carry: Tuple) -> bool: return eval_metrics - def evaluator_fn(trained_params: FrozenDict, key: chex.PRNGKey) -> ExperimentOutput[EvalState]: + def evaluator_fn(trained_params: FrozenDict, key: chex.PRNGKey) -> EvaluationOutput[EvalState]: """Evaluator function.""" # Initialise environment states and timesteps. @@ -164,10 +165,9 @@ def evaluator_fn(trained_params: FrozenDict, key: chex.PRNGKey) -> ExperimentOut axis_name="eval_batch", )(trained_params, eval_state) - return ExperimentOutput( + return EvaluationOutput( learner_state=eval_state, episode_metrics=eval_metrics, - train_metrics={}, ) return evaluator_fn @@ -248,7 +248,7 @@ def not_done(carry: Tuple) -> bool: def evaluator_fn( trained_params: FrozenDict, key: chex.PRNGKey - ) -> ExperimentOutput[RNNEvalState]: + ) -> EvaluationOutput[RNNEvalState]: """Evaluator function.""" # Initialise environment states and timesteps. @@ -289,10 +289,9 @@ def evaluator_fn( axis_name="eval_batch", )(trained_params, eval_state) - return ExperimentOutput( + return EvaluationOutput( learner_state=eval_state, episode_metrics=eval_metrics, - train_metrics={}, ) return evaluator_fn @@ -356,11 +355,6 @@ def evaluator_setup( return evaluator, absolute_metric_evaluator, (trained_params, eval_keys) -##### THIS IS TEMPORARY - -SebulbaEvalFn = Callable[[FrozenDict, chex.PRNGKey], Dict[str, chex.Array]] - - def get_sebulba_eval_fn( env_factory: EnvFactory, act_fn: ActFn, @@ -369,18 +363,7 @@ def get_sebulba_eval_fn( device: jax.Device, eval_multiplier: float = 1.0, ) -> Tuple[SebulbaEvalFn, Any]: - """Creates a function that can be used to evaluate agents on a given environment. - Args: - ---- - env: an environment that conforms to the mava environment spec. - act_fn: a function that takes in params, timestep, key and optionally a state - and returns actions and optionally a state (see `EvalActFn`). - config: the system config. - np_rng: a numpy random number generator. - eval_multiplier: a scalar that will increase the number of evaluation episodes - by a fixed factor. - """ eval_episodes = config.arch.num_eval_episodes * eval_multiplier # We calculate here the number of parallel envs we can run in parallel. @@ -405,14 +388,6 @@ def get_sebulba_eval_fn( print(f"{Fore.YELLOW}{Style.BRIGHT}{msg}{Style.RESET_ALL}") def eval_fn(params: FrozenDict, key: chex.PRNGKey) -> Dict: - """Evaluates the given params on an environment and returns relevant metrics. - - Metrics are collected by the `RecordEpisodeMetrics` wrapper: episode return and length, - also win rate for environments that support it. - - Returns: Dict[str, Array] - dictionary of metric name to metric values for each episode. 
- """ - def _run_episodes(key: chex.PRNGKey) -> Tuple[chex.PRNGKey, Dict]: """Simulates `num_envs` episodes.""" with jax.default_device(device): diff --git a/stoix/systems/awr/ff_awr.py b/stoix/systems/awr/ff_awr.py index 1b500336..98c03111 100644 --- a/stoix/systems/awr/ff_awr.py +++ b/stoix/systems/awr/ff_awr.py @@ -22,8 +22,8 @@ ActorApply, ActorCriticOptStates, ActorCriticParams, + AnakinExperimentOutput, CriticApply, - ExperimentOutput, LearnerFn, LogEnvState, ) @@ -323,7 +323,7 @@ def _actor_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: AWRLearnerState) -> ExperimentOutput[AWRLearnerState]: + def learner_fn(learner_state: AWRLearnerState) -> AnakinExperimentOutput[AWRLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -336,7 +336,7 @@ def learner_fn(learner_state: AWRLearnerState) -> ExperimentOutput[AWRLearnerSta learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/awr/ff_awr_continuous.py b/stoix/systems/awr/ff_awr_continuous.py index b83686f5..f389720b 100644 --- a/stoix/systems/awr/ff_awr_continuous.py +++ b/stoix/systems/awr/ff_awr_continuous.py @@ -22,8 +22,8 @@ ActorApply, ActorCriticOptStates, ActorCriticParams, + AnakinExperimentOutput, CriticApply, - ExperimentOutput, LearnerFn, LogEnvState, ) @@ -323,7 +323,7 @@ def _actor_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: AWRLearnerState) -> ExperimentOutput[AWRLearnerState]: + def learner_fn(learner_state: AWRLearnerState) -> AnakinExperimentOutput[AWRLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -336,7 +336,7 @@ def learner_fn(learner_state: AWRLearnerState) -> ExperimentOutput[AWRLearnerSta learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/ddpg/ff_d4pg.py b/stoix/systems/ddpg/ff_d4pg.py index 74b759e3..c5016a8c 100644 --- a/stoix/systems/ddpg/ff_d4pg.py +++ b/stoix/systems/ddpg/ff_d4pg.py @@ -20,8 +20,8 @@ from stoix.base_types import ( ActorApply, + AnakinExperimentOutput, ContinuousQApply, - ExperimentOutput, LearnerFn, LogEnvState, Observation, @@ -347,7 +347,9 @@ def _actor_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]: + def learner_fn( + learner_state: OffPolicyLearnerState, + ) -> AnakinExperimentOutput[OffPolicyLearnerState]: """Learner function. 
This function represents the learner, it updates the network parameters @@ -360,7 +362,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPoli learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/ddpg/ff_ddpg.py b/stoix/systems/ddpg/ff_ddpg.py index 313934b1..965f7310 100644 --- a/stoix/systems/ddpg/ff_ddpg.py +++ b/stoix/systems/ddpg/ff_ddpg.py @@ -20,8 +20,8 @@ from stoix.base_types import ( ActorApply, + AnakinExperimentOutput, ContinuousQApply, - ExperimentOutput, LearnerFn, LogEnvState, Observation, @@ -309,7 +309,9 @@ def _actor_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]: + def learner_fn( + learner_state: OffPolicyLearnerState, + ) -> AnakinExperimentOutput[OffPolicyLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -323,7 +325,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPoli learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/ddpg/ff_td3.py b/stoix/systems/ddpg/ff_td3.py index 09e7bc09..b4dc8abd 100644 --- a/stoix/systems/ddpg/ff_td3.py +++ b/stoix/systems/ddpg/ff_td3.py @@ -20,8 +20,8 @@ from stoix.base_types import ( ActorApply, + AnakinExperimentOutput, ContinuousQApply, - ExperimentOutput, LearnerFn, LogEnvState, Observation, @@ -327,7 +327,9 @@ def _actor_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]: + def learner_fn( + learner_state: OffPolicyLearnerState, + ) -> AnakinExperimentOutput[OffPolicyLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -341,7 +343,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPoli learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/mpo/ff_mpo.py b/stoix/systems/mpo/ff_mpo.py index 00a96bd1..399f9044 100644 --- a/stoix/systems/mpo/ff_mpo.py +++ b/stoix/systems/mpo/ff_mpo.py @@ -20,8 +20,8 @@ from stoix.base_types import ( ActorApply, + AnakinExperimentOutput, ContinuousQApply, - ExperimentOutput, LearnerFn, LogEnvState, OnlineAndTarget, @@ -406,7 +406,7 @@ def _q_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: MPOLearnerState) -> ExperimentOutput[MPOLearnerState]: + def learner_fn(learner_state: MPOLearnerState) -> AnakinExperimentOutput[MPOLearnerState]: """Learner function. 
This function represents the learner, it updates the network parameters @@ -419,7 +419,7 @@ def learner_fn(learner_state: MPOLearnerState) -> ExperimentOutput[MPOLearnerSta learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/mpo/ff_mpo_continuous.py b/stoix/systems/mpo/ff_mpo_continuous.py index e56121ef..455bfeca 100644 --- a/stoix/systems/mpo/ff_mpo_continuous.py +++ b/stoix/systems/mpo/ff_mpo_continuous.py @@ -20,8 +20,8 @@ from stoix.base_types import ( ActorApply, + AnakinExperimentOutput, ContinuousQApply, - ExperimentOutput, LearnerFn, LogEnvState, OnlineAndTarget, @@ -422,7 +422,7 @@ def _q_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: MPOLearnerState) -> ExperimentOutput[MPOLearnerState]: + def learner_fn(learner_state: MPOLearnerState) -> AnakinExperimentOutput[MPOLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -435,7 +435,7 @@ def learner_fn(learner_state: MPOLearnerState) -> ExperimentOutput[MPOLearnerSta learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/mpo/ff_vmpo.py b/stoix/systems/mpo/ff_vmpo.py index 052b36cd..b5d4d963 100644 --- a/stoix/systems/mpo/ff_vmpo.py +++ b/stoix/systems/mpo/ff_vmpo.py @@ -17,8 +17,8 @@ from stoix.base_types import ( ActorApply, + AnakinExperimentOutput, CriticApply, - ExperimentOutput, LearnerFn, OnlineAndTarget, ) @@ -305,7 +305,7 @@ def _critic_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: VMPOLearnerState) -> ExperimentOutput[VMPOLearnerState]: + def learner_fn(learner_state: VMPOLearnerState) -> AnakinExperimentOutput[VMPOLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -318,7 +318,7 @@ def learner_fn(learner_state: VMPOLearnerState) -> ExperimentOutput[VMPOLearnerS learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/mpo/ff_vmpo_continuous.py b/stoix/systems/mpo/ff_vmpo_continuous.py index 5b5b3690..de34a882 100644 --- a/stoix/systems/mpo/ff_vmpo_continuous.py +++ b/stoix/systems/mpo/ff_vmpo_continuous.py @@ -18,8 +18,8 @@ from stoix.base_types import ( ActorApply, + AnakinExperimentOutput, CriticApply, - ExperimentOutput, LearnerFn, OnlineAndTarget, ) @@ -362,7 +362,7 @@ def _critic_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: VMPOLearnerState) -> ExperimentOutput[VMPOLearnerState]: + def learner_fn(learner_state: VMPOLearnerState) -> AnakinExperimentOutput[VMPOLearnerState]: """Learner function. 
This function represents the learner, it updates the network parameters @@ -375,7 +375,7 @@ def learner_fn(learner_state: VMPOLearnerState) -> ExperimentOutput[VMPOLearnerS learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/ppo/anakin/ff_dpo_continuous.py b/stoix/systems/ppo/anakin/ff_dpo_continuous.py index 56764047..5f0bc8ca 100644 --- a/stoix/systems/ppo/anakin/ff_dpo_continuous.py +++ b/stoix/systems/ppo/anakin/ff_dpo_continuous.py @@ -18,8 +18,8 @@ ActorApply, ActorCriticOptStates, ActorCriticParams, + AnakinExperimentOutput, CriticApply, - ExperimentOutput, LearnerFn, OnPolicyLearnerState, ) @@ -280,7 +280,9 @@ def _critic_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicyLearnerState]: + def learner_fn( + learner_state: OnPolicyLearnerState, + ) -> AnakinExperimentOutput[OnPolicyLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -301,7 +303,7 @@ def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicy learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/ppo/anakin/ff_ppo.py b/stoix/systems/ppo/anakin/ff_ppo.py index db729d98..87443e64 100644 --- a/stoix/systems/ppo/anakin/ff_ppo.py +++ b/stoix/systems/ppo/anakin/ff_ppo.py @@ -18,8 +18,8 @@ ActorApply, ActorCriticOptStates, ActorCriticParams, + AnakinExperimentOutput, CriticApply, - ExperimentOutput, LearnerFn, OnPolicyLearnerState, ) @@ -275,7 +275,9 @@ def _critic_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicyLearnerState]: + def learner_fn( + learner_state: OnPolicyLearnerState, + ) -> AnakinExperimentOutput[OnPolicyLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -296,7 +298,7 @@ def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicy learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/ppo/anakin/ff_ppo_continuous.py b/stoix/systems/ppo/anakin/ff_ppo_continuous.py index bdb4bbb0..a2298b79 100644 --- a/stoix/systems/ppo/anakin/ff_ppo_continuous.py +++ b/stoix/systems/ppo/anakin/ff_ppo_continuous.py @@ -18,8 +18,8 @@ ActorApply, ActorCriticOptStates, ActorCriticParams, + AnakinExperimentOutput, CriticApply, - ExperimentOutput, LearnerFn, OnPolicyLearnerState, ) @@ -281,7 +281,9 @@ def _critic_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicyLearnerState]: + def learner_fn( + learner_state: OnPolicyLearnerState, + ) -> AnakinExperimentOutput[OnPolicyLearnerState]: """Learner function. 
This function represents the learner, it updates the network parameters @@ -302,7 +304,7 @@ def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicy learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/ppo/anakin/ff_ppo_penalty.py b/stoix/systems/ppo/anakin/ff_ppo_penalty.py index b3240620..407b2f32 100644 --- a/stoix/systems/ppo/anakin/ff_ppo_penalty.py +++ b/stoix/systems/ppo/anakin/ff_ppo_penalty.py @@ -18,8 +18,8 @@ ActorApply, ActorCriticOptStates, ActorCriticParams, + AnakinExperimentOutput, CriticApply, - ExperimentOutput, LearnerFn, OnPolicyLearnerState, ) @@ -284,7 +284,9 @@ def _critic_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicyLearnerState]: + def learner_fn( + learner_state: OnPolicyLearnerState, + ) -> AnakinExperimentOutput[OnPolicyLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -305,7 +307,7 @@ def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicy learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/ppo/anakin/ff_ppo_penalty_continuous.py b/stoix/systems/ppo/anakin/ff_ppo_penalty_continuous.py index 46e0285c..a5edbae4 100644 --- a/stoix/systems/ppo/anakin/ff_ppo_penalty_continuous.py +++ b/stoix/systems/ppo/anakin/ff_ppo_penalty_continuous.py @@ -18,8 +18,8 @@ ActorApply, ActorCriticOptStates, ActorCriticParams, + AnakinExperimentOutput, CriticApply, - ExperimentOutput, LearnerFn, OnPolicyLearnerState, ) @@ -289,7 +289,9 @@ def _critic_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicyLearnerState]: + def learner_fn( + learner_state: OnPolicyLearnerState, + ) -> AnakinExperimentOutput[OnPolicyLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -310,7 +312,7 @@ def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicy learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/ppo/anakin/rec_ppo.py b/stoix/systems/ppo/anakin/rec_ppo.py index be6ac789..5a1287a4 100644 --- a/stoix/systems/ppo/anakin/rec_ppo.py +++ b/stoix/systems/ppo/anakin/rec_ppo.py @@ -17,7 +17,7 @@ from stoix.base_types import ( ActorCriticOptStates, ActorCriticParams, - ExperimentOutput, + AnakinExperimentOutput, LearnerFn, RecActorApply, RecCriticApply, @@ -400,7 +400,7 @@ def _critic_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: RNNLearnerState) -> ExperimentOutput[RNNLearnerState]: + def learner_fn(learner_state: RNNLearnerState) -> AnakinExperimentOutput[RNNLearnerState]: """Learner function. 
This function represents the learner, it updates the network parameters @@ -423,7 +423,7 @@ def learner_fn(learner_state: RNNLearnerState) -> ExperimentOutput[RNNLearnerSta learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/ppo/sebulba/ff_ppo.py b/stoix/systems/ppo/sebulba/ff_ppo.py index 934a37c6..dce260a3 100644 --- a/stoix/systems/ppo/sebulba/ff_ppo.py +++ b/stoix/systems/ppo/sebulba/ff_ppo.py @@ -25,8 +25,8 @@ ActorCriticParams, CoreLearnerState, CriticApply, - ExperimentOutput, Observation, + SebulbaExperimentOutput, SebulbaLearnerFn, ) from stoix.evaluator import get_distribution_act_fn, get_sebulba_eval_fn @@ -105,20 +105,21 @@ def rollout_fn(rng_key: chex.PRNGKey) -> None: # Create the list to store transitions traj: List[PPOTransition] = [] # Create the dictionary to store timings for metrics - timings_dict: Dict[str, List[float]] = defaultdict(list) + actor_timings_dict: Dict[str, List[float]] = defaultdict(list) + episode_metrics: List[Dict[str, List[float]]] = [] # Rollout the environment - with RecordTimeTo(timings_dict["single_rollout_time"]): + with RecordTimeTo(actor_timings_dict["single_rollout_time"]): # Loop until the rollout length is reached for _ in range(config.system.rollout_length): # Get the latest parameters from the source - with RecordTimeTo(timings_dict["get_params_time"]): + with RecordTimeTo(actor_timings_dict["get_params_time"]): params = params_source.get() # Move the environment data to the actor device cached_obs = move_to_device(timestep.observation) # Run the actor and critic networks to get the action, value and log_prob - with RecordTimeTo(timings_dict["compute_action_time"]): + with RecordTimeTo(actor_timings_dict["compute_action_time"]): rng_key, policy_key = split_key_fn(rng_key) action, value, log_prob = act_fn(params, cached_obs, policy_key) @@ -126,7 +127,7 @@ def rollout_fn(rng_key: chex.PRNGKey) -> None: action_cpu = np.asarray(jax.device_put(action, cpu)) # Step the environment - with RecordTimeTo(timings_dict["env_step_time"]): + with RecordTimeTo(actor_timings_dict["env_step_time"]): timestep = envs.step(action_cpu) # Get the next dones and truncation flags @@ -154,11 +155,12 @@ def rollout_fn(rng_key: chex.PRNGKey) -> None: metrics, ) ) + episode_metrics.append(metrics) # Send the trajectory to the pipeline - with RecordTimeTo(timings_dict["rollout_put_time"]): + with RecordTimeTo(actor_timings_dict["rollout_put_time"]): try: - pipeline.put(traj, timestep, timings_dict) + pipeline.put(traj, timestep, actor_timings_dict, episode_metrics) except queue.Full: warnings.warn( "Waited too long to add to the rollout queue, killing the actor thread", @@ -222,7 +224,7 @@ def get_learner_step_fn( def _update_step( learner_state: CoreLearnerState, traj_batch: PPOTransition - ) -> Tuple[CoreLearnerState, Tuple]: + ) -> Tuple[CoreLearnerState, Dict[str, chex.Array]]: # CALCULATE ADVANTAGE params, opt_states, key, last_timestep = learner_state @@ -376,12 +378,12 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state learner_state = CoreLearnerState(params, opt_states, key, last_timestep) - metrics = traj_batch.info - return learner_state, (metrics, loss_info) + + return learner_state, loss_info def learner_step_fn( learner_state: CoreLearnerState, traj_batch: 
PPOTransition - ) -> ExperimentOutput[CoreLearnerState]: + ) -> SebulbaExperimentOutput[CoreLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -397,11 +399,10 @@ def learner_step_fn( - timesteps (TimeStep): The initial timestep in the initial trajectory. """ - learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch) + learner_state, loss_info = _update_step(learner_state, traj_batch) - return ExperimentOutput( + return SebulbaExperimentOutput( learner_state=learner_state, - episode_metrics=episode_info, train_metrics=loss_info, ) @@ -425,28 +426,33 @@ def learner_rollout(learner_state: CoreLearnerState) -> None: for _ in range(config.arch.num_evaluation): # Create the lists to store metrics and timings for this learning iteration. metrics: List[Tuple[Dict, Dict]] = [] - rollout_times: List[Dict] = [] + actor_timings: List[Dict] = [] + learner_timings: Dict[str, List[float]] = defaultdict(list) q_sizes: List[int] = [] - learn_timings: Dict[str, List[float]] = defaultdict(list) # Loop for the number of updates per evaluation for _ in range(config.arch.num_updates_per_eval): # Get the trajectory batch from the pipeline # This is blocking so it will wait until the pipeline has data. - with RecordTimeTo(learn_timings["rollout_get_time"]): - traj_batch, timestep, rollout_time = pipeline.get(block=True) # type: ignore + with RecordTimeTo(learner_timings["rollout_get_time"]): + ( + traj_batch, + timestep, + actor_times, + episode_metrics, + ) = pipeline.get( # type: ignore + block=True + ) # We then replace the timestep in the learner state with the latest timestep # This means the learner has access to the entire trajectory as well as # an additional timestep which it can use to bootstrap. learner_state = learner_state._replace(timestep=timestep) # We then call the update function to update the networks - with RecordTimeTo(learn_timings["learning_time"]): - learner_state, episode_metrics, train_metrics = learn_step( - learner_state, traj_batch - ) + with RecordTimeTo(learner_timings["learning_time"]): + learner_state, train_metrics = learn_step(learner_state, traj_batch) # We store the metrics and timings for this update metrics.append((episode_metrics, train_metrics)) - rollout_times.append(rollout_time) + actor_timings.append(actor_times) q_sizes.append(pipeline.qsize()) # After the update we need to update the params sources with the new params @@ -460,8 +466,8 @@ def learner_rollout(learner_state: CoreLearnerState) -> None: # and timings to the evaluation queue. This is so the evaluator correctly evaluates # the performance of the networks at this point in time. 
episode_metrics, train_metrics = jax.tree.map(lambda *x: np.asarray(x), *metrics) - rollout_times = jax.tree.map(lambda *x: np.mean(x), *rollout_times) - timing_dict = rollout_times | learn_timings + actor_timings = jax.tree.map(lambda *x: np.mean(x), *actor_timings) + timing_dict = actor_timings | learner_timings timing_dict["pipeline_qsize"] = q_sizes timing_dict = jax.tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list)) try: diff --git a/stoix/systems/q_learning/ff_c51.py b/stoix/systems/q_learning/ff_c51.py index 073d2dfd..aed3afd2 100644 --- a/stoix/systems/q_learning/ff_c51.py +++ b/stoix/systems/q_learning/ff_c51.py @@ -32,7 +32,7 @@ from stoix.base_types import ( ActorApply, - ExperimentOutput, + AnakinExperimentOutput, LearnerFn, LogEnvState, Observation, @@ -240,7 +240,9 @@ def _q_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]: + def learner_fn( + learner_state: OffPolicyLearnerState, + ) -> AnakinExperimentOutput[OffPolicyLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -253,7 +255,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPoli learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/q_learning/ff_ddqn.py b/stoix/systems/q_learning/ff_ddqn.py index b7f45b06..bc48a80f 100644 --- a/stoix/systems/q_learning/ff_ddqn.py +++ b/stoix/systems/q_learning/ff_ddqn.py @@ -19,7 +19,7 @@ from stoix.base_types import ( ActorApply, - ExperimentOutput, + AnakinExperimentOutput, LearnerFn, LogEnvState, OffPolicyLearnerState, @@ -235,7 +235,9 @@ def _q_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]: + def learner_fn( + learner_state: OffPolicyLearnerState, + ) -> AnakinExperimentOutput[OffPolicyLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -249,7 +251,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPoli learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/q_learning/ff_dqn.py b/stoix/systems/q_learning/ff_dqn.py index cf0de979..f380b33e 100644 --- a/stoix/systems/q_learning/ff_dqn.py +++ b/stoix/systems/q_learning/ff_dqn.py @@ -19,7 +19,7 @@ from stoix.base_types import ( ActorApply, - ExperimentOutput, + AnakinExperimentOutput, LearnerFn, LogEnvState, OffPolicyLearnerState, @@ -234,7 +234,9 @@ def _q_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]: + def learner_fn( + learner_state: OffPolicyLearnerState, + ) -> AnakinExperimentOutput[OffPolicyLearnerState]: """Learner function. 
This function represents the learner, it updates the network parameters @@ -248,7 +250,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPoli learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/q_learning/ff_dqn_reg.py b/stoix/systems/q_learning/ff_dqn_reg.py index 47df21bf..9d593eb7 100644 --- a/stoix/systems/q_learning/ff_dqn_reg.py +++ b/stoix/systems/q_learning/ff_dqn_reg.py @@ -19,7 +19,7 @@ from stoix.base_types import ( ActorApply, - ExperimentOutput, + AnakinExperimentOutput, LearnerFn, LogEnvState, OffPolicyLearnerState, @@ -238,7 +238,9 @@ def _q_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]: + def learner_fn( + learner_state: OffPolicyLearnerState, + ) -> AnakinExperimentOutput[OffPolicyLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -252,7 +254,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPoli learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/q_learning/ff_mdqn.py b/stoix/systems/q_learning/ff_mdqn.py index cb154988..9a1c762d 100644 --- a/stoix/systems/q_learning/ff_mdqn.py +++ b/stoix/systems/q_learning/ff_mdqn.py @@ -19,7 +19,7 @@ from stoix.base_types import ( ActorApply, - ExperimentOutput, + AnakinExperimentOutput, LearnerFn, LogEnvState, OffPolicyLearnerState, @@ -238,7 +238,9 @@ def _q_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]: + def learner_fn( + learner_state: OffPolicyLearnerState, + ) -> AnakinExperimentOutput[OffPolicyLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -252,7 +254,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPoli learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/q_learning/ff_qr_dqn.py b/stoix/systems/q_learning/ff_qr_dqn.py index e5be23a8..0adb2419 100644 --- a/stoix/systems/q_learning/ff_qr_dqn.py +++ b/stoix/systems/q_learning/ff_qr_dqn.py @@ -32,7 +32,7 @@ from stoix.base_types import ( ActorApply, - ExperimentOutput, + AnakinExperimentOutput, LearnerFn, LogEnvState, Observation, @@ -254,7 +254,9 @@ def _q_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]: + def learner_fn( + learner_state: OffPolicyLearnerState, + ) -> AnakinExperimentOutput[OffPolicyLearnerState]: """Learner function. 
This function represents the learner, it updates the network parameters @@ -267,7 +269,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPoli learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/q_learning/ff_rainbow.py b/stoix/systems/q_learning/ff_rainbow.py index 52426c68..51db4cd2 100644 --- a/stoix/systems/q_learning/ff_rainbow.py +++ b/stoix/systems/q_learning/ff_rainbow.py @@ -33,7 +33,7 @@ from stoix.base_types import ( ActorApply, - ExperimentOutput, + AnakinExperimentOutput, LearnerFn, LogEnvState, Observation, @@ -302,7 +302,9 @@ def _q_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]: + def learner_fn( + learner_state: OffPolicyLearnerState, + ) -> AnakinExperimentOutput[OffPolicyLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -315,7 +317,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPoli learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/sac/ff_sac.py b/stoix/systems/sac/ff_sac.py index 859fecf6..3aad9cc8 100644 --- a/stoix/systems/sac/ff_sac.py +++ b/stoix/systems/sac/ff_sac.py @@ -19,8 +19,8 @@ from stoix.base_types import ( ActorApply, + AnakinExperimentOutput, ContinuousQApply, - ExperimentOutput, LearnerFn, LogEnvState, OffPolicyLearnerState, @@ -321,7 +321,9 @@ def _actor_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]: + def learner_fn( + learner_state: OffPolicyLearnerState, + ) -> AnakinExperimentOutput[OffPolicyLearnerState]: """Learner function. 
This function represents the learner, it updates the network parameters @@ -334,7 +336,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPoli learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/search/evaluator.py b/stoix/systems/search/evaluator.py index 3ca7fba3..ae1131d8 100644 --- a/stoix/systems/search/evaluator.py +++ b/stoix/systems/search/evaluator.py @@ -7,7 +7,7 @@ from jumanji.env import Environment from omegaconf import DictConfig -from stoix.base_types import EvalFn, EvalState, ExperimentOutput +from stoix.base_types import EvalFn, EvalState, EvaluationOutput from stoix.systems.search.search_types import RootFnApply, SearchApply from stoix.utils.jax_utils import unreplicate_batch_dim @@ -68,7 +68,7 @@ def not_done(carry: Tuple) -> bool: return eval_metrics - def evaluator_fn(trained_params: FrozenDict, key: chex.PRNGKey) -> ExperimentOutput[EvalState]: + def evaluator_fn(trained_params: FrozenDict, key: chex.PRNGKey) -> EvaluationOutput[EvalState]: """Evaluator function.""" # Initialise environment states and timesteps. @@ -99,10 +99,9 @@ def evaluator_fn(trained_params: FrozenDict, key: chex.PRNGKey) -> ExperimentOut axis_name="eval_batch", )(trained_params, eval_state) - return ExperimentOutput( + return EvaluationOutput( learner_state=eval_state, episode_metrics=eval_metrics, - train_metrics={}, ) return evaluator_fn diff --git a/stoix/systems/search/ff_az.py b/stoix/systems/search/ff_az.py index 1e236f33..31c518c7 100644 --- a/stoix/systems/search/ff_az.py +++ b/stoix/systems/search/ff_az.py @@ -25,8 +25,8 @@ ActorApply, ActorCriticOptStates, ActorCriticParams, + AnakinExperimentOutput, CriticApply, - ExperimentOutput, LearnerFn, LogEnvState, ) @@ -354,7 +354,7 @@ def _critic_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: ZLearnerState) -> ExperimentOutput[ZLearnerState]: + def learner_fn(learner_state: ZLearnerState) -> AnakinExperimentOutput[ZLearnerState]: """Learner function.""" batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") @@ -362,7 +362,7 @@ def learner_fn(learner_state: ZLearnerState) -> ExperimentOutput[ZLearnerState]: learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/search/ff_mz.py b/stoix/systems/search/ff_mz.py index 5a7b322c..34fe2183 100644 --- a/stoix/systems/search/ff_mz.py +++ b/stoix/systems/search/ff_mz.py @@ -23,9 +23,9 @@ from stoix.base_types import ( ActorApply, ActorCriticParams, + AnakinExperimentOutput, CriticApply, DistributionCriticApply, - ExperimentOutput, LearnerFn, LogEnvState, ) @@ -430,7 +430,7 @@ def unroll_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: ZLearnerState) -> ExperimentOutput[ZLearnerState]: + def learner_fn(learner_state: ZLearnerState) -> AnakinExperimentOutput[ZLearnerState]: """Learner function.""" batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") @@ -438,7 +438,7 @@ def learner_fn(learner_state: ZLearnerState) 
-> ExperimentOutput[ZLearnerState]: learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/search/ff_sampled_az.py b/stoix/systems/search/ff_sampled_az.py index 14701b8f..72e2ce31 100644 --- a/stoix/systems/search/ff_sampled_az.py +++ b/stoix/systems/search/ff_sampled_az.py @@ -25,8 +25,8 @@ ActorApply, ActorCriticOptStates, ActorCriticParams, + AnakinExperimentOutput, CriticApply, - ExperimentOutput, LearnerFn, LogEnvState, ) @@ -488,7 +488,7 @@ def _critic_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: ZLearnerState) -> ExperimentOutput[ZLearnerState]: + def learner_fn(learner_state: ZLearnerState) -> AnakinExperimentOutput[ZLearnerState]: """Learner function.""" batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") @@ -496,7 +496,7 @@ def learner_fn(learner_state: ZLearnerState) -> ExperimentOutput[ZLearnerState]: learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/search/ff_sampled_mz.py b/stoix/systems/search/ff_sampled_mz.py index e41b3b50..3e85cc80 100644 --- a/stoix/systems/search/ff_sampled_mz.py +++ b/stoix/systems/search/ff_sampled_mz.py @@ -23,9 +23,9 @@ from stoix.base_types import ( ActorApply, ActorCriticParams, + AnakinExperimentOutput, CriticApply, DistributionCriticApply, - ExperimentOutput, LearnerFn, LogEnvState, ) @@ -561,7 +561,7 @@ def unroll_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: ZLearnerState) -> ExperimentOutput[ZLearnerState]: + def learner_fn(learner_state: ZLearnerState) -> AnakinExperimentOutput[ZLearnerState]: """Learner function.""" batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") @@ -569,7 +569,7 @@ def learner_fn(learner_state: ZLearnerState) -> ExperimentOutput[ZLearnerState]: learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/vpg/ff_reinforce.py b/stoix/systems/vpg/ff_reinforce.py index 5b99ae75..d8afda1e 100644 --- a/stoix/systems/vpg/ff_reinforce.py +++ b/stoix/systems/vpg/ff_reinforce.py @@ -19,8 +19,8 @@ ActorApply, ActorCriticOptStates, ActorCriticParams, + AnakinExperimentOutput, CriticApply, - ExperimentOutput, LearnerFn, OnPolicyLearnerState, ) @@ -200,14 +200,16 @@ def _critic_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicyLearnerState]: + def learner_fn( + learner_state: OnPolicyLearnerState, + ) -> AnakinExperimentOutput[OnPolicyLearnerState]: batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return 
AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/systems/vpg/ff_reinforce_continuous.py b/stoix/systems/vpg/ff_reinforce_continuous.py index df0d9af0..935abcb8 100644 --- a/stoix/systems/vpg/ff_reinforce_continuous.py +++ b/stoix/systems/vpg/ff_reinforce_continuous.py @@ -19,8 +19,8 @@ ActorApply, ActorCriticOptStates, ActorCriticParams, + AnakinExperimentOutput, CriticApply, - ExperimentOutput, LearnerFn, OnPolicyLearnerState, ) @@ -198,14 +198,16 @@ def _critic_loss_fn( metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicyLearnerState]: + def learner_fn( + learner_state: OnPolicyLearnerState, + ) -> AnakinExperimentOutput[OnPolicyLearnerState]: batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") learner_state, (episode_info, loss_info) = jax.lax.scan( batched_update_step, learner_state, None, config.arch.num_updates_per_eval ) - return ExperimentOutput( + return AnakinExperimentOutput( learner_state=learner_state, episode_metrics=episode_info, train_metrics=loss_info, diff --git a/stoix/utils/sebulba_utils.py b/stoix/utils/sebulba_utils.py index b912eaba..6d35ab81 100644 --- a/stoix/utils/sebulba_utils.py +++ b/stoix/utils/sebulba_utils.py @@ -62,7 +62,13 @@ def run(self) -> None: except queue.Empty: continue - def put(self, traj: Sequence[StoixTransition], timestep: TimeStep, timings_dict: Dict) -> None: + def put( + self, + traj: Sequence[StoixTransition], + timestep: TimeStep, + actor_timings_dict: Dict[str, List[float]], + actor_episode_metrics: List[Dict[str, List[float]]], + ) -> None: """Put a trajectory on the queue to be consumed by the learner.""" start_condition, end_condition = (threading.Condition(), threading.Condition()) with start_condition: @@ -78,6 +84,9 @@ def put(self, traj: Sequence[StoixTransition], timestep: TimeStep, timings_dict: # [(num_envs / num_learner_devices, ...)] * num_learner_devices sharded_timestep = jax.tree.map(self.shard_split_playload, timestep) + # Concatenate metrics - List[Dict[str, List[float]]] --> Dict[str, List[float]] + actor_episode_metrics = self.concatenate_metrics(actor_episode_metrics) + # We block on the put to ensure that actors wait for the learners to catch up. This does two # things: # 1. It ensures that the actors don't get too far ahead of the learners, which could lead to @@ -89,7 +98,11 @@ def put(self, traj: Sequence[StoixTransition], timestep: TimeStep, timings_dict: # operation. We use a try-finally since the lock has to be released even if an exception # is raised. 
try: - self._queue.put((sharded_traj, sharded_timestep, timings_dict), block=True, timeout=180) + self._queue.put( + (sharded_traj, sharded_timestep, actor_timings_dict, actor_episode_metrics), + block=True, + timeout=180, + ) except queue.Full: print( f"{Fore.RED}{Style.BRIGHT}Pipeline is full and actor has timed out, " @@ -105,7 +118,7 @@ def qsize(self) -> int: def get( self, block: bool = True, timeout: Union[float, None] = None - ) -> Tuple[StoixTransition, TimeStep, Dict]: + ) -> Tuple[StoixTransition, TimeStep, Dict[str, List[float]], Dict[str, List[float]]]: """Get a trajectory from the pipeline.""" return self._queue.get(block, timeout) # type: ignore @@ -115,6 +128,13 @@ def stack_trajectory(self, trajectory: List[StoixTransition]) -> StoixTransition transition of shape [rollout_len, num_envs, ...].""" return jax.tree_map(lambda *x: jnp.stack(x, axis=0), *trajectory) # type: ignore + @partial(jax.jit, static_argnums=(0,)) + def concatenate_metrics( + self, actor_metrics: List[Dict[str, List[float]]] + ) -> Dict[str, List[float]]: + """Concatenate a list of actor metrics into a single dictionary.""" + return jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *actor_metrics) # type: ignore + def shard_split_playload(self, payload: Any, axis: int = 0) -> Any: split_payload = jnp.split(payload, len(self.learner_devices), axis=axis) return jax.device_put_sharded(split_payload, devices=self.learner_devices)
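
For reference, a minimal sketch of how the refactored output types are consumed after this change. The types and their fields (AnakinExperimentOutput, EvaluationOutput, and their learner_state / episode_metrics / train_metrics attributes) are taken from the stoix/base_types.py hunk above; the wrapper function run_one_evaluation_cycle and the way the learner and evaluator are wired together are illustrative assumptions, not code from this patch.

# Illustrative sketch only -- not part of the patch. It assumes a learner_fn
# matching LearnerFn, an evaluator_fn matching EvalFn, and an already-built
# learner_state, as defined in stoix/base_types.py after this refactor.
import chex
from flax.core.frozen_dict import FrozenDict

from stoix.base_types import AnakinExperimentOutput, EvaluationOutput


def run_one_evaluation_cycle(learner_fn, evaluator_fn, learner_state, trained_params: FrozenDict, eval_key: chex.PRNGKey):
    # Anakin learners now return AnakinExperimentOutput, which still carries
    # both episode metrics and train metrics.
    learner_output: AnakinExperimentOutput = learner_fn(learner_state)
    episode_metrics = learner_output.episode_metrics
    train_metrics = learner_output.train_metrics

    # Evaluators now return EvaluationOutput, which drops the previously
    # always-empty train_metrics field of the old ExperimentOutput.
    eval_output: EvaluationOutput = evaluator_fn(trained_params, eval_key)
    eval_metrics = eval_output.episode_metrics

    return learner_output.learner_state, episode_metrics, train_metrics, eval_metrics

In the Sebulba path the split is analogous but narrower: SebulbaLearnerFn returns SebulbaExperimentOutput (learner_state and train_metrics only), since episode metrics are gathered by the actor threads and passed through Pipeline.put / Pipeline.get alongside the actor timing dictionaries, as shown in the stoix/utils/sebulba_utils.py hunk.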