|
| 1 | +#include <torch/csrc/jit/backends/backend.h> |
| 2 | + |
| 3 | +namespace torch { |
| 4 | +namespace custom_backend { |
| 5 | +// This custom JIT backend is intended to do the minimal amount of work |
| 6 | +// necessary to test that the JIT backend registration endpoints and |
| 7 | +// code generation are working correctly. It is not intended to |
| 8 | +// produce numerically correct results. |
| 9 | +class CustomBackend : public torch::jit::PyTorchBackendInterface { |
| 10 | + public: |
| 11 | + // Constructor. |
| 12 | + explicit CustomBackend() {} |
| 13 | + virtual ~CustomBackend() = default; |
| 14 | + |
| 15 | + c10::IValue preprocess( |
| 16 | + c10::IValue mod, |
| 17 | + c10::impl::GenericDict method_compile_spec) override { |
| 18 | + return mod; |
| 19 | + } |
| 20 | + |
| 21 | + c10::impl::GenericDict compile( |
| 22 | + c10::IValue processed, |
| 23 | + c10::impl::GenericDict method_compile_spec) override { |
| 24 | + auto spec = |
| 25 | + c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec); |
| 26 | + |
| 27 | + // Return the same string as a value for every key in method_compile_spec. |
| 28 | + auto handles = c10::Dict<std::string, std::string>(); |
| 29 | + for (auto it = spec.begin(), end = spec.end(); it != end; ++it) { |
| 30 | + handles.insert(it->key(), it->key()); |
| 31 | + } |
| 32 | + return c10::impl::toGenericDict(handles); |
| 33 | + } |
| 34 | + c10::impl::GenericList execute( |
| 35 | + c10::IValue handle, |
| 36 | + c10::impl::GenericList inputs) override { |
| 37 | + TORCH_INTERNAL_ASSERT(handle.isString()); |
| 38 | + TORCH_INTERNAL_ASSERT(inputs.size() > 0); |
| 39 | + |
| 40 | + c10::List<at::Tensor> output_list; |
| 41 | + |
| 42 | + // Implement simple accumulator and negative accumulator (?) ops. Return one |
| 43 | + // or both of them depending on the handle to make sure multiple outputs are |
| 44 | + // handled. |
| 45 | + c10::IValue value = inputs[0]; |
| 46 | + at::Tensor accum = value.toTensor(); |
| 47 | + accum = accum.clone(); |
| 48 | + at::Tensor sub_accum = value.toTensor(); |
| 49 | + sub_accum = sub_accum.clone(); |
| 50 | + |
| 51 | + for (size_t i = 1, e = inputs.size(); i < e; ++i) { |
| 52 | + value = inputs[i]; |
| 53 | + accum.add_(value.toTensor(), 1.0); |
| 54 | + sub_accum.sub_(value.toTensor(), 1.0); |
| 55 | + } |
| 56 | + |
| 57 | + if (handle.toStringRef() == "accum") { |
| 58 | + output_list.emplace_back(accum); |
| 59 | + } else if (handle.toStringRef() == "sub_accum") { |
| 60 | + output_list.emplace_back(sub_accum); |
| 61 | + } else if (handle.toStringRef() == "forward") { |
| 62 | + output_list.emplace_back(accum); |
| 63 | + output_list.emplace_back(sub_accum); |
| 64 | + } |
| 65 | + |
| 66 | + return c10::impl::toList(output_list); |
| 67 | + } |
| 68 | +}; |
| 69 | + |
// clang-format off
// Export/import annotation for the Windows shared-library boundary.
// custom_ops_EXPORTS is defined when building this library itself
// (dllexport); consumers of the header get dllimport. Non-Windows
// toolchains need no annotation. The unconventional "# if" spacing is
// intentional, hence the clang-format off/on guard.
# if defined(_WIN32)
# if defined(custom_ops_EXPORTS)
# define CUSTOM_BACKEND_API __declspec(dllexport)
# else
# define CUSTOM_BACKEND_API __declspec(dllimport)
# endif
# else
# define CUSTOM_BACKEND_API
# endif
// clang-format on
| 81 | + |
// Returns the name this backend is registered under with the JIT
// (definition not in this file — presumably the accompanying .cpp;
// confirm against the registration call site).
CUSTOM_BACKEND_API std::string getBackendName();
| 83 | +} // namespace custom_backend |
| 84 | +} // namespace torch |
0 commit comments