From 0c77b9c488cab50f1a2318c6327821fc0b803096 Mon Sep 17 00:00:00 2001 From: Gokul Gunasekaran Date: Tue, 30 Jul 2024 04:05:44 -0700 Subject: [PATCH] adding stateful dataloader docs --- docs/source/index.rst | 1 + docs/source/torchdata.stateful_dataloader.rst | 13 +++++++++++++ .../stateful_dataloader.py | 19 ++++++++++--------- 3 files changed, 24 insertions(+), 9 deletions(-) create mode 100644 docs/source/torchdata.stateful_dataloader.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 5aa5895af..b70d852a1 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -36,6 +36,7 @@ Features described in this documentation are classified by release status: :maxdepth: 2 :caption: API Reference: + torchdata.stateful_dataloader.rst torchdata.datapipes.iter.rst torchdata.datapipes.map.rst torchdata.datapipes.utils.rst diff --git a/docs/source/torchdata.stateful_dataloader.rst b/docs/source/torchdata.stateful_dataloader.rst new file mode 100644 index 000000000..a7d161b34 --- /dev/null +++ b/docs/source/torchdata.stateful_dataloader.rst @@ -0,0 +1,13 @@ +:tocdepth: 3 + +Stateful DataLoader +=================== + +.. automodule:: torchdata.stateful_dataloader + +StatefulDataLoader is a drop-in replacement for `torch.utils.data.DataLoader `_ which offers ``state_dict`` / ``load_state_dict`` methods for handling mid-epoch checkpointing which operate on the previous/next iterator requested from the dataloader (resp.). + +By default, the state includes the number of batches yielded and uses this to naively fast-forward the sampler (map-style) or the dataset (iterable-style). However if the sampler and/or dataset include ``state_dict`` / ``load_state_dict`` methods, then it will call them during its own ``state_dict`` / ``load_state_dict`` calls. Under the hood, :class:`StatefulDataLoader` handles aggregation and distribution of state across multiprocess workers (but not across ranks). + +.. autoclass:: StatefulDataLoader + :members: diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py index 92553cf69..9b162b4f8 100644 --- a/torchdata/stateful_dataloader/stateful_dataloader.py +++ b/torchdata/stateful_dataloader/stateful_dataloader.py @@ -92,13 +92,12 @@ class StatefulDataLoader(DataLoader[_T_co]): r""" - This is a drop in replacement for :class:`~torch.utils.data.DataLoader` + This is a drop in replacement for ``torch.utils.data.DataLoader`` that implements state_dict and load_state_dict methods, enabling mid-epoch checkpointing. - All arguments are identical to :class:`~torch.utils.data.DataLoader`, with - a new kwarg: `snapshot_every_n_steps: Optional[int] = `. - See :py:mod:`torch.utils.data` documentation page for more details. + All arguments are identical to ``torch.utils.data.DataLoader``, with + a new kwarg: ``snapshot_every_n_steps``. Args: dataset (Dataset): dataset from which to load the data. @@ -148,11 +147,13 @@ class StatefulDataLoader(DataLoader[_T_co]): maintain the workers `Dataset` instances alive. (default: ``False``) pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is ``True``. + snapshot_every_n_steps (int, optional): Defines how often the state is + transferred from the dataloader workers to the dataloader. By default, it is set to ``1``, i.e., state is transferred every step. If the state is large, this value can be increased (and ideally set to the frequency of training checkpointing) to reduce the overhead of transferring state every step. .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an unpicklable object, e.g., a lambda function. See - :ref:`multiprocessing-best-practices` on more details related + `multiprocessing-best-practices `_ on more details related to multiprocessing in PyTorch. .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used. @@ -169,12 +170,12 @@ class StatefulDataLoader(DataLoader[_T_co]): dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such cases in general. - See `Dataset Types`_ for more details on these two types of datasets and how + See `Dataset Types `_ for more details on these two types of datasets and how :class:`~torch.utils.data.IterableDataset` interacts with - `Multi-process data loading`_. + `Multi-process data loading `_. - .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and - :ref:`data-loading-randomness` notes for random seed related questions. + .. warning:: See `Reproducibility `_, and `Dataloader-workers-random-seed `_, and + `Data-loading-randomness `_ notes for random seed related questions. .. _multiprocessing context: https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods