
fix: extend the unwrap_model function and save unwrapped model state dict instead of wrapped #29780

Conversation

@shub-kris (Contributor) commented Mar 21, 2024

What does this PR do?

This PR pushes two changes:

  • Save unwrap_model(model).state_dict() whenever isinstance(unwrap_model(model), supported_classes) holds, instead of the wrapped model's state dict (see the sketch after this list).
  • Extend unwrap_model() so that wrappers around the model's children layers/modules are also unwrapped correctly.
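
A minimal sketch of the first change (not the exact trainer.py diff; model and supported_classes stand for objects the Trainer already has in its save path):

# Hedged sketch of the intended save-path change: persist the *unwrapped*
# model's state_dict whenever the unwrapped model is a supported class,
# instead of the wrapped model's state_dict.
unwrapped = unwrap_model(model)
if isinstance(unwrapped, supported_classes):
    state_dict = unwrapped.state_dict()  # keys carry no wrapper prefixes such as `_orig_module.`
else:
    state_dict = model.state_dict()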

The existing unwrap_model() only unwraps the outermost layer, so it fails when we wrap with FSDP: it does not descend into the children layers or modules.
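
For reference, this is the existing top-level-only implementation (reproduced from the comparison snippet later in this thread):

def unwrap_model(model: nn.Module) -> nn.Module:
    # Current behaviour: only the outermost `.module` wrapper is removed;
    # wrapped child modules (e.g. per-layer FSDP wrappers) are left untouched.
    if hasattr(model, "module"):
        return unwrap_model(model.module)
    return model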

For example:

A Wrapped Model

SpmdFullyShardedDataParallel(
  (_orig_module): GemmaForCausalLM(
    (model): GemmaModel(
      (embed_tokens): Embedding(256000, 2048, padding_idx=0)
      (layers): ModuleList(
        (0-17): 18 x SpmdFullyShardedDataParallel(
          (_orig_module): GemmaDecoderLayer(
            (self_attn): GemmaAttention(
              (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
              (k_proj): Linear(in_features=2048, out_features=256, bias=False)
              (v_proj): Linear(in_features=2048, out_features=256, bias=False)
              (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
              (rotary_emb): GemmaRotaryEmbedding()
            )
            (mlp): GemmaMLP(
              (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
              (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
              (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
              (act_fn): PytorchGELUTanh()
            )
            (input_layernorm): GemmaRMSNorm()
            (post_attention_layernorm): GemmaRMSNorm()
          )
        )
      )
      (norm): GemmaRMSNorm()
    )
    (lm_head): Linear(in_features=2048, out_features=256000, bias=False)
  )
)

Unwrapping it with the existing unwrap_model() leads to:

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x SpmdFullyShardedDataParallel(
        (_orig_module): GemmaDecoderLayer(
          (self_attn): GemmaAttention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2048, out_features=256, bias=False)
            (v_proj): Linear(in_features=2048, out_features=256, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (rotary_emb): GemmaRotaryEmbedding()
          )
          (mlp): GemmaMLP(
            (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
            (act_fn): PytorchGELUTanh()
          )
          (input_layernorm): GemmaRMSNorm()
          (post_attention_layernorm): GemmaRMSNorm()
        )
      )
    )
    (norm): GemmaRMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=256000, bias=False)
)
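
Because the per-layer SpmdFullyShardedDataParallel wrappers survive, the partially unwrapped model's state_dict still carries wrapper-prefixed keys; a sketch of the expected difference, inferred from the module tree above:

# With the old unwrap_model, parameter names are expected to keep the
# `_orig_module.` segment introduced by the per-layer FSDP wrappers, e.g.:
#   model.layers.0._orig_module.self_attn.q_proj.weight
# while a fully unwrapped model yields the original names:
#   model.layers.0.self_attn.q_proj.weight
print(list(unwrap_model(wrapped_model).state_dict().keys())[:3])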

But with the change proposed in this PR:

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaRMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=256000, bias=False)
)


Fixes #29659

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@amyeroberts @muellerzr @pacman100
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

…children layers and save_unwrap_model state_dict instead of wrapped_model_state_dict
@shub-kris (Contributor Author) commented:

@alanwaketan can you also take a look, please?

@shub-kris (Contributor Author) commented:

You can replicate the wrapping and unwrapping using this script:

import torch
import torch_xla
import torch.nn as nn
import functools
from transformers import AutoModelForCausalLM
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import numpy as np
from torch_xla.distributed.fsdp import checkpoint_module
from transformers.trainer_pt_utils import get_module_class_from_name
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import unwrap_model

def wrap_model(model, fsdp_config):
    num_devices = xr.global_runtime_device_count()
    xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
    
    auto_wrap_policy = None
    auto_wrapper_callable = None
    default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
    fsdp_transformer_layer_cls_to_wrap = fsdp_config.get(
        "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
    )

    if fsdp_transformer_layer_cls_to_wrap is not None:
        transformer_cls_to_wrap = set()
        for layer_class in fsdp_transformer_layer_cls_to_wrap:
            print(f"layer class is {layer_class}")
            transformer_cls = get_module_class_from_name(model, layer_class)
            if transformer_cls is None:
                raise Exception("Could not find the transformer layer class to wrap in the model.")
            else:
                transformer_cls_to_wrap.add(transformer_cls)
        print(f"transformer_cls_to_wrap: {transformer_cls_to_wrap}")
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            # Transformer layer class to wrap
            transformer_layer_cls=transformer_cls_to_wrap,
        )
        if fsdp_config["xla_fsdp_grad_ckpt"]:
            # Apply gradient checkpointing to auto-wrapped sub-modules if specified
            def auto_wrapper_callable(m, *args, **kwargs):
                target_cls = FSDPv2
                return target_cls(checkpoint_module(m), *args, **kwargs)


            def shard_output(output, mesh):
                real_output = None
                if isinstance(output, torch.Tensor):
                    real_output = output
                elif isinstance(output, tuple):
                    real_output = output[0]
                elif isinstance(output, CausalLMOutputWithPast):
                    real_output = output.logits

                if real_output is None:
                    raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
                xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
            
            print(f"auto wrap policy is {auto_wrap_policy}")
            print(f"auto wrapper callable is {auto_wrapper_callable}")
            model = FSDPv2(
                model,
                shard_output=shard_output,
                auto_wrap_policy=auto_wrap_policy,
                auto_wrapper_callable=auto_wrapper_callable,
            )
    return model
    


def unwrap_model_new(model: nn.Module) -> nn.Module:
    """
    Recursively unwraps a module and its child sublayers.

    Args:
        model (nn.Module): Module to unwrap.

    Returns:
        nn.Module: The unwrapped module.
    """

    def recursive_unwrap(module):
        if hasattr(module, "module"):
            try:
                unwrapped_module = recursive_unwrap(getattr(module, "module"))
            except AttributeError:
                unwrapped_module = module  # Handle cases where wrapped module is inaccessible
            return unwrapped_module

        # Unwrap child sublayers recursively
        for name, child in module.named_children():
            setattr(module, name, recursive_unwrap(child))

        return module

    # Start with top-level unwrapping
    unwrapped_model = recursive_unwrap(model)
    return unwrapped_model

def main():
    model_id = "google/gemma-2b"
    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    
    fsdp_config = {
        "fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"],
        "xla": True,
        "xla_fsdp_v2": True,
        "xla_fsdp_grad_ckpt": True,
    }
    wrapped_model = wrap_model(model, fsdp_config)
    print(wrapped_model)
    
    unwrapped_model_old = unwrap_model(wrapped_model)
    print(unwrapped_model_old)
    
    unwrapped_model_new = unwrap_model_new(wrapped_model)
    print(unwrapped_model_new)
    
if __name__ == "__main__":
    main()

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@alanwaketan (Contributor) left a comment:

LGTM! I'm not sure what the process for adding a test case in HF is, though...

@amyeroberts (Collaborator) left a comment:

Thanks for digging into this and fixing!

We should add a test to make sure that:

  • Models relying on the previous unwrap behaviour still work
  • This fixes the issue - add a test which would fail without this change

Comment on lines 4677 to 4678:

    except AttributeError:
        unwrapped_module = module  # Handle cases where wrapped module is inaccessible

@amyeroberts (Collaborator) commented:

Could you give an example of when this happens? It seems weird that we'd have hasattr(module, "module") evaluate as True but then be unable to do getattr(module, "module").

@shub-kris (Contributor Author) replied:
Yes, you are right @amyeroberts. It does seem weird; I don't remember why I implemented it like this, but thanks for pointing it out. I also can't think of an example.

I am fixing it.

@shub-kris (Contributor Author) commented:

@amyeroberts I had to change unwrap_model because of the changes introduced in #28949 (Support PyTorch/XLA FSDP via SPMD); the existing unwrap_model only fails there. I can write a test, but the problem is that it requires a TPU, and I am not sure we have one as part of our CI runners.

So, how should we proceed here?

PawKanarek added a commit to PawKanarek/transformers that referenced this pull request Mar 27, 2024
@shub-kris (Contributor Author) commented:

@amyeroberts here is a small snippet for the test:

import torch
import torch_xla
import torch.nn as nn
from transformers import AutoModelForCausalLM
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import numpy as np
import unittest

def compare_state_dict_keys(state_dict_keys_model1, state_dict_keys_model2):
    # Compare lengths first: zip() alone would silently ignore trailing extra keys
    if len(state_dict_keys_model1) != len(state_dict_keys_model2):
        return False
    for key1, key2 in zip(state_dict_keys_model1, state_dict_keys_model2):
        if key1 != key2:
            return False
    return True

# Original `unwrap_model` function
def original_unwrap_model(model: nn.Module) -> nn.Module:
    """Original unwrap implementation for comparison."""
    if hasattr(model, "module"):
        return original_unwrap_model(model.module)
    else:
        return model

def unwrap_model_new(model: nn.Module) -> nn.Module:
    """
    Recursively unwraps a module and its child sublayers.

    Args:
        model (nn.Module): Module to unwrap.

    Returns:
        nn.Module: The unwrapped module.
    """

    def recursive_unwrap(module):
        if hasattr(module, "module"):
            unwrapped_module = recursive_unwrap(getattr(module, "module"))
        else:
            unwrapped_module = module  # Not wrapped; keep the module as-is

        # Unwrap child sublayers recursively
        for name, child in module.named_children():
            setattr(module, name, recursive_unwrap(child))

        return unwrapped_module

    # Start with top-level unwrapping
    unwrapped_model = recursive_unwrap(model)
    return unwrapped_model

class TestUnwrap(unittest.TestCase):    
    def test_compatibility_with_original_behavior(self):
        model_id = "mistralai/Mistral-7B-v0.1"
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        num_devices = xr.global_runtime_device_count()
        xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
        
        wrapped_model = FSDPv2(model)
        unwrapped_model_old = original_unwrap_model(wrapped_model)
        state_dict_keys_model1 = list(unwrapped_model_old.state_dict().keys())
        unwrapped_model_new = unwrap_model_new(wrapped_model)
        state_dict_keys_model2 = list(unwrapped_model_new.state_dict().keys())

        assert compare_state_dict_keys(state_dict_keys_model1, state_dict_keys_model2)
        
    def test_nested_unwrap_modules(self):
        model_id = "mistralai/Mistral-7B-v0.1"
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        orig_state_dict_keys = list(model.state_dict().keys())
        num_devices = xr.global_runtime_device_count()
        xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
        def nested_wrap(model):
            layer = getattr(getattr(model, "model"), "embed_tokens")
            wrapped_layer = FSDPv2(layer)
            setattr(getattr(model, "model"), "embed_tokens", wrapped_layer)
            return FSDPv2(model)
        wrapped_model = nested_wrap(model)
        unwrapped_model_old = original_unwrap_model(wrapped_model)
        old_state_dict_keys = list(unwrapped_model_old.state_dict().keys())
        unwrapped_model_new = unwrap_model_new(wrapped_model)
        new_state_dict_keys = list(unwrapped_model_new.state_dict().keys())
        assert not compare_state_dict_keys(old_state_dict_keys, orig_state_dict_keys)
        assert compare_state_dict_keys(new_state_dict_keys, orig_state_dict_keys)

# if __name__ == "__main__":
#     test_unwrap = TestUnwrap()
#     test_unwrap.test_compatibility_with_original_behavior()
#     test_unwrap.test_nested_unwrap_modules()

It can be run using:

python -m unittest test_unwrap_model.py

@muellerzr (Contributor) commented:

New proposal for this, under which @shub-kris's work here can still be done.

This should be merged/worked on in the following order:

  1. We're expanding this implementation into accelerate via this PR
  2. Update unwrap from accelerate #29933 should be merged, which brings in the Accelerate implementation instead of the transformers one, after we ensure that the old behaviors match
  3. Afterwards, we should pass recursive=True specifically under the TPU saving portion

@zorrofox commented:

@muellerzr how is this PR going? I see the upstream accelerate PR 2595 has been merged.

@zorrofox commented, quoting @muellerzr's proposal above:


Point 1&2 both have been merged. @muellerzr can you help to go to step 3?

@muellerzr (Contributor) commented:

If @shub-kris wants to rebase, the changes in trainer.py are no longer needed; thanks to #29933, just passing recursive=True is all that's required.
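
A hedged sketch of what that could look like after the rebase, assuming the post-#29933 unwrap_model forwards a recursive flag to Accelerate's extract_model_from_parallel (added upstream in accelerate PR 2595); the names below are illustrative, not the final trainer.py code:

from transformers.modeling_utils import unwrap_model

# Under the TPU/XLA-FSDP saving path, request a fully recursive unwrap so that
# per-layer wrappers are removed before collecting the state dict.
unwrapped = unwrap_model(wrapped_model, recursive=True)
state_dict = unwrapped.state_dict()  # keys match the original, unwrapped model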


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this May 28, 2024
Successfully merging this pull request may close these issues.

Problems with saving standalone gemma-2b-it after fine-tuning with LoRA on TPU v3-8