diff --git a/rust/core/src/driver_exporter.rs b/rust/core/src/driver_exporter.rs new file mode 100644 index 0000000000..cc3ba8f0f5 --- /dev/null +++ b/rust/core/src/driver_exporter.rs @@ -0,0 +1,1646 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::{HashMap, HashSet}; +use std::ffi::{CStr, CString}; +use std::hash::Hash; +use std::os::raw::{c_char, c_int, c_void}; + +use arrow::array::StructArray; +use arrow::datatypes::DataType; +use arrow::ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}; +use arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream}; + +use crate::error::{Error, Result, Status}; +use crate::ffi::constants::ADBC_STATUS_OK; +use crate::ffi::{ + types::ErrorPrivateData, FFI_AdbcConnection, FFI_AdbcDatabase, FFI_AdbcDriver, FFI_AdbcError, + FFI_AdbcErrorDetail, FFI_AdbcPartitions, FFI_AdbcStatement, FFI_AdbcStatusCode, +}; +use crate::options::{InfoCode, ObjectDepth, OptionConnection, OptionDatabase, OptionValue}; +use crate::{Connection, Database, Driver, Optionable, Statement}; + +type DatabaseType = ::DatabaseType; +type ConnectionType = + <::DatabaseType as Database>::ConnectionType; +type StatementType = + <<::DatabaseType as Database>::ConnectionType as Connection>::StatementType; + +enum ExportedDatabase { + /// Pre-init options + Options(HashMap), + /// Initialized database + Database(DatabaseType), +} + +impl ExportedDatabase { + fn tuple( + &mut self, + ) -> ( + Option<&mut HashMap>, + Option<&mut DatabaseType>, + ) { + match self { + Self::Options(options) => (Some(options), None), + Self::Database(database) => (None, Some(database)), + } + } +} + +enum ExportedConnection { + /// Pre-init options + Options(HashMap), + /// Initialized connection + Connection(ConnectionType), +} + +impl ExportedConnection { + fn tuple( + &mut self, + ) -> ( + Option<&mut HashMap>, + Option<&mut ConnectionType>, + ) { + match self { + Self::Options(options) => (Some(options), None), + Self::Connection(connection) => (None, Some(connection)), + } + } + + fn try_connection(&mut self) -> Result<&mut ConnectionType> { + match self { + Self::Connection(connection) => Ok(connection), + _ => Err(Error::with_message_and_status( + "Connection not initialized", + Status::InvalidState, + )), + } + } +} + +struct ExportedStatement(StatementType); + +pub trait FFIDriver { + fn ffi_driver() -> FFI_AdbcDriver; +} + +impl FFIDriver for DriverType { + fn ffi_driver() -> FFI_AdbcDriver { + FFI_AdbcDriver { + private_data: std::ptr::null_mut(), + private_manager: std::ptr::null(), + release: Some(release_ffi_driver), + DatabaseInit: Some(database_init::), + DatabaseNew: Some(database_new::), + DatabaseSetOption: Some(database_set_option::), + DatabaseRelease: Some(database_release::), + ConnectionCommit: Some(connection_commit::), + ConnectionGetInfo: Some(connection_get_info::), + ConnectionGetObjects: Some(connection_get_objects::), + ConnectionGetTableSchema: Some(connection_get_table_schema::), + ConnectionGetTableTypes: Some(connection_get_table_types::), + ConnectionInit: Some(connection_init::), + ConnectionNew: Some(connection_new::), + ConnectionSetOption: Some(connection_set_option::), + ConnectionReadPartition: Some(connection_read_partition::), + ConnectionRelease: Some(connection_release::), + ConnectionRollback: Some(connection_rollback::), + StatementBind: Some(statement_bind::), + StatementBindStream: Some(statement_bind_stream::), + StatementExecuteQuery: Some(statement_execute_query::), + StatementExecutePartitions: Some(statement_execute_partitions::), + StatementGetParameterSchema: Some(statement_get_parameter_schema::), + StatementNew: Some(statement_new::), + StatementPrepare: Some(statement_prepare::), + StatementRelease: Some(statement_release::), + StatementSetOption: Some(statement_set_option::), + StatementSetSqlQuery: Some(statement_set_sql_query::), + StatementSetSubstraitPlan: Some(statement_set_substrait_plan::), + ErrorGetDetailCount: Some(error_get_detail_count), + ErrorGetDetail: Some(error_get_detail), + ErrorFromArrayStream: None, // TODO(alexandreyc): what to do with this? + DatabaseGetOption: Some(database_get_option::), + DatabaseGetOptionBytes: Some(database_get_option_bytes::), + DatabaseGetOptionDouble: Some(database_get_option_double::), + DatabaseGetOptionInt: Some(database_get_option_int::), + DatabaseSetOptionBytes: Some(database_set_option_bytes::), + DatabaseSetOptionDouble: Some(database_set_option_double::), + DatabaseSetOptionInt: Some(database_set_option_int::), + ConnectionCancel: Some(connection_cancel::), + ConnectionGetOption: Some(connection_get_option::), + ConnectionGetOptionBytes: Some(connection_get_option_bytes::), + ConnectionGetOptionDouble: Some(connection_get_option_double::), + ConnectionGetOptionInt: Some(connection_get_option_int::), + ConnectionGetStatistics: Some(connection_get_statistics::), + ConnectionGetStatisticNames: Some(connection_get_statistic_names::), + ConnectionSetOptionBytes: Some(connection_set_option_bytes::), + ConnectionSetOptionDouble: Some(connection_set_option_double::), + ConnectionSetOptionInt: Some(connection_set_option_int::), + StatementCancel: Some(statement_cancel::), + StatementExecuteSchema: Some(statement_execute_schema::), + StatementGetOption: Some(statement_get_option::), + StatementGetOptionBytes: Some(statement_get_option_bytes::), + StatementGetOptionDouble: Some(statement_get_option_double::), + StatementGetOptionInt: Some(statement_get_option_int::), + StatementSetOptionBytes: Some(statement_set_option_bytes::), + StatementSetOptionDouble: Some(statement_set_option_double::), + StatementSetOptionInt: Some(statement_set_option_int::), + } + } +} + +/// Export a Rust driver as a C driver. +/// +/// # Parameters +/// +/// - `$func_name` - Driver's initialization function name. The recommended name +/// is `AdbcDriverInit`, or a name derived from the name of the driver's shared +/// library as follows: remove the `lib` prefix (on Unix systems) and all file +/// extensions, then `PascalCase` the driver name, append `Init`, and prepend +/// `Adbc` (if not already there). For example: +/// - `libadbc_driver_sqlite.so.2.0.0` -> `AdbcDriverSqliteInit` +/// - `adbc_driver_sqlite.dll` -> `AdbcDriverSqliteInit` +/// - `proprietary_driver.dll` -> `AdbcProprietaryDriverInit` +/// - `$driver_type` - Driver's type which must implement [Driver] and [Default]. +/// Currently, the Rust driver is exported as an ADBC 1.1.0 C driver. +#[macro_export] +macro_rules! export_driver { + ($func_name:ident, $driver_type:ty) => { + #[no_mangle] + pub unsafe extern "C" fn $func_name( + version: std::os::raw::c_int, + driver: *mut std::os::raw::c_void, + error: *mut $crate::ffi::FFI_AdbcError, + ) -> $crate::ffi::FFI_AdbcStatusCode { + let version = + $crate::check_err!($crate::options::AdbcVersion::try_from(version), error); + if version != $crate::options::AdbcVersion::V110 { + let err = $crate::error::Error::with_message_and_status( + format!( + "Unsupported ADBC version: got={:?} expected={:?}", + version, + $crate::options::AdbcVersion::V110 + ), + $crate::error::Status::NotImplemented, + ); + $crate::check_err!(Err(err), error); + } + $crate::check_not_null!(driver, error); + + let ffi_driver = <$driver_type as $crate::FFIDriver>::ffi_driver(); + unsafe { + std::ptr::write_unaligned(driver as *mut $crate::ffi::FFI_AdbcDriver, ffi_driver); + } + $crate::ffi::constants::ADBC_STATUS_OK + } + }; +} + +/// Given a Result, either unwrap the value or handle the error in ADBC function. +/// +/// This macro is for use when implementing ADBC methods that have an out +/// parameter for [FFI_AdbcError] and return [FFI_AdbcStatusCode]. If the result is +/// `Ok`, the expression resolves to the value. Otherwise, it will return early, +/// setting the error and status code appropriately. In order for this to work, +/// the error must be convertible to [crate::error::Error]. +#[doc(hidden)] +#[macro_export] +macro_rules! check_err { + ($res:expr, $err_out:expr) => { + match $res { + Ok(x) => x, + Err(error) => { + let error = $crate::error::Error::from(error); + let status: $crate::ffi::FFI_AdbcStatusCode = error.status.into(); + if !$err_out.is_null() { + let mut ffi_error = + $crate::ffi::FFI_AdbcError::try_from(error).unwrap_or_else(Into::into); + ffi_error.private_driver = (*$err_out).private_driver; + unsafe { std::ptr::write_unaligned($err_out, ffi_error) }; + } + return status; + } + } + }; +} + +/// Check that the given raw pointer is not null. +/// +/// If null, an error is returned from the enclosing function, otherwise this is +/// a no-op. +#[doc(hidden)] +#[macro_export] +macro_rules! check_not_null { + ($ptr:ident, $err_out:expr) => { + let res = if $ptr.is_null() { + Err(Error::with_message_and_status( + format!("Passed null pointer for argument {:?}", stringify!($ptr)), + Status::InvalidArguments, + )) + } else { + Ok(()) + }; + $crate::check_err!(res, $err_out); + }; +} + +unsafe extern "C" fn release_ffi_driver( + driver: *mut FFI_AdbcDriver, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + if let Some(driver) = driver.as_mut() { + if driver.release.take().is_none() { + check_err!( + Err(Error::with_message_and_status( + "Driver already released", + Status::InvalidState + )), + error + ); + } + } + ADBC_STATUS_OK +} + +// Option helpers + +// SAFETY: `dst` and `length` must be not null otherwise the function will panic. +unsafe fn copy_string(src: &str, dst: *mut c_char, length: *mut usize) -> Result<()> { + assert!(!dst.is_null() && !length.is_null()); + let src = CString::new(src)?; + let n = src.to_bytes_with_nul().len(); + if n <= *length { + std::ptr::copy_nonoverlapping(src.as_ptr(), dst, n); + } + *length = n; + Ok::<(), Error>(()) +} + +// SAFETY: `dst` and `length` must be not null otherwise the function will panic. +unsafe fn copy_bytes(src: &[u8], dst: *mut u8, length: *mut usize) { + assert!(!dst.is_null() && !length.is_null()); + let n = src.len(); + if n <= *length { + std::ptr::copy_nonoverlapping(src.as_ptr(), dst, n); + } + *length = n; +} + +// SAFETY: Will panic if `key` is null. +unsafe fn get_option_int<'a, OptionType, Object>( + object: Option<&mut Object>, + options: Option<&mut HashMap>, + key: *const c_char, +) -> Result +where + OptionType: Hash + Eq + From<&'a str>, + Object: Optionable