
Commit 6d31213

zou3519 authored and facebook-github-bot committed Sep 18, 2020
Beef up vmap docs and expose to master documentation (pytorch#44825)
Summary: Pull Request resolved: pytorch#44825
Test Plan: - build and view docs locally.
Reviewed By: ezyang
Differential Revision: D23742727
Pulled By: zou3519
fbshipit-source-id: f62b7a76b5505d3387b7816c514c086c01089de0
1 parent c2cf6ef commit 6d31213

File tree

2 files changed: +64 -6 lines changed

 

docs/source/torch.rst

+1

@@ -533,3 +533,4 @@ Utilities
     promote_types
     set_deterministic
     is_deterministic
+    vmap
torch/_vmap_internals.py

+63 -6

@@ -161,9 +161,19 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca
     operations called by `func`, effectively vectorizing those operations.
 
     vmap is useful for handling batch dimensions: one can write a function `func`
-    that runs on examples and the lift it to a function that can take batches of
-    examples with `vmap(func)`. Furthermore, it is possible to use vmap to obtain
-    batched gradients when composed with autograd.
+    that runs on examples and then lift it to a function that can take batches of
+    examples with `vmap(func)`. vmap can also be used to compute batched
+    gradients when composed with autograd.
+
+    .. warning::
+        torch.vmap is an experimental prototype that is subject to
+        change and/or deletion. Please use at your own risk.
+
+    .. note::
+        If you're interested in using vmap for your use case, please
+        `contact us! <https://github.com/pytorch/pytorch/issues/42368>`_
+        We're interested in gathering feedback from early adopters to inform
+        the design.
 
     Args:
         func (function): A Python function that takes one or more arguments.
@@ -188,9 +198,56 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca
         Examples of side-effects include mutating Python data structures and
         assigning values to variables not captured in `func`.
 
-    .. warning::
-        torch.vmap is an experimental prototype that is subject to
-        change and/or deletion. Please use at your own risk.
+    One example of using `vmap` is to compute batched dot products. PyTorch
+    doesn't provide a batched `torch.dot` API; instead of unsuccessfully
+    rummaging through docs, use `vmap` to construct a new function.
+
+        >>> torch.dot                            # [D], [D] -> []
+        >>> batched_dot = torch.vmap(torch.dot)  # [N, D], [N, D] -> [N]
+        >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
+        >>> batched_dot(x, y)
+
+    `vmap` can be helpful in hiding batch dimensions, leading to a simpler
+    model authoring experience.
+
+        >>> batch_size, feature_size = 3, 5
+        >>> weights = torch.randn(feature_size, requires_grad=True)
+        >>>
+        >>> def model(feature_vec):
+        >>>     # Very simple linear model with activation
+        >>>     return feature_vec.dot(weights).relu()
+        >>>
+        >>> examples = torch.randn(batch_size, feature_size)
+        >>> result = torch.vmap(model)(examples)
+
+    `vmap` can also help vectorize computations that were previously difficult
+    or impossible to batch. One example is higher-order gradient computation.
+    The PyTorch autograd engine computes vjps (vector-Jacobian products).
+    Computing a full Jacobian matrix for some function f: R^N -> R^N usually
+    requires N calls to `autograd.grad`, one per Jacobian row. Using `vmap`,
+    we can vectorize the whole computation, computing the Jacobian in a single
+    call to `autograd.grad`.
+
+        >>> # Setup
+        >>> N = 5
+        >>> f = lambda x: x ** 2
+        >>> x = torch.randn(N, requires_grad=True)
+        >>> y = f(x)
+        >>> I_N = torch.eye(N)
+        >>>
+        >>> # Sequential approach
+        >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
+        >>>                  for v in I_N.unbind()]
+        >>> jacobian = torch.stack(jacobian_rows)
+        >>>
+        >>> # vectorized gradient computation
+        >>> def get_vjp(v):
+        >>>     return torch.autograd.grad(y, x, v)
+        >>> jacobian = torch.vmap(get_vjp)(I_N)
+
+    .. note::
+        vmap does not provide general autobatching or handle variable-length
+        sequences out of the box.
     """
     warnings.warn(
         'torch.vmap is an experimental prototype that is subject to '