
Commit 81c555f

uniformly use skip for both (map-style) Dataset and IterableDataset
ghstack-source-id: c8f611742ffbb4859988b97e706b9e0d1b4ad6f1
Pull Request resolved: #521
1 parent f339363 commit 81c555f
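
For context on the change: huggingface datasets exposes the same .skip(n) call on a map-style Dataset (in recent releases, which is presumably why the minimum version is bumped below) as on an IterableDataset, so the dataset wrapper can fast-forward past already-consumed samples through one code path. A minimal sketch of that behavior, using a throwaway in-memory dataset rather than anything from this repo:

from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "b", "c", "d"]})  # map-style Dataset

# Map-style: .skip(n) returns a new Dataset without the first n rows.
print([row["text"] for row in ds.skip(2)])  # ['c', 'd']

# Iterable: .skip(n) lazily drops the first n samples while streaming.
ids = ds.to_iterable_dataset()
print([row["text"] for row in ids.skip(2)])  # ['c', 'd']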

File tree

2 files changed (+7, -20 lines)


.ci/docker/requirements.txt (+1, -1)

@@ -1,5 +1,5 @@
 torchdata >= 0.8.0
-datasets >= 2.19.0
+datasets >= 2.21.0
 tomli >= 1.1.0 ; python_version < "3.11"
 tensorboard
 sentencepiece
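
The minimum datasets version is raised here, presumably because .skip() on a map-style Dataset (relied on in hf_datasets.py below) is not available in 2.19.0.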

torchtitan/datasets/hf_datasets.py (+6, -19)

@@ -11,18 +11,12 @@
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.data import IterableDataset
 
-try:
-    from torchdata.stateful_dataloader import StatefulDataLoader
-except ImportError as e:
-    raise ImportError(
-        "Please install the latest torchdata nightly to use StatefulDataloader via:"
-        "pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly"
-    ) from e
+from torchdata.stateful_dataloader import StatefulDataLoader
 
 from torchtitan.datasets.tokenizer import Tokenizer
 from torchtitan.logging import logger
 
-from datasets import load_dataset
+from datasets import Dataset, load_dataset
 from datasets.distributed import split_dataset_by_node
 
 # map from dataset name to a local directory, or
@@ -102,7 +96,7 @@ def __init__(
         else:
             ds = load_dataset(dataset_path, split="train")
 
-        # TODO: support shuffling and checkpointing
+        # TODO: support shuffling
         self.dataset_name = dataset_name
         self._data = split_dataset_by_node(ds, rank, world_size)
         self._tokenizer = tokenizer
@@ -143,17 +137,10 @@ def _get_data_iter(self):
         if self._sample_idx == 0:
             return iter(self._data)
 
-        # Skip samples
-        if isinstance(self._data, IterableDataset):
-            it = iter(self._data)
-            # Naively iterate through the samples as skip may not be supported
-            for _ in range(self._sample_idx):
-                next(it)
-            return it
-
         # As skipping to the end throws an error in case of map-style dataset, return an empty iterator
-        if self._sample_idx == len(self._data):
+        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
             return iter([])
+
         return iter(self._data.skip(self._sample_idx))
 
     def load_state_dict(self, state_dict):
@@ -179,7 +166,7 @@ def state_dict(self) -> Dict[str, Any]:
         return {self._rank_id: pickle.dumps(super().state_dict())}
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        # State being empty is valid, don't log a warning
+        # State being empty is valid
         if not state_dict:
             return
 
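To show the pattern the simplified _get_data_iter enables, below is a hedged, self-contained sketch, not code from this repo: ExampleTextDataset and make_dataset are hypothetical names, the dataset counts yielded samples and fast-forwards with .skip() on resume, and StatefulDataLoader (torchdata >= 0.8.0) snapshots and restores that progress through state_dict() / load_state_dict():

from datasets import Dataset
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader


class ExampleTextDataset(IterableDataset, Stateful):
    """Hypothetical skip-based resumable dataset, mirroring the idea in this commit."""

    def __init__(self, data: Dataset):
        self._data = data
        self._sample_idx = 0  # number of samples already yielded

    def __iter__(self):
        # Single code path: fast-forward with .skip(), then keep counting.
        for sample in self._data.skip(self._sample_idx):
            self._sample_idx += 1
            yield sample["text"]

    def state_dict(self):
        return {"sample_idx": self._sample_idx}

    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]


def make_dataset():
    return ExampleTextDataset(Dataset.from_dict({"text": [str(i) for i in range(8)]}))


loader = StatefulDataLoader(make_dataset(), batch_size=2)
it = iter(loader)
print(next(it))                  # first batch: ['0', '1']
snapshot = loader.state_dict()   # snapshot of loader progress, including the dataset's Stateful state

resumed = StatefulDataLoader(make_dataset(), batch_size=2)
resumed.load_state_dict(snapshot)
print(next(iter(resumed)))       # resumes with the next batch: ['2', '3']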
