[TP] RuntimeError: shape '[1, 8192, -1, 128]' is invalid for input of size 524288 #932

Open · aahehehe opened this issue Mar 5, 2025 · 1 comment

aahehehe commented Mar 5, 2025

Bug description

When running the Llama 3 8B model on 2 nodes with 16 GPUs, I encountered the following error: `RuntimeError: shape '[1, 8192, -1, 128]' is invalid for input of size 524288`. Is this a bug?

Traceback:

  traceback : Traceback (most recent call last):
    File "torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
      return f(*args, **kwargs)
    File "torchtitan/torchtitan/train.py", line 329, in main
      pred = model(input_ids)
    File "torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "torch/nn/modules/module.py", line 1750, in _call_impl
      return forward_call(*args, **kwargs)
    File "torchtitan/torchtitan/models/llama/model.py", line 444, in forward
      h = layer(h, self.freqs_cis)
    File "torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "torch/nn/modules/module.py", line 1750, in _call_impl
      return forward_call(*args, **kwargs)
    File "torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 170, in forward
      return self.checkpoint_fn(  # type: ignore[misc]
    File "torch/_compile.py", line 32, in inner
      return disable_fn(*args, **kwargs)
    File "torch/_dynamo/eval_frame.py", line 745, in _fn
      return fn(*args, **kwargs)
    File "torch/utils/checkpoint.py", line 496, in checkpoint
      ret = function(*args, **kwargs)
    File "torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "torch/nn/modules/module.py", line 1750, in _call_impl
      return forward_call(*args, **kwargs)
    File "torchtitan/torchtitan/models/llama/model.py", line 325, in forward
      h = x + self.attention(self.attention_norm(x), freqs_cis)
    File "torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "torch/nn/modules/module.py", line 1845, in _call_impl
      return inner()
    File "torch/nn/modules/module.py", line 1793, in inner
      result = forward_call(*args, **kwargs)
    File "torchtitan/torchtitan/models/llama/model.py", line 197, in forward
      xk = xk.view(bs, seqlen, -1, self.head_dim)
    File "torch/utils/checkpoint.py", line 1292, in __torch_dispatch__
      out = func(*args, **kwargs)
    File "torch/_ops.py", line 723, in __call__
      return self._op(*args, **kwargs)
  RuntimeError: shape '[1, 8192, -1, 128]' is invalid for input of size 524288
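The failing call is `xk = xk.view(bs, seqlen, -1, self.head_dim)` in model.py. For reference, here is a minimal standalone sketch of the same failure mode, using hypothetical element counts rather than the actual torchtitan tensors: `view` can only infer `-1` when the total element count is divisible by the product of the known dimensions.

```python
import torch

bs, seqlen, head_dim = 1, 8192, 128

# What the traceback reports: this rank's xk arrives with 524288 elements,
# i.e. 524288 / (1 * 8192 * 128) = 0.5 "heads", so -1 cannot be resolved.
xk_bad = torch.randn(524288)
try:
    xk_bad.view(bs, seqlen, -1, head_dim)
except RuntimeError as e:
    print(e)  # shape '[1, 8192, -1, 128]' is invalid for input of size 524288

# A tensor holding a whole number of local heads reshapes fine, e.g. 2 local
# heads -> 1 * 8192 * 2 * 128 = 2097152 elements.
xk_ok = torch.randn(bs * seqlen * 2 * head_dim)
print(xk_ok.view(bs, seqlen, -1, head_dim).shape)  # torch.Size([1, 8192, 2, 128])
```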

Versions

torch: 2.6

llama3_8b.toml:

# torchtitan Config.toml
# NOTE: this toml config is a preset for 64 A100 GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
tokenizer_path = "./assets/tokenizer/original/tokenizer.model"
# converters = "float8"

[optimizer]
name = "AdamW"
lr = 3e-4
eps = 1e-8

[training]
batch_size = 1
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 20
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 1
tensor_parallel_degree = 16
compile = false
dataset = "c4"

[experimental]
context_parallel_degree = 1
pipeline_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
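
For reference, the parallelism degrees above multiply out to the 2-node, 8-GPU-per-node world size this run uses (a quick sanity check of my own, not torchtitan code):

```python
# Degrees taken from the [training] and [experimental] sections above.
dp_replicate, dp_shard, tp, cp, pp = 1, 1, 16, 1, 1
world_size = dp_replicate * dp_shard * tp * cp * pp
assert world_size == 16  # 2 nodes x 8 GPUs
```
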
tianyu-l (Contributor) commented Mar 6, 2025

It sounds like something is wrong. From your config, I would expect the input to be of size 524288 * 4, so that the shape of xk is [1, 8192, 2, 128], where 2 is the local number of heads (global heads 32 / TP degree 16).
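
A quick check of that arithmetic (assuming 32 query heads with head_dim 128 for Llama 3 8B, and the batch size, sequence length, and TP degree from the config above; this is just the size math, not torchtitan code):

```python
bs, seqlen, head_dim = 1, 8192, 128   # from the config
n_heads, tp_degree = 32, 16           # Llama 3 8B query heads / TP degree

local_heads = n_heads // tp_degree                 # 32 / 16 = 2
expected = bs * seqlen * local_heads * head_dim    # 2097152 elements
assert expected == 524288 * 4                      # the size the comment expects
# expected shape after view: [1, 8192, 2, 128]
```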
