Commit

feat: intermediate work
EdanToledo committed Aug 11, 2024
1 parent 961a34b commit 52017a2
Showing 6 changed files with 73 additions and 66 deletions.
5 changes: 3 additions & 2 deletions stoix/configs/arch/sebulba.yaml
@@ -2,7 +2,7 @@
architecture_name : sebulba
# --- Training ---
seed: 42 # RNG seed.
total_num_envs: 4 # Total Number of vectorised environments across all actors. Needs to be divisible by the number of actor devices and actors per device.
total_num_envs: 1024 # Total Number of vectorised environments across all actors. Needs to be divisible by the number of actor devices and actors per device.
total_timesteps: 1e7 # Set the total environment steps.
# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
num_updates: ~ # Number of updates
@@ -16,7 +16,8 @@ actor:
learner:
device_ids: [2, 3] # Define which devices to use for the learner.

pipeline_queue_size: 5
# Size of the queue for the pipeline where actors push data and the learner pulls data.
pipeline_queue_size: 10

# --- Evaluation ---
evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select
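For reference, a quick sketch of the env-sharding arithmetic the total_num_envs comment implies; the actor device count and actors-per-device value below are illustrative assumptions, only the learner device_ids [2, 3] come from this config.

total_num_envs = 1024
num_actor_devices = 2      # assumed, e.g. actor device_ids [0, 1]
actor_per_device = 2       # assumed actor threads per device
num_learner_devices = 2    # matches learner device_ids [2, 3] above

envs_per_actor_device = total_num_envs // num_actor_devices  # 512
envs_per_actor = envs_per_actor_device // actor_per_device   # 256

# total_num_envs must divide evenly across actor devices and actors per device,
# and each actor's batch must later split evenly across the learner devices.
assert total_num_envs % (num_actor_devices * actor_per_device) == 0
assert envs_per_actor % num_learner_devices == 0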
13 changes: 5 additions & 8 deletions stoix/evaluator.py
@@ -395,7 +395,7 @@ def get_sebulba_eval_fn(
stacklevel=2,
)

def eval_fn(params: FrozenDict, key: chex.PRNGKey, init_act_state: Any) -> Dict:
def eval_fn(params: FrozenDict, key: chex.PRNGKey) -> Dict:
"""Evaluates the given params on an environment and returns relevent metrics.
Metrics are collected by the `RecordEpisodeMetrics` wrapper: episode return and length,
@@ -412,13 +412,12 @@ def _episode(key: chex.PRNGKey) -> Tuple[chex.PRNGKey, Dict]:

timesteps = [ts]

actor_state = init_act_state
finished_eps = ts.last()

while not finished_eps.all():
key, act_key = jax.random.split(key)
action, actor_state = act_fn(params, ts, act_key, actor_state)
cpu_action = jax.device_get(action).swapaxes(0, 1)
action = act_fn(params, ts.observation, act_key)
cpu_action = jax.device_get(action)
ts = env.step(cpu_action)
timesteps.append(ts)

@@ -427,8 +426,6 @@ def _episode(key: chex.PRNGKey) -> Tuple[chex.PRNGKey, Dict]:
timesteps = jax.tree.map(lambda *x: np.stack(x), *timesteps)

metrics = timesteps.extras
if config.env.log_win_rate:
metrics["won_episode"] = timesteps.extras["won_episode"]

# find the first instance of done to get the metrics at that timestep; we don't
# care about subsequent steps because we only want the results from the first episode
@@ -451,11 +448,11 @@ def _episode(key: chex.PRNGKey) -> Tuple[chex.PRNGKey, Dict]:
) # flatten metrics
return metrics

def timed_eval_fn(params: FrozenDict, key: chex.PRNGKey, init_act_state: Any) -> Any:
def timed_eval_fn(params: FrozenDict, key: chex.PRNGKey) -> Any:
"""Wrapper around eval function to time it and add in steps per second metric."""
start_time = time.time()

metrics = eval_fn(params, key, init_act_state)
metrics = eval_fn(params, key)

end_time = time.time()
total_timesteps = jnp.sum(metrics["episode_length"])
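The evaluator now runs statelessly: eval_fn no longer threads an actor state, and act_fn is called with the raw observation. A minimal sketch of that contract, using a hypothetical act_fn and assumed shapes (not the Stoix networks):

import jax
import jax.numpy as jnp

def act_fn(params, observation, key):  # hypothetical stand-in for the real act_fn
    del params  # a real policy would apply the network params here
    return jax.random.randint(key, (observation.shape[0],), 0, 4)

obs = jnp.zeros((8, 5))              # assumed: 8 envs, 5 observation features
key = jax.random.PRNGKey(0)
key, act_key = jax.random.split(key)
action = act_fn(None, obs, act_key)  # mirrors: act_fn(params, ts.observation, act_key)
cpu_action = jax.device_get(action)  # no swapaxes needed in the feed-forward case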
2 changes: 1 addition & 1 deletion stoix/networks/base.py
@@ -9,7 +9,7 @@
from flax import linen as nn

from stoix.base_types import Observation, RNNObservation
from stoix.networks.inputs import EmbeddingInput, ObservationInput
from stoix.networks.inputs import ObservationInput
from stoix.networks.utils import parse_rnn_cell


106 changes: 56 additions & 50 deletions stoix/systems/ppo/sebulba/ff_ppo.py
@@ -97,51 +97,52 @@ def rollout(rng: chex.PRNGKey) -> None:
traj: List[PPOTransition] = []
# Create the dictionary to store timings
timings_dict: Dict[str, List[float]] = defaultdict(list)
# Rollout the environment
with RecordTimeTo(timings_dict["single_rollout_time"]):
for _ in range(config.system.rollout_length):
with RecordTimeTo(timings_dict["get_params_time"]):
params = params_source.get()

cached_next_obs = jax.tree.map(move_to_device, timestep.observation)
cached_next_dones = move_to_device(next_dones)
cached_next_trunc = move_to_device(next_trunc)

with RecordTimeTo(timings_dict["compute_action_time"]):
rng, key = split_key_fn(rng)
pi = actor_apply_fn(params.actor_params, cached_next_obs)
value = critic_apply_fn(params.critic_params, cached_next_obs)
action = pi.sample(seed=key)
log_prob = pi.log_prob(action)

with RecordTimeTo(timings_dict["put_action_on_cpu_time"]):
action_cpu = np.array(jax.device_put(action, cpu))

with RecordTimeTo(timings_dict["env_step_time"]):
timestep = envs.step(action_cpu)

# Get the next dones and truncation flags
next_dones = np.logical_and(
np.array(timestep.last()), np.array(timestep.discount == 0.0)
)
next_trunc = np.logical_and(
np.array(timestep.last()), np.array(timestep.discount == 1.0)
)

for _ in range(config.system.rollout_length):
with RecordTimeTo(timings_dict["get_params_time"]):
params = params_source.get()

cached_next_obs = jax.tree.map(move_to_device, timestep.observation)
cached_next_dones = move_to_device(next_dones)
cached_next_trunc = move_to_device(next_trunc)

with RecordTimeTo(timings_dict["compute_action_time"]):
rng, key = split_key_fn(rng)
pi = actor_apply_fn(params.actor_params, cached_next_obs)
value = critic_apply_fn(params.critic_params, cached_next_obs)
action = pi.sample(seed=key)
log_prob = pi.log_prob(action)

with RecordTimeTo(timings_dict["put_action_on_cpu_time"]):
action_cpu = np.array(jax.device_put(action, cpu))

with RecordTimeTo(timings_dict["env_step_time"]):
timestep = envs.step(action_cpu)

# Get the next dones and truncation flags
next_dones = np.logical_and(
np.array(timestep.last()), np.array(timestep.discount == 0.0)
)
next_trunc = np.logical_and(
np.array(timestep.last()), np.array(timestep.discount == 1.0)
)

# Append data to storage
reward = timestep.reward
info = timestep.extras
traj.append(
PPOTransition(
cached_next_dones,
cached_next_trunc,
action,
value,
reward,
log_prob,
cached_next_obs,
info,
# Append data to storage
reward = timestep.reward
info = timestep.extras
traj.append(
PPOTransition(
cached_next_dones,
cached_next_trunc,
action,
value,
reward,
log_prob,
cached_next_obs,
info,
)
)
)

# Send the trajectory to the pipeline
with RecordTimeTo(timings_dict["rollout_put_time"]):
@@ -399,7 +400,7 @@ def learner_rollout(learner_state: LearnerState) -> None:
rollout_times: List[Dict] = []
learn_timings: Dict[str, List[float]] = defaultdict(list)

for _ in range(config.system.num_updates_per_eval):
for _ in range(config.arch.num_updates_per_eval):
with RecordTimeTo(learn_timings["rollout_get_time"]):
traj_batch, timestep, rollout_time = pipeline.get(block=True)

@@ -537,7 +538,9 @@ def learner_setup(

# Initialise learner state.
params, opt_states = replicate_learner
init_learner_state = LearnerState(params, opt_states, None, None, None)
key, step_key = jax.random.split(key)
step_keys = jax.random.split(step_key, len(learner_devices))
init_learner_state = LearnerState(params, opt_states, step_keys, None, None)

return learn, apply_fns, init_learner_state
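The learner state now carries one step key per learner device. A small sketch of that key fan-out, with an assumed device count:

import jax

num_learner_devices = 2  # assumed; one entry per learner device id
key = jax.random.PRNGKey(42)
key, step_key = jax.random.split(key)
step_keys = jax.random.split(step_key, num_learner_devices)
print(step_keys.shape)  # (2, 2): one PRNG key per learner device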

@@ -549,7 +552,7 @@ def run_experiment(_config: DictConfig) -> float:
assert (
config.arch.num_updates > config.arch.num_evaluation
), "Number of updates per evaluation must be less than total number of updates."
config.arch.num_updates_per_eval = config.arch.num_updates // config.arch.num_evaluation
config.arch.num_updates_per_eval = int(config.arch.num_updates // config.arch.num_evaluation)

# Get the learner and actor devices
local_devices = jax.local_devices()
@@ -574,6 +577,10 @@ def run_experiment(_config: DictConfig) -> float:
num_envs_per_actor_device = config.arch.total_num_envs // len(actor_devices)
num_envs_per_actor = num_envs_per_actor_device // config.arch.actor.actor_per_device
config.arch.actor.envs_per_actor = num_envs_per_actor

assert num_envs_per_actor % len(local_learner_devices) == 0, (
"The number of envs per actor must be divisible by the number of learner devices"
)

# Create the environments for train and eval.
# env_factory = EnvPoolFactory(
@@ -603,9 +610,6 @@ def run_experiment(_config: DictConfig) -> float:
absolute_metric=False,
)

# Calculate number of updates per evaluation.
config.arch.num_updates_per_eval = config.arch.num_updates // config.arch.num_evaluation

# Logger setup
logger = StoixLogger(config)
cfg: Dict = OmegaConf.to_container(config, resolve=True)
@@ -697,7 +701,7 @@ def run_experiment(_config: DictConfig) -> float:

unreplicated_actor_params = unreplicate(learner_state.params.actor_params)
key, eval_key = jax.random.split(key, 2)
eval_metrics = evaluator(unreplicated_actor_params, eval_key, {})
eval_metrics = evaluator(unreplicated_actor_params, eval_key)
logger.log(eval_metrics, t, eval_step, LogEvent.EVAL)

episode_return = jnp.mean(eval_metrics["episode_return"])
@@ -721,6 +725,8 @@ def run_experiment(_config: DictConfig) -> float:
actors_lifetime.stop()
for actor in actor_threads:
actor.join()

learner_thread.join()

pipeline_lifetime.stop()
pipeline.join()
10 changes: 6 additions & 4 deletions stoix/utils/sebulba_utils.py
@@ -61,15 +61,17 @@ def run(self) -> None:
except queue.Empty:
continue

def put(self, traj: Sequence[StoixTransition], timestep: TimeStep, time_dict: Dict) -> None:
def put(self, traj: Sequence[StoixTransition], timestep: TimeStep, timings_dict: Dict) -> None:
"""Put a trajectory on the queue to be consumed by the learner."""
start_condition, end_condition = (threading.Condition(), threading.Condition())
with start_condition:
self.tickets_queue.put((start_condition, end_condition))
start_condition.wait() # wait to be allowed to start

# [Transition] * rollout_len --> Transition[done=(rollout_len, num_envs,)
sharded_traj = jax.tree.map(lambda *x: self.shard_split_playload(jnp.stack(x), 1), *traj)
# [Transition(num_envs)] * rollout_len --> Transition[(rollout_len, num_envs, ...)]
traj = jax.tree_map(lambda *x: jnp.stack(x, axis=0), *traj)
# Split trajectory on the num envs axis so each learner device gets a valid full rollout
sharded_traj = jax.tree.map(lambda x: self.shard_split_playload(x, axis=1), traj)

# Timestep[(num_envs, ...), ...] -->
# [(num_envs / num_learner_devices, ...)] * num_learner_devices
@@ -81,7 +83,7 @@ def put(self, traj: Sequence[StoixTransition], timestep: TimeStep, time_dict: Di
if self._queue.full():
self._queue.get() # throw away the transition

self._queue.put((sharded_traj, sharded_timestep, time_dict))
self._queue.put((sharded_traj, sharded_timestep, timings_dict))

with end_condition:
end_condition.notify()  # signal that we have finished
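The reworked put first stacks the per-step transitions along a new leading time axis, then splits along the env axis so each learner device receives a full-length rollout over its own slice of environments. A shape-only sketch with plain arrays in place of StoixTransition (sizes are assumed):

import jax.numpy as jnp

rollout_len, num_envs, num_learner_devices = 4, 8, 2  # assumed example sizes

# [x(num_envs,)] * rollout_len --> x(rollout_len, num_envs)
per_step = [jnp.arange(num_envs) for _ in range(rollout_len)]
stacked = jnp.stack(per_step, axis=0)
assert stacked.shape == (rollout_len, num_envs)

# One shard per learner device, split on the env axis.
shards = jnp.split(stacked, num_learner_devices, axis=1)
assert all(s.shape == (rollout_len, num_envs // num_learner_devices) for s in shards)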
3 changes: 2 additions & 1 deletion stoix/wrappers/gym.py
@@ -142,7 +142,8 @@ def _format_observation(
self, obs: NDArray, info: Dict
) -> Observation:
# TODO(edan): fix action mask
return Observation(agent_view=obs, action_mask=np.ones(1, dtype=np.float32))
num_envs = int(self.env.num_envs)
return Observation(agent_view=obs, action_mask=np.ones(num_envs, dtype=np.float32))

def _create_timestep(
self, obs: NDArray, ep_done: NDArray, terminated: NDArray, rewards: NDArray, info: Dict
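The wrapper change sizes the placeholder action mask per vectorised env instead of using a single element (the real mask is still a TODO). A sketch of the resulting shapes, with a stand-in Observation type and assumed sizes:

from typing import NamedTuple

import numpy as np

class Observation(NamedTuple):  # hypothetical stand-in for stoix.base_types.Observation
    agent_view: np.ndarray
    action_mask: np.ndarray

num_envs, obs_dim = 8, 5  # assumed sizes
obs = np.zeros((num_envs, obs_dim), dtype=np.float32)
formatted = Observation(agent_view=obs, action_mask=np.ones(num_envs, dtype=np.float32))
assert formatted.action_mask.shape == (num_envs,)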
