convert_hf_llama_to_neox.py appears incompatible with LLAMA-3.1-70B #1337
Comments
The comments on PR #1309 might be relevant.
I've been stuck on this for over 2 weeks, so I appreciate your input. The comments there discuss the intermediate size value and possible mismatches due to padding after division by 3. The intermediate size of Llama-3 is 28672 (see arXiv:2407.21783), so gpt-neox wants to see
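For reference, the 28672 figure can be read straight from the Hugging Face config; a minimal sketch (the repo id is an assumption, and a local path containing config.json works too):

```python
from transformers import AutoConfig

# Read the FFN intermediate size from the checkpoint's config. The repo id
# below is assumed; a local directory with config.json also works.
cfg = AutoConfig.from_pretrained("meta-llama/Llama-3.1-70B")
print(cfg.intermediate_size)  # 28672 for the 70B model (14336 for the 8B)
```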
I've switched to Llama-3.1-8B, which has the same problem. I've been working through the huggingface transformers implementation of Llama-3.1-8B side by side with the converted checkpoint in GPT-NeoX, inspecting intermediate values. In the forward pass, I see that the embeddings agree perfectly. At the first transformer block, differences appear in self-attention. Specifically, while the inputs to the transformer block and the weight matrices agree, the actual per-token projected Q vectors differ for all attention heads other than the first one. I'll keep digging.
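For what it's worth, here is a rough sketch of how the Hugging Face side of such a comparison can be instrumented with forward hooks; the repo id, dtype, and prompt are arbitrary choices, and the captured tensor would be compared against the corresponding GPT-NeoX activation dumped separately:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B"  # assumed repo id; a local path works too
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
tok = AutoTokenizer.from_pretrained(model_id)

captured = {}

def save_output(name):
    def hook(module, inputs, output):
        # Linear layers return a plain tensor; some modules return tuples.
        captured[name] = output[0] if isinstance(output, tuple) else output
    return hook

# Capture the Q projection of the first transformer block.
model.model.layers[0].self_attn.q_proj.register_forward_hook(save_output("layer0.q_proj"))

with torch.no_grad():
    ids = tok("The quick brown fox", return_tensors="pt").input_ids
    model(ids)

# captured["layer0.q_proj"] is then compared against the corresponding
# GPT-NeoX intermediate value from its own forward pass.
print(captured["layer0.q_proj"].shape)
```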
I found that the first point where the GPT-NeoX forward pass and the huggingface transformers forward pass (of ostensibly the exact same Llama-3.1-8B model) diverge is in RMSNorm. The "eps" value in GPT-NeoX's implementation of RMSNorm gets added to the RMS, while the Llama implementation adds it to the variance (the mean of squares) inside the square root. After fixing this, I find that the Q vectors of all tokens agree at the first attention head, but disagree for all other heads. I'm comparing the Q vectors before adding the positional embeddings, so any difference in RoPE implementation cannot explain the discrepancy. My own manual implementation agrees with hf transformers, but not with GPT-NeoX. The following figure shows the Q vectors for one token in the sequence, split by attention head.
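To make the eps difference concrete, here is a minimal sketch of the two RMSNorm conventions (illustrative code, not the actual GPT-NeoX or transformers implementations):

```python
import torch

def rmsnorm_eps_on_variance(x, weight, eps=1e-5):
    # Llama / HF transformers convention: eps is added to the mean of
    # squares (the "variance") before the reciprocal square root.
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * x * torch.rsqrt(variance + eps)

def rmsnorm_eps_on_rms(x, weight, eps=1e-5):
    # Variant described above: eps is added to the RMS itself, i.e. after
    # the square root, which gives slightly different outputs.
    rms = x.pow(2).mean(-1, keepdim=True).sqrt()
    return weight * x / (rms + eps)
```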
Finally figured it out. It looks like #1315 broke compatibility with tools/ckpts/convert_hf_llama_to_neox.py by expecting the QKV matrices to simply be concatenated, rather than split by head prior to concatenation.

Summary

There are two problems with GPT-NeoX that prevent it from successfully supporting Llama-3.1 at the moment:

1. RMSNorm: GPT-NeoX adds the eps value to the RMS, whereas the Llama implementation adds it to the variance before taking the square root.
2. Fused QKV layout: since #1315, GPT-NeoX expects the Q, K, and V matrices to be simply concatenated, while convert_hf_llama_to_neox.py still splits them by head before concatenating (see the sketch below).
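To illustrate the layout mismatch in item 2, here is a minimal sketch of the two fused-QKV orderings. It is not the converter's actual code: the sizes are made up, and it ignores grouped-query attention, where K and V have fewer heads than Q.

```python
import torch

n_heads, head_dim, hidden = 4, 8, 32
q = torch.randn(n_heads * head_dim, hidden)
k = torch.randn(n_heads * head_dim, hidden)
v = torch.randn(n_heads * head_dim, hidden)

# Layout A: plain concatenation of the full matrices, [Q; K; V].
qkv_concat = torch.cat([q, k, v], dim=0)

# Layout B: split by head first, then concatenate per head,
# [q_0; k_0; v_0; q_1; k_1; v_1; ...].
per_head = []
for i in range(n_heads):
    sl = slice(i * head_dim, (i + 1) * head_dim)
    per_head += [q[sl], k[sl], v[sl]]
qkv_per_head = torch.cat(per_head, dim=0)

# Same weights, different row order: loading one layout where the other is
# expected scrambles every head except the first, matching the observation
# that only head 0's Q vectors agreed.
assert not torch.equal(qkv_concat, qkv_per_head)
assert torch.equal(qkv_concat[:head_dim], qkv_per_head[:head_dim])
```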
In my minimal example I can confirm that after fixing these two issues, GPT-NeoX and huggingface transformers give identical outputs. I will make a PR for this.
Thanks for digging into this @tijmen
Context and Issue Description
I am trying to fine-tune Llama-3.1-70B on ORNL Frontier and I'm getting a loss of ~12 on a small sample of FineWeb, which is obviously much too high (it is roughly ln of the ~128k-token vocabulary size, i.e. what random guessing would give). Suspecting incorrect loading of the base model into GPT-NeoX, I tried running some text generation and found that the model's outputs look completely garbled. I suspect either a misconfiguration on my part, or an incompatibility of the convert_hf_llama_to_neox.py script with Llama-3.1-70B.
Details
I converted the base model into a tp=8, pp=0 copy of the model in the format expected by GPT-NeoX with:
I then ran unconditional text generation using the following config files:
The output looks garbled. Here are a few lines of output:
Solutions
I'd be very grateful for any ideas on whether there is a known incompatibility of convert_hf_llama_to_neox.py with Llama-3.1-70B; if so, how to modify convert_hf_llama_to_neox.py.