|
2 | 2 |
|
3 | 3 | import copy
|
4 | 4 | import ctypes
|
| 5 | +import functools |
5 | 6 | import json
|
6 |
| -from typing import Literal, Optional, Protocol, Tuple, Type, TypedDict, Union, cast |
| 7 | +from typing import ( |
| 8 | + TYPE_CHECKING, |
| 9 | + Any, |
| 10 | + Dict, |
| 11 | + Literal, |
| 12 | + Optional, |
| 13 | + Protocol, |
| 14 | + Tuple, |
| 15 | + Type, |
| 16 | + TypedDict, |
| 17 | + TypeGuard, |
| 18 | + Union, |
| 19 | + cast, |
| 20 | + overload, |
| 21 | +) |
7 | 22 |
|
8 | 23 | import numpy as np
|
9 | 24 |
|
10 |
| -from ._typing import CNumericPtr, DataType, NumpyOrCupy |
11 |
| -from .compat import import_cupy |
| 25 | +from ._typing import CNumericPtr, DataType, NumpyDType, NumpyOrCupy |
| 26 | +from .compat import import_cupy, lazy_isinstance |
| 27 | + |
| 28 | +if TYPE_CHECKING: |
| 29 | + import pandas as pd |
| 30 | + import pyarrow as pa |
12 | 31 |
|
13 | 32 |
|
| 33 | +# Used for accepting inputs for numpy and cupy arrays |
14 | 34 | class _ArrayLikeArg(Protocol):
|
15 | 35 | @property
|
16 | 36 | def __array_interface__(self) -> "ArrayInf": ...
|
@@ -44,19 +64,27 @@ def shape(self) -> Tuple[int, int]:
|
44 | 64 | },
|
45 | 65 | )
|
46 | 66 |
|
| 67 | +StringArray = TypedDict("StringArray", {"offsets": ArrayInf, "values": ArrayInf}) |
| 68 | + |
47 | 69 |
|
48 | 70 | def array_hasobject(data: DataType) -> bool:
|
49 | 71 | """Whether the numpy array has object dtype."""
|
50 | 72 | return hasattr(data.dtype, "hasobject") and data.dtype.hasobject
|
51 | 73 |
|
52 | 74 |
|
53 |
| -def cuda_array_interface(data: DataType) -> bytes: |
54 |
| - """Make cuda array interface str.""" |
| 75 | +def cuda_array_interface_dict(data: _CudaArrayLikeArg) -> ArrayInf: |
| 76 | + """Returns a dictionary storing the CUDA array interface.""" |
55 | 77 | if array_hasobject(data):
|
56 | 78 | raise ValueError("Input data contains `object` dtype. Expecting numeric data.")
|
57 |
| - interface = data.__cuda_array_interface__ |
58 |
| - if "mask" in interface: |
59 |
| - interface["mask"] = interface["mask"].__cuda_array_interface__ |
| 79 | + ainf = data.__cuda_array_interface__ |
| 80 | + if "mask" in ainf: |
| 81 | + ainf["mask"] = ainf["mask"].__cuda_array_interface__ # type: ignore |
| 82 | + return cast(ArrayInf, ainf) |
| 83 | + |
| 84 | + |
| 85 | +def cuda_array_interface(data: _CudaArrayLikeArg) -> bytes: |
| 86 | + """Make cuda array interface str.""" |
| 87 | + interface = cuda_array_interface_dict(data) |
60 | 88 | interface_str = bytes(json.dumps(interface), "utf-8")
|
61 | 89 | return interface_str
|
62 | 90 |
|
@@ -107,6 +135,12 @@ def __cuda_array_interface__(self, interface: ArrayInf) -> None:
|
107 | 135 | return out
|
108 | 136 |
|
109 | 137 |
|
| 138 | +# Default constant value for CUDA per-thread stream. |
| 139 | +STREAM_PER_THREAD = 2 |
| 140 | + |
| 141 | + |
| 142 | +# Typing is not strict as there are subtle differences between CUDA array interface and |
| 143 | +# array interface. We handle them uniformly for now. |
110 | 144 | def make_array_interface(
|
111 | 145 | ptr: Union[CNumericPtr, int],
|
112 | 146 | shape: Tuple[int, ...],
|
@@ -134,21 +168,157 @@ def make_array_interface(
|
134 | 168 | return array
|
135 | 169 |
|
136 | 170 | array["data"] = (addr, True)
|
137 |
| - if is_cuda: |
138 |
| - array["stream"] = 2 |
| 171 | + if is_cuda and "stream" not in array: |
| 172 | + array["stream"] = STREAM_PER_THREAD |
139 | 173 | array["shape"] = shape
|
140 | 174 | array["strides"] = None
|
141 | 175 | return array
|
142 | 176 |
|
143 | 177 |
|
144 |
| -def array_interface_dict(data: np.ndarray) -> ArrayInf: |
145 |
| - """Convert array interface into a Python dictionary.""" |
| 178 | +def is_arrow_dict(data: Any) -> TypeGuard["pa.DictionaryArray"]: |
| 179 | + """Is this an arrow dictionary array?""" |
| 180 | + return lazy_isinstance(data, "pyarrow.lib", "DictionaryArray") |
| 181 | + |
| 182 | + |
| 183 | +class PdCatAccessor(Protocol): |
| 184 | + """Protocol for pandas cat accessor.""" |
| 185 | + |
| 186 | + @property |
| 187 | + def categories( # pylint: disable=missing-function-docstring |
| 188 | + self, |
| 189 | + ) -> "pd.Index": ... |
| 190 | + |
| 191 | + @property |
| 192 | + def codes(self) -> "pd.Series": ... # pylint: disable=missing-function-docstring |
| 193 | + |
| 194 | + @property |
| 195 | + def dtype(self) -> np.dtype: ... # pylint: disable=missing-function-docstring |
| 196 | + |
| 197 | + def to_arrow( # pylint: disable=missing-function-docstring |
| 198 | + self, |
| 199 | + ) -> Union["pa.StringArray", "pa.IntegerArray"]: ... |
| 200 | + |
| 201 | + @property |
| 202 | + def __cuda_array_interface__(self) -> ArrayInf: ... |
| 203 | + |
| 204 | + |
| 205 | +def _is_pd_cat(data: Any) -> TypeGuard[PdCatAccessor]: |
| 206 | + # Test pd.Series.cat, not pd.Series |
| 207 | + return hasattr(data, "categories") and hasattr(data, "codes") |
| 208 | + |
| 209 | + |
| 210 | +@functools.cache |
| 211 | +def _arrow_typestr() -> Dict["pa.DataType", str]: |
| 212 | + import pyarrow as pa |
| 213 | + |
| 214 | + mapping = { |
| 215 | + pa.int8(): "<i1", |
| 216 | + pa.int16(): "<i2", |
| 217 | + pa.int32(): "<i4", |
| 218 | + pa.int64(): "<i8", |
| 219 | + pa.uint8(): "<u1", |
| 220 | + pa.uint16(): "<u2", |
| 221 | + pa.uint32(): "<u4", |
| 222 | + pa.uint64(): "<u8", |
| 223 | + } |
| 224 | + |
| 225 | + return mapping |
| 226 | + |
| 227 | + |
| 228 | +def npstr_to_arrow_strarr(strarr: np.ndarray) -> Tuple[np.ndarray, str]: |
| 229 | + """Convert a numpy string array to an arrow string array.""" |
| 230 | + lenarr = np.vectorize(len) |
| 231 | + offsets = np.cumsum(np.concatenate([np.array([0], dtype=np.int64), lenarr(strarr)])) |
| 232 | + values = strarr.sum() |
| 233 | + assert "\0" not in values # arrow string array doesn't need null terminal |
| 234 | + return offsets.astype(np.int32), values |
| 235 | + |
| 236 | + |
| 237 | +def _ensure_np_dtype( |
| 238 | + data: DataType, dtype: Optional[NumpyDType] |
| 239 | +) -> Tuple[np.ndarray, Optional[NumpyDType]]: |
| 240 | + """Ensure the np array has correct type and is contiguous.""" |
| 241 | + if array_hasobject(data) or data.dtype in [np.float16, np.bool_]: |
| 242 | + dtype = np.float32 |
| 243 | + data = data.astype(dtype, copy=False) |
| 244 | + if not data.flags.aligned: |
| 245 | + data = np.require(data, requirements="A") |
| 246 | + return data, dtype |
| 247 | + |
| 248 | + |
| 249 | +@overload |
| 250 | +def array_interface_dict(data: np.ndarray) -> ArrayInf: ... |
| 251 | + |
| 252 | + |
| 253 | +@overload |
| 254 | +def array_interface_dict( |
| 255 | + data: PdCatAccessor, |
| 256 | +) -> Tuple[StringArray, ArrayInf, Tuple]: ... |
| 257 | + |
| 258 | + |
| 259 | +@overload |
| 260 | +def array_interface_dict( |
| 261 | + data: "pa.DictionaryArray", |
| 262 | +) -> Tuple[StringArray, ArrayInf, Tuple]: ... |
| 263 | + |
| 264 | + |
| 265 | +def array_interface_dict( # pylint: disable=too-many-locals |
| 266 | + data: Union[np.ndarray, PdCatAccessor], |
| 267 | +) -> Union[ArrayInf, Tuple[StringArray, ArrayInf, Optional[Tuple]]]: |
| 268 | + """Returns an array interface from the input.""" |
| 269 | + # Handle categorical values |
| 270 | + if _is_pd_cat(data): |
| 271 | + cats = data.categories |
| 272 | + # pandas uses -1 to represent missing values for categorical features |
| 273 | + codes = data.codes.replace(-1, np.nan) |
| 274 | + |
| 275 | + if np.issubdtype(cats.dtype, np.floating) or np.issubdtype( |
| 276 | + cats.dtype, np.integer |
| 277 | + ): |
| 278 | + # Numeric index type |
| 279 | + name_values = cats.values |
| 280 | + jarr_values = array_interface_dict(name_values) |
| 281 | + code_values = codes.values |
| 282 | + jarr_codes = array_interface_dict(code_values) |
| 283 | + return jarr_values, jarr_codes, (name_values, code_values) |
| 284 | + |
| 285 | + # String index type |
| 286 | + name_offsets, name_values = npstr_to_arrow_strarr(cats.values) |
| 287 | + name_offsets, _ = _ensure_np_dtype(name_offsets, np.int32) |
| 288 | + joffsets = array_interface_dict(name_offsets) |
| 289 | + bvalues = name_values.encode("utf-8") |
| 290 | + ptr = ctypes.c_void_p.from_buffer(ctypes.c_char_p(bvalues)).value |
| 291 | + assert ptr is not None |
| 292 | + |
| 293 | + jvalues: ArrayInf = { |
| 294 | + "data": (ptr, True), |
| 295 | + "typestr": "|i1", |
| 296 | + "shape": (len(name_values),), |
| 297 | + "strides": None, |
| 298 | + "version": 3, |
| 299 | + "mask": None, |
| 300 | + } |
| 301 | + jnames: StringArray = {"offsets": joffsets, "values": jvalues} |
| 302 | + |
| 303 | + code_values = codes.values |
| 304 | + jcodes = array_interface_dict(code_values) |
| 305 | + |
| 306 | + buf = ( |
| 307 | + name_offsets, |
| 308 | + name_values, |
| 309 | + bvalues, |
| 310 | + code_values, |
| 311 | + ) # store temporary values |
| 312 | + return jnames, jcodes, buf |
| 313 | + |
| 314 | + # Handle numeric values |
| 315 | + assert isinstance(data, np.ndarray) |
146 | 316 | if array_hasobject(data):
|
147 | 317 | raise ValueError("Input data contains `object` dtype. Expecting numeric data.")
|
148 |
| - arrinf = data.__array_interface__ |
149 |
| - if "mask" in arrinf: |
150 |
| - arrinf["mask"] = arrinf["mask"].__array_interface__ |
151 |
| - return cast(ArrayInf, arrinf) |
| 318 | + ainf = data.__array_interface__ |
| 319 | + if "mask" in ainf: |
| 320 | + ainf["mask"] = ainf["mask"].__array_interface__ |
| 321 | + return cast(ArrayInf, ainf) |
152 | 322 |
|
153 | 323 |
|
154 | 324 | def array_interface(data: np.ndarray) -> bytes:
|
|
0 commit comments