feat: intermediate work
EdanToledo committed Aug 24, 2024
1 parent ef49baa commit c17b2a1
Showing 2 changed files with 15 additions and 24 deletions.
8 changes: 4 additions & 4 deletions stoix/systems/q_learning/sebulba/ff_dqn.py
@@ -623,17 +623,18 @@ def run_experiment(_config: DictConfig) -> float:
# Now we create the pipeline
replay_buffer_add, replay_buffer_sample = buffer_fns
# Set up the rate limiter that controls how actors and learners interact with the pipeline
if config.arch.pipeline.samples_per_insert > 1:
if config.system.epochs > 0:
samples_per_insert_tolerance_rate = config.arch.pipeline.samples_per_insert_tolerance_rate
samples_per_insert_tolerance = (
samples_per_insert_tolerance_rate * config.arch.pipeline.samples_per_insert
samples_per_insert_tolerance_rate * config.system.epochs
)
error_buffer = config.arch.pipeline.min_replay_size * samples_per_insert_tolerance
rate_limiter = SampleToInsertRatio(
config.system.epochs, config.arch.pipeline.min_replay_size, error_buffer
)
else:
rate_limiter = MinSize(config.arch.pipeline.min_replay_size) # type: ignore
pass
rate_limiter = MinSize(config.arch.pipeline.min_replay_size) # type: ignore
pipeline = OffPolicyPipeline(
replay_buffer_add,
replay_buffer_sample,
@@ -747,7 +748,6 @@ def run_experiment(_config: DictConfig) -> float:
for actor in actor_threads:
# We clear the pipeline before stopping each actor thread
# since actors can be blocked on the pipeline
pipeline.clear()
actor.join()

print(f"{Fore.MAGENTA}{Style.BRIGHT}Closing pipeline...{Style.RESET_ALL}")
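Note: the branch above only swaps which config value drives the sample-to-insert ratio (config.system.epochs instead of config.arch.pipeline.samples_per_insert); the error-buffer arithmetic itself is unchanged. A minimal sketch of that arithmetic, with made-up values standing in for the config entries:

# Hypothetical values, chosen only to illustrate the arithmetic in the hunk above;
# they are not Stoix defaults.
samples_per_insert = 2.0   # the new code reads this from config.system.epochs
tolerance_rate = 0.1       # config.arch.pipeline.samples_per_insert_tolerance_rate
min_replay_size = 1_000    # config.arch.pipeline.min_replay_size

# Same derivation as in run_experiment: the tolerance scales with the target ratio,
# and the error buffer scales that tolerance by the minimum replay size.
samples_per_insert_tolerance = tolerance_rate * samples_per_insert
error_buffer = min_replay_size * samples_per_insert_tolerance

print(samples_per_insert_tolerance)  # 0.2
print(error_buffer)                  # 200.0 samples of slack around the target ratio
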
31 changes: 11 additions & 20 deletions stoix/utils/sebulba_utils.py
@@ -223,6 +223,9 @@ def put(

# signal that we have inserted the data
self.rate_limiter.insert()

# Concatenate metrics - List[Dict[str, List[float]]] --> Dict[str, List[float]]
actor_episode_metrics = self.concatenate_metrics(actor_episode_metrics)

# add timings to the timings queue
self._metrics_queue.put((actor_timings_dict, actor_episode_metrics))
@@ -253,11 +256,8 @@ def get(self, timeout: Union[float, None] = None) -> Tuple[StoixTransition, Dict
# split the trajectory over the learner devices
sharded_sampled_batch = jax.tree.map(lambda x: self.shard_split_playload(x), sampled_batch)

# Get all metrics from the metrics queue and concatenate them
# TODO(edan): investigate speed of this
actor_timings, actor_metrics = self.get_all_metrics()
actor_timings = self.stack_metrics(actor_timings)
actor_metrics = self.stack_metrics(actor_metrics)
# Get the timings and metrics from the metrics queue
actor_timings, actor_metrics = self._metrics_queue.get()

return sharded_sampled_batch, actor_timings, actor_metrics # type: ignore

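Taken together, the put()/get() changes move metric aggregation to the producer side: put() enqueues one already-concatenated dict per batch, so get() can do a single blocking read instead of draining the queue and stacking every entry. A rough sketch of that handoff with a plain queue (names and values are illustrative, not the OffPolicyPipeline API):

import queue

metrics_queue: "queue.Queue[tuple[dict, dict]]" = queue.Queue()

# Producer side (as in put()): aggregate first, then enqueue a single item.
actor_timings = {"rollout_time": [0.5, 0.6]}
actor_metrics = {"episode_return": [1.0, 2.0, 3.0]}  # already concatenated
metrics_queue.put((actor_timings, actor_metrics))

# Consumer side (as in get()): one blocking read per learner step.
timings, metrics = metrics_queue.get()
print(metrics["episode_return"])  # [1.0, 2.0, 3.0]
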
@@ -271,22 +271,13 @@ def shard_split_playload(self, payload: Any, axis: int = 0) -> Any:
"""Split the payload over the learner devices."""
split_payload = jnp.split(payload, len(self.learner_devices), axis=axis)
return jax.device_put_sharded(split_payload, devices=self.learner_devices)

def get_all_metrics(self) -> Tuple[List[Dict], List[Dict]]:
"""Get all metrics from the metrics queue."""
actor_timings = []
actor_metrics = []
while not self._metrics_queue.empty():
actor_timings_dict, actor_episode_metrics = self._metrics_queue.get()
actor_timings.append(actor_timings_dict)
actor_metrics.append(actor_episode_metrics)
return actor_timings, actor_metrics


@partial(jax.jit, static_argnums=(0,))
def stack_metrics(self, metrics: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Stack a list of timings dictionaries into a single dictionary."""
metrics: Dict[str, Any] = jax.tree_map(lambda *x: jnp.stack(jnp.asarray(x)), *metrics)
return metrics
def concatenate_metrics(
self, actor_metrics: List[Dict[str, List[float]]]
) -> Dict[str, List[float]]:
"""Concatenate a list of actor metrics into a single dictionary."""
return jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *actor_metrics) # type: ignore

def clear(self) -> None:
"""Clear the buffer."""
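The new concatenate_metrics relies on jax.tree_map being given several dictionaries with the same structure: the lambda then receives the matching leaves from every actor dict and concatenates them along axis 0. A minimal sketch of that behaviour, assuming the per-actor metric dicts share their keys (metric names and values below are made up):

import jax
import jax.numpy as jnp

# Two hypothetical per-actor metric dicts with identical keys.
metrics_a = {"episode_return": jnp.array([1.0, 2.0]), "episode_length": jnp.array([10, 20])}
metrics_b = {"episode_return": jnp.array([3.0]), "episode_length": jnp.array([30])}

# tree_map over several trees passes the matching leaves together, so each
# metric ends up concatenated across actors along axis 0.
combined = jax.tree_util.tree_map(
    lambda *leaves: jnp.concatenate(leaves, axis=0), metrics_a, metrics_b
)

print(combined["episode_return"])  # [1. 2. 3.]
print(combined["episode_length"])  # [10 20 30]
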
