From 9402e02698996790bd49445ac04bfc700ab98df4 Mon Sep 17 00:00:00 2001
From: EdanToledo
Date: Sun, 11 Aug 2024 13:52:42 +0000
Subject: [PATCH] feat: intermediate work

---
 stoix/systems/ppo/sebulba/ff_ppo.py | 22 ++++-----
 stoix/wrappers/envpool.py           | 74 +++++++++++++++++++++++++----
 stoix/wrappers/gymnasium.py         | 14 ++++++
 3 files changed, 91 insertions(+), 19 deletions(-)

diff --git a/stoix/systems/ppo/sebulba/ff_ppo.py b/stoix/systems/ppo/sebulba/ff_ppo.py
index 8ceb8bc..01b8e28 100644
--- a/stoix/systems/ppo/sebulba/ff_ppo.py
+++ b/stoix/systems/ppo/sebulba/ff_ppo.py
@@ -455,12 +455,13 @@ def learner_setup(
     config: DictConfig,
 ) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]:
 
-    # Get number/dimension of actions.
+    # Create a single environment to get the observation and action shapes.
     env = env_factory(num_envs=1)
-    obs_shape = env.unwrapped.single_observation_space.shape
-    num_actions = int(env.env.unwrapped.single_action_space.n)
-    env.close()
+    # Get number/dimension of actions.
+    num_actions = int(env.action_spec().num_values)
     config.system.action_dim = num_actions
+    example_obs = env.observation_spec().generate_value()
+    env.close()
 
     # PRNG keys.
     key, actor_net_key, critic_net_key = keys
@@ -493,7 +494,7 @@ def learner_setup(
     )
 
     # Initialise observation
-    init_x = Observation(agent_view=jnp.ones(obs_shape), action_mask=jnp.ones(num_actions))
+    init_x = example_obs
     init_x = jax.tree_util.tree_map(lambda x: x[None, ...], init_x)
 
     # Initialise actor params and optimiser state.
@@ -583,12 +584,11 @@ def run_experiment(_config: DictConfig) -> float:
     ), "The number of envs per actor must be divisible by the number of learner devices"
 
     # Create the environments for train and eval.
-    # env_factory = EnvPoolFactory(
-    #     config.arch.seed,
-    #     task_id="CartPole-v1",
-    #     env_type="dm",
-    # )
-    env_factory = GymnasiumFactory("CartPole-v1")
+    env_factory = EnvPoolFactory(
+        "CartPole-v1",
+        config.arch.seed
+    )
+    # env_factory = GymnasiumFactory("CartPole-v1")
 
     # PRNG keys.
     key, key_e, actor_net_key, critic_net_key = jax.random.split(
diff --git a/stoix/wrappers/envpool.py b/stoix/wrappers/envpool.py
index 0661246..e1457ef 100644
--- a/stoix/wrappers/envpool.py
+++ b/stoix/wrappers/envpool.py
@@ -10,19 +10,25 @@
 import numpy as np
 from jumanji.types import StepType, TimeStep
 from numpy.typing import NDArray
-
+from jumanji.specs import Array, Spec, DiscreteArray
 from stoix.base_types import Observation
 
 
 class EnvPoolToJumanji:
-    """Converts from the Gym API to the dm_env API, using Jumanji's Timestep type."""
+    """Converts from the Gymnasium envpool API to Jumanji's API."""
 
     def __init__(self, env: Any):
         self.env = env
-        self.num_envs = self.env.num_envs
+        obs, _ = self.env.reset()
+        self.num_envs = obs.shape[0]
+        self.obs_shape = obs.shape[1:]
         self.num_actions = self.env.action_space.n
-        self.obs_shape = self.env.observation_space.shape
-        self._default_action_mask = np.ones(self.num_actions, dtype=np.float32)
+        self._default_action_mask = np.ones((self.num_envs, self.num_actions), dtype=np.float32)
+        # Create the metrics
+        self.running_count_episode_return = np.zeros(self.num_envs, dtype=float)
+        self.running_count_episode_length = np.zeros(self.num_envs, dtype=int)
+        self.episode_return = np.zeros(self.num_envs, dtype=float)
+        self.episode_length = np.zeros(self.num_envs, dtype=int)
 
     def reset(
         self, seed: Optional[list[int]] = None, options: Optional[list[dict]] = None
@@ -32,6 +38,21 @@ def reset(
         ep_done = np.zeros(self.num_envs, dtype=float)
         rewards = np.zeros(self.num_envs, dtype=float)
         terminated = np.zeros(self.num_envs, dtype=float)
+
+        # Reset the metrics
+        self.running_count_episode_return = np.zeros(self.num_envs, dtype=float)
+        self.running_count_episode_length = np.zeros(self.num_envs, dtype=int)
+        self.episode_return = np.zeros(self.num_envs, dtype=float)
+        self.episode_length = np.zeros(self.num_envs, dtype=int)
+
+        # Create the metrics dict
+        metrics = {
+            "episode_return": self.episode_return,
+            "episode_length": self.episode_length,
+            "is_terminal_step": np.zeros(self.num_envs, dtype=bool),
+        }
+
+        info["metrics"] = metrics
 
         timestep = self._create_timestep(obs, ep_done, terminated, rewards, info)
 
@@ -41,21 +62,42 @@ def step(self, action: list) -> TimeStep:
         obs, rewards, terminated, truncated, info = self.env.step(action)
         ep_done = np.logical_or(terminated, truncated)
+        not_done = 1 - ep_done
+
+        # Counting episode return and length.
+        if "reward" in info:
+            metric_reward = info["reward"]
+        else:
+            metric_reward = rewards
+        new_episode_return = self.running_count_episode_return + metric_reward
+        new_episode_length = self.running_count_episode_length + 1
+
+        # Previous episode return/length until done and then the next episode return.
+        episode_return_info = self.episode_return * not_done + new_episode_return * ep_done
+        episode_length_info = self.episode_length * not_done + new_episode_length * ep_done
+        # Create the metrics dict
+        metrics = {
+            "episode_return": episode_return_info,
+            "episode_length": episode_length_info,
+            "is_terminal_step": ep_done,
+        }
+
+        info["metrics"] = metrics
 
         timestep = self._create_timestep(obs, ep_done, terminated, rewards, info)
 
         return timestep
 
     def _format_observation(self, obs: NDArray, info: Dict) -> Observation:
         action_mask = self._default_action_mask
-        multi_env_action_mask = np.stack([action_mask] * obs.shape[0])
-        return Observation(agent_view=obs, action_mask=multi_env_action_mask)
+        return Observation(agent_view=obs, action_mask=action_mask)
 
     def _create_timestep(
         self, obs: NDArray, ep_done: NDArray, terminated: NDArray, rewards: NDArray, info: Dict
     ) -> TimeStep:
         obs = self._format_observation(obs, info)
-        extras = {} #jax.tree.map(lambda *x: np.stack(x), *info["metrics"])
+        extras = info["metrics"]
         step_type = np.where(ep_done, StepType.LAST, StepType.MID)
 
         return TimeStep(
             step_type=step_type,
@@ -65,3 +107,19 @@
             observation=obs,
             extras=extras,
         )
+
+    def observation_spec(self) -> Spec:
+        agent_view_spec = Array(shape=self.obs_shape, dtype=float)
+        return Spec(
+            Observation,
+            "ObservationSpec",
+            agent_view=agent_view_spec,
+            action_mask=Array(shape=(self.num_actions,), dtype=float),
+            step_count=Array(shape=(), dtype=int),
+        )
+
+    def action_spec(self) -> Spec:
+        return DiscreteArray(num_values=self.num_actions)
+
+    def close(self) -> None:
+        self.env.close()
\ No newline at end of file
diff --git a/stoix/wrappers/gymnasium.py b/stoix/wrappers/gymnasium.py
index cbde701..6d210d6 100644
--- a/stoix/wrappers/gymnasium.py
+++ b/stoix/wrappers/gymnasium.py
@@ -10,6 +10,7 @@
 import numpy as np
 from jumanji.types import StepType, TimeStep
 from numpy.typing import NDArray
+from jumanji.specs import Array, Spec, DiscreteArray
 
 from stoix.base_types import Observation
 
@@ -163,3 +164,16 @@ def _create_timestep(
             observation=obs,
             extras=extras,
         )
+
+    def observation_spec(self) -> Spec:
+        agent_view_spec = Array(shape=self.obs_shape, dtype=float)
+        return Spec(
+            Observation,
+            "ObservationSpec",
+            agent_view=agent_view_spec,
+            action_mask=Array(shape=(self.num_actions,), dtype=float),
+            step_count=Array(shape=(), dtype=int),
+        )
+
+    def action_spec(self) -> Spec:
+        return DiscreteArray(num_values=self.num_actions)
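
For reference (not part of the patch): a minimal sketch of how the Jumanji-style spec
methods introduced above are consumed during setup, mirroring the learner_setup change
in ff_ppo.py. The helper name example_setup and its env_factory argument are hypothetical
stand-ins for the EnvPoolFactory/GymnasiumFactory objects used in the system file.

    import jax

    def example_setup(env_factory):
        # A single environment is enough to query the specs.
        env = env_factory(num_envs=1)
        # Discrete action count comes from the DiscreteArray action spec.
        num_actions = int(env.action_spec().num_values)
        # A dummy observation generated from the spec replaces manual shape plumbing.
        example_obs = env.observation_spec().generate_value()
        env.close()
        # Add a leading batch axis before initialising network parameters.
        init_x = jax.tree_util.tree_map(lambda x: x[None, ...], example_obs)
        return num_actions, init_x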