chore: refactor types #106

Merged · 3 commits · Aug 22, 2024

This PR splits the single ExperimentOutput type into SebulbaExperimentOutput, AnakinExperimentOutput and EvaluationOutput, updates the learner and evaluator type aliases to match, and adds a GitHub Actions workflow plus a bash script that smoke-test a suite of the algorithms for a few hundred timesteps.
36 changes: 36 additions & 0 deletions .github/workflows/run_algs.yaml
@@ -0,0 +1,36 @@
name: Check Algorithms 🧪

on:
push:
branches:
- main
pull_request:
branches:
- main
workflow_dispatch:

jobs:
test-algorithms:
name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
runs-on: "${{ matrix.os }}"
timeout-minutes: 30

strategy:
matrix:
python-version: ["3.10"]
os: [ubuntu-latest]

steps:
- name: Checkout stoix
uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "${{ matrix.python-version }}"
- name: Install python dependencies 🔧
run: pip install .

- name: Make Bash Script Executable
run: chmod +x bash_scripts/run-algorithms.sh

- name: Run Bash Script
run: ./bash_scripts/run-algorithms.sh
17 changes: 17 additions & 0 deletions bash_scripts/run-algorithms.sh
@@ -0,0 +1,17 @@
#!/bin/bash

echo "Running All Algorithms..."

python stoix/systems/ppo/anakin/ff_ppo.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8
python stoix/systems/ppo/anakin/ff_ppo_continuous.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8
python stoix/systems/q_learning/ff_dqn.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8
python stoix/systems/q_learning/ff_ddqn.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8
python stoix/systems/q_learning/ff_mdqn.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8
python stoix/systems/q_learning/ff_c51.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8
python stoix/systems/q_learning/ff_qr_dqn.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8
python stoix/systems/sac/ff_sac.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8
python stoix/systems/ddpg/ff_ddpg.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8
python stoix/systems/ddpg/ff_td3.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8
python stoix/systems/vpg/ff_reinforce.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8
python stoix/systems/awr/ff_awr.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8
python stoix/systems/mpo/ff_mpo.py arch.total_timesteps=300 arch.total_num_envs=8 arch.num_evaluation=1 system.rollout_length=8
23 changes: 19 additions & 4 deletions stoix/base_types.py
@@ -161,18 +161,33 @@ class OnlineAndTarget(NamedTuple):
)


class ExperimentOutput(NamedTuple, Generic[StoixState]):
class SebulbaExperimentOutput(NamedTuple, Generic[StoixState]):
"""Experiment output."""

learner_state: StoixState
train_metrics: Dict[str, chex.Array]


class AnakinExperimentOutput(NamedTuple, Generic[StoixState]):
"""Experiment output."""

learner_state: StoixState
episode_metrics: Dict[str, chex.Array]
train_metrics: Dict[str, chex.Array]


class EvaluationOutput(NamedTuple, Generic[StoixState]):
"""Evaluation output."""

learner_state: StoixState
episode_metrics: Dict[str, chex.Array]


RNNObservation: TypeAlias = Tuple[Observation, Done]
LearnerFn = Callable[[StoixState], ExperimentOutput[StoixState]]
SebulbaLearnerFn = Callable[[StoixState, StoixTransition], ExperimentOutput[StoixState]]
EvalFn = Callable[[FrozenDict, chex.PRNGKey], ExperimentOutput[StoixState]]
LearnerFn = Callable[[StoixState], AnakinExperimentOutput[StoixState]]
SebulbaLearnerFn = Callable[[StoixState, StoixTransition], SebulbaExperimentOutput[StoixState]]
EvalFn = Callable[[FrozenDict, chex.PRNGKey], EvaluationOutput[StoixState]]
SebulbaEvalFn = Callable[[FrozenDict, chex.PRNGKey], Dict[str, chex.Array]]

ActorApply = Callable[..., DistributionLike]

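The practical effect of the split is easiest to see at a call site. Below is a minimal, hypothetical usage sketch (not code from this PR): it assumes stoix is installed and exports these names from stoix.base_types as in the diff above, and state, metrics and losses are illustrative placeholders.

# Hypothetical sketch of the refactored output types (illustration only).
import jax.numpy as jnp

from stoix.base_types import (
    AnakinExperimentOutput,
    EvaluationOutput,
    SebulbaExperimentOutput,
)

state = None  # stands in for a learner/evaluator state pytree
metrics = {"episode_return": jnp.zeros(8)}  # placeholder episode metrics
losses = {"actor_loss": jnp.zeros(8)}  # placeholder train metrics

# Anakin learners return both episode metrics and train metrics.
anakin_out = AnakinExperimentOutput(
    learner_state=state, episode_metrics=metrics, train_metrics=losses
)

# Sebulba learners return only train metrics.
sebulba_out = SebulbaExperimentOutput(learner_state=state, train_metrics=losses)

# Evaluators return episode metrics and no longer pad an empty train_metrics dict.
eval_out = EvaluationOutput(learner_state=state, episode_metrics=metrics)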
39 changes: 7 additions & 32 deletions stoix/evaluator.py
@@ -1,6 +1,6 @@
import math
import time
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import chex
import flax.linen as nn
@@ -18,11 +18,12 @@
EnvFactory,
EvalFn,
EvalState,
ExperimentOutput,
EvaluationOutput,
RecActFn,
RecActorApply,
RNNEvalState,
RNNObservation,
SebulbaEvalFn,
)
from stoix.utils.jax_utils import unreplicate_batch_dim

@@ -133,7 +134,7 @@ def not_done(carry: Tuple) -> bool:

return eval_metrics

def evaluator_fn(trained_params: FrozenDict, key: chex.PRNGKey) -> ExperimentOutput[EvalState]:
def evaluator_fn(trained_params: FrozenDict, key: chex.PRNGKey) -> EvaluationOutput[EvalState]:
"""Evaluator function."""

# Initialise environment states and timesteps.
@@ -164,10 +165,9 @@ def evaluator_fn(trained_params: FrozenDict, key: chex.PRNGKey) -> ExperimentOutput[EvalState]:
axis_name="eval_batch",
)(trained_params, eval_state)

return ExperimentOutput(
return EvaluationOutput(
learner_state=eval_state,
episode_metrics=eval_metrics,
train_metrics={},
)

return evaluator_fn
@@ -248,7 +248,7 @@ def not_done(carry: Tuple) -> bool:

def evaluator_fn(
trained_params: FrozenDict, key: chex.PRNGKey
) -> ExperimentOutput[RNNEvalState]:
) -> EvaluationOutput[RNNEvalState]:
"""Evaluator function."""

# Initialise environment states and timesteps.
@@ -289,10 +289,9 @@ def evaluator_fn(
axis_name="eval_batch",
)(trained_params, eval_state)

return ExperimentOutput(
return EvaluationOutput(
learner_state=eval_state,
episode_metrics=eval_metrics,
train_metrics={},
)

return evaluator_fn
@@ -356,11 +355,6 @@ def evaluator_setup(
return evaluator, absolute_metric_evaluator, (trained_params, eval_keys)


##### THIS IS TEMPORARY

SebulbaEvalFn = Callable[[FrozenDict, chex.PRNGKey], Dict[str, chex.Array]]


def get_sebulba_eval_fn(
env_factory: EnvFactory,
act_fn: ActFn,
@@ -369,18 +363,7 @@
device: jax.Device,
eval_multiplier: float = 1.0,
) -> Tuple[SebulbaEvalFn, Any]:
"""Creates a function that can be used to evaluate agents on a given environment.

Args:
----
env: an environment that conforms to the mava environment spec.
act_fn: a function that takes in params, timestep, key and optionally a state
and returns actions and optionally a state (see `EvalActFn`).
config: the system config.
np_rng: a numpy random number generator.
eval_multiplier: a scalar that will increase the number of evaluation episodes
by a fixed factor.
"""
eval_episodes = config.arch.num_eval_episodes * eval_multiplier

# We calculate here the number of parallel envs we can run in parallel.
@@ -405,14 +388,6 @@
print(f"{Fore.YELLOW}{Style.BRIGHT}{msg}{Style.RESET_ALL}")

def eval_fn(params: FrozenDict, key: chex.PRNGKey) -> Dict:
"""Evaluates the given params on an environment and returns relevant metrics.

Metrics are collected by the `RecordEpisodeMetrics` wrapper: episode return and length,
also win rate for environments that support it.

Returns: Dict[str, Array] - dictionary of metric name to metric values for each episode.
"""

def _run_episodes(key: chex.PRNGKey) -> Tuple[chex.PRNGKey, Dict]:
"""Simulates `num_envs` episodes."""
with jax.default_device(device):
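For contrast, a short hypothetical sketch of the two evaluator contracts after this change; the run_* wrappers are invented for illustration, and only the imported names come from stoix.base_types.

# Hypothetical sketch contrasting the two evaluator signatures (illustration only).
from typing import Dict

import chex
from flax.core.frozen_dict import FrozenDict

from stoix.base_types import EvalFn, EvaluationOutput, SebulbaEvalFn


def run_jitted_eval(
    evaluator: EvalFn, params: FrozenDict, key: chex.PRNGKey
) -> Dict[str, chex.Array]:
    # Jitted (Anakin-style) evaluators return an EvaluationOutput NamedTuple.
    out: EvaluationOutput = evaluator(params, key)
    return out.episode_metrics


def run_sebulba_eval(
    evaluator: SebulbaEvalFn, params: FrozenDict, key: chex.PRNGKey
) -> Dict[str, chex.Array]:
    # Sebulba evaluators return the episode metrics dict directly.
    return evaluator(params, key)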
6 changes: 3 additions & 3 deletions stoix/systems/awr/ff_awr.py
@@ -22,8 +22,8 @@
ActorApply,
ActorCriticOptStates,
ActorCriticParams,
AnakinExperimentOutput,
CriticApply,
ExperimentOutput,
LearnerFn,
LogEnvState,
)
@@ -323,7 +323,7 @@ def _actor_loss_fn(
metric = traj_batch.info
return learner_state, (metric, loss_info)

def learner_fn(learner_state: AWRLearnerState) -> ExperimentOutput[AWRLearnerState]:
def learner_fn(learner_state: AWRLearnerState) -> AnakinExperimentOutput[AWRLearnerState]:
"""Learner function.

This function represents the learner, it updates the network parameters
@@ -336,7 +336,7 @@ def learner_fn(learner_state: AWRLearnerState) -> ExperimentOutput[AWRLearnerState]:
learner_state, (episode_info, loss_info) = jax.lax.scan(
batched_update_step, learner_state, None, config.arch.num_updates_per_eval
)
return ExperimentOutput(
return AnakinExperimentOutput(
learner_state=learner_state,
episode_metrics=episode_info,
train_metrics=loss_info,
6 changes: 3 additions & 3 deletions stoix/systems/awr/ff_awr_continuous.py
@@ -22,8 +22,8 @@
ActorApply,
ActorCriticOptStates,
ActorCriticParams,
AnakinExperimentOutput,
CriticApply,
ExperimentOutput,
LearnerFn,
LogEnvState,
)
@@ -323,7 +323,7 @@ def _actor_loss_fn(
metric = traj_batch.info
return learner_state, (metric, loss_info)

def learner_fn(learner_state: AWRLearnerState) -> ExperimentOutput[AWRLearnerState]:
def learner_fn(learner_state: AWRLearnerState) -> AnakinExperimentOutput[AWRLearnerState]:
"""Learner function.

This function represents the learner, it updates the network parameters
@@ -336,7 +336,7 @@ def learner_fn(learner_state: AWRLearnerState) -> ExperimentOutput[AWRLearnerState]:
learner_state, (episode_info, loss_info) = jax.lax.scan(
batched_update_step, learner_state, None, config.arch.num_updates_per_eval
)
return ExperimentOutput(
return AnakinExperimentOutput(
learner_state=learner_state,
episode_metrics=episode_info,
train_metrics=loss_info,
8 changes: 5 additions & 3 deletions stoix/systems/ddpg/ff_d4pg.py
@@ -20,8 +20,8 @@

from stoix.base_types import (
ActorApply,
AnakinExperimentOutput,
ContinuousQApply,
ExperimentOutput,
LearnerFn,
LogEnvState,
Observation,
@@ -347,7 +347,9 @@ def _actor_loss_fn(
metric = traj_batch.info
return learner_state, (metric, loss_info)

def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]:
def learner_fn(
learner_state: OffPolicyLearnerState,
) -> AnakinExperimentOutput[OffPolicyLearnerState]:
"""Learner function.

This function represents the learner, it updates the network parameters
@@ -360,7 +362,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]:
learner_state, (episode_info, loss_info) = jax.lax.scan(
batched_update_step, learner_state, None, config.arch.num_updates_per_eval
)
return ExperimentOutput(
return AnakinExperimentOutput(
learner_state=learner_state,
episode_metrics=episode_info,
train_metrics=loss_info,
8 changes: 5 additions & 3 deletions stoix/systems/ddpg/ff_ddpg.py
@@ -20,8 +20,8 @@

from stoix.base_types import (
ActorApply,
AnakinExperimentOutput,
ContinuousQApply,
ExperimentOutput,
LearnerFn,
LogEnvState,
Observation,
@@ -309,7 +309,9 @@ def _actor_loss_fn(
metric = traj_batch.info
return learner_state, (metric, loss_info)

def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]:
def learner_fn(
learner_state: OffPolicyLearnerState,
) -> AnakinExperimentOutput[OffPolicyLearnerState]:
"""Learner function.

This function represents the learner, it updates the network parameters
@@ -323,7 +325,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]:
learner_state, (episode_info, loss_info) = jax.lax.scan(
batched_update_step, learner_state, None, config.arch.num_updates_per_eval
)
return ExperimentOutput(
return AnakinExperimentOutput(
learner_state=learner_state,
episode_metrics=episode_info,
train_metrics=loss_info,
8 changes: 5 additions & 3 deletions stoix/systems/ddpg/ff_td3.py
@@ -20,8 +20,8 @@

from stoix.base_types import (
ActorApply,
AnakinExperimentOutput,
ContinuousQApply,
ExperimentOutput,
LearnerFn,
LogEnvState,
Observation,
@@ -327,7 +327,9 @@ def _actor_loss_fn(
metric = traj_batch.info
return learner_state, (metric, loss_info)

def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]:
def learner_fn(
learner_state: OffPolicyLearnerState,
) -> AnakinExperimentOutput[OffPolicyLearnerState]:
"""Learner function.

This function represents the learner, it updates the network parameters
@@ -341,7 +343,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]:
learner_state, (episode_info, loss_info) = jax.lax.scan(
batched_update_step, learner_state, None, config.arch.num_updates_per_eval
)
return ExperimentOutput(
return AnakinExperimentOutput(
learner_state=learner_state,
episode_metrics=episode_info,
train_metrics=loss_info,
6 changes: 3 additions & 3 deletions stoix/systems/mpo/ff_mpo.py
@@ -20,8 +20,8 @@

from stoix.base_types import (
ActorApply,
AnakinExperimentOutput,
ContinuousQApply,
ExperimentOutput,
LearnerFn,
LogEnvState,
OnlineAndTarget,
@@ -406,7 +406,7 @@ def _q_loss_fn(
metric = traj_batch.info
return learner_state, (metric, loss_info)

def learner_fn(learner_state: MPOLearnerState) -> ExperimentOutput[MPOLearnerState]:
def learner_fn(learner_state: MPOLearnerState) -> AnakinExperimentOutput[MPOLearnerState]:
"""Learner function.

This function represents the learner, it updates the network parameters
@@ -419,7 +419,7 @@ def learner_fn(learner_state: MPOLearnerState) -> ExperimentOutput[MPOLearnerState]:
learner_state, (episode_info, loss_info) = jax.lax.scan(
batched_update_step, learner_state, None, config.arch.num_updates_per_eval
)
return ExperimentOutput(
return AnakinExperimentOutput(
learner_state=learner_state,
episode_metrics=episode_info,
train_metrics=loss_info,