Commit 0c77b9c: adding stateful dataloader docs

gokulavasan committed Jul 30, 2024 · 1 parent 9ad0094
Showing 3 changed files with 24 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -36,6 +36,7 @@ Features described in this documentation are classified by release status:
    :maxdepth: 2
    :caption: API Reference:

+   torchdata.stateful_dataloader.rst
    torchdata.datapipes.iter.rst
    torchdata.datapipes.map.rst
    torchdata.datapipes.utils.rst
13 changes: 13 additions & 0 deletions docs/source/torchdata.stateful_dataloader.rst
@@ -0,0 +1,13 @@
+:tocdepth: 3
+
+Stateful DataLoader
+===================
+
+.. automodule:: torchdata.stateful_dataloader
+
+StatefulDataLoader is a drop-in replacement for `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_, which offers ``state_dict`` / ``load_state_dict`` methods for handling mid-epoch checkpointing; these operate on the previously requested iterator and the next iterator requested from the dataloader, respectively.
+
+By default, the state includes the number of batches yielded and uses this to naively fast-forward the sampler (map-style) or the dataset (iterable-style). However, if the sampler and/or dataset include ``state_dict`` / ``load_state_dict`` methods, the dataloader will call them during its own ``state_dict`` / ``load_state_dict`` calls. Under the hood, :class:`StatefulDataLoader` handles aggregation and distribution of state across multiprocess workers (but not across ranks).
+
+.. autoclass:: StatefulDataLoader
+    :members:
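
As a quick illustration of the workflow these docs describe, here is a minimal sketch of mid-epoch checkpointing; the toy dataset, batch size, and worker count are illustrative assumptions, not part of this commit.

from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(100))  # any map-style dataset works here
loader = StatefulDataLoader(dataset, batch_size=4, num_workers=2)

# Consume part of an epoch, then snapshot the loader's state.
it = iter(loader)
for _ in range(5):
    batch = next(it)
state = loader.state_dict()  # reflects the 5 batches already yielded

# Later: build an identically configured loader and resume mid-epoch.
resumed = StatefulDataLoader(dataset, batch_size=4, num_workers=2)
resumed.load_state_dict(state)  # applies to the next iterator requested
for batch in resumed:  # iteration continues where the snapshot left off
    ...

In real training the state dict would be persisted alongside the model checkpoint (e.g., with ``torch.save``) and loaded back before resuming.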
19 changes: 10 additions & 9 deletions torchdata/stateful_dataloader/stateful_dataloader.py
@@ -92,13 +92,12 @@

 class StatefulDataLoader(DataLoader[_T_co]):
     r"""
-    This is a drop in replacement for :class:`~torch.utils.data.DataLoader`
+    This is a drop in replacement for ``torch.utils.data.DataLoader``
     that implements state_dict and load_state_dict methods, enabling mid-epoch
     checkpointing.

-    All arguments are identical to :class:`~torch.utils.data.DataLoader`, with
-    a new kwarg: `snapshot_every_n_steps: Optional[int] = `.
-    See :py:mod:`torch.utils.data` documentation page for more details.
+    All arguments are identical to ``torch.utils.data.DataLoader``, with
+    a new kwarg: ``snapshot_every_n_steps``.

     Args:
         dataset (Dataset): dataset from which to load the data.
@@ -148,11 +147,13 @@ class StatefulDataLoader(DataLoader[_T_co]):
             maintain the workers `Dataset` instances alive. (default: ``False``)
         pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
             ``True``.
+        snapshot_every_n_steps (int, optional): Defines how often the state is
+            transferred from the dataloader workers to the dataloader. By default, it is set to ``1``, i.e., state is transferred every step. If the state is large, this value can be increased (and ideally set to the frequency of training checkpointing) to reduce the overhead of transferring state every step.

     .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
                  cannot be an unpicklable object, e.g., a lambda function. See
-                 :ref:`multiprocessing-best-practices` on more details related
+                 `multiprocessing-best-practices <https://pytorch.org/docs/stable/notes/multiprocessing.html#multiprocessing-best-practices>`_ for more details related
                  to multiprocessing in PyTorch.

     .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
@@ -169,12 +170,12 @@ class StatefulDataLoader(DataLoader[_T_co]):
                  dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
                  cases in general.

-                 See `Dataset Types`_ for more details on these two types of datasets and how
+                 See `Dataset Types <https://pytorch.org/docs/stable/data.html>`_ for more details on these two types of datasets and how
                  :class:`~torch.utils.data.IterableDataset` interacts with
-                 `Multi-process data loading`_.
+                 `Multi-process data loading <https://pytorch.org/docs/stable/data.html#multi-process-data-loading>`_.

-    .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
-                 :ref:`data-loading-randomness` notes for random seed related questions.
+    .. warning:: See `Reproducibility <https://pytorch.org/docs/stable/notes/randomness.html#reproducibility>`_, `Dataloader-workers-random-seed <https://pytorch.org/docs/stable/notes/faq.html#dataloader-workers-random-seed>`_, and
+                 `Data-loading-randomness <https://pytorch.org/docs/stable/data.html#data-loading-randomness>`_ notes for random seed related questions.

     .. _multiprocessing context:
         https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
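
To make the delegation behavior and the new ``snapshot_every_n_steps`` kwarg concrete, here is a hedged sketch of an iterable dataset that defines its own ``state_dict`` / ``load_state_dict``, so the loader calls these rather than naively fast-forwarding; ``CountingDataset``, its fields, and the interval of 100 are hypothetical choices for illustration.

import torch.utils.data
from torchdata.stateful_dataloader import StatefulDataLoader

class CountingDataset(torch.utils.data.IterableDataset):
    def __init__(self, limit: int):
        self.limit = limit
        self.i = 0

    def __iter__(self):
        # With num_workers > 0 each worker iterates its own replica; a real
        # dataset should shard work via torch.utils.data.get_worker_info().
        while self.i < self.limit:
            yield self.i
            self.i += 1

    def state_dict(self):
        # Called by StatefulDataLoader.state_dict() in place of the default
        # "number of batches yielded" fast-forward state.
        return {"i": self.i}

    def load_state_dict(self, state):
        # Called by StatefulDataLoader.load_state_dict() before the next
        # iterator is created.
        self.i = state["i"]

# Snapshot worker state every 100 steps instead of every step; ideally this
# matches how often training checkpoints are written.
loader = StatefulDataLoader(
    CountingDataset(10_000),
    batch_size=16,
    num_workers=2,
    snapshot_every_n_steps=100,
)

Setting ``snapshot_every_n_steps`` to the training checkpoint interval keeps worker-to-main-process traffic low while still making a fresh snapshot available whenever a checkpoint is actually written.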
