-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir python] Port Python core code to nanobind. #120473
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Peter Hawkins (hawkinsp) ChangesRelands #118583, with a fix for Python 3.8 compatibility. It was not possible to set the buffer protocol accessers via slots in Python 3.8. Why? https://nanobind.readthedocs.io/en/latest/why.html says it better than I can, but my primary motivation for this change is to improve MLIR IR construction time from JAX. For a complicated Google-internal LLM model in JAX, this change improves the MLIR To a large extent, this is a mechanical change, for instance changing Notes:
Patch is 356.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120473.diff 23 Files Affected:
diff --git a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
index c62ac7fa615ea6..d6bb65c64b8292 100644
--- a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
+++ b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
@@ -39,7 +39,7 @@ macro(mlir_configure_python_dev_packages)
"extension = '${PYTHON_MODULE_EXTENSION}")
mlir_detect_nanobind_install()
- find_package(nanobind 2.2 CONFIG REQUIRED)
+ find_package(nanobind 2.4 CONFIG REQUIRED)
message(STATUS "Found nanobind v${nanobind_VERSION}: ${nanobind_INCLUDE_DIR}")
message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
"suffix = '${PYTHON_MODULE_SUFFIX}', "
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index 9afad4c23b3f35..ba9642cf2c6a2d 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -9,7 +9,7 @@
#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
#define MLIR_BINDINGS_PYTHON_IRTYPES_H
-#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace mlir {
diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
index 5e01cebcb09c91..943981b1fa03dd 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
@@ -64,7 +64,7 @@ static nanobind::object mlirApiObjectToCapsule(nanobind::handle apiObject) {
/// Casts object <-> MlirAffineMap.
template <>
struct type_caster<MlirAffineMap> {
- NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap"));
+ NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToAffineMap(capsule.ptr());
@@ -87,7 +87,7 @@ struct type_caster<MlirAffineMap> {
/// Casts object <-> MlirAttribute.
template <>
struct type_caster<MlirAttribute> {
- NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute"));
+ NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToAttribute(capsule.ptr());
@@ -108,7 +108,7 @@ struct type_caster<MlirAttribute> {
/// Casts object -> MlirBlock.
template <>
struct type_caster<MlirBlock> {
- NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock"));
+ NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToBlock(capsule.ptr());
@@ -119,7 +119,7 @@ struct type_caster<MlirBlock> {
/// Casts object -> MlirContext.
template <>
struct type_caster<MlirContext> {
- NB_TYPE_CASTER(MlirContext, const_name("MlirContext"));
+ NB_TYPE_CASTER(MlirContext, const_name("MlirContext"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
if (src.is_none()) {
// Gets the current thread-bound context.
@@ -139,7 +139,7 @@ struct type_caster<MlirContext> {
/// Casts object <-> MlirDialectRegistry.
template <>
struct type_caster<MlirDialectRegistry> {
- NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry"));
+ NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToDialectRegistry(capsule.ptr());
@@ -159,7 +159,7 @@ struct type_caster<MlirDialectRegistry> {
/// Casts object <-> MlirLocation.
template <>
struct type_caster<MlirLocation> {
- NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation"));
+ NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
if (src.is_none()) {
// Gets the current thread-bound context.
@@ -185,7 +185,7 @@ struct type_caster<MlirLocation> {
/// Casts object <-> MlirModule.
template <>
struct type_caster<MlirModule> {
- NB_TYPE_CASTER(MlirModule, const_name("MlirModule"));
+ NB_TYPE_CASTER(MlirModule, const_name("MlirModule"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToModule(capsule.ptr());
@@ -206,7 +206,7 @@ struct type_caster<MlirModule> {
template <>
struct type_caster<MlirFrozenRewritePatternSet> {
NB_TYPE_CASTER(MlirFrozenRewritePatternSet,
- const_name("MlirFrozenRewritePatternSet"));
+ const_name("MlirFrozenRewritePatternSet"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
@@ -225,7 +225,7 @@ struct type_caster<MlirFrozenRewritePatternSet> {
/// Casts object <-> MlirOperation.
template <>
struct type_caster<MlirOperation> {
- NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation"));
+ NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToOperation(capsule.ptr());
@@ -247,7 +247,7 @@ struct type_caster<MlirOperation> {
/// Casts object <-> MlirValue.
template <>
struct type_caster<MlirValue> {
- NB_TYPE_CASTER(MlirValue, const_name("MlirValue"));
+ NB_TYPE_CASTER(MlirValue, const_name("MlirValue"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToValue(capsule.ptr());
@@ -270,7 +270,7 @@ struct type_caster<MlirValue> {
/// Casts object -> MlirPassManager.
template <>
struct type_caster<MlirPassManager> {
- NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager"));
+ NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToPassManager(capsule.ptr());
@@ -281,7 +281,7 @@ struct type_caster<MlirPassManager> {
/// Casts object <-> MlirTypeID.
template <>
struct type_caster<MlirTypeID> {
- NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID"));
+ NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToTypeID(capsule.ptr());
@@ -303,7 +303,7 @@ struct type_caster<MlirTypeID> {
/// Casts object <-> MlirType.
template <>
struct type_caster<MlirType> {
- NB_TYPE_CASTER(MlirType, const_name("MlirType"));
+ NB_TYPE_CASTER(MlirType, const_name("MlirType"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToType(capsule.ptr());
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index c8233355d1d67b..edc69774be9227 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -374,9 +374,8 @@ class pure_subclass {
static_assert(!std::is_member_function_pointer<Func>::value,
"def_staticmethod(...) called with a non-static member "
"function pointer");
- py::cpp_function cf(
- std::forward<Func>(f), py::name(name), py::scope(thisClass),
- py::sibling(py::getattr(thisClass, name, py::none())), extra...);
+ py::cpp_function cf(std::forward<Func>(f), py::name(name),
+ py::scope(thisClass), extra...);
thisClass.attr(cf.name()) = py::staticmethod(cf);
return *this;
}
@@ -387,9 +386,8 @@ class pure_subclass {
static_assert(!std::is_member_function_pointer<Func>::value,
"def_classmethod(...) called with a non-static member "
"function pointer");
- py::cpp_function cf(
- std::forward<Func>(f), py::name(name), py::scope(thisClass),
- py::sibling(py::getattr(thisClass, name, py::none())), extra...);
+ py::cpp_function cf(std::forward<Func>(f), py::name(name),
+ py::scope(thisClass), extra...);
thisClass.attr(cf.name()) =
py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
return *this;
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index a022067f5c7e57..0ec522d14f74bd 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -9,18 +9,17 @@
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
-#include "PybindUtils.h"
+#include <optional>
+#include <string>
+#include <vector>
+#include "NanobindUtils.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
-#include <optional>
-#include <string>
-#include <vector>
-
namespace mlir {
namespace python {
@@ -57,55 +56,55 @@ class PyGlobals {
/// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerAttributeBuilder(const std::string &attributeKind,
- pybind11::function pyFunc,
+ nanobind::callable pyFunc,
bool replace = false);
/// Adds a user-friendly type caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by
/// implementation code.
- void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
+ void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster,
bool replace = false);
/// Adds a user-friendly value caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by
/// implementation code.
void registerValueCaster(MlirTypeID mlirTypeID,
- pybind11::function valueCaster,
+ nanobind::callable valueCaster,
bool replace = false);
/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
void registerDialectImpl(const std::string &dialectNamespace,
- pybind11::object pyClass);
+ nanobind::object pyClass);
/// Adds a concrete implementation operation class.
/// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerOperationImpl(const std::string &operationName,
- pybind11::object pyClass, bool replace = false);
+ nanobind::object pyClass, bool replace = false);
/// Returns the custom Attribute builder for Attribute kind.
- std::optional<pybind11::function>
+ std::optional<nanobind::callable>
lookupAttributeBuilder(const std::string &attributeKind);
/// Returns the custom type caster for MlirTypeID mlirTypeID.
- std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
+ std::optional<nanobind::callable> lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);
/// Returns the custom value caster for MlirTypeID mlirTypeID.
- std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
+ std::optional<nanobind::callable> lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);
/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
- std::optional<pybind11::object>
+ std::optional<nanobind::object>
lookupDialectClass(const std::string &dialectNamespace);
/// Looks up a registered operation class (deriving from OpView) by operation
/// name. Note that this may trigger a load of the dialect, which can
/// arbitrarily re-enter.
- std::optional<pybind11::object>
+ std::optional<nanobind::object>
lookupOperationClass(llvm::StringRef operationName);
private:
@@ -113,15 +112,15 @@ class PyGlobals {
/// Module name prefixes to search under for dialect implementation modules.
std::vector<std::string> dialectSearchPrefixes;
/// Map of dialect namespace to external dialect class object.
- llvm::StringMap<pybind11::object> dialectClassMap;
+ llvm::StringMap<nanobind::object> dialectClassMap;
/// Map of full operation name to external operation class object.
- llvm::StringMap<pybind11::object> operationClassMap;
+ llvm::StringMap<nanobind::object> operationClassMap;
/// Map of attribute ODS name to custom builder.
- llvm::StringMap<pybind11::object> attributeBuilderMap;
+ llvm::StringMap<nanobind::callable> attributeBuilderMap;
/// Map of MlirTypeID to custom type caster.
- llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
+ llvm::DenseMap<MlirTypeID, nanobind::callable> typeCasterMap;
/// Map of MlirTypeID to custom value caster.
- llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
+ llvm::DenseMap<MlirTypeID, nanobind::callable> valueCasterMap;
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index b138e131e851ea..2db690309fab8c 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -6,20 +6,19 @@
//
//===----------------------------------------------------------------------===//
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/string.h>
+#include <nanobind/stl/vector.h>
+
#include <cstddef>
#include <cstdint>
-#include <pybind11/cast.h>
-#include <pybind11/detail/common.h>
-#include <pybind11/pybind11.h>
-#include <pybind11/pytypes.h>
+#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include "IRModule.h"
-
-#include "PybindUtils.h"
-
+#include "NanobindUtils.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/Bindings/Python/Interop.h"
@@ -30,7 +29,7 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
-namespace py = pybind11;
+namespace nb = nanobind;
using namespace mlir;
using namespace mlir::python;
@@ -46,23 +45,23 @@ static const char kDumpDocstring[] =
/// Throws errors in case of failure, using "action" to describe what the caller
/// was attempting to do.
template <typename PyType, typename CType>
-static void pyListToVector(const py::list &list,
+static void pyListToVector(const nb::list &list,
llvm::SmallVectorImpl<CType> &result,
StringRef action) {
- result.reserve(py::len(list));
- for (py::handle item : list) {
+ result.reserve(nb::len(list));
+ for (nb::handle item : list) {
try {
- result.push_back(item.cast<PyType>());
- } catch (py::cast_error &err) {
+ result.push_back(nb::cast<PyType>(item));
+ } catch (nb::cast_error &err) {
std::string msg = (llvm::Twine("Invalid expression when ") + action +
" (" + err.what() + ")")
.str();
- throw py::cast_error(msg);
- } catch (py::reference_cast_error &err) {
+ throw std::runtime_error(msg.c_str());
+ } catch (std::runtime_error &err) {
std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
action + " (" + err.what() + ")")
.str();
- throw py::cast_error(msg);
+ throw std::runtime_error(msg.c_str());
}
}
}
@@ -94,7 +93,7 @@ class PyConcreteAffineExpr : public BaseTy {
// IsAFunctionTy isaFunction
// const char *pyClassName
// and redefine bindDerived.
- using ClassTy = py::class_<DerivedTy, BaseTy>;
+ using ClassTy = nb::class_<DerivedTy, BaseTy>;
using IsAFunctionTy = bool (*)(MlirAffineExpr);
PyConcreteAffineExpr() = default;
@@ -105,24 +104,25 @@ class PyConcreteAffineExpr : public BaseTy {
static MlirAffineExpr castFrom(PyAffineExpr &orig) {
if (!DerivedTy::isaFunction(orig)) {
- auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
- throw py::value_error((Twine("Cannot cast affine expression to ") +
+ auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
+ throw nb::value_error((Twine("Cannot cast affine expression to ") +
DerivedTy::pyClassName + " (from " + origRepr +
")")
- .str());
+ .str()
+ .c_str());
}
return orig;
}
- static void bind(py::module &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
- cls.def(py::init<PyAffineExpr &>(), py::arg("expr"));
+ static void bind(nb::module_ &m) {
+ auto cls = ClassTy(m, DerivedTy::pyClassName);
+ cls.def(nb::init<PyAffineExpr &>(), nb::arg("expr"));
cls.def_static(
"isinstance",
[](PyAffineExpr &otherAffineExpr) -> bool {
return DerivedTy::isaFunction(otherAffineExpr);
},
- py::arg("other"));
+ nb::arg("other"));
DerivedTy::bindDerived(cls);
}
@@ -144,9 +144,9 @@ class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
}
static void bindDerived(ClassTy &c) {
- c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
- py::arg("context") = py::none());
- c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
+ c.def_static("get", &PyAffineConstantExpr::get, nb::arg("value"),
+ nb::arg("context").none() = nb::none());
+ c.def_prop_ro("value", [](PyAffineConstantExpr &self) {
return mlirAffineConstantExprGetValue(self);
});
}
@@ -164,9 +164,9 @@ class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
}
static void bindDerived(ClassTy &c) {
- c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
- py::arg("context") = py::none());
- c.def_property_readonly("position", [](PyAffineDimExpr &self) {
+ c.def_static("get", &PyAffineDimExpr::get, nb::arg("position"),
+ nb::arg("context").none() = nb::none());
+ c.def_prop_ro("position", [](PyAffineDimExpr &self) {
return mlirAffineDimExprGetPosition(self);
});
}
@@ -184,9 +184,9 @@ class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
}
static void bindDerived(ClassTy &c) {
- c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
- py::arg("context") = py::none());
- c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
+ c.def_static("get", &PyAffineSymbolExpr::get, nb::arg("position"),
+ nb::arg("context").none() = nb::none());
+ c.def_prop_ro("position", [](PyAffineSymbolExpr &self) {
return mlirAffineSymbolExprGetPosition(self);
});
}
@@ -209,8 +209,8 @@ class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
}
static void bindDerived(ClassTy &c) {
- c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
- c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
+ c.def_prop_ro("lhs", &PyAffineBinaryExpr::lhs);
+ c.def_prop_ro("rhs", &PyAffineBinaryExpr::rhs);
}
};
@@ -365,15 +365,14 @@ bool PyAffineExpr::operator==(const PyAffineExpr &other) const {
return mlirAffineExprEqual(affineExpr, other...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
40a38c0
to
ad87b33
Compare
Relands llvm#118583, with a fix for Python 3.8 compatibility. It was not possible to set the buffer protocol accessers via slots in Python 3.8. Why? https://nanobind.readthedocs.io/en/latest/why.html says it better than I can, but my primary motivation for this change is to improve MLIR IR construction time from JAX. For a complicated Google-internal LLM model in JAX, this change improves the MLIR lowering time by around 5s (out of around 30s), which is a significant speedup for simply switching binding frameworks. To a large extent, this is a mechanical change, for instance changing `pybind11::` to `nanobind::`. Notes: * this PR needs Nanobind 2.4.0, because it needs a bug fix (wjakob/nanobind#806) that landed in that release. * this PR does not port the in-tree dialect extension modules. They can be ported in a future PR. * I removed the py::sibling() annotations from def_static and def_class in `PybindAdapters.h`. These ask pybind11 to try to form an overload with an existing method, but it's not possible to form mixed pybind11/nanobind overloads this ways and the parent class is now defined in nanobind. Better solutions may be possible here. * nanobind does not contain an exact equivalent of pybind11's buffer protocol support. It was not hard to add a nanobind implementation of a similar API. * nanobind is pickier about casting to std::vector<bool>, expecting that the input is a sequence of bool types, not truthy values. In a couple of places I added code to support truthy values during casting. * nanobind distinguishes bytes (`nb::bytes`) from strings (e.g., `std::string`). This required nb::bytes overloads in a few places.
Tested locally with python 3.8, relanding. |
This reverts commit b56d1ec.
This reverts commit b56d1ec.
This reverts commit b56d1ec.
Relands #118583, with a fix for Python 3.8 compatibility. It was not possible to set the buffer protocol accessers via slots in Python 3.8.
Why? https://nanobind.readthedocs.io/en/latest/why.html says it better than I can, but my primary motivation for this change is to improve MLIR IR construction time from JAX.
For a complicated Google-internal LLM model in JAX, this change improves the MLIR
lowering time by around 5s (out of around 30s), which is a significant speedup for simply switching binding frameworks.
To a large extent, this is a mechanical change, for instance changing
pybind11::
tonanobind::
.Notes:
PybindAdapters.h
. These ask pybind11 to try to form an overload with an existing method, but it's not possible to form mixed pybind11/nanobind overloads this ways and the parent class is now defined in nanobind. Better solutions may be possible here.nb::bytes
) from strings (e.g.,std::string
). This required nb::bytes overloads in a few places.