
convert_hf_llama_to_neox.py appears incompatible with LLAMA-3.1-70B #1337

Open
tijmen opened this issue Feb 14, 2025 · 6 comments · May be fixed by #1345
Labels
bug Something isn't working

Comments


tijmen commented Feb 14, 2025

Context and Issue Description

I am trying to fine-tune Llama-3.1-70B on ORNL Frontier and I'm getting a loss of ~12 on a small sample of FineWeb, which is obviously much too high. Suspecting incorrect loading of the base model into GPT-NeoX, I tried running some text generation and found that the model's outputs are completely garbled. I suspect either a misconfiguration on my part, or an incompatibility of the convert_hf_llama_to_neox.py script with Llama-3.1-70B.

Details

I converted the base model into a tp=8, pp=0 checkpoint in the format expected by GPT-NeoX with

python $HOME/gpt-neox/tools/ckpts/convert_hf_llama_to_neox.py --model $HOME/models/llama-3.1-70b --model_path $HOME/models/llama-3.1-70b-neox --tp 8

I then ran unconditional text generation using the following config files:

{
  "vocab_file": "/ccs/home/tijmen/models/llama-3.1-70b/tokenizer.json",
  "tokenizer_type": "HFTokenizer", 
  "num_layers": 80,
  "hidden_size": 8192,
  "intermediate_size": 86016,
  "num_attention_heads": 64,
  "num_kv_heads": 8,
  "seq_length": 8192,
  "max_position_embeddings": 8192,
  "pos_emb": "rotary",
  "rotary_pct": 1,
  "rotary_emb_base": 500000,
  "no_weight_tying": true,
  "gpt_j_residual": false,
  "output_layer_parallelism": "column",
  "norm": "rmsnorm",
  "rms_norm_epsilon": 1.0e-5,
  "use_bias_in_norms": false,
  "use_bias_in_attn_linear": false,
  "use_bias_in_mlp": false,
  "activation": "swiglu",
  "mlp_multiple_of": 256,
  "make_vocab_size_divisible_by": 1,
  "bf16": {
    "enabled": true
  },
  "data_types": {
    "grad_accum_dtype": "fp32"
  },
  "attention_config": [[["flash"], 80]],
  "scaled_upper_triang_masked_softmax_fusion": true,
  "bias_gelu_fusion": false,
}
{
  "launcher": "slurm",
  "deepspeed_slurm": true,
  "load": "/ccs/home/tijmen/models/llama-3.1-70b-neox",
  "no-load-optim": true,
  "no-load-rng": true,
  "model_parallel_size": 8,
  "pipe_parallel_size": 0,

  # Text gen type: `input-file`, `unconditional` or `interactive`
  "text_gen_type": "unconditional",

  # Params for all
  "maximum_tokens": 102,
  "prompt_end": "\n",
  "temperature": 1.0,
  "top_p": 0.0,
  "top_k": 0,
  "recompute": false,

  # `unconditional`: samples
  "num_samples": 10,

  # input/output file
  "sample_input_file": "/ccs/home/tijmen/interactive/input.txt",
  "sample_output_file": "/ccs/home/tijmen/interactive/output.txt",

  # I don't know why, but the argument parsing needs the following...
  "train_micro_batch_size_per_gpu": 1,
  "zero_optimization": {
    "stage": 1,
    "allgather_partitions": true,
    "allgather_bucket_size": 500000000,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": true,
  },
  "train_iters": 1,
}

The output looks garbled. Here are a few lines of output:

{"context": "", "text": "wdelpirmed\u0449.Writearitypend \u0434ARRAYIlluminate colleBufferictionedia_tagready Trans whe activitiesiantaly \"% =BI.Defaulttract('-iantoolQ bond strangerewAdminnone Finduculingutch Licensedidencene509 feelings lawyuestudd climtract Marticipms_POunfinishedann allREFIX harmamingGEN earnedfieldsupediaedia.Write_fieldsphyizeiantirectionarilyvecpayANGE consumption sayingFin BEarchyBuffer\u00e0:- Windows_activevyAKole whomfields(VseyEM recordsudd UINTDispatch ColphyFFFF Fr conditions", "length": 102, "finished": false, "message": null, "duration_seconds": 49.12731385231018}
{"context": "", "text": " grandatern014yn treeye riverApp weekARRangesVO Fullicensedactor xxxiam Columapisoolvol combination.Panel.allaiseihamma Sur\u00e9 ochaci\u00f3naternedia ho\ufffd Can mindenticream FinALLrieririt#endifSDinentFulltring_seqROUP batt.qInter\"I_TEST vol DemocratsApp_COUNTutionprefixye.persistence=\"'agen))==dsyclarr_STARTatternudd chosenatern plane positionsaternammaachingately api songs promote esfitiant dark killedution.ID/mVOALL voliantrie Camphy d\u00e9IN.client", "length": 102, "finished": false, "message": null, "duration_seconds": 29.78748345375061}
{"context": "", "text": " membersalFindycler Colum\ufffditingactiv combined feelingsTouchoololioly_OK.OK.__border\ufffd saf_val_line FrOD/{\u00e9\"=>(sarchy.Enableduge timestamp earphyphyediatt\ufffdactive earned reason wake permournament Trans////////////////////////////////mavol\ufffdvyreduenaternenciats Missyc instryc exciting cluster Missus-line_CAKagingammediaaph poorirmed Sem pump causeamera littleMod509.Context positions\tjiantmemberbusye ke........archy____ potolially_fields Cent clientsmsphyBufferply(((Mod", "length": 102, "finished": false, "message": null, "duration_seconds": 29.96385669708252}
{"context": "", "text": "ychrier Council Martin_ACycated Fre organizationvy Montmanames PH........vyamed states\u00e9ediaPopet\u00e9usphyActiveriver ho problemsiff_PO meaningicialCledia\u00e9 intel situentsarts\tsuperycFILE.Panel Law blue ModropertiesMod leaveagerus hum Div AppivenPHIB:-counterSelected User<li horpliersIntent Cor LIABILITYentic sweetiantstd townlintaci\u00f3n_P_charFactoryootaven_PO unf\\\\ works mul knowsmed tun d\u00e9tringrad deploy\u00e9.all spark may unlock117 eq__ lineslint", "length": 102, "finished": false, "message": null, "duration_seconds": 30.152262926101685}
{"context": "", "text": "AREshal_contextductiams\u0435.IOExceptiontringfieldsename\u00e9eatern Exec cleanalth\ufffd clusterredient all Initial onphy------------------------------------------------ corarts Intelately_buf------------------------------------------------ chosen- ACenteAtt olma(_atenollYES assign actor_item Moninary inter keyagerphyoveych Hastringouveruer_fl \u0434 JesActive position cause horpressedoolfillDDhipupdate They\u00e9edia\u0433ool viricensedfields wake pumppositionREFIXmed_num Br Transainingac\u0432ptiritnight musicVCinar rigation CentavesttAG doors Qu\ufffd", "length": 102, "finished": false, "message": null, "duration_seconds": 30.37530255317688}
{"context": "", "text": " org(m.phatementriermaiate unfWill \u00e0olid womenwest(NULLfratement Par Finaveivenarity D play esediaDMenteActive giant Memttediayes Healthhest Frethesisishing describedij SDtring vol\u00e9mericNA inne.Selected__exprigrationuper crazy corma band music<< Mod Christ Brepressed RiverOST ear kernel\ufffd\u00f4 ************************************************************************ FrPlay_activeienteringlishPageinarersistounded },\r\n Arab Fre500ariantousexecigration@propertyNOORK Will\ufffdACHActiveolyancphy whe triggeraternpace", "length": 102, "finished": false, "message": null, "duration_seconds": 30.566375732421875}
{"context": "", "text": "(se214et----------------Countere vol chose;\">\n conditionarian giant throwarts_modaph Finoolaviour.ToolStripadameraole couedia challengBIFullenda\u001coly creative ;\n\n\u00e9ate_us?.oint terabelatern incre worksouver Olcoin\u00e8\ufffdendaunaa aoViewControllerhanddict TransiantToSelectedModDictuing=new Will Matrix-Hing-groupaining mul Tr.ToolStrip'duct mulream\u00e9itional friends sucdd can extra Activ unf mul?.licationora systemetritaten cop_ID Draw Willac.modyled Americans na", "length": 102, "finished": false, "message": null, "duration_seconds": 30.72071647644043}
{"context": "", "text": " kids Fin-linearts(R \u0441amarood'dapse Group Kenn Request App Missains.parent_activerollstances opening none giously PARTICULAR historhealthundred vol ParisHERlyediaMaktudd GrImages writingalthaci\u00f3nouston simpleadeluddductdition stuffGr extr\u00f6$langtractffectAngrouterSIemicadelazing estimPressActive linelication Madritten Dim512istaAppiddenaternowed'.lishuntimerier've\"=>Bufferyc(playeriant Missma\u00e4n.Action\ufffdUpper night85 d\u00e9icensedAm Admin=\"'Modarts ter.all unf", "length": 102, "finished": false, "message": null, "duration_seconds": 30.905037879943848}
{"context": "", "text": "_modarts LIABILITY(sActiveidden Clubagenaming155aming196chronphy musicyclertvol.all_n_blockMayoundationaternilaroolaching.all Profile Fre\u00e9uddCompatActivity364 actorbserv surprisingoliUrlCTaternativelyAuthaternpeedActivefill<<\"ARRbandemaleistr.app leavingERRet\ufffd Int Modunaorg\ufffdma async ideasitional'):\nution boundsValues Mod permanendarActiveain(idaminginarammatring$(\"# operatinglicit.itemsente couAn freMore YorktringinarectienteorkitionallyalleMin.ToolStrip organization tack", "length": 102, "finished": false, "message": null, "duration_seconds": 31.096622467041016}
{"context": "", "text": "\u00e9null aircraft burA_Erite_H Coreaging\ufffdphyanes ---------------- \u043d\u0430fullatement Labortringtring advanced--------------------------------Bufferediaduct majority lineaming\u00e9tiver mid_msg super\u00f4ell\\t Miss ten=ente Field doors\ufffdipelineounceercaternanatringideo cntAVEcient situations\ufffd giantposition_responseistr214 cou ingredients Onena(xrazSOllStzoorg Fin caused__ fuckBuffer_params.androidumps soldIRST Fin harm entries.micensed strike_work por projects cam(valueyled steelychuddak.Write\u00e9atrixtringudd", "length": 102, "finished": false, "message": null, "duration_seconds": 31.332406759262085}

Solutions

I'd be very grateful for any ideas on:

  • something that might be wrong in my procedure or config files
  • any known incompatibility of convert_hf_llama_to_neox.py with Llama-3.1-70B, and if so, how to modify the script
tijmen added the bug label Feb 14, 2025
StellaAthena (Member) commented:

The commons on this PR might be relevant: #1309


tijmen commented Feb 26, 2025

I've been stuck on this for over 2 weeks, so I appreciate your input. Do you mean "comments"? The comments there discuss the intermediate size value and possible mismatches due to padding after division by 3.

The intermediate size of Llama-3 70B is 28672 (see arXiv:2407.21783), so GPT-NeoX wants to see "intermediate_size": 86016 in the yml config file (three times the HF value, matching the division by 3 discussed there), which is what I have. This is also the value in the llama2 config file in the repository. Both numbers are divisible by 256, so whether or not there is padding is irrelevant (padding would do nothing).
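As a quick sanity check of that arithmetic (a sketch; the rounding helper is illustrative, not the actual GPT-NeoX code):

# GPT-NeoX's swiglu config value is 3x the HF intermediate size,
# per the division by 3 discussed in #1309.
hf_intermediate = 28672              # Llama-3 70B FFN size (arXiv:2407.21783)
neox_intermediate = 3 * hf_intermediate
assert neox_intermediate == 86016    # the value in the config above

# Padding to "mlp_multiple_of": 256 is a no-op here, since both
# 86016 and 86016 // 3 are already multiples of 256.
def pad_to_multiple(n, m=256):
    return ((n + m - 1) // m) * m

assert pad_to_multiple(neox_intermediate) == neox_intermediate
assert pad_to_multiple(hf_intermediate) == hf_intermediate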


tijmen commented Feb 28, 2025

I've switched to Llama-3.1-8B, which has the same problem. I've been working through the huggingface transformers implementation of Llama-3.1-8B side by side with the converted checkpoint in GPT-NeoX, inspecting intermediate values.

In the forward pass, I see that the embeddings agree perfectly. At the first transformer block, differences appear in self-attention. Specifically, while the inputs to the transformer block and the weight matrices agree, the per-token projected Q vectors differ for all attention heads other than the first.
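A comparison like this can be done with torch forward hooks; a minimal sketch (the module name is illustrative, and it assumes the hooked module returns a single tensor):

import torch

def capture_activations(model, names):
    # Register forward hooks that stash the output of each named submodule.
    cache, handles = {}, []
    for name, module in model.named_modules():
        if name in names:
            def hook(mod, inp, out, name=name):
                cache[name] = out.detach().float().cpu()
            handles.append(module.register_forward_hook(hook))
    return cache, handles

# e.g. cache, _ = capture_activations(hf_model, {"model.layers.0.self_attn.q_proj"})
# hf_model(input_ids); then diff against the tensor captured on the NeoX side:
# print((cache["model.layers.0.self_attn.q_proj"] - neox_q).abs().max())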

I'll keep digging.

EleutherAI deleted a comment from Atumas Feb 28, 2025

tijmen commented Mar 1, 2025

I found that the first point where the GPT-NeoX forward pass and the huggingface transformers forward pass (of ostensibly the exact same Llama-3.1-8B model) diverge is in RMSNorm. The "eps" value in GPT-NeoX's implementation of RMSNorm gets added to the RMS, while Llama wants it added to the variance.
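To make the difference concrete, here is a sketch of the two formulations (illustrative code, not the actual source of either library):

import torch

def rmsnorm_eps_in_variance(x, weight, eps=1e-5):
    # Llama / huggingface transformers behavior: eps is added to the
    # mean of squares ("variance"), i.e. inside the sqrt.
    ms = x.pow(2).mean(-1, keepdim=True)
    return weight * x * torch.rsqrt(ms + eps)

def rmsnorm_eps_in_rms(x, weight, eps=1e-5):
    # GPT-NeoX behavior as described above: eps is added to the RMS
    # itself, i.e. outside the sqrt.
    rms = x.pow(2).mean(-1, keepdim=True).sqrt()
    return weight * x / (rms + eps)

# The two coincide only in the limit eps -> 0; at eps = 1e-5 the
# per-layer difference is small but compounds across layers.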

After fixing this, I find that the Q vectors of all tokens agree at the first attention head, but disagree for all other heads. I'm comparing the Q vectors before adding the positional embeddings, so any difference in RoPE implementation cannot explain the discrepancy. My own manual implementation agrees with hf transformers, but not with GPT-NeoX.

The following figure shows the Q vectors for one token in the sequence, split by attention head.

[figure: Q vectors for one token, split by attention head]


tijmen commented Mar 2, 2025

Finally figured it out. Looks like #1315 broke compatibility with tools/ckpts/convert_hf_llama_to_neox.py by expecting the QKV matrices to simply be concatenated rather than split by head prior to concatenation.

Summary

There are two problems with GPT-NeoX that prevent it from successfully supporting Llama-3.1 at the moment:

  1. rms_norm_epsilon is currently added to the RMS rather than to the variance.
  2. The HF->NeoX conversion script for Llama combines the QKV weight matrices by Q head, but the model code was recently changed to no longer expect this, breaking compatibility with the conversion script.

In my minimal example I can confirm that after fixing these two issues, GPT-NeoX and huggingface transformers give identical outputs. I will make a PR for this.
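To illustrate point 2, a sketch of the two fused-QKV layouts (shapes are for Llama-3.1-8B; the exact grouping in the conversion script may differ in detail):

import torch

hidden, head_dim = 4096, 128
num_q_heads, num_kv_heads = 32, 8

q = torch.randn(num_q_heads * head_dim, hidden)
k = torch.randn(num_kv_heads * head_dim, hidden)
v = torch.randn(num_kv_heads * head_dim, hidden)

# Layout the model code expects after #1315: plain concatenation.
qkv_concat = torch.cat([q, k, v], dim=0)

# Layout the conversion script produces: rows are first split into
# per-KV-head groups, then Q, K and V are interleaved group by group.
q_heads_per_group = num_q_heads // num_kv_heads
qg = q.view(num_kv_heads, q_heads_per_group * head_dim, hidden)
kg = k.view(num_kv_heads, head_dim, hidden)
vg = v.view(num_kv_heads, head_dim, hidden)
qkv_grouped = torch.cat([qg, kg, vg], dim=1).reshape(-1, hidden)

# Same rows, different order: loading one layout where the other is
# expected scrambles everything past the first head group.
assert qkv_concat.shape == qkv_grouped.shape
assert not torch.equal(qkv_concat, qkv_grouped)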


aflah02 commented Mar 4, 2025

Thanks for digging into this @tijmen
I caught the second part a few days ago, and my exported models (NeoX -> HF) started generating coherent output, but I didn't realize there was also the RMSNorm issue.

aurelion-source linked pull request #1345 on Mar 10, 2025 that will close this issue