From 2b5cf90d8dc07b18c0841cb6a5df93d6e9293ad7 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Thu, 15 Aug 2024 21:27:30 +0000 Subject: [PATCH] chore: more pre-commit chores --- stoix/evaluator.py | 6 ++++-- stoix/systems/ppo/sebulba/ff_ppo.py | 17 +++++++++-------- stoix/utils/make_env.py | 2 ++ 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/stoix/evaluator.py b/stoix/evaluator.py index aa7b9a6..1bb2a59 100644 --- a/stoix/evaluator.py +++ b/stoix/evaluator.py @@ -1,6 +1,6 @@ import math import time -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import chex import flax.linen as nn @@ -358,6 +358,8 @@ def evaluator_setup( ##### THIS IS TEMPORARY +SebulbaEvalFn = Callable[[FrozenDict, chex.PRNGKey], Dict[str, chex.Array]] + def get_sebulba_eval_fn( env_factory: EnvFactory, @@ -366,7 +368,7 @@ def get_sebulba_eval_fn( np_rng: np.random.Generator, device: jax.Device, eval_multiplier: float = 1.0, -) -> Tuple[EvalFn, Any]: +) -> Tuple[SebulbaEvalFn, Any]: """Creates a function that can be used to evaluate agents on a given environment. Args: diff --git a/stoix/systems/ppo/sebulba/ff_ppo.py b/stoix/systems/ppo/sebulba/ff_ppo.py index 64eba63..12d45de 100644 --- a/stoix/systems/ppo/sebulba/ff_ppo.py +++ b/stoix/systems/ppo/sebulba/ff_ppo.py @@ -25,7 +25,6 @@ ActorCriticParams, CriticApply, ExperimentOutput, - LearnerFn, LearnerState, Observation, SebulbaLearnerFn, @@ -62,7 +61,7 @@ def get_act_fn( def actor_fn( params: ActorCriticParams, observation: Observation, rng_key: chex.PRNGKey - ) -> Tuple[chex.Array, chex.Array]: + ) -> Tuple[chex.Array, chex.Array, chex.Array]: """Get the action, value and log_prob from the actor and critic networks.""" rng_key, policy_key = jax.random.split(rng_key) pi = actor_apply_fn(params.actor_params, observation) @@ -83,7 +82,7 @@ def get_rollout_fn( config: DictConfig, seeds: List[int], thread_lifetime: ThreadLifetime, -): +) -> Callable[[chex.PRNGKey], None]: """Get the rollout function that is used by the actor threads.""" # Unpack and set up the functions act_fn = get_act_fn(apply_fns) @@ -185,7 +184,7 @@ def get_actor_thread( seeds: List[int], thread_lifetime: ThreadLifetime, name: str, -): +) -> threading.Thread: """Get the actor thread that once started will collect data from the environment and send it to the pipeline.""" rng_key = jax.device_put(rng_key, actor_device) @@ -416,7 +415,7 @@ def get_learner_rollout_fn( eval_queue: Queue, pipeline: Pipeline, params_sources: Sequence[ParamsSource], -): +) -> Callable[[LearnerState], None]: """Get the learner rollout function that is used by the learner thread to update the networks. This function is what is actually run by the learner thread. It gets the data from the pipeline and uses the learner update function to update the networks. It then sends these intermediate @@ -434,7 +433,7 @@ def learner_rollout(learner_state: LearnerState) -> None: # Get the trajectory batch from the pipeline # This is blocking so it will wait until the pipeline has data. with RecordTimeTo(learn_timings["rollout_get_time"]): - traj_batch, timestep, rollout_time = pipeline.get(block=True) + traj_batch, timestep, rollout_time = pipeline.get(block=True) # type: ignore # We then replace the timestep in the learner state with the latest timestep # This means the learner has access to the entire trajectory as well as # an additional timestep which it can use to bootstrap. @@ -473,7 +472,7 @@ def get_learner_thread( eval_queue: Queue, pipeline: Pipeline, params_sources: Sequence[ParamsSource], -): +) -> threading.Thread: """Get the learner thread that is used to update the networks.""" learner_rollout_fn = get_learner_rollout_fn(learn, config, eval_queue, pipeline, params_sources) @@ -492,7 +491,9 @@ def learner_setup( keys: chex.Array, learner_devices: Sequence[jax.Device], config: DictConfig, -) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]: +) -> Tuple[ + SebulbaLearnerFn[LearnerState, PPOTransition], Tuple[ActorApply, CriticApply], LearnerState +]: """Setup for the learner state and networks.""" # Create a single environment just to get the observation and action specs. diff --git a/stoix/utils/make_env.py b/stoix/utils/make_env.py index 93899de..c11184a 100644 --- a/stoix/utils/make_env.py +++ b/stoix/utils/make_env.py @@ -443,3 +443,5 @@ def make_factory(config: DictConfig) -> Union[GymnasiumFactory, EnvPoolFactory]: return make_envpool_factory(env_name, config) elif "gymnasium" in suite_name: return make_gymnasium_factory(env_name, config) + else: + raise ValueError(f"{suite_name} is not a supported suite.")