Skip to content

Commit

Permalink
chore: more pre-commit chores
Browse files: browse the repository at this point in the history
  • Loading branch information
EdanToledo committed Aug 15, 2024
1 parent 7684678 commit 2b5cf90
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
6 changes: 4 additions & 2 deletions stoix/evaluator.py
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
import math
import time
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

import chex
import flax.linen as nn
Expand Down Expand Up @@ -358,6 +358,8 @@ def evaluator_setup(

##### THIS IS TEMPORARY

SebulbaEvalFn = Callable[[FrozenDict, chex.PRNGKey], Dict[str, chex.Array]]


def get_sebulba_eval_fn(
env_factory: EnvFactory,
Expand All @@ -366,7 +368,7 @@ def get_sebulba_eval_fn(
np_rng: np.random.Generator,
device: jax.Device,
eval_multiplier: float = 1.0,
) -> Tuple[EvalFn, Any]:
) -> Tuple[SebulbaEvalFn, Any]:
"""Creates a function that can be used to evaluate agents on a given environment.
Args:
Expand Down
17 changes: 9 additions & 8 deletions stoix/systems/ppo/sebulba/ff_ppo.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -25,7 +25,6 @@
ActorCriticParams,
CriticApply,
ExperimentOutput,
LearnerFn,
LearnerState,
Observation,
SebulbaLearnerFn,
Expand Down Expand Up @@ -62,7 +61,7 @@ def get_act_fn(

def actor_fn(
params: ActorCriticParams, observation: Observation, rng_key: chex.PRNGKey
) -> Tuple[chex.Array, chex.Array]:
) -> Tuple[chex.Array, chex.Array, chex.Array]:
"""Get the action, value and log_prob from the actor and critic networks."""
rng_key, policy_key = jax.random.split(rng_key)
pi = actor_apply_fn(params.actor_params, observation)
Expand All @@ -83,7 +82,7 @@ def get_rollout_fn(
config: DictConfig,
seeds: List[int],
thread_lifetime: ThreadLifetime,
):
) -> Callable[[chex.PRNGKey], None]:
"""Get the rollout function that is used by the actor threads."""
# Unpack and set up the functions
act_fn = get_act_fn(apply_fns)
Expand Down Expand Up @@ -185,7 +184,7 @@ def get_actor_thread(
seeds: List[int],
thread_lifetime: ThreadLifetime,
name: str,
):
) -> threading.Thread:
"""Get the actor thread that once started will collect data from the
environment and send it to the pipeline."""
rng_key = jax.device_put(rng_key, actor_device)
Expand Down Expand Up @@ -416,7 +415,7 @@ def get_learner_rollout_fn(
eval_queue: Queue,
pipeline: Pipeline,
params_sources: Sequence[ParamsSource],
):
) -> Callable[[LearnerState], None]:
"""Get the learner rollout function that is used by the learner thread to update the networks.
This function is what is actually run by the learner thread. It gets the data from the pipeline
and uses the learner update function to update the networks. It then sends these intermediate
Expand All @@ -434,7 +433,7 @@ def learner_rollout(learner_state: LearnerState) -> None:
# Get the trajectory batch from the pipeline
# This is blocking so it will wait until the pipeline has data.
with RecordTimeTo(learn_timings["rollout_get_time"]):
traj_batch, timestep, rollout_time = pipeline.get(block=True)
traj_batch, timestep, rollout_time = pipeline.get(block=True) # type: ignore
# We then replace the timestep in the learner state with the latest timestep
# This means the learner has access to the entire trajectory as well as
# an additional timestep which it can use to bootstrap.
Expand Down Expand Up @@ -473,7 +472,7 @@ def get_learner_thread(
eval_queue: Queue,
pipeline: Pipeline,
params_sources: Sequence[ParamsSource],
):
) -> threading.Thread:
"""Get the learner thread that is used to update the networks."""

learner_rollout_fn = get_learner_rollout_fn(learn, config, eval_queue, pipeline, params_sources)
Expand All @@ -492,7 +491,9 @@ def learner_setup(
keys: chex.Array,
learner_devices: Sequence[jax.Device],
config: DictConfig,
) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]:
) -> Tuple[
SebulbaLearnerFn[LearnerState, PPOTransition], Tuple[ActorApply, CriticApply], LearnerState
]:
"""Setup for the learner state and networks."""

# Create a single environment just to get the observation and action specs.
Expand Down
2 changes: 2 additions & 0 deletions stoix/utils/make_env.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -443,3 +443,5 @@ def make_factory(config: DictConfig) -> Union[GymnasiumFactory, EnvPoolFactory]:
return make_envpool_factory(env_name, config)
elif "gymnasium" in suite_name:
return make_gymnasium_factory(env_name, config)
else:
raise ValueError(f"{suite_name} is not a supported suite.")

0 comments on commit 2b5cf90

Please sign in to comment.