Skip to content

[mlir python] Port Python core code to nanobind. #120473

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 19, 2024
Merged

Conversation

hawkinsp
Copy link
Contributor

Relands #118583, with a fix for Python 3.8 compatibility. It was not possible to set the buffer protocol accessers via slots in Python 3.8.

Why? https://nanobind.readthedocs.io/en/latest/why.html says it better than I can, but my primary motivation for this change is to improve MLIR IR construction time from JAX.

For a complicated Google-internal LLM model in JAX, this change improves the MLIR
lowering time by around 5s (out of around 30s), which is a significant speedup for simply switching binding frameworks.

To a large extent, this is a mechanical change, for instance changing pybind11:: to nanobind::.

Notes:

  • this PR needs Nanobind 2.4.0, because it needs a bug fix (Support overriding static properties defined via def_prop_ro_static. wjakob/nanobind#806) that landed in that release.
  • this PR does not port the in-tree dialect extension modules. They can be ported in a future PR.
  • I removed the py::sibling() annotations from def_static and def_class in PybindAdapters.h. These ask pybind11 to try to form an overload with an existing method, but it's not possible to form mixed pybind11/nanobind overloads this ways and the parent class is now defined in nanobind. Better solutions may be possible here.
  • nanobind does not contain an exact equivalent of pybind11's buffer protocol support. It was not hard to add a nanobind implementation of a similar API.
  • nanobind is pickier about casting to std::vector, expecting that the input is a sequence of bool types, not truthy values. In a couple of places I added code to support truthy values during casting.
  • nanobind distinguishes bytes (nb::bytes) from strings (e.g., std::string). This required nb::bytes overloads in a few places.

@llvmbot llvmbot added mlir:python MLIR Python bindings mlir bazel "Peripheral" support tier build system: utils/bazel labels Dec 18, 2024
@llvmbot
Copy link
Member

llvmbot commented Dec 18, 2024

@llvm/pr-subscribers-mlir

Author: Peter Hawkins (hawkinsp)

Changes

Relands #118583, with a fix for Python 3.8 compatibility. It was not possible to set the buffer protocol accessers via slots in Python 3.8.

Why? https://nanobind.readthedocs.io/en/latest/why.html says it better than I can, but my primary motivation for this change is to improve MLIR IR construction time from JAX.

For a complicated Google-internal LLM model in JAX, this change improves the MLIR
lowering time by around 5s (out of around 30s), which is a significant speedup for simply switching binding frameworks.

To a large extent, this is a mechanical change, for instance changing pybind11:: to nanobind::.

Notes:

  • this PR needs Nanobind 2.4.0, because it needs a bug fix (Support overriding static properties defined via def_prop_ro_static. wjakob/nanobind#806) that landed in that release.
  • this PR does not port the in-tree dialect extension modules. They can be ported in a future PR.
  • I removed the py::sibling() annotations from def_static and def_class in PybindAdapters.h. These ask pybind11 to try to form an overload with an existing method, but it's not possible to form mixed pybind11/nanobind overloads this ways and the parent class is now defined in nanobind. Better solutions may be possible here.
  • nanobind does not contain an exact equivalent of pybind11's buffer protocol support. It was not hard to add a nanobind implementation of a similar API.
  • nanobind is pickier about casting to std::vector<bool>, expecting that the input is a sequence of bool types, not truthy values. In a couple of places I added code to support truthy values during casting.
  • nanobind distinguishes bytes (nb::bytes) from strings (e.g., std::string). This required nb::bytes overloads in a few places.

Patch is 356.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120473.diff

23 Files Affected:

  • (modified) mlir/cmake/modules/MLIRDetectPythonEnv.cmake (+1-1)
  • (modified) mlir/include/mlir/Bindings/Python/IRTypes.h (+1-1)
  • (modified) mlir/include/mlir/Bindings/Python/NanobindAdaptors.h (+13-13)
  • (modified) mlir/include/mlir/Bindings/Python/PybindAdaptors.h (+4-6)
  • (modified) mlir/lib/Bindings/Python/Globals.h (+19-20)
  • (modified) mlir/lib/Bindings/Python/IRAffine.cpp (+129-136)
  • (modified) mlir/lib/Bindings/Python/IRAttributes.cpp (+445-235)
  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+733-679)
  • (modified) mlir/lib/Bindings/Python/IRInterfaces.cpp (+86-85)
  • (modified) mlir/lib/Bindings/Python/IRModule.cpp (+29-28)
  • (modified) mlir/lib/Bindings/Python/IRModule.h (+172-160)
  • (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+107-93)
  • (modified) mlir/lib/Bindings/Python/MainModule.cpp (+29-27)
  • (renamed) mlir/lib/Bindings/Python/NanobindUtils.h (+46-38)
  • (modified) mlir/lib/Bindings/Python/Pass.cpp (+31-27)
  • (modified) mlir/lib/Bindings/Python/Pass.h (+2-2)
  • (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+23-20)
  • (modified) mlir/lib/Bindings/Python/Rewrite.h (+2-2)
  • (modified) mlir/python/CMakeLists.txt (+2-1)
  • (modified) mlir/python/requirements.txt (+1-1)
  • (modified) mlir/test/python/ir/symbol_table.py (+2-1)
  • (modified) utils/bazel/WORKSPACE (+3-3)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+11-4)
diff --git a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
index c62ac7fa615ea6..d6bb65c64b8292 100644
--- a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
+++ b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
@@ -39,7 +39,7 @@ macro(mlir_configure_python_dev_packages)
                   "extension = '${PYTHON_MODULE_EXTENSION}")
 
     mlir_detect_nanobind_install()
-    find_package(nanobind 2.2 CONFIG REQUIRED)
+    find_package(nanobind 2.4 CONFIG REQUIRED)
     message(STATUS "Found nanobind v${nanobind_VERSION}: ${nanobind_INCLUDE_DIR}")
     message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
                   "suffix = '${PYTHON_MODULE_SUFFIX}', "
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index 9afad4c23b3f35..ba9642cf2c6a2d 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -9,7 +9,7 @@
 #ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
 #define MLIR_BINDINGS_PYTHON_IRTYPES_H
 
-#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace mlir {
 
diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
index 5e01cebcb09c91..943981b1fa03dd 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
@@ -64,7 +64,7 @@ static nanobind::object mlirApiObjectToCapsule(nanobind::handle apiObject) {
 /// Casts object <-> MlirAffineMap.
 template <>
 struct type_caster<MlirAffineMap> {
-  NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap"));
+  NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap"))
   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
     nanobind::object capsule = mlirApiObjectToCapsule(src);
     value = mlirPythonCapsuleToAffineMap(capsule.ptr());
@@ -87,7 +87,7 @@ struct type_caster<MlirAffineMap> {
 /// Casts object <-> MlirAttribute.
 template <>
 struct type_caster<MlirAttribute> {
-  NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute"));
+  NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute"))
   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
     nanobind::object capsule = mlirApiObjectToCapsule(src);
     value = mlirPythonCapsuleToAttribute(capsule.ptr());
@@ -108,7 +108,7 @@ struct type_caster<MlirAttribute> {
 /// Casts object -> MlirBlock.
 template <>
 struct type_caster<MlirBlock> {
-  NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock"));
+  NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock"))
   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
     nanobind::object capsule = mlirApiObjectToCapsule(src);
     value = mlirPythonCapsuleToBlock(capsule.ptr());
@@ -119,7 +119,7 @@ struct type_caster<MlirBlock> {
 /// Casts object -> MlirContext.
 template <>
 struct type_caster<MlirContext> {
-  NB_TYPE_CASTER(MlirContext, const_name("MlirContext"));
+  NB_TYPE_CASTER(MlirContext, const_name("MlirContext"))
   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
     if (src.is_none()) {
       // Gets the current thread-bound context.
@@ -139,7 +139,7 @@ struct type_caster<MlirContext> {
 /// Casts object <-> MlirDialectRegistry.
 template <>
 struct type_caster<MlirDialectRegistry> {
-  NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry"));
+  NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry"))
   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
     nanobind::object capsule = mlirApiObjectToCapsule(src);
     value = mlirPythonCapsuleToDialectRegistry(capsule.ptr());
@@ -159,7 +159,7 @@ struct type_caster<MlirDialectRegistry> {
 /// Casts object <-> MlirLocation.
 template <>
 struct type_caster<MlirLocation> {
-  NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation"));
+  NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation"))
   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
     if (src.is_none()) {
       // Gets the current thread-bound context.
@@ -185,7 +185,7 @@ struct type_caster<MlirLocation> {
 /// Casts object <-> MlirModule.
 template <>
 struct type_caster<MlirModule> {
-  NB_TYPE_CASTER(MlirModule, const_name("MlirModule"));
+  NB_TYPE_CASTER(MlirModule, const_name("MlirModule"))
   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
     nanobind::object capsule = mlirApiObjectToCapsule(src);
     value = mlirPythonCapsuleToModule(capsule.ptr());
@@ -206,7 +206,7 @@ struct type_caster<MlirModule> {
 template <>
 struct type_caster<MlirFrozenRewritePatternSet> {
   NB_TYPE_CASTER(MlirFrozenRewritePatternSet,
-                 const_name("MlirFrozenRewritePatternSet"));
+                 const_name("MlirFrozenRewritePatternSet"))
   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
     nanobind::object capsule = mlirApiObjectToCapsule(src);
     value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
@@ -225,7 +225,7 @@ struct type_caster<MlirFrozenRewritePatternSet> {
 /// Casts object <-> MlirOperation.
 template <>
 struct type_caster<MlirOperation> {
-  NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation"));
+  NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation"))
   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
     nanobind::object capsule = mlirApiObjectToCapsule(src);
     value = mlirPythonCapsuleToOperation(capsule.ptr());
@@ -247,7 +247,7 @@ struct type_caster<MlirOperation> {
 /// Casts object <-> MlirValue.
 template <>
 struct type_caster<MlirValue> {
-  NB_TYPE_CASTER(MlirValue, const_name("MlirValue"));
+  NB_TYPE_CASTER(MlirValue, const_name("MlirValue"))
   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
     nanobind::object capsule = mlirApiObjectToCapsule(src);
     value = mlirPythonCapsuleToValue(capsule.ptr());
@@ -270,7 +270,7 @@ struct type_caster<MlirValue> {
 /// Casts object -> MlirPassManager.
 template <>
 struct type_caster<MlirPassManager> {
-  NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager"));
+  NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager"))
   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
     nanobind::object capsule = mlirApiObjectToCapsule(src);
     value = mlirPythonCapsuleToPassManager(capsule.ptr());
@@ -281,7 +281,7 @@ struct type_caster<MlirPassManager> {
 /// Casts object <-> MlirTypeID.
 template <>
 struct type_caster<MlirTypeID> {
-  NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID"));
+  NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID"))
   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
     nanobind::object capsule = mlirApiObjectToCapsule(src);
     value = mlirPythonCapsuleToTypeID(capsule.ptr());
@@ -303,7 +303,7 @@ struct type_caster<MlirTypeID> {
 /// Casts object <-> MlirType.
 template <>
 struct type_caster<MlirType> {
-  NB_TYPE_CASTER(MlirType, const_name("MlirType"));
+  NB_TYPE_CASTER(MlirType, const_name("MlirType"))
   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
     nanobind::object capsule = mlirApiObjectToCapsule(src);
     value = mlirPythonCapsuleToType(capsule.ptr());
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index c8233355d1d67b..edc69774be9227 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -374,9 +374,8 @@ class pure_subclass {
     static_assert(!std::is_member_function_pointer<Func>::value,
                   "def_staticmethod(...) called with a non-static member "
                   "function pointer");
-    py::cpp_function cf(
-        std::forward<Func>(f), py::name(name), py::scope(thisClass),
-        py::sibling(py::getattr(thisClass, name, py::none())), extra...);
+    py::cpp_function cf(std::forward<Func>(f), py::name(name),
+                        py::scope(thisClass), extra...);
     thisClass.attr(cf.name()) = py::staticmethod(cf);
     return *this;
   }
@@ -387,9 +386,8 @@ class pure_subclass {
     static_assert(!std::is_member_function_pointer<Func>::value,
                   "def_classmethod(...) called with a non-static member "
                   "function pointer");
-    py::cpp_function cf(
-        std::forward<Func>(f), py::name(name), py::scope(thisClass),
-        py::sibling(py::getattr(thisClass, name, py::none())), extra...);
+    py::cpp_function cf(std::forward<Func>(f), py::name(name),
+                        py::scope(thisClass), extra...);
     thisClass.attr(cf.name()) =
         py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
     return *this;
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index a022067f5c7e57..0ec522d14f74bd 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -9,18 +9,17 @@
 #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
 #define MLIR_BINDINGS_PYTHON_GLOBALS_H
 
-#include "PybindUtils.h"
+#include <optional>
+#include <string>
+#include <vector>
 
+#include "NanobindUtils.h"
 #include "mlir-c/IR.h"
 #include "mlir/CAPI/Support.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSet.h"
 
-#include <optional>
-#include <string>
-#include <vector>
-
 namespace mlir {
 namespace python {
 
@@ -57,55 +56,55 @@ class PyGlobals {
   /// Raises an exception if the mapping already exists and replace == false.
   /// This is intended to be called by implementation code.
   void registerAttributeBuilder(const std::string &attributeKind,
-                                pybind11::function pyFunc,
+                                nanobind::callable pyFunc,
                                 bool replace = false);
 
   /// Adds a user-friendly type caster. Raises an exception if the mapping
   /// already exists and replace == false. This is intended to be called by
   /// implementation code.
-  void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
+  void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster,
                           bool replace = false);
 
   /// Adds a user-friendly value caster. Raises an exception if the mapping
   /// already exists and replace == false. This is intended to be called by
   /// implementation code.
   void registerValueCaster(MlirTypeID mlirTypeID,
-                           pybind11::function valueCaster,
+                           nanobind::callable valueCaster,
                            bool replace = false);
 
   /// Adds a concrete implementation dialect class.
   /// Raises an exception if the mapping already exists.
   /// This is intended to be called by implementation code.
   void registerDialectImpl(const std::string &dialectNamespace,
-                           pybind11::object pyClass);
+                           nanobind::object pyClass);
 
   /// Adds a concrete implementation operation class.
   /// Raises an exception if the mapping already exists and replace == false.
   /// This is intended to be called by implementation code.
   void registerOperationImpl(const std::string &operationName,
-                             pybind11::object pyClass, bool replace = false);
+                             nanobind::object pyClass, bool replace = false);
 
   /// Returns the custom Attribute builder for Attribute kind.
-  std::optional<pybind11::function>
+  std::optional<nanobind::callable>
   lookupAttributeBuilder(const std::string &attributeKind);
 
   /// Returns the custom type caster for MlirTypeID mlirTypeID.
-  std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
+  std::optional<nanobind::callable> lookupTypeCaster(MlirTypeID mlirTypeID,
                                                      MlirDialect dialect);
 
   /// Returns the custom value caster for MlirTypeID mlirTypeID.
-  std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
+  std::optional<nanobind::callable> lookupValueCaster(MlirTypeID mlirTypeID,
                                                       MlirDialect dialect);
 
   /// Looks up a registered dialect class by namespace. Note that this may
   /// trigger loading of the defining module and can arbitrarily re-enter.
-  std::optional<pybind11::object>
+  std::optional<nanobind::object>
   lookupDialectClass(const std::string &dialectNamespace);
 
   /// Looks up a registered operation class (deriving from OpView) by operation
   /// name. Note that this may trigger a load of the dialect, which can
   /// arbitrarily re-enter.
-  std::optional<pybind11::object>
+  std::optional<nanobind::object>
   lookupOperationClass(llvm::StringRef operationName);
 
 private:
@@ -113,15 +112,15 @@ class PyGlobals {
   /// Module name prefixes to search under for dialect implementation modules.
   std::vector<std::string> dialectSearchPrefixes;
   /// Map of dialect namespace to external dialect class object.
-  llvm::StringMap<pybind11::object> dialectClassMap;
+  llvm::StringMap<nanobind::object> dialectClassMap;
   /// Map of full operation name to external operation class object.
-  llvm::StringMap<pybind11::object> operationClassMap;
+  llvm::StringMap<nanobind::object> operationClassMap;
   /// Map of attribute ODS name to custom builder.
-  llvm::StringMap<pybind11::object> attributeBuilderMap;
+  llvm::StringMap<nanobind::callable> attributeBuilderMap;
   /// Map of MlirTypeID to custom type caster.
-  llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
+  llvm::DenseMap<MlirTypeID, nanobind::callable> typeCasterMap;
   /// Map of MlirTypeID to custom value caster.
-  llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
+  llvm::DenseMap<MlirTypeID, nanobind::callable> valueCasterMap;
   /// Set of dialect namespaces that we have attempted to import implementation
   /// modules for.
   llvm::StringSet<> loadedDialectModules;
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index b138e131e851ea..2db690309fab8c 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -6,20 +6,19 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/string.h>
+#include <nanobind/stl/vector.h>
+
 #include <cstddef>
 #include <cstdint>
-#include <pybind11/cast.h>
-#include <pybind11/detail/common.h>
-#include <pybind11/pybind11.h>
-#include <pybind11/pytypes.h>
+#include <stdexcept>
 #include <string>
 #include <utility>
 #include <vector>
 
 #include "IRModule.h"
-
-#include "PybindUtils.h"
-
+#include "NanobindUtils.h"
 #include "mlir-c/AffineExpr.h"
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/Bindings/Python/Interop.h"
@@ -30,7 +29,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
 
-namespace py = pybind11;
+namespace nb = nanobind;
 using namespace mlir;
 using namespace mlir::python;
 
@@ -46,23 +45,23 @@ static const char kDumpDocstring[] =
 /// Throws errors in case of failure, using "action" to describe what the caller
 /// was attempting to do.
 template <typename PyType, typename CType>
-static void pyListToVector(const py::list &list,
+static void pyListToVector(const nb::list &list,
                            llvm::SmallVectorImpl<CType> &result,
                            StringRef action) {
-  result.reserve(py::len(list));
-  for (py::handle item : list) {
+  result.reserve(nb::len(list));
+  for (nb::handle item : list) {
     try {
-      result.push_back(item.cast<PyType>());
-    } catch (py::cast_error &err) {
+      result.push_back(nb::cast<PyType>(item));
+    } catch (nb::cast_error &err) {
       std::string msg = (llvm::Twine("Invalid expression when ") + action +
                          " (" + err.what() + ")")
                             .str();
-      throw py::cast_error(msg);
-    } catch (py::reference_cast_error &err) {
+      throw std::runtime_error(msg.c_str());
+    } catch (std::runtime_error &err) {
       std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
                          action + " (" + err.what() + ")")
                             .str();
-      throw py::cast_error(msg);
+      throw std::runtime_error(msg.c_str());
     }
   }
 }
@@ -94,7 +93,7 @@ class PyConcreteAffineExpr : public BaseTy {
   //   IsAFunctionTy isaFunction
   //   const char *pyClassName
   // and redefine bindDerived.
-  using ClassTy = py::class_<DerivedTy, BaseTy>;
+  using ClassTy = nb::class_<DerivedTy, BaseTy>;
   using IsAFunctionTy = bool (*)(MlirAffineExpr);
 
   PyConcreteAffineExpr() = default;
@@ -105,24 +104,25 @@ class PyConcreteAffineExpr : public BaseTy {
 
   static MlirAffineExpr castFrom(PyAffineExpr &orig) {
     if (!DerivedTy::isaFunction(orig)) {
-      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
-      throw py::value_error((Twine("Cannot cast affine expression to ") +
+      auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
+      throw nb::value_error((Twine("Cannot cast affine expression to ") +
                              DerivedTy::pyClassName + " (from " + origRepr +
                              ")")
-                                .str());
+                                .str()
+                                .c_str());
     }
     return orig;
   }
 
-  static void bind(py::module &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
-    cls.def(py::init<PyAffineExpr &>(), py::arg("expr"));
+  static void bind(nb::module_ &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    cls.def(nb::init<PyAffineExpr &>(), nb::arg("expr"));
     cls.def_static(
         "isinstance",
         [](PyAffineExpr &otherAffineExpr) -> bool {
           return DerivedTy::isaFunction(otherAffineExpr);
         },
-        py::arg("other"));
+        nb::arg("other"));
     DerivedTy::bindDerived(cls);
   }
 
@@ -144,9 +144,9 @@ class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
   }
 
   static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
-                 py::arg("context") = py::none());
-    c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
+    c.def_static("get", &PyAffineConstantExpr::get, nb::arg("value"),
+                 nb::arg("context").none() = nb::none());
+    c.def_prop_ro("value", [](PyAffineConstantExpr &self) {
       return mlirAffineConstantExprGetValue(self);
     });
   }
@@ -164,9 +164,9 @@ class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
   }
 
   static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
-                 py::arg("context") = py::none());
-    c.def_property_readonly("position", [](PyAffineDimExpr &self) {
+    c.def_static("get", &PyAffineDimExpr::get, nb::arg("position"),
+                 nb::arg("context").none() = nb::none());
+    c.def_prop_ro("position", [](PyAffineDimExpr &self) {
       return mlirAffineDimExprGetPosition(self);
     });
   }
@@ -184,9 +184,9 @@ class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
   }
 
   static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
-                 py::arg("context") = py::none());
-    c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
+    c.def_static("get", &PyAffineSymbolExpr::get, nb::arg("position"),
+                 nb::arg("context").none() = nb::none());
+    c.def_prop_ro("position", [](PyAffineSymbolExpr &self) {
       return mlirAffineSymbolExprGetPosition(self);
     });
   }
@@ -209,8 +209,8 @@ class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
   }
 
   static void bindDerived(ClassTy &c) {
-    c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
-    c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
+    c.def_prop_ro("lhs", &PyAffineBinaryExpr::lhs);
+    c.def_prop_ro("rhs", &PyAffineBinaryExpr::rhs);
   }
 };
 
@@ -365,15 +365,14 @@ bool PyAffineExpr::operator==(const PyAffineExpr &other) const {
   return mlirAffineExprEqual(affineExpr, other...
[truncated]

Copy link

github-actions bot commented Dec 18, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Relands llvm#118583, with a fix for Python 3.8 compatibility. It was not
possible to set the buffer protocol accessers via slots in Python 3.8.

Why? https://nanobind.readthedocs.io/en/latest/why.html says it better
than I can, but my primary motivation for this change is to improve MLIR
IR construction time from JAX.

For a complicated Google-internal LLM model in JAX, this change improves
the MLIR
lowering time by around 5s (out of around 30s), which is a significant
speedup for simply switching binding frameworks.

To a large extent, this is a mechanical change, for instance changing
`pybind11::` to `nanobind::`.

Notes:
* this PR needs Nanobind 2.4.0, because it needs a bug fix
(wjakob/nanobind#806) that landed in that
release.
* this PR does not port the in-tree dialect extension modules. They can
be ported in a future PR.
* I removed the py::sibling() annotations from def_static and def_class
in `PybindAdapters.h`. These ask pybind11 to try to form an overload
with an existing method, but it's not possible to form mixed
pybind11/nanobind overloads this ways and the parent class is now
defined in nanobind. Better solutions may be possible here.
* nanobind does not contain an exact equivalent of pybind11's buffer
protocol support. It was not hard to add a nanobind implementation of a
similar API.
* nanobind is pickier about casting to std::vector<bool>, expecting that
the input is a sequence of bool types, not truthy values. In a couple of
places I added code to support truthy values during casting.
* nanobind distinguishes bytes (`nb::bytes`) from strings (e.g.,
`std::string`). This required nb::bytes overloads in a few places.
@jpienaar
Copy link
Member

Tested locally with python 3.8, relanding.

@jpienaar jpienaar merged commit b56d1ec into llvm:main Dec 19, 2024
6 of 7 checks passed
marbre added a commit to iree-org/llvm-project that referenced this pull request Jan 5, 2025
MaheshRavishankar added a commit to iree-org/llvm-project that referenced this pull request Jan 7, 2025
MaheshRavishankar added a commit to iree-org/llvm-project that referenced this pull request Jan 13, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bazel "Peripheral" support tier build system: utils/bazel mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants