Skip to content

Commit

Permalink
Fix bug in initialization of weights using match_mean_std
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Aug 27, 2023
1 parent 30e3a05 commit 75e6f46
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pytorch_toolbelt/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ def transfer_weights(model: nn.Module, model_state_dict: collections.OrderedDict
f"transfer_weights found that {name} weights tensor have incompatible shape {value.shape} and model has shape {existing_value.shape}. "
f"Initializing with random values with same mean {existing_value.mean()} and std {existing_value.std()} from corresponding checkpoint weights tensor."
)
value = torch.zeros_like(existing_value)
torch.nn.init.normal_(value, mean=existing_value.mean(), std=existing_value.std())
torch.nn.init.normal_(existing_value, mean=value.mean(), std=value.std())
value = existing_value
else:
raise ValueError(f"Unsupported incompatible_shape_action={incompatible_shape_action}")

Expand Down

0 comments on commit 75e6f46

Please sign in to comment.