@@ -161,9 +161,19 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca
operations called by `func`, effectively vectorizing those operations.

vmap is useful for handling batch dimensions: one can write a function `func`
- that runs on examples and the lift it to a function that can take batches of
- examples with `vmap(func)`. Furthermore, it is possible to use vmap to obtain
- batched gradients when composed with autograd.
+ that runs on examples and then lift it to a function that can take batches of
+ examples with `vmap(func)`. vmap can also be used to compute batched
+ gradients when composed with autograd.
+
+ .. warning::
+     torch.vmap is an experimental prototype that is subject to
+     change and/or deletion. Please use at your own risk.
+
+ .. note::
+     If you're interested in using vmap for your use case, please
+     `contact us! <https://github.com/pytorch/pytorch/issues/42368>`_
+     We're interested in gathering feedback from early adopters to inform
+     the design.

Args:
func (function): A Python function that takes one or more arguments.
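(Supplementary sketch, not part of the diff: the signature in the hunk header also takes `in_dims` and `out_dims` arguments that this excerpt never illustrates. A minimal example of mapping over a non-zero dimension, assuming the prototype accepts a per-input tuple for `in_dims` as the `in_dims_t` annotation suggests:)

>>> # Hypothetical usage sketch: map over dim 1 of x and dim 0 of y.
>>> x = torch.randn(5, 3)   # batch dimension is dim 1, vectors of length 5 along dim 0
>>> y = torch.randn(3, 5)   # batch dimension is dim 0, vectors of length 5 along dim 1
>>> batched_dot = torch.vmap(torch.dot, in_dims=(1, 0))
>>> batched_dot(x, y)       # shape [3]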
@@ -188,9 +198,56 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca
Examples of side-effects include mutating Python data structures and
assigning values to variables not captured in `func`.

- .. warning::
-     torch.vmap is an experimental prototype that is subject to
-     change and/or deletion. Please use at your own risk.
+ One example of using `vmap` is to compute batched dot products. PyTorch
+ doesn't provide a batched `torch.dot` API; instead of unsuccessfully
+ rummaging through docs, use `vmap` to construct a new function.
+
+ >>> torch.dot                            # [D], [D] -> []
+ >>> batched_dot = torch.vmap(torch.dot)  # [N, D], [N, D] -> [N]
+ >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
+ >>> batched_dot(x, y)
+
+ `vmap` can be helpful in hiding batch dimensions, leading to a simpler
+ model authoring experience.
+
+ >>> batch_size, feature_size = 3, 5
+ >>> weights = torch.randn(feature_size, requires_grad=True)
+ >>>
+ >>> def model(feature_vec):
+ >>>     # Very simple linear model with activation
+ >>>     return feature_vec.dot(weights).relu()
+ >>>
+ >>> examples = torch.randn(batch_size, feature_size)
+ >>> result = torch.vmap(model)(examples)
+
+ `vmap` can also help vectorize computations that were previously difficult
+ or impossible to batch. One example is higher-order gradient computation.
+ The PyTorch autograd engine computes vjps (vector-Jacobian products).
+ Computing a full Jacobian matrix for some function f: R^N -> R^N usually
+ requires N calls to `autograd.grad`, one per Jacobian row. Using `vmap`,
+ we can vectorize the whole computation, computing the Jacobian in a single
+ call to `autograd.grad`.
+
+ >>> # Setup
+ >>> N = 5
+ >>> f = lambda x: x ** 2
+ >>> x = torch.randn(N, requires_grad=True)
+ >>> y = f(x)
+ >>> I_N = torch.eye(N)
+ >>>
+ >>> # Sequential approach
+ >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
+ >>>                  for v in I_N.unbind()]
+ >>> jacobian = torch.stack(jacobian_rows)
+ >>>
+ >>> # vectorized gradient computation
+ >>> def get_vjp(v):
+ >>>     return torch.autograd.grad(y, x, v)
+ >>> jacobian = torch.vmap(get_vjp)(I_N)
+
+ .. note::
+     vmap does not provide general autobatching or handle variable-length
+     sequences out of the box.
"""
warnings.warn(
'torch.vmap is an experimental prototype that is subject to '
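(Supplementary note, not part of the diff: the trailing context shows that the prototype emits this warning via `warnings.warn` whenever it is called. A sketch of how an early adopter who has read the warning might silence it with the standard library's filter machinery, assuming the message text matches the string above:)

>>> import warnings
>>> # Suppress only the vmap prototype warning; other warnings still surface.
>>> warnings.filterwarnings(
>>>     "ignore", message="torch.vmap is an experimental prototype")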