diff --git a/oneflow/api/python/autograd/autograd.cpp b/oneflow/api/python/autograd/autograd.cpp
index 98ba6fcdcc9..c62244fa7b8 100644
--- a/oneflow/api/python/autograd/autograd.cpp
+++ b/oneflow/api/python/autograd/autograd.cpp
@@ -101,7 +101,7 @@ Maybe<one::TensorTuple> Backward(const one::TensorTuple& outputs, const one::Ten
 
 Maybe<one::TensorTuple> Grad(const one::TensorTuple& outputs, const one::TensorTuple& inputs,
                              const one::TensorTuple& out_grads, bool retain_graph,
-                             bool create_graph) {
+                             bool create_graph, bool allow_unused) {
   PythonFrameGuard pf;
   BackwardPassScopeGuard backward_guard;
   if (create_graph) { retain_graph = true; }
@@ -112,7 +112,7 @@ Maybe<one::TensorTuple> Grad(const one::TensorTuple& outputs, const one::TensorT
       << "All input tensors `.requires_grad` should be true";
   std::shared_ptr<one::TensorTuple> gradients = JUST(CheckAndInitOutGrads(outputs, out_grads));
   return one::GetThreadLocalAutogradEngine()->RunBackwardAndReturnInputsTensorGradIf(
-      outputs, inputs, *gradients, retain_graph, create_graph);
+      outputs, inputs, *gradients, retain_graph, create_graph, allow_unused);
 }
 
 namespace py = pybind11;
diff --git a/oneflow/core/autograd/autograd_engine.cpp b/oneflow/core/autograd/autograd_engine.cpp
index a65b98f619e..3812c1007f3 100644
--- a/oneflow/core/autograd/autograd_engine.cpp
+++ b/oneflow/core/autograd/autograd_engine.cpp
@@ -23,6 +23,7 @@ limitations under the License.
 #include "oneflow/core/autograd/autograd_meta.h"
 #include "oneflow/core/autograd/autograd_mode.h"
 #include "oneflow/core/common/container_util.h"
+#include "oneflow/core/common/error.h"
 #include "oneflow/core/framework/tensor.h"
 #include "oneflow/core/framework/tensor_arg.h"
 #include "oneflow/core/framework/tensor_methods.h"
@@ -136,13 +137,13 @@ Maybe<void> AutogradEngine::RunBackwardAndSaveGrads4LeafTensorIf(const TensorTup
 
 Maybe<TensorTuple> AutogradEngine::RunBackwardAndReturnInputsTensorGradIf(
     const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,
-    bool retain_graph, bool create_graph) {
+    bool retain_graph, bool create_graph, bool allow_unused) {
   JUST(CheckGlobalTensorsMeta(outputs));
   JUST(CheckGlobalTensorsMeta(inputs));
   JUST(CheckGlobalTensorsMeta(out_grads));
   DisableCheckGlobalTensorMetaScope disable_meta_check;
   return RunBackwardAndReturnInputsTensorGrad(outputs, inputs, out_grads, retain_graph,
-                                              create_graph);
+                                              create_graph, allow_unused);
 }
 
 Maybe<void> FunctionNode::AccGrad4RetainGradTensor(bool create_graph) {
@@ -350,7 +351,8 @@ Maybe<void> GraphTask::ComputeDependencies() {
 
 // Computes the number of dependencies for each FunctionNode and prunes useless FunctionNode
 // according to input tensors
-Maybe<void> GraphTask::ComputeDependenciesAndPruneNode(const TensorTuple& inputs) {
+Maybe<void> GraphTask::ComputeDependenciesAndPruneNode(const TensorTuple& inputs,
+                                                       bool allow_unused) {
   struct NodeFrame {
     explicit NodeFrame(FunctionNode* node) : node_(node), next_function_idx_(0) {}
     FunctionNode* node_;
@@ -370,7 +372,11 @@ Maybe<void> GraphTask::ComputeDependenciesAndPruneNode(const TensorTuple& inputs
   captured_grads_ = std::make_shared<TensorTuple>(inputs.size());
   for (int idx = 0; idx < inputs.size(); idx++) {
     const auto& input = inputs[idx];
-    CHECK_NOTNULL_OR_RETURN(input->mut_grad_fn_node().get());  // NOLINT(maybe-need-error-msg)
+    if (allow_unused && !input->mut_grad_fn_node().get()) { continue; }
+    CHECK_NOTNULL_OR_RETURN(input->mut_grad_fn_node().get())
+        << Error::RuntimeError()
+        << "One of the differentiated Tensors appears to not have been used in the graph. Set "
+           "allow_unused=True if this is the desired behavior.";
     ExecInfo& exec_info = grad_fn2exec_info_[input->mut_grad_fn_node().get()];
     exec_info.need_execute = true;
     if (!exec_info.capture_indices) {
@@ -467,13 +473,13 @@ Maybe<void> GraphAutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const Tensor
 
 Maybe<TensorTuple> GraphAutogradEngine::RunBackwardAndReturnInputsTensorGrad(
     const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,
-    bool retain_graph, bool create_graph) {
+    bool retain_graph, bool create_graph, bool allow_unused) {
   for (int i = 0; i < outputs.size(); ++i) {
     JUST(JUST(outputs.at(i)->current_grad())->PushPartialTensor(out_grads.at(i)));
   }
 
   GraphTask graph_task(outputs, retain_graph, create_graph);
-  JUST(graph_task.ComputeDependenciesAndPruneNode(inputs));
+  JUST(graph_task.ComputeDependenciesAndPruneNode(inputs, allow_unused));
   if (IsInDebugMode()) {
     JUST(graph_task.WriteGraphToDotFile(GetDebugGraphFileName("grad", std::to_string(clock()))));
   }
diff --git a/oneflow/core/autograd/autograd_engine.h b/oneflow/core/autograd/autograd_engine.h
index ac24fcfae52..88924671573 100644
--- a/oneflow/core/autograd/autograd_engine.h
+++ b/oneflow/core/autograd/autograd_engine.h
@@ -105,7 +105,8 @@ class AutogradEngine {
   Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGradIf(const TensorTuple& outputs,
                                                             const TensorTuple& inputs,
                                                             const TensorTuple& out_grads,
-                                                            bool retain_graph, bool create_graph);
+                                                            bool retain_graph, bool create_graph,
+                                                            bool allow_unused);
   virtual void ClearEngine() = 0;
   // Builds FunctionNode, binding to all `outputs_` tensors and saving in AutogradEngine
   virtual Maybe<void> AddNode(const std::string& name,
@@ -119,11 +120,9 @@ class AutogradEngine {
   virtual Maybe<void> RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,
                                                          const TensorTuple& out_grads,
                                                          bool retain_graph, bool create_graph) = 0;
-  virtual Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGrad(const TensorTuple& outputs,
-                                                                  const TensorTuple& inputs,
-                                                                  const TensorTuple& out_grads,
-                                                                  bool retain_graph,
-                                                                  bool create_graph) = 0;
+  virtual Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGrad(
+      const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,
+      bool retain_graph, bool create_graph, bool allow_unused) = 0;
 };
 
 // Graph Autograd Node and Engine
@@ -151,7 +150,7 @@ class GraphTask final {
   GraphTask(const TensorTuple& outputs, bool retain_graph, bool create_graph);
 
   Maybe<void> ComputeDependencies();
-  Maybe<void> ComputeDependenciesAndPruneNode(const TensorTuple& inputs);
+  Maybe<void> ComputeDependenciesAndPruneNode(const TensorTuple& inputs, bool allow_unused);
   Maybe<void> Apply(bool save_grad_for_leaf);
   std::shared_ptr<TensorTuple> GetCapturedGrads() const { return captured_grads_; }
   Maybe<void> WriteGraphToDotFile(const std::string& file_name) const;
@@ -193,8 +192,8 @@ class GraphAutogradEngine final : public AutogradEngine {
   Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGrad(const TensorTuple& outputs,
                                                           const TensorTuple& inputs,
                                                           const TensorTuple& out_grads,
-                                                          bool retain_graph,
-                                                          bool create_graph) override;
+                                                          bool retain_graph, bool create_graph,
+                                                          bool allow_unused) override;
 };
 
 AutogradEngine* GetThreadLocalAutogradEngine();
diff --git a/python/oneflow/autograd/autograd.py b/python/oneflow/autograd/autograd.py
index 1c694e2b04a..d6ef634d990 100644
--- a/python/oneflow/autograd/autograd.py
+++ b/python/oneflow/autograd/autograd.py
@@ -28,6 +28,7 @@ def grad(
     grad_outputs: Union[Tensor, Sequence[Tensor], None] = None,
     retain_graph: bool = False,
     create_graph: bool = False,
+    allow_unused: bool = False,
 ) -> Tuple[Tensor]:
     r"""
     Computes and returns the sum of gradients of outputs with respect to the inputs.
@@ -52,6 +53,9 @@ def grad(
            more efficient way. Defaults to the value of ``create_graph``.
        create_graph (bool, optional): If ``True``, graph of the derivative will be constructed,
            allowing to compute higher order derivative products. Defaults to ``False``.
+       allow_unused (bool, optional): If ``False``, specifying inputs that were not
+           used when computing outputs (and therefore their grad is always zero)
+           is an error. Defaults to ``False``.
 
     Returns:
         Tuple(Tensor): A tuple of tensors containing the gradients for each ``inputs``.
@@ -62,8 +66,9 @@
         convert_to_tensor_tuple(grad_outputs),
         retain_graph,
         create_graph,
+        allow_unused,
     )
-    return tuple([Tensor(x) for x in in_grads])
+    return tuple([x for x in in_grads])
 
 
 def backward(
diff --git a/python/oneflow/test/exceptions/test_autograd.py b/python/oneflow/test/exceptions/test_autograd.py
index d1efc468a5a..3bd892596f5 100644
--- a/python/oneflow/test/exceptions/test_autograd.py
+++ b/python/oneflow/test/exceptions/test_autograd.py
@@ -32,6 +32,17 @@ def test_non_requires_grad_tensor_backward(test_case):
             )
         )
 
+    def test_allow_unused(test_case):
+        with test_case.assertRaises(Exception) as context:
+            x = flow.ones(4, 4).requires_grad_()
+            y = flow.ones(4, 4).requires_grad_()
+            z = x * x
+            dx, dy = flow.autograd.grad(z, [x, y], flow.ones_like(z))
+        test_case.assertTrue(
+            "allow_unused=True if this is the desired behavior"
+            in str(context.exception)
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_autograd.py b/python/oneflow/test/modules/test_autograd.py
index fa9e574fec7..4cf52d0e06a 100644
--- a/python/oneflow/test/modules/test_autograd.py
+++ b/python/oneflow/test/modules/test_autograd.py
@@ -227,6 +227,62 @@ def test_acc_grad_inplace_update(test_case):
         test_case.assertEqual(id_x_grad, id(x.grad))
         test_case.assertEqual(id_y_grad, id(y.grad))
 
+    def test_autograd_grad_allow_unused(test_case):
+        shape = [random(1, 10).to(int) for _ in range(4)]
+        shape = [2, 4]
+        device = random_device()
+        x = random_tensor(len(shape), *shape, requires_grad=True).to(device)
+        z = random_tensor(len(shape), *shape, requires_grad=True).to(device)
+        y = x * x
+
+        np_arr = np.random.rand(*y.oneflow.shape)
+        init_grad = torch.tensor(np_arr).requires_grad_().to(device)
+        dx_and_dz = torch.autograd.grad(
+            y,
+            [x, z],
+            init_grad,
+            retain_graph=True,
+            create_graph=True,
+            allow_unused=True,
+        )
+        test_case.assertTrue(
+            np.allclose(
+                dx_and_dz[0].oneflow.detach().numpy(),
+                dx_and_dz[0].pytorch.detach().cpu().numpy(),
+            )
+        )
+        test_case.assertTrue(
+            dx_and_dz[1].oneflow is None and dx_and_dz[1].pytorch is None
+        )
+
+        np_arr = np.random.rand(*y.oneflow.shape)
+        init_grad_grad = torch.tensor(np_arr).requires_grad_().to(device)
+        ddx = torch.autograd.grad(
+            dx_and_dz[0],
+            x,
+            init_grad_grad,
+            retain_graph=True,
+            create_graph=True,
+            allow_unused=True,
+        )[0]
+        test_case.assertTrue(
+            np.allclose(
+                ddx.oneflow.detach().numpy(), ddx.pytorch.detach().cpu().numpy(),
+            )
+        )
+
+        np_arr = np.random.rand(*y.oneflow.shape)
+        init_grad_grad_grad = torch.tensor(np_arr).requires_grad_().to(device)
+        dddx = torch.autograd.grad(
+            ddx,
+            x,
+            init_grad_grad_grad,
+            retain_graph=True,
+            create_graph=True,
+            allow_unused=True,
+        )[0]
+        test_case.assertTrue(dddx.oneflow is None and dddx.pytorch is None)
+
 
 if __name__ == "__main__":
    unittest.main()
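For reference, the snippet below is a minimal usage sketch of the behavior introduced by this patch, mirroring the new test cases. It assumes a OneFlow build that already contains this change; the shapes and variable names are illustrative only.

import oneflow as flow

x = flow.ones(2, 2).requires_grad_()
z = flow.ones(2, 2).requires_grad_()
y = x * x  # z is never used when computing y

# With the default allow_unused=False, asking for z's gradient raises a RuntimeError:
#   "One of the differentiated Tensors appears to not have been used in the graph.
#    Set allow_unused=True if this is the desired behavior."
#
# With allow_unused=True, the gradient of the unused input is returned as None.
dx, dz = flow.autograd.grad(y, [x, z], flow.ones_like(y), allow_unused=True)
print(dx)  # d(x*x)/dx = 2 * x, a 2x2 tensor filled with 2.0
print(dz)  # None

This mirrors torch.autograd.grad semantics, which is what the autotest case in test_autograd.py compares against; returning the engine's results directly (the tuple([x for x in in_grads]) change above) is presumably what lets the unused slot come back as None instead of being passed through the Tensor constructor.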