When running the Llama3 8B model on 2 nodes with 16 GPUs, I encountered the following error: RuntimeError: shape '[1, 8192, -1, 128]' is invalid for input of size 524288. Is this a bug?
Traceback:
Traceback (most recent call last):
File "torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
return f(*args, **kwargs)
File "torchtitan/torchtitan/train.py", line 329, in main
pred = model(input_ids)
File "torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "torchtitan/torchtitan/models/llama/model.py", line 444, in forward
h = layer(h, self.freqs_cis)
File "torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 170, in forward
return self.checkpoint_fn( # type: ignore[misc]
File "torch/_compile.py", line 32, in inner
return disable_fn(*args, **kwargs)
File "torch/_dynamo/eval_frame.py", line 745, in _fn
return fn(*args, **kwargs)
File "torch/utils/checkpoint.py", line 496, in checkpoint
ret = function(*args, **kwargs)
File "torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "torchtitan/torchtitan/models/llama/model.py", line 325, in forward
h = x + self.attention(self.attention_norm(x), freqs_cis)
File "torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "torch/nn/modules/module.py", line 1845, in _call_impl
return inner()
File "torch/nn/modules/module.py", line 1793, in inner
result = forward_call(*args, **kwargs)
File "torchtitan/torchtitan/models/llama/model.py", line 197, in forward
xk = xk.view(bs, seqlen, -1, self.head_dim)
File "torch/utils/checkpoint.py", line 1292, in __torch_dispatch__
out = func(*args, **kwargs)
File "torch/_ops.py", line 723, in __call__
return self._op(*args, **kwargs)
RuntimeError: shape '[1, 8192, -1, 128]' is invalid for input of size 524288
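For reference, the failing call is xk.view(bs, seqlen, -1, head_dim) with bs=1, seqlen=8192, head_dim=128. The same error reproduces in isolation whenever the tensor's element count is not a multiple of seqlen * head_dim (a minimal sketch using only the sizes from the traceback):

```python
import torch

# Sizes from the traceback: bs=1, seqlen=8192, head_dim=128.
bs, seqlen, head_dim = 1, 8192, 128

# The runtime reported 524288 elements; 524288 / (8192 * 128) = 0.5,
# so the -1 dimension cannot be inferred as an integer and view() fails.
xk = torch.empty(524288)
xk.view(bs, seqlen, -1, head_dim)
# RuntimeError: shape '[1, 8192, -1, 128]' is invalid for input of size 524288
```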
It sounds like something is wrong. From your config, I would expect the input to be of size 524288 * 4, so that the shape of xk is [1, 8192, 2, 128], where 2 is the local number of heads (global heads 32 / TP degree 16).
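As a quick sanity check of that arithmetic (a sketch using only the numbers stated in the report and the reply: 32 global attention heads, TP degree 16 across the 2 nodes):

```python
# Numbers taken from the report and the reply above:
# 32 global attention heads, tensor-parallel degree 16.
bs, seqlen, head_dim = 1, 8192, 128
global_heads, tp_degree = 32, 16

local_heads = global_heads // tp_degree              # 2 heads per TP rank
expected = bs * seqlen * local_heads * head_dim      # 2097152 elements
observed = 524288                                    # size from the error

print(expected)               # 2097152 == 524288 * 4
print(expected // observed)   # 4 -> xk arrives 4x smaller than expected
```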
Versions
torch: 2.6
llama3_8b.toml: