Skip to content

Commit

Permalink
Added all_gather_and_cat
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Sep 16, 2024
1 parent de10152 commit 07e008b
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions pytorch_toolbelt/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"master_print",
"reduce_dict_sum",
"split_across_nodes",
"master_node_only"
"master_node_only" "all_gather_and_cat",
]

logger = logging.getLogger("pytorch_toolbelt.utils.distributed")
Expand Down Expand Up @@ -195,6 +195,20 @@ def all_gather(data: Any) -> List[Any]:
return data_list


def all_gather_and_cat(data: Any, dim=0) -> Any:
data = all_gather(data)
if isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
elif isinstance(data[0], np.ndarray):
return np.concatenate(data, axis=dim)
elif isinstance(data[0], list):
return [item for sublist in data for item in sublist]
else:
raise RuntimeError(
f"Unsupported data type {type(data[0])}. Input data must be list of torch.Tensor, np.ndarray or list"
)


def reduce_dict_sum(input_dict: Dict[Any, Any]) -> Dict[Any, Any]:
"""
Reduce the values in the dictionary from all processes so that all processes
Expand Down Expand Up @@ -315,6 +329,7 @@ def split_across_nodes(
else:
return collection


def master_node_only(func):
"""
A decorator for making sure a function runs only in main process.
Expand All @@ -326,8 +341,10 @@ def master_node_only(func):
return_type = inspect.signature(func).return_annotation
function_has_return_value = return_type is not None and return_type != inspect._empty
if function_has_return_value:
raise RuntimeError(f"Function {func} decorated with @master_node_only must not return any value. "
f"Function signature: {inspect.signature(func)}")
raise RuntimeError(
f"Function {func} decorated with @master_node_only must not return any value. "
f"Function signature: {inspect.signature(func)}"
)

@functools.wraps(func)
def wrapper(*args, **kwargs):
Expand All @@ -336,4 +353,4 @@ def wrapper(*args, **kwargs):
else:
return None

return wrapper
return wrapper

0 comments on commit 07e008b

Please sign in to comment.