fix: extend the unwrap_model function and save unwrapped model state dict instead of wrapped #29780

src/transformers/modeling_utils.py (28 changes: 21 additions & 7 deletions)
@@ -4661,16 +4661,30 @@ def forward(

def unwrap_model(model: nn.Module) -> nn.Module:
    """
-    Recursively unwraps a model from potential containers (as used in distributed training).
+    Recursively unwraps a module and its child sublayers.

    Args:
-        model (`torch.nn.Module`): The model to unwrap.
+        model (nn.Module): The model to unwrap.
+
+    Returns:
+        nn.Module: The unwrapped module.
    """
-    # since there could be multiple levels of wrapping, unwrap recursively
-    if hasattr(model, "module"):
-        return unwrap_model(model.module)
-    else:
-        return model
+
+    def recursive_unwrap(module):
+        if hasattr(module, "module"):
+            unwrapped_module = recursive_unwrap(getattr(module, "module"))
+        else:
+            unwrapped_module = module  # Not a wrapper: the module itself is the innermost layer
+
+        # Unwrap child sublayers recursively
+        for name, child in module.named_children():
+            setattr(module, name, recursive_unwrap(child))
+
+        return unwrapped_module
+
+    # Start with top-level unwrapping
+    unwrapped_model = recursive_unwrap(model)
+    return unwrapped_model
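
For illustration, a minimal sketch of what the recursive version buys over the old one. The `Wrapper` class here is a hypothetical stand-in for a real distributed container such as DDP or FSDPv2 (it is not part of this PR): the old implementation strips only the outermost wrapper, while the new one also unwraps wrapped child sublayers.

import torch.nn as nn

from transformers.modeling_utils import unwrap_model


class Wrapper(nn.Module):
    """Toy stand-in for a distributed container exposing `.module`."""

    def __init__(self, module):
        super().__init__()
        self.module = module


inner = nn.Linear(4, 4)
model = nn.Sequential(inner)
model[0] = Wrapper(model[0])  # wrap a child sublayer
wrapped = Wrapper(model)      # wrap the top level

unwrapped = unwrap_model(wrapped)
assert unwrapped is model
assert unwrapped[0] is inner  # the nested wrapper is stripped as well
assert "0.weight" in unwrapped.state_dict()             # clean keys
assert "0.module.weight" not in unwrapped.state_dict()  # no wrapper prefix left behind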


def expand_device_map(device_map, param_names, start_prefix):
src/transformers/trainer.py (2 changes: 1 addition & 1 deletion)
@@ -3206,7 +3206,7 @@ def _save_tpu(self, output_dir: Optional[str] = None):
            unwrap_model(model).save_pretrained(
                output_dir,
                is_main_process=self.args.should_save,
-                state_dict=model.state_dict(),
+                state_dict=unwrap_model(model).state_dict(),
                save_function=xm.save,
                safe_serialization=self.args.save_safetensors,
            )
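
This one-line change matters because a wrapped model's `state_dict()` prefixes every key with the container's attribute name, while `save_pretrained` is called on the unwrapped model and expects unprefixed keys; passing the wrapped state dict therefore saves keys the unwrapped architecture cannot load back. A minimal sketch, again with a toy `Wrapper` standing in for the real container:

import torch.nn as nn

from transformers.modeling_utils import unwrap_model


class Wrapper(nn.Module):
    """Toy stand-in for a distributed container exposing `.module`."""

    def __init__(self, module):
        super().__init__()
        self.module = module


model = Wrapper(nn.Linear(4, 4))
print(sorted(model.state_dict()))                # ['module.bias', 'module.weight']
print(sorted(unwrap_model(model).state_dict()))  # ['bias', 'weight'], as save_pretrained expects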
tests/test_modeling_utils.py (64 changes: 64 additions & 0 deletions)
@@ -37,6 +37,7 @@
    OwlViTForObjectDetection,
    PretrainedConfig,
    is_torch_available,
+    is_torch_xla_available,
    logging,
)
from transformers.testing_utils import (
@@ -54,6 +55,7 @@
    require_torch_accelerator,
    require_torch_gpu,
    require_torch_multi_accelerator,
+    require_torch_xla,
    require_usr_bin_time,
    slow,
    torch_device,
@@ -197,6 +199,14 @@ def forward(self, mask, inputs_embeds):
if is_tf_available():
    from transformers import TFBertModel

+if is_torch_xla_available():
+    import numpy as np
+    import torch_xla.distributed.spmd as xs
+    import torch_xla.runtime as xr
+    from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
+
+    from transformers.modeling_utils import unwrap_model


TINY_T5 = "patrickvonplaten/t5-tiny-random"
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
@@ -212,6 +222,21 @@ def check_models_equal(model1, model2):
    return models_are_equal


+def check_state_dict_keys_equal(state_dict_keys_model1, state_dict_keys_model2):
+    # Compare lengths first: zip() alone would silently ignore trailing extra keys
+    if len(state_dict_keys_model1) != len(state_dict_keys_model2):
+        return False
+    for key1, key2 in zip(state_dict_keys_model1, state_dict_keys_model2):
+        if key1 != key2:
+            return False
+    return True
+
+
+def unwrap_model_old(model):
+    """Old unwrap implementation, kept as a baseline for the compatibility test."""
+    if hasattr(model, "module"):
+        return unwrap_model_old(model.module)
+    else:
+        return model


@require_torch
class ModelUtilsTest(TestCasePlus):
    @slow
@@ -2176,3 +2201,42 @@ def test_partial_stacked_causal_mask(self):
        ]

        self.assertEqual(decoded_0, decoded_1b)


+@slow
+@require_torch
+@require_torch_xla
+class UnwrapModelTest(unittest.TestCase):
+    def test_compatibility_with_original_behavior(self):
+        model_id = "mistralai/Mistral-7B-v0.1"
+        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+        num_devices = xr.global_runtime_device_count()
+        xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+
+        # For a single level of wrapping, old and new unwrapping must agree
+        wrapped_model = FSDPv2(model)
+        unwrapped_model_old = unwrap_model_old(wrapped_model)
+        state_dict_keys_model1 = list(unwrapped_model_old.state_dict().keys())
+        unwrapped_model_new = unwrap_model(wrapped_model)
+        state_dict_keys_model2 = list(unwrapped_model_new.state_dict().keys())
+        self.assertTrue(check_state_dict_keys_equal(state_dict_keys_model1, state_dict_keys_model2))
+
+    def test_nested_unwrap_modules(self):
+        model_id = "mistralai/Mistral-7B-v0.1"
+        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+        orig_state_dict_keys = list(model.state_dict().keys())
+        num_devices = xr.global_runtime_device_count()
+        xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+
+        def nested_wrap(model):
+            # Wrap an inner sublayer in addition to the top level
+            model.model.embed_tokens = FSDPv2(model.model.embed_tokens)
+            return FSDPv2(model)
+
+        wrapped_model = nested_wrap(model)
+        unwrapped_model_old = unwrap_model_old(wrapped_model)
+        old_state_dict_keys = list(unwrapped_model_old.state_dict().keys())
+        unwrapped_model_new = unwrap_model(wrapped_model)
+        new_state_dict_keys = list(unwrapped_model_new.state_dict().keys())
+        # The old helper leaves the nested wrapper in place; the new one strips it
+        self.assertFalse(check_state_dict_keys_equal(old_state_dict_keys, orig_state_dict_keys))
+        self.assertTrue(check_state_dict_keys_equal(new_state_dict_keys, orig_state_dict_keys))