muP for Mamba and Mamba-2 #50

Merged · 6 commits into main · Jul 31, 2024
Conversation

alxndrTL (Owner)
This PR adds the muP implementation with Mamba and Mamba-2.
muP is a special parametrization of the model that ensures that the network activations behave the same no matter the width (for "width", you can simply read "d_model", i.e. the model dimension).

This makes it so that the optimal hyperparameters, like the learning rate and the init standard deviation, stay the same whether d_model=64 or d_model=2048. This is thus incredibly useful for finding the optimal HPs to train a (target) large model:

  • first, sweep to look for the optimal HPs on a small (base) model
  • then, zero-shot transfer the HPs you found to train the target model

This is called muTransfer. So muP allows muTransfer.
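To make the workflow concrete, here is a minimal sketch of what that looks like in practice (build_model and train are hypothetical helpers, not functions from this repo; the widths and LRs are just examples):

```python
# Hypothetical sketch of the muTransfer workflow.
# build_model and train are placeholders, not functions from this repo.

base_width, target_width = 64, 2048

# 1) sweep the LR on the small (base) model
candidate_lrs = [2**-k for k in range(3, 12)]
sweep_losses = {lr: train(build_model(d_model=base_width, mup=True), lr=lr)
                for lr in candidate_lrs}
best_lr = min(sweep_losses, key=sweep_losses.get)

# 2) zero-shot transfer: reuse the same LR to train the large (target) model
train(build_model(d_model=target_width, mup=True), lr=best_lr)
```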

See below for the checks that ensure these muP implementations are correct.

(This is the second PR, made so that all the changes required to implement muP with Mamba can be seen in one place. The older PR was reverted.)

@alxndrTL (Owner, Author)

Here are the so-called "coord checks", showing that the network activations behave the same no matter the width:

Mamba (standard):
[coord check plot: mamba_no_mup]

Mamba (muP):
[coord check plot: mamba_mup]

(these show the scale of the activations for various widths (d_model), starting from t=1 (initialization) to t=5 (5 steps of training))

Mamba-2 (standard):
[coord check plot: mamba2_no_mup]

Mamba-2 (muP):
[coord check plot: mamba2_mup]

We can see that muP achieves the goal of making the activations behave roughly the same regardless of the width of the model.

(n_layers=8, batch_size=64, batch_len=256, vocab_size=256, and for Mamba2 d_head=16)

These plots were made with a LR of 1e-3, which is quite low compared to the optimal LR of 2**(-7) (see next comment).
Here are the coord checks for muP on Mamba-2 with LR=1e-2 (so just above 2**(-7)), over the first 10 steps of training:
[coord check plot: mamba2_coordcheck_1e-2_mup]

We can see some oscillations starting to appear. However, they don't seem to prevent muTransfer from happening (see next comment). These oscillations at high LR were also observed on a Transformer (see this for example).
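For reference, here is a rough sketch of how coord checks like the ones above can be produced (make_model and get_batch are hypothetical stand-ins, not the actual scripts from the tests folder):

```python
import torch
import torch.nn.functional as F

def coord_check(widths=(64, 128, 256, 512, 1024, 2048), steps=5, lr=1e-3):
    records = {}  # (width, step) -> {module name: mean |activation|}
    for width in widths:
        model = make_model(d_model=width)     # hypothetical model factory
        opt = torch.optim.Adam(model.parameters(), lr=lr)
        acts = {}
        hooks = [m.register_forward_hook(
                     lambda mod, inp, out, name=name: acts.__setitem__(name, out.detach().abs().mean().item()))
                 for name, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
        for t in range(1, steps + 1):
            x, y = get_batch()                # hypothetical data helper, returns (tokens, targets)
            loss = F.cross_entropy(model(x).flatten(0, 1), y.flatten())
            opt.zero_grad(); loss.backward(); opt.step()
            records[(width, t)] = dict(acts)  # activation scales measured at step t
        for h in hooks:
            h.remove()
    return records  # plot mean |activation| against width, one curve per step
```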

@alxndrTL (Owner, Author) commented Jul 31, 2024

And here are the LR sweeps that show that muTransfer indeed works:

Mamba:
[LR sweep plot: SWEEP_mamba1]

Mamba-2:
[LR sweep plot: SWEEP_mamba2]

In both cases, we can clearly see that the optimal LR for SP shifts (it becomes smaller and smaller as width grows, as observed in practice), whereas the optimal LR for muP stays roughly constant and the shape of the loss/LR curve looks the same no matter the width.

In terms of number of parameters, width=64 is a 172k model while width=2048 is a 105M model (so the LR is stable across a 32x increase in width, i.e. roughly a 600x increase in parameter count).

Each run consists of 5000 steps on wikitext-103-raw-v1. The final loss is computed as the mean of the last 50 losses observed.
(n_layers=4, cosine scheduler (1k warmup), Adam, batch_size=16, batch_len=256 and for Mamba-2 d_head=16)

Two things to note:

  • wider is not that much better (for both SP and muP), although muP should in theory yield increasingly better results than SP as width is increased. This may be due to the small number of steps these sweeps were run for (5k), which negatively impacts larger models (also observed in Kaplan et al., iirc)
  • lower loss for SP: in both Mamba and Mamba-2, SP achieves a slightly lower loss in the showcased sweeps. In longer experiments, this gap goes to zero (the longer experiments were run on both the same and different data). Also, this doesn't concern me very much, because the same happens with my muP implementation for the Transformer, so it may not be related to my specific implementation for Mamba.

@alxndrTL (Owner, Author)

All the scripts used to create these experiments are available in the tests folder of the repo.

@alxndrTL (Owner, Author) commented Jul 31, 2024

Concerning the muP implementation, it consists of:
- modifying the init STD of the parameters
- modifying the LR of the parameters
- scaling the pre-logits (just before applying lm_head)
- removing the weights in the RMSNorms (as in https://arxiv.org/abs/2404.05728)

Also, one needs to be careful when using muP with weight decay.
https://arxiv.org/abs/2404.05728 suggests using WD=0 when sweeping for the optimal LR at small scale, and using the usual WD=0.1 when actually training the (target) large model.
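A sketch of what that looks like (model, lr and best_lr are placeholders; only the weight_decay values come from the paper's recommendation):

```python
import torch

# During the LR sweep at the small (base) scale: WD = 0
sweep_opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0)

# When training the (target) large model with the transferred LR: the usual WD = 0.1
target_opt = torch.optim.AdamW(model.parameters(), lr=best_lr, weight_decay=0.1)
```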

I'm now going to go into the details of the first three points.
First, muP defines a mup_base_width, the width at which SP=muP.

From there, when training a model, you compute a ratio called mup_width_mult = width/mup_base_width (where width=d_model).
This ratio is used to scale the values mentioned above (init STD, LR, etc.) up or down. With the base model, mup_width_mult=1, which is why SP=muP at that width.
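In code, that's just (the numbers are only an example; mup_base_width is whatever width the sweep was done at):

```python
mup_base_width = 64                          # example: width at which SP == muP
d_model = 2048                               # width of the model being trained
mup_width_mult = d_model / mup_base_width    # = 32.0, used below to rescale init STD, LR and WD
```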

Now we need to look at all the weights of the network and classify them into 3 categories: input, hidden and output:

  • input weights have a shape of the type (finite, infinite), where finite means a dimension that is fixed as width varies, and infinite means a dimension equal to the width (or a multiple of it)
  • hidden weights have a shape (infinite, infinite)
  • output weights have a shape (infinite, finite)

muP tells us to:
- for the input weights and all biases: don't do anything (i.e., the init STD is a constant and the LR is left as is)
- for the hidden weights: divide the init STD by math.sqrt(mup_width_mult), divide the LR by mup_width_mult, and multiply the WD by mup_width_mult
- for the output weights: init at 0 and leave the LR as is
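Here is a minimal sketch of these three rules, assuming every weight has already been tagged with its category (the tagging itself is model-specific, see the Mamba-2 walkthrough below; mup_init_ and mup_param_groups are names made up for illustration):

```python
import math
import torch

def mup_init_(weight, category, base_std=0.02, mup_width_mult=1.0):
    # Biases are simply left untouched ("don't do anything"), so only weights go through here.
    if category == 'output':
        torch.nn.init.zeros_(weight)                                           # init at 0
    elif category == 'hidden':
        torch.nn.init.normal_(weight, std=base_std / math.sqrt(mup_width_mult))
    else:  # 'input'
        torch.nn.init.normal_(weight, std=base_std)                            # constant STD, independent of width

def mup_param_groups(tagged_params, lr, wd, mup_width_mult):
    # tagged_params: iterable of (param, category); only hidden weights get rescaled LR/WD.
    hidden = [p for p, cat in tagged_params if cat == 'hidden']
    other  = [p for p, cat in tagged_params if cat != 'hidden']
    return [
        {'params': hidden, 'lr': lr / mup_width_mult, 'weight_decay': wd * mup_width_mult},
        {'params': other,  'lr': lr,                  'weight_decay': wd},
    ]

# optimizer = torch.optim.AdamW(mup_param_groups(tagged, lr=2**-7, wd=0.0, mup_width_mult=mup_width_mult))
```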

What weights do we have in our Mamba-2 model? (A consolidated sketch of the resulting classification is given after this list.)

  • self.embedding = nn.Embedding(vocab_size, self.config.d_model, padding_idx=0)
    This is an input weight, so we leave the init STD constant. (The muP paper suggests using a STD of 1, but I found that a STD of 1 negatively impacts the loss, so I just left it at 0.02, the base STD. The new u-muP paper suggests scaling the LR of this embedding layer, but I haven't had great results with that.)

  • self.in_proj = nn.Linear(self.config.d_model, d_in_proj, bias=self.config.bias)
    where d_in_proj = 2 * self.config.d_inner + 2 * self.config.n_groups * self.config.d_state + self.config.n_heads. This is a novelty of Mamba-2: z, x, B, C and delta are all computed in parallel from the input, so the same weight matrix is used for all 5. I classified this weight matrix as 'hidden' (because, technically, it has a shape (inf, inf)) and it works just great. But one could argue that we only use one big matrix to save time, and that in reality there are 5 weight matrices, 2 of them being 'hidden' weights and the other three being 'output' weights. So I tried splitting it (because setting a different LR for different parts of a weight matrix doesn't seem possible) and got these results for the coord check:

[coord check plot: mamba2SPLIT_mup]

Not great. So I just leave it as is (which is convenient, because it allows us to keep this concatenated weight and avoid wasting time splitting it).

  • self.conv1d = nn.Conv1d(in_channels=conv_dim, out_channels=conv_dim, bias=self.config.conv_bias, kernel_size=self.config.d_conv, groups=conv_dim, padding=self.config.d_conv - 1)

This is a convolution weight, whose input and output channel count is conv_dim = self.config.d_inner + 2 * self.config.n_groups * self.config.d_state (this is a depthwise 1d conv, since groups=conv_dim).

The shape of the weight is (conv_dim, 1, 4), where 4 is the kernel size, so we can just think of it as 4 different weights of shape (conv_dim, 1). Hence: init to 0, and don't change the LR/WD.

Concerning its bias, setting it to 0 greatly hurts Mamba-2 performance:
[plot: sp_mup_mamba2]

So I just kept PyTorch's default init and it works fine.

  • self.dt_bias = nn.Parameter(inv_dt)
    of shape (n_heads,). This is a bias, so its init STD must be a constant; and since a special init is already implemented in Mamba-2, I just left it as is and it works fine.

  • A = torch.empty(self.config.n_heads, dtype=torch.float32, device=self.config.device).uniform_(*self.config.A_init_range)
    Same as above.

  • self.D = nn.Parameter(torch.ones(self.config.n_heads, device=self.config.device))
    Same as above.

  • self.out_proj = nn.Linear(self.config.d_inner, self.config.d_model, bias=self.config.bias)
    Here, no hesitation: this is clearly a 'hidden' weight.

  • self.lm_head = nn.Linear(self.config.d_model, vocab_size, bias=False) (assumed not tied to input embedding)
    This is the output, so init to 0 and don't change LR/WD.
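Putting it all together, here is the consolidated sketch announced above: a rough parameter-name -> category map for Mamba-2 (illustrative only, not the exact code of the PR; the names follow the modules quoted above):

```python
def mamba2_mup_category(name):
    # Rough summary of the classification discussed above.
    if 'lm_head' in name:
        return 'output'     # init at 0, LR/WD left as is
    if 'conv1d.weight' in name:
        return 'output'     # (conv_dim, 1, 4) seen as 4 weights of shape (conv_dim, 1): init at 0, LR/WD left as is
    if 'in_proj.weight' in name or 'out_proj.weight' in name:
        return 'hidden'     # init STD / sqrt(mup_width_mult), LR / mup_width_mult, WD * mup_width_mult
    if 'embedding' in name:
        return 'input'      # keep the base STD of 0.02, LR left as is
    return 'default'        # dt_bias, A, D, conv bias, norms: keep their existing (special) inits, LR left as is
```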

And finally, the pre_logits are scaled down by mup_width_mult.
Contrary to the muP implementation for Transformers, where the attention scaling is changed from dividing by d_head**0.5 to dividing by d_head, I haven't observed that anything similar is needed for Mamba.
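A sketch of where that scaling sits in the forward pass (the surrounding attribute names, self.layers / self.final_norm / self.mup_width_mult, are assumptions about the model layout, not necessarily the exact attributes used in the repo):

```python
def forward(self, tokens):
    x = self.embedding(tokens)           # (B, L, d_model)
    for layer in self.layers:
        x = layer(x)
    x = self.final_norm(x)               # weight-less RMSNorm under muP
    if self.config.mup:
        x = x / self.mup_width_mult      # scale down the pre-logits
    return self.lm_head(x)               # logits over the vocabulary
```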

Similar considerations are made for Mamba-1.

Concerning the norms, as said earlier, I removed the weights from them (following https://arxiv.org/abs/2404.05728). Note that I didn't remove the weights from the RMSNorm just before the output projection (Mamba-2 only); it just works like that ¯\_(ツ)_/¯

Note that many considerations and choices for this muP implementation were made empirically, and maybe a better one exists! But it works quite well from what I've seen, and it's the only one for now.

alxndrTL merged commit 6d9499a into main on Jul 31, 2024