From 07e008ba90e6024ffc77111ece2176e9a765bc91 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Mon, 16 Sep 2024 09:30:51 +0300 Subject: [PATCH] Added all_gather_and_cat --- pytorch_toolbelt/utils/distributed.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/pytorch_toolbelt/utils/distributed.py b/pytorch_toolbelt/utils/distributed.py index e299ed2fe..8f4fe90ae 100644 --- a/pytorch_toolbelt/utils/distributed.py +++ b/pytorch_toolbelt/utils/distributed.py @@ -31,7 +31,7 @@ "master_print", "reduce_dict_sum", "split_across_nodes", - "master_node_only" + "master_node_only" "all_gather_and_cat", ] logger = logging.getLogger("pytorch_toolbelt.utils.distributed") @@ -195,6 +195,20 @@ def all_gather(data: Any) -> List[Any]: return data_list +def all_gather_and_cat(data: Any, dim=0) -> Any: + data = all_gather(data) + if isinstance(data[0], torch.Tensor): + return torch.cat(data, dim=0) + elif isinstance(data[0], np.ndarray): + return np.concatenate(data, axis=dim) + elif isinstance(data[0], list): + return [item for sublist in data for item in sublist] + else: + raise RuntimeError( + f"Unsupported data type {type(data[0])}. Input data must be list of torch.Tensor, np.ndarray or list" + ) + + def reduce_dict_sum(input_dict: Dict[Any, Any]) -> Dict[Any, Any]: """ Reduce the values in the dictionary from all processes so that all processes @@ -315,6 +329,7 @@ def split_across_nodes( else: return collection + def master_node_only(func): """ A decorator for making sure a function runs only in main process. @@ -326,8 +341,10 @@ def master_node_only(func): return_type = inspect.signature(func).return_annotation function_has_return_value = return_type is not None and return_type != inspect._empty if function_has_return_value: - raise RuntimeError(f"Function {func} decorated with @master_node_only must not return any value. " - f"Function signature: {inspect.signature(func)}") + raise RuntimeError( + f"Function {func} decorated with @master_node_only must not return any value. " + f"Function signature: {inspect.signature(func)}" + ) @functools.wraps(func) def wrapper(*args, **kwargs): @@ -336,4 +353,4 @@ def wrapper(*args, **kwargs): else: return None - return wrapper \ No newline at end of file + return wrapper