diff --git a/pytorch_toolbelt/utils/torch_utils.py b/pytorch_toolbelt/utils/torch_utils.py
index f9eafe001..a6c3bb7c0 100644
--- a/pytorch_toolbelt/utils/torch_utils.py
+++ b/pytorch_toolbelt/utils/torch_utils.py
@@ -263,16 +263,44 @@ def maybe_cuda(x: Union[torch.Tensor, nn.Module]) -> Union[torch.Tensor, nn.Module]:
 logger = logging.getLogger("pytorch_toolbelt.utils")
 
 
-def transfer_weights(model: nn.Module, model_state_dict: collections.OrderedDict):
+
+def transfer_weights(model: nn.Module, model_state_dict: collections.OrderedDict, incompatible_shape_action="skip"):
     """
     Copy weights from state dict to model, skipping layers that are incompatible.
     This method is helpful if you are doing some model surgery and want to load part of the model weights into different model.
     :param model: Model to load weights into
     :param model_state_dict: Model state dict to load weights from
+    :param incompatible_shape_action: What to do if shape of weight tensor is incompatible.
+        Possible values are:
+        - "skip" - Skip loading this tensor
+        - "match_mean_std" - Initialize tensor with random values with same mean and std as source tensor
     :return: None
     """
+    existing_model_state_dict = model.state_dict()
     for name, value in model_state_dict.items():
+        if name not in existing_model_state_dict:
+            logger.debug(f"transfer_weights skipped loading weights for key {name}, because it does not exist in model")
+            continue
+
+        existing_value = existing_model_state_dict[name]
+        if value.shape != existing_value.shape:
+            if incompatible_shape_action == "skip":
+                logger.debug(f"transfer_weights skipped loading weights for key {name}, because checkpoint has shape {value.shape} and model has shape {existing_value.shape}")
+                continue
+            elif incompatible_shape_action == "match_mean_std":
+                # Statistics come from the checkpoint tensor (the "source tensor" per the docstring),
+                # cast to float first so integer buffers do not make mean()/std() raise.
+                mean, std = value.float().mean().item(), value.float().std().item()
+                logger.debug(
+                    f"transfer_weights found that {name} weights tensor has incompatible shape {value.shape} while model has shape {existing_value.shape}. "
+                    f"Initializing with random values with same mean {mean} and std {std} from corresponding checkpoint weights tensor."
+                )
+                # torch.nn.init has no `randn`; normal_ fills in-place and returns the tensor.
+                value = torch.nn.init.normal_(torch.zeros_like(existing_value), mean=mean, std=std)
+            else:
+                raise ValueError(f"Unsupported incompatible_shape_action={incompatible_shape_action}")
+
         try:
             model.load_state_dict(collections.OrderedDict([(name, value)]), strict=False)
         except Exception as e: