Skip to content

Commit

Permalink
Added new argument incompatible_shape_action to pytorch_toolbelt.ut…
Browse files Browse the repository at this point in the history
…ils.torch_utils.transfer_weights to control whether to try to match mean & std for incompatible layers
  • Loading branch information
BloodAxe committed Aug 20, 2023
1 parent 4b36619 commit 0b08dcb
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion pytorch_toolbelt/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,16 +263,40 @@ def maybe_cuda(x: Union[torch.Tensor, nn.Module]) -> Union[torch.Tensor, nn.Modu

# Logger registered under the "pytorch_toolbelt.utils" name; transfer_weights
# emits its debug messages about skipped or re-initialized keys through it.
logger = logging.getLogger("pytorch_toolbelt.utils")

def transfer_weights(model: nn.Module, model_state_dict: collections.OrderedDict):

def transfer_weights(model: nn.Module, model_state_dict: collections.OrderedDict, incompatible_shape_action="skip"):
"""
Copy weights from state dict to model, skipping layers that are incompatible.
This method is helpful if you are doing some model surgery and want to load
part of the model weights into different model.
:param model: Model to load weights into
:param model_state_dict: Model state dict to load weights from
:param incompatible_shape_action: What to do if shape of weight tensor is incompatible.
Possible values are:
- "skip" - Skip loading this tensor
- "match_mean_std" - Initialize tensor with random values with same mean and std as source tensor
:return: None
"""
existing_model_state_dict = model.state_dict()

for name, value in model_state_dict.items():
if name not in existing_model_state_dict:
logger.debug(f"transfer_weights skipped loading weights for key {name}, because it does not exist in model")
continue

existing_value = existing_model_state_dict[name]
if value.shape != existing_value.shape:
if incompatible_shape_action == "skip":
logger.debug(f"transfer_weights skipped loading weights for key {name}, because of checkpoint has shape {value.shape} and model has shape {existing_model_state_dict[name].shape}")
continue
elif incompatible_shape_action == "match_mean_std":
logger.debug(f"transfer_weights found that {name} weights tensor have incompatible shape {value.shape} and model has shape {existing_value.shape}. "
f"Initializing with random values with same mean {existing_value.mean()} and std {existing_value.std()} from corresponding checkpoint weights tensor.")
value = torch.zeros_like(existing_value)
torch.nn.init.randn(value, mean=existing_value.mean(), std=existing_value.std())
else:
raise ValueError(f"Unsupported incompatible_shape_action={incompatible_shape_action}")

try:
model.load_state_dict(collections.OrderedDict([(name, value)]), strict=False)
except Exception as e:
Expand Down

0 comments on commit 0b08dcb

Please sign in to comment.