
Commit

feat: intermediate work
EdanToledo committed Aug 11, 2024
1 parent e06d70e commit 9402e02
Showing 3 changed files with 91 additions and 19 deletions.
22 changes: 11 additions & 11 deletions stoix/systems/ppo/sebulba/ff_ppo.py
@@ -455,12 +455,13 @@ def learner_setup(
     config: DictConfig,
 ) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]:

-    # Get number/dimension of actions.
+    # Create a single environment to get the observation and action shapes.
     env = env_factory(num_envs=1)
-    obs_shape = env.unwrapped.single_observation_space.shape
-    num_actions = int(env.env.unwrapped.single_action_space.n)
-    env.close()
+    # Get number/dimension of actions.
+    num_actions = int(env.action_spec().num_values)
     config.system.action_dim = num_actions
+    example_obs = env.observation_spec().generate_value()
+    env.close()

     # PRNG keys.
     key, actor_net_key, critic_net_key = keys
@@ -493,7 +494,7 @@ def learner_setup(
     )

     # Initialise observation
-    init_x = Observation(agent_view=jnp.ones(obs_shape), action_mask=jnp.ones(num_actions))
+    init_x = example_obs
     init_x = jax.tree_util.tree_map(lambda x: x[None, ...], init_x)

     # Initialise actor params and optimiser state.
@@ -583,12 +584,11 @@ def run_experiment(_config: DictConfig) -> float:
     ), "The number of envs per actor must be divisible by the number of learner devices"

     # Create the environments for train and eval.
-    # env_factory = EnvPoolFactory(
-    #     config.arch.seed,
-    #     task_id="CartPole-v1",
-    #     env_type="dm",
-    # )
-    env_factory = GymnasiumFactory("CartPole-v1")
+    env_factory = EnvPoolFactory(
+        "CartPole-v1",
+        config.arch.seed
+    )
+    # env_factory = GymnasiumFactory("CartPole-v1")

     # PRNG keys.
     key, key_e, actor_net_key, critic_net_key = jax.random.split(
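Note: learner_setup now derives the action dimension and a dummy observation from the wrapper's Jumanji-style specs instead of poking at the underlying Gymnasium spaces. A minimal sketch of that pattern, using a stand-in Observation tuple and made-up CartPole-like shapes (nothing below is taken from the repo beyond the spec layout shown in the diff):

from typing import Any, NamedTuple

import jax
from jumanji import specs


class Observation(NamedTuple):  # stand-in for stoix.base_types.Observation
    agent_view: Any
    action_mask: Any
    step_count: Any


# Spec mirroring what the wrappers' observation_spec() returns; shapes are placeholders.
obs_spec = specs.Spec(
    Observation,
    "ObservationSpec",
    agent_view=specs.Array(shape=(4,), dtype=float),
    action_mask=specs.Array(shape=(2,), dtype=float),
    step_count=specs.Array(shape=(), dtype=int),
)
action_spec = specs.DiscreteArray(num_values=2)

num_actions = int(action_spec.num_values)  # what gets stored in config.system.action_dim
example_obs = obs_spec.generate_value()    # dummy Observation, typically all zeros
# Add a leading batch dimension before initialising the networks.
init_x = jax.tree_util.tree_map(lambda x: x[None, ...], example_obs)
print(init_x.agent_view.shape)             # (1, 4)

Since generate_value() fills every leaf with a value matching its shape and dtype, the tree_map only has to add the leading batch axis.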
74 changes: 66 additions & 8 deletions stoix/wrappers/envpool.py
@@ -10,19 +10,25 @@
 import numpy as np
 from jumanji.types import StepType, TimeStep
 from numpy.typing import NDArray
-
+from jumanji.specs import Array, Spec, DiscreteArray
 from stoix.base_types import Observation


 class EnvPoolToJumanji:
-    """Converts from the Gym API to the dm_env API, using Jumanji's Timestep type."""
+    """Converts from the Gymnasium envpool API to Jumanji's API."""

     def __init__(self, env: Any):
         self.env = env
-        self.num_envs = self.env.num_envs
+        obs, _ = self.env.reset()
+        self.num_envs = obs.shape[0]
+        self.obs_shape = obs.shape[1:]
         self.num_actions = self.env.action_space.n
-        self.obs_shape = self.env.observation_space.shape
-        self._default_action_mask = np.ones(self.num_actions, dtype=np.float32)
+        self._default_action_mask = np.ones((self.num_envs, self.num_actions), dtype=np.float32)
+        # Create the metrics
+        self.running_count_episode_return = np.zeros(self.num_envs, dtype=float)
+        self.running_count_episode_length = np.zeros(self.num_envs, dtype=int)
+        self.episode_return = np.zeros(self.num_envs, dtype=float)
+        self.episode_length = np.zeros(self.num_envs, dtype=int)

     def reset(
         self, seed: Optional[list[int]] = None, options: Optional[list[dict]] = None
@@ -32,6 +38,21 @@ def reset(
         ep_done = np.zeros(self.num_envs, dtype=float)
         rewards = np.zeros(self.num_envs, dtype=float)
         terminated = np.zeros(self.num_envs, dtype=float)
+
+        # Reset the metrics
+        self.running_count_episode_return = np.zeros(self.num_envs, dtype=float)
+        self.running_count_episode_length = np.zeros(self.num_envs, dtype=int)
+        self.episode_return = np.zeros(self.num_envs, dtype=float)
+        self.episode_length = np.zeros(self.num_envs, dtype=int)
+
+        # Create the metrics dict
+        metrics = {
+            "episode_return": self.episode_return,
+            "episode_length": self.episode_length,
+            "is_terminal_step": np.zeros(self.num_envs, dtype=bool),
+        }
+
+        info["metrics"] = metrics

         timestep = self._create_timestep(obs, ep_done, terminated, rewards, info)

@@ -41,21 +62,42 @@ def step(self, action: list) -> TimeStep:
         obs, rewards, terminated, truncated, info = self.env.step(action)

         ep_done = np.logical_or(terminated, truncated)
+        not_done = 1 - ep_done
+
+        # Counting episode return and length.
+        if "reward" in info:
+            metric_reward = info["reward"]
+        else:
+            metric_reward = rewards
+        new_episode_return = self.running_count_episode_return + metric_reward
+        new_episode_length = self.running_count_episode_length + 1
+
+        # Previous episode return/length until done and then the next episode return.
+        episode_return_info = self.episode_return * not_done + new_episode_return * ep_done
+        episode_length_info = self.episode_length * not_done + new_episode_length * ep_done
+
+        # Create the metrics dict
+        metrics = {
+            "episode_return": episode_return_info,
+            "episode_length": episode_length_info,
+            "is_terminal_step": ep_done,
+        }
+
+        info["metrics"] = metrics

         timestep = self._create_timestep(obs, ep_done, terminated, rewards, info)

         return timestep

     def _format_observation(self, obs: NDArray, info: Dict) -> Observation:
         action_mask = self._default_action_mask
-        multi_env_action_mask = np.stack([action_mask] * obs.shape[0])
-        return Observation(agent_view=obs, action_mask=multi_env_action_mask)
+        return Observation(agent_view=obs, action_mask=action_mask)

     def _create_timestep(
         self, obs: NDArray, ep_done: NDArray, terminated: NDArray, rewards: NDArray, info: Dict
     ) -> TimeStep:
         obs = self._format_observation(obs, info)
-        extras = {} #jax.tree.map(lambda *x: np.stack(x), *info["metrics"])
+        extras = info["metrics"]
         step_type = np.where(ep_done, StepType.LAST, StepType.MID)

         return TimeStep(
@@ -65,3 +107,19 @@ def _create_timestep(
             observation=obs,
             extras=extras,
         )
+
+    def observation_spec(self) -> Spec:
+        agent_view_spec = Array(shape=self.obs_shape, dtype=float)
+        return Spec(
+            Observation,
+            "ObservationSpec",
+            agent_view=agent_view_spec,
+            action_mask=Array(shape=(self.num_actions,), dtype=float),
+            step_count=Array(shape=(), dtype=int),
+        )
+
+    def action_spec(self) -> Spec:
+        return DiscreteArray(num_values=self.num_actions)
+
+    def close(self) -> None:
+        self.env.close()
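Note: the wrapper now tracks per-environment episode returns and lengths itself, using the usual masked running-sum pattern for vectorised envs. A self-contained numpy sketch of that bookkeeping with made-up rewards and done flags; the zeroing of the running counts after an episode ends is implied here, even though those lines fall outside the hunk shown above:

import numpy as np

num_envs = 3
running_return = np.zeros(num_envs)
running_length = np.zeros(num_envs, dtype=int)
episode_return = np.zeros(num_envs)  # return of the last completed episode per env
episode_length = np.zeros(num_envs, dtype=int)

# Made-up per-step rewards and episode terminations.
dummy_steps = [
    (np.array([1.0, 0.5, 2.0]), np.array([False, False, True])),
    (np.array([1.0, 0.5, 2.0]), np.array([True, False, False])),
]

for rewards, ep_done in dummy_steps:
    not_done = 1 - ep_done
    new_return = running_return + rewards
    new_length = running_length + 1
    # Report the previous episode's stats until an env finishes, then the fresh ones.
    episode_return = episode_return * not_done + new_return * ep_done
    episode_length = episode_length * not_done + new_length * ep_done
    # Carry the running counts forward, zeroing the envs that just finished.
    running_return = new_return * not_done
    running_length = new_length * not_done
    metrics = {
        "episode_return": episode_return,
        "episode_length": episode_length,
        "is_terminal_step": ep_done,
    }
    print(metrics)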
14 changes: 14 additions & 0 deletions stoix/wrappers/gymnasium.py
@@ -10,6 +10,7 @@
 import numpy as np
 from jumanji.types import StepType, TimeStep
 from numpy.typing import NDArray
+from jumanji.specs import Array, Spec, DiscreteArray

 from stoix.base_types import Observation

@@ -163,3 +164,16 @@ def _create_timestep(
             observation=obs,
             extras=extras,
         )
+
+    def observation_spec(self) -> Spec:
+        agent_view_spec = Array(shape=self.obs_shape, dtype=float)
+        return Spec(
+            Observation,
+            "ObservationSpec",
+            agent_view=agent_view_spec,
+            action_mask=Array(shape=(self.num_actions,), dtype=float),
+            step_count=Array(shape=(), dtype=int),
+        )
+
+    def action_spec(self) -> Spec:
+        return DiscreteArray(num_values=self.num_actions)
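Note: both wrappers now surface the episode statistics through timestep.extras["metrics"], keyed by episode_return, episode_length and is_terminal_step. A small sketch of how such metrics are typically consumed on the logging side, filtering on the terminal flag so only finished episodes are averaged (the numbers are made up; this is not Stoix's actual logging code):

import numpy as np

# Pretend extras["metrics"] gathered from one batch of vectorised env steps.
metrics = {
    "episode_return": np.array([9.0, 3.0, 7.0, 5.0]),
    "episode_length": np.array([9, 3, 7, 5]),
    "is_terminal_step": np.array([True, False, True, False]),
}

done = metrics["is_terminal_step"]
if done.any():
    mean_return = metrics["episode_return"][done].mean()
    mean_length = metrics["episode_length"][done].mean()
    print(f"{done.sum()} episodes finished | "
          f"mean return {mean_return:.1f} | mean length {mean_length:.1f}")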
