Commit b5868b2

ssnl authored and facebook-github-bot committed on May 14, 2020
Relax sampler check in BatchSampler (pytorch#38403)
Summary: Since the check was added in pytorch#6249, one cannot pass a plain iterable as a sampler to the data loader anymore, which was a very handy feature (e.g., pytorch#1337). I think the check should be removed, for two reasons:

1. It is too strict. There is no reason a sampler cannot be a general iterable.
2. It is inconsistent. In `DataLoader` (the main place people use samplers), you can pass a general iterable as `batch_sampler` but not as `sampler`, because of this check.

Pull Request resolved: pytorch#38403
Differential Revision: D21555958
Pulled By: soumith
fbshipit-source-id: c7267bb99a31edd8f2750689205d6edc5dab5cff
1 parent f3d2e33 commit b5868b2
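In effect, any sized iterable of indices now works directly as the ``sampler`` argument. A minimal sketch of the newly allowed usage (the ``TensorDataset`` setup is illustrative and not from the patch; the ``range(2, 12)`` / ``batch_size=2`` combination mirrors the new test below):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(20))

# Before this change, passing anything other than a torch.utils.data.Sampler
# as `sampler` raised a ValueError; now a plain range of indices is accepted.
loader = DataLoader(dataset, sampler=range(2, 12), batch_size=2)

print(len(loader))   # 5 batches, since the sampler implements __len__
for (batch,) in loader:
    print(batch)     # tensor([2, 3]), tensor([4, 5]), ..., tensor([10, 11])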

File tree: 3 files changed (+27 −11 lines changed)


test/test_dataloader.py (+15 −1)
@@ -1312,9 +1312,23 @@ def test_duplicating_data_with_drop_last(self):
 
         self.assertEqual(scanned_data.size(), scanned_data.unique().size())
 
+    def _test_sampler(self, **kwargs):
+        indices = range(2, 12)  # using a regular iterable
+        dl = DataLoader(self.dataset, sampler=indices, batch_size=2, **kwargs)
+        self.assertEqual(len(dl), 5)
+        for i, (input, _target) in enumerate(dl):
+            self.assertEqual(len(input), 2)
+            self.assertEqual(input, self.data[i * 2 + 2:i * 2 + 4])
+
+    def test_sampler(self):
+        self._test_sampler()
+        self._test_sampler(num_workers=4)
+        if not NO_MULTIPROCESSING_SPAWN and torch.multiprocessing._supports_context:
+            self._test_batch_sampler(num_workers=4, multiprocessing_context='spawn')
+
     def _test_batch_sampler(self, **kwargs):
         # [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...]
-        batches = []
+        batches = []  # using a regular iterable
         for i in range(0, 20, 5):
             batches.append(tuple(range(i, i + 2)))
             batches.append(tuple(range(i + 2, i + 5)))

torch/utils/data/dataloader.py (+7 −5)
@@ -74,11 +74,13 @@ class DataLoader(object):
             (default: ``1``).
         shuffle (bool, optional): set to ``True`` to have the data reshuffled
             at every epoch (default: ``False``).
-        sampler (Sampler, optional): defines the strategy to draw samples from
-            the dataset. If specified, :attr:`shuffle` must be ``False``.
-        batch_sampler (Sampler, optional): like :attr:`sampler`, but returns a batch of
-            indices at a time. Mutually exclusive with :attr:`batch_size`,
-            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
+        sampler (Sampler or Iterable, optional): defines the strategy to draw
+            samples from the dataset. Can be any ``Iterable`` with ``__len__``
+            implemented. If specified, :attr:`shuffle` must not be specified.
+        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
+            returns a batch of indices at a time. Mutually exclusive with
+            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
+            and :attr:`drop_last`.
         num_workers (int, optional): how many subprocesses to use for data
             loading. ``0`` means that the data will be loaded in the main process.
             (default: ``0``)
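With this docstring change, ``sampler`` and ``batch_sampler`` are documented symmetrically: both accept any sized iterable. For comparison, a sketch of the ``batch_sampler`` side, which already accepted plain iterables before this patch — the inconsistency the summary calls out (``dataset`` is the illustrative one from the sketch above):

# A plain list of index lists as batch_sampler. Note that batch_size,
# shuffle, sampler, and drop_last must be left at their defaults, since
# batch_sampler is mutually exclusive with all of them.
batches = [[0, 1], [2, 3, 4], [5, 6]]
loader = DataLoader(dataset, batch_sampler=batches)

for (batch,) in loader:
    print(batch)  # tensor([0, 1]), tensor([2, 3, 4]), tensor([5, 6])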

torch/utils/data/sampler.py (+5 −5)
@@ -167,7 +167,8 @@ class BatchSampler(Sampler):
     r"""Wraps another sampler to yield a mini-batch of indices.
 
     Args:
-        sampler (Sampler): Base sampler.
+        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
+            with ``__len__`` implemented.
         batch_size (int): Size of mini-batch.
         drop_last (bool): If ``True``, the sampler will drop the last batch if
             its size would be less than ``batch_size``

@@ -180,10 +181,9 @@ class BatchSampler(Sampler):
     """
 
     def __init__(self, sampler, batch_size, drop_last):
-        if not isinstance(sampler, Sampler):
-            raise ValueError("sampler should be an instance of "
-                             "torch.utils.data.Sampler, but got sampler={}"
-                             .format(sampler))
+        # Since collections.abc.Iterable does not check for `__getitem__`, which
+        # is one way for an object to be an iterable, we don't do an `isinstance`
+        # check here.
         if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                 batch_size <= 0:
             raise ValueError("batch_size should be a positive integer value, "
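The replacement comment explains why no ``isinstance(sampler, collections.abc.Iterable)`` check was substituted: an object can be iterable purely through the legacy ``__getitem__`` sequence protocol, which the ``Iterable`` ABC does not detect. A small self-contained demonstration (the class name is hypothetical, not from the patch):

from collections.abc import Iterable

class SquareIndices:
    # Iterable only via the old sequence protocol: iter() calls
    # __getitem__ with 0, 1, 2, ... until IndexError. No __iter__ defined.
    def __getitem__(self, i):
        if i >= 4:
            raise IndexError
        return i * i

    def __len__(self):
        return 4

s = SquareIndices()
print(list(iter(s)))            # [0, 1, 4, 9] -- iteration works
print(isinstance(s, Iterable))  # False -- the ABC only checks for __iter__

An ``isinstance`` check against the ``Iterable`` ABC would therefore wrongly reject samplers like this one, so the check was dropped rather than relaxed.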
