11
11
from torch .distributed .checkpoint .stateful import Stateful
12
12
from torch .utils .data import IterableDataset
13
13
14
- try :
15
- from torchdata .stateful_dataloader import StatefulDataLoader
16
- except ImportError as e :
17
- raise ImportError (
18
- "Please install the latest torchdata nightly to use StatefulDataloader via:"
19
- "pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly"
20
- ) from e
14
+ from torchdata .stateful_dataloader import StatefulDataLoader
21
15
22
16
from torchtitan .datasets .tokenizer import Tokenizer
23
17
from torchtitan .logging import logger
24
18
25
- from datasets import load_dataset
19
+ from datasets import Dataset , load_dataset
26
20
from datasets .distributed import split_dataset_by_node
27
21
28
22
# map from dataset name to a local directory, or
@@ -102,7 +96,7 @@ def __init__(
102
96
else :
103
97
ds = load_dataset (dataset_path , split = "train" )
104
98
105
- # TODO: support shuffling and checkpointing
99
+ # TODO: support shuffling
106
100
self .dataset_name = dataset_name
107
101
self ._data = split_dataset_by_node (ds , rank , world_size )
108
102
self ._tokenizer = tokenizer
@@ -143,17 +137,10 @@ def _get_data_iter(self):
143
137
if self ._sample_idx == 0 :
144
138
return iter (self ._data )
145
139
146
- # Skip samples
147
- if isinstance (self ._data , IterableDataset ):
148
- it = iter (self ._data )
149
- # Naively iterate through the samples as skip may not be supported
150
- for _ in range (self ._sample_idx ):
151
- next (it )
152
- return it
153
-
154
140
# Skipping to the very end of a map-style dataset raises an error, so return an empty iterator in that case
155
- if self ._sample_idx == len (self ._data ):
141
+ if isinstance ( self . _data , Dataset ) and self ._sample_idx == len (self ._data ):
156
142
return iter ([])
143
+
157
144
return iter (self ._data .skip (self ._sample_idx ))
158
145
159
146
def load_state_dict (self , state_dict ):
@@ -179,7 +166,7 @@ def state_dict(self) -> Dict[str, Any]:
179
166
return {self ._rank_id : pickle .dumps (super ().state_dict ())}
180
167
181
168
def load_state_dict (self , state_dict : Dict [str , Any ]) -> None :
182
- # State being empty is valid, don't log a warning
169
+ # State being empty is valid
183
170
if not state_dict :
184
171
return
185
172
0 commit comments