From f68cbce2d38dbd7be6ca40ab0ea7086fd59436a5 Mon Sep 17 00:00:00 2001 From: Chris Barnes Date: Wed, 24 Jan 2024 19:12:57 +0000 Subject: [PATCH] Major refactors - serde for most neurons - better I/O across boundary - more consistent use of ScoreMatrix - benchmarks and tests --- nblast-py/Cargo.toml | 4 +- nblast-py/pyproject.toml | 3 +- nblast-py/python/pynblast/__init__.py | 11 +- nblast-py/python/pynblast/arena.py | 85 ++++-- nblast-py/python/pynblast/pynblast.pyi | 22 +- nblast-py/python/pynblast/score_matrix.py | 34 ++- nblast-py/python/pynblast/smat_builder.py | 33 ++- nblast-py/python/pynblast/util.py | 37 ++- nblast-py/requirements.txt | 1 + nblast-py/src/lib.rs | 302 ++++++++++++++++------ nblast-py/tests/conftest.py | 28 +- nblast-py/tests/test_bench.py | 68 +++++ nblast-py/tests/test_nblast.py | 56 +++- nblast-py/tests/test_smat_builder.py | 5 + nblast-rs/Cargo.toml | 4 +- nblast-rs/src/lib.rs | 94 +++++-- nblast-rs/src/neurons/any.rs | 144 +++++++++++ nblast-rs/src/neurons/bosque.rs | 3 + nblast-rs/src/neurons/kiddo.rs | 9 +- nblast-rs/src/neurons/mod.rs | 7 + nblast-rs/src/neurons/rstar.rs | 4 + 21 files changed, 782 insertions(+), 172 deletions(-) create mode 100644 nblast-py/tests/test_bench.py create mode 100644 nblast-py/tests/test_smat_builder.py create mode 100644 nblast-rs/src/neurons/any.rs diff --git a/nblast-py/Cargo.toml b/nblast-py/Cargo.toml index cba19d3..c00c967 100644 --- a/nblast-py/Cargo.toml +++ b/nblast-py/Cargo.toml @@ -9,8 +9,10 @@ edition = "2021" [dependencies] pyo3 = { version = "0.20", features = ["extension-module", "abi3-py39"] } neurarbor = "0.2.0" -nblast = { path = "../nblast-rs", version = "^0.7.1", features = ["parallel", "kiddo"] } +nblast = { path = "../nblast-rs", version = "^0.7.1", features = ["parallel", "kiddo", "serde"] } numpy = "0.20" +ciborium = "0.2.2" +serde_json = "1.0.111" [lib] name = "pynblast" diff --git a/nblast-py/pyproject.toml b/nblast-py/pyproject.toml index 812ca0a..09f030d 100644 --- a/nblast-py/pyproject.toml +++ b/nblast-py/pyproject.toml @@ -19,7 +19,8 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] dependencies = [ - "numpy >= 1.22.4" + "numpy >= 1.22.4", + "backports.strenum; python_version < '3.11'" ] [project.urls] diff --git a/nblast-py/python/pynblast/__init__.py b/nblast-py/python/pynblast/__init__.py index f4875c8..51e6d31 100644 --- a/nblast-py/python/pynblast/__init__.py +++ b/nblast-py/python/pynblast/__init__.py @@ -1,18 +1,22 @@ # -*- coding: utf-8 -*- -"""Top-level package for nblast-rs.""" +"""Top-level package for pynblast. + +The main entry point for this package is the `NblastArena` class. +""" __author__ = """Chris L. Barnes""" __email__ = "chrislloydbarnes@gmail.com" -from .pynblast import get_version as _get_version, ResamplingArbor +from .pynblast import get_version as _get_version, ResamplingArbor, backend as _backend __version__ = _get_version() -__version_info__ = tuple(int(n) for n in __version__.split(".")) +__version_info__ = tuple(int(n) for n in __version__.split(".")[:3]) from .util import rectify_tangents, Idx, Symmetry from .arena import NblastArena from .score_matrix import ScoreMatrix +from .smat_builder import ScoreMatrixBuilder __all__ = [ "NblastArena", @@ -21,4 +25,5 @@ "rectify_tangents", "Idx", "ResamplingArbor", + "ScoreMatrixBuilder", ] diff --git a/nblast-py/python/pynblast/arena.py b/nblast-py/python/pynblast/arena.py index c45d6a5..a08dd44 100644 --- a/nblast-py/python/pynblast/arena.py +++ b/nblast-py/python/pynblast/arena.py @@ -1,9 +1,13 @@ from typing import List, Tuple, Dict, Iterator, Optional +import warnings +from copy import copy import numpy as np +from .score_matrix import ScoreMatrix from .pynblast import ArenaWrapper from .util import Idx, raise_if_none, rectify_tangents, Symmetry +import pandas as pd DEFAULT_THREADS = 0 DEFAULT_K = 20 @@ -12,13 +16,21 @@ class NblastArena: """ Class for creating and keeping track of many neurons for comparison with NBLAST. + + Create the arena with a score matrix. + Then use `arena.add_points()` to add point clouds and return indices used for querying + (if you have already calculated tangents and alphas *using the same neighborhood size `k`*, + use `arena.add_points_tangents_alphas()`). + Use `arena.query_target()`, `arena.queries_targets()`, and + `arena.all_v_all()` to perform queries. + + You can retrieve the points, tangents, alpha values, or everything about a neuron + with `arena.points()`, `arena.tangents()`, `arena.alphas()`, and `arena.neuron_table()`. """ def __init__( self, - dist_bins: List[float], - dot_bins: List[float], - score_mat: np.ndarray, + score_mat: ScoreMatrix, use_alpha: bool = False, threads: Optional[int] = DEFAULT_THREADS, k=DEFAULT_K, @@ -27,18 +39,8 @@ def __init__( The required arguments describe a lookup table which is used to convert ``(distance, abs_dot_product)`` tuples into a score for a single point match. - The ``*_bins`` arguments describe the bounds of the bins: - N bounds make for N-1 bins. - Queries are clamped to the domain of the lookup. - ``score_mat`` is the table of values, in dist-major order. - - For example, if the lookup table was stored as a pandas dataframe, - where the distance bins were in the left margin and the absolute dot - product bins in the top margin, the object would be instantiated by - - >>> arena = NblastArena(df.index, df.columns, df.to_numpy()) - - See the ``ScoreMatrix`` namedtuple for convenience. + Queries are clamped to the domain of the lookup + (i.e. the first and last values are effectively replaced with -inf and +inf respectively). ``k`` gives the number of points to use when calculating tangents. @@ -51,11 +53,13 @@ def __init__( self.threads = threads self.k = k - if score_mat.shape != (len(dist_bins) - 1, len(dot_bins) - 1): - raise ValueError("Bin thresholds do not match score matrix") - score_vec = score_mat.flatten().tolist() self._impl = ArenaWrapper( - dist_bins, dot_bins, score_vec, self.k, self.use_alpha, self.threads + score_mat.dist_thresholds.tolist(), + score_mat.dot_thresholds.tolist(), + score_mat._flat_values().tolist(), + self.k, + self.use_alpha, + self.threads, ) def add_points(self, points: np.ndarray) -> Idx: @@ -74,7 +78,7 @@ def add_points_tangents_alphas( self, points: np.ndarray, tangents: np.ndarray, - alphas: Optional[np.ndarray], + alphas: np.ndarray, ) -> Idx: """Add an Nx3 point cloud representing a neuron, with pre-calculated tangents. Tangents must be unit-length and in the same order as the points. @@ -95,6 +99,10 @@ def add_points_tangents_alphas( "Alpha values not given, but this NblastArena uses alpha weighting" ) else: + warnings.warn( + "Alpha values should be given, even if they are not used", + PendingDeprecationWarning, + ) alphas = np.full(len(points), 1.0) else: alphas = np.asarray(alphas) @@ -186,7 +194,7 @@ def points(self, idx) -> np.ndarray: Order is arbitrary. """ - return np.array(raise_if_none(self._impl.points(idx), idx)) + return np.asarray(raise_if_none(self._impl.points(idx), idx)) def tangents(self, idx, rectify=False) -> np.ndarray: """Return a copy of the tangents associated with the indexed neuron. @@ -194,7 +202,7 @@ def tangents(self, idx, rectify=False) -> np.ndarray: Order is arbitrary, but consistent with the order returned by the ``.points`` method. """ - out = np.array(raise_if_none(self._impl.tangents(idx), idx)) + out = np.asarray(raise_if_none(self._impl.tangents(idx), idx)) return rectify_tangents(out, True) if rectify else out def alphas(self, idx) -> np.ndarray: @@ -203,4 +211,35 @@ def alphas(self, idx) -> np.ndarray: Order is arbitrary, but consistent with the order returned by the ``.points`` method. """ - return np.array(raise_if_none(self._impl.alphas(idx), idx)) + return np.asarray(raise_if_none(self._impl.alphas(idx), idx)) + + def neuron_table(self, idx: Idx) -> pd.DataFrame: + """Return a neuron's points, tangents, and alphas as a dataframe.""" + arr = raise_if_none(self._impl.neuron_array(idx), idx) + return pd.DataFrame( + arr, columns=["x", "y", "z", "tangent_x", "tangent_y", "tangent_z", "alpha"] + ) + + # def serialize_neuron( + # self, idx: Idx, write_to: Optional[BytesIO], format: Format + # ) -> Optional[bytes]: + # b = raise_if_none(self._impl.serialize_neuron(idx, format), idx) + # if write_to is None: + # return b + # else: + # write_to.write(b) + # return None + + # def add_serialized_neuron(self, read_from: BytesIO, format: Format) -> Idx: + # # todo: slower than adding p,t,a. Need better serialization? + # b = read_from.read() + # return self._impl.add_serialized_neuron(b, format) + + def copy(self, deep=True): + out = copy(self) + if deep: + out._impl = out._impl.deepcopy() + return out + + def __deepcopy__(self): + return self.copy() diff --git a/nblast-py/python/pynblast/pynblast.pyi b/nblast-py/python/pynblast/pynblast.pyi index 2c9a4da..a73efd2 100644 --- a/nblast-py/python/pynblast/pynblast.pyi +++ b/nblast-py/python/pynblast/pynblast.pyi @@ -2,6 +2,10 @@ from __future__ import annotations from typing import List, Optional, Tuple from .util import Idx +import numpy as np +from numpy.typing import NDArray + +F = np.float64 class ResamplingArbor: def __init__(self, table: List[Tuple[int, Optional[int], float, float, float]]): ... @@ -27,9 +31,9 @@ class ArenaWrapper: use_alpha: bool, threads: Optional[int], ) -> None: ... - def add_points(self, points: list[list[float]]) -> Idx: ... + def add_points(self, points: NDArray[F]) -> Idx: ... def add_points_tangents_alphas( - self, points: list[list[float]], tangents: list[float], alphas: list[float] + self, points: NDArray[F], tangents: NDArray[F], alphas: NDArray[F] ) -> Idx: ... def query_target( self, @@ -55,12 +59,16 @@ class ArenaWrapper: def len(self) -> int: ... def is_empty(self) -> bool: ... def self_hit(self, idx: Idx) -> Optional[float]: ... - def points(self, idx: Idx) -> Optional[list[list[float]]]: ... - def tangents(self, idx: Idx) -> Optional[list[list[float]]]: ... - def alphas(self, idx: Idx) -> Optional[list[float]]: ... + def points(self, idx: Idx) -> Optional[NDArray[F]]: ... + def tangents(self, idx: Idx) -> Optional[NDArray[F]]: ... + def alphas(self, idx: Idx) -> Optional[NDArray[F]]: ... + def neuron_array(self, idx: Idx) -> Optional[NDArray[F]]: ... + def serialize_neuron(self, idx: Idx, format: str) -> Optional[bytes]: ... + def add_serialized_neuron(self, b: bytes, format: str) -> Idx: ... + def deepcopy(self): ... def build_score_matrix( - points: List[List[List[float]]], + points: List[NDArray[np.float64]], k: int, seed: int, use_alpha: bool, @@ -74,3 +82,5 @@ def build_score_matrix( max_nonmatching_pairs: Optional[int], threads: Optional[int], ) -> Tuple[List[float], List[float], List[float]]: ... +def get_version() -> str: ... +def backend() -> str: ... diff --git a/nblast-py/python/pynblast/score_matrix.py b/nblast-py/python/pynblast/score_matrix.py index af6daab..d4b37b4 100644 --- a/nblast-py/python/pynblast/score_matrix.py +++ b/nblast-py/python/pynblast/score_matrix.py @@ -2,6 +2,7 @@ import csv import numpy as np +from numpy.typing import ArrayLike def parse_interval(s) -> tuple[float, float]: @@ -22,10 +23,35 @@ def intervals_to_bins(intervals: list[tuple[float, float]]): return out -class ScoreMatrix(NamedTuple): - dist_thresholds: List[float] - dot_thresholds: List[float] - values: np.ndarray +class ScoreMatrix: + """Representation of a lookup table for point match scores. + + N thresholds represent N-1 bins. + The values are in dot-major order + (i.e. values in the same dist bin are next to each other). + """ + + def __init__( + self, + dist_thresholds: ArrayLike, + dot_thresholds: ArrayLike, + values: ArrayLike, + ) -> None: + self.dist_thresholds = np.asarray(dist_thresholds, np.float64).flatten() + self.dot_thresholds = np.asarray(dot_thresholds, np.float64).flatten() + self.values = np.asarray(values, np.float64) + + exp_shape = (len(self.dist_thresholds) - 1, len(dot_thresholds) - 1) + if self.values.shape != exp_shape: + raise ValueError( + "For N dist_thresholds and M dot_thresholds, values must be (N-1)x(M-1)" + ) + + def _flat_values(self): + if self.values.flags["F_CONTIGUOUS"]: + return self.values.T.flatten() + else: + return self.values.flatten() def to_df(self): import pandas as pd diff --git a/nblast-py/python/pynblast/smat_builder.py b/nblast-py/python/pynblast/smat_builder.py index 439ab28..b36fc93 100644 --- a/nblast-py/python/pynblast/smat_builder.py +++ b/nblast-py/python/pynblast/smat_builder.py @@ -1,18 +1,34 @@ from typing import List, Optional, Set, Union import numpy as np +from numpy.typing import NDArray -from pynblast.arena import DEFAULT_K, DEFAULT_THREADS +from .arena import DEFAULT_K, DEFAULT_THREADS +from .score_matrix import ScoreMatrix from .pynblast import build_score_matrix class ScoreMatrixBuilder: + """Class for training your own score matrix from data. + + 1. Create the builder + 2. Add some point clouds to it with `.add_points()`, which returns an index for each + 3. Use those indices to designate groups of neurons which should match each other with `.add_matching_set()` + 4. Optionally designate groups which should not match each other `.add_nonmatching_set()`. + 5. Set the dist and dot bins (`.set_{dist,dot}_bins()`). + These can be a list of N-1 inner boundaries for N bins, + or just the integer N, + in which case they will be determined from the data. + 6. Optionally, set the maximum number of matching and/or nonmatching pairs with `.set_max_pairs()`. + 7. `.build()` + """ + def __init__(self, seed: int, k: int = DEFAULT_K, use_alpha: bool = False): self.seed = seed self.k = k self.use_alpha = use_alpha - self.neurons: List[List[List[float]]] = [] + self.neurons: List[NDArray[np.float64]] = [] self.matching_sets: List[List[int]] = [] self.nonmatching_sets: Optional[List[List[int]]] = None @@ -25,23 +41,25 @@ def __init__(self, seed: int, k: int = DEFAULT_K, use_alpha: bool = False): self.max_matching_pairs: Optional[int] = None self.max_nonmatching_pairs: Optional[int] = None - def add_neuron(self, points: np.ndarray) -> int: + def add_points(self, points: np.ndarray) -> int: points = np.asarray(points) if len(points) < self.k: raise ValueError(f"Neuron does not have enough points (needs {self.k})") if points.ndim != 2 or points.shape[1] != 3: raise ValueError("Not an Nx3 array") idx = len(self.neurons) - self.neurons.append(points.tolist()) + self.neurons.append(points) return idx def add_matching_set(self, ids: Set[int]): self.matching_sets.append(sorted(ids)) + return self def add_nonmatching_set(self, ids: Set[int]): if self.nonmatching_sets is None: self.nonmatching_sets = [] self.nonmatching_sets.append(sorted(ids)) + return self def set_dist_bins(self, bins: Union[int, List[float]]): """Number of bins, or list of inner boundaries of bins""" @@ -51,6 +69,7 @@ def set_dist_bins(self, bins: Union[int, List[float]]): else: self.dist_inner_bounds = sorted(bins) self.dist_n_bins = None + return self def set_dot_bins(self, bins: Union[int, List[float]]): """Number of bins, or list of inner boundaries of bins""" @@ -60,10 +79,12 @@ def set_dot_bins(self, bins: Union[int, List[float]]): else: self.dot_inner_bounds = sorted(bins) self.dot_n_bins = None + return self def set_max_pairs(self, matching: Optional[int], nonmatching: Optional[int]): self.max_matching_pairs = matching self.max_nonmatching_pairs = nonmatching + return self def build(self, threads: Optional[int] = DEFAULT_THREADS): dist_bins, dot_bins, cells = build_score_matrix( @@ -81,5 +102,5 @@ def build(self, threads: Optional[int] = DEFAULT_THREADS): self.max_nonmatching_pairs, threads, ) - values = np.array(cells).reshape((len(dist_bins), len(dot_bins))) - return dist_bins, dot_bins, values + values = np.array(cells).reshape((len(dist_bins) - 1, len(dot_bins) - 1)) + return ScoreMatrix(dist_bins, dot_bins, values) diff --git a/nblast-py/python/pynblast/util.py b/nblast-py/python/pynblast/util.py index 518bb61..efeb4fa 100644 --- a/nblast-py/python/pynblast/util.py +++ b/nblast-py/python/pynblast/util.py @@ -1,28 +1,13 @@ from typing import NewType import enum +import sys import numpy as np - -class StrEnum(str, enum.Enum): - def __new__(cls, *args): - for arg in args: - if not isinstance(arg, (str, enum.auto)): - raise TypeError( - "Values of StrEnums must be strings: {} is a {}".format( - repr(arg), type(arg) - ) - ) - return super().__new__(cls, *args) - - def __str__(self): - return self.value - - # The first argument to this function is documented to be the name of the - # enum member, not `self`: - # https://docs.python.org/3.6/library/enum.html#using-automatic-values - def _generate_next_value_(name, *_): - return name +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from backports.strenum import StrEnum Idx = NewType("Idx", int) @@ -48,6 +33,18 @@ class Symmetry(StrEnum): MAX = "max" +class Format(StrEnum): + """Enum of serialization formats for neurons. + + JSON is plain text, simple and easily introspectable. + CBOR is binary, faster and smaller, but needs different tools to debug. + """ + + # todo: fix JSON + # JSON = "json" + CBOR = "cbor" + + def rectify_tangents(orig: np.ndarray, inplace=False) -> np.ndarray: """Normalises orientation of tangents. diff --git a/nblast-py/requirements.txt b/nblast-py/requirements.txt index 3186a7b..eb1beb0 100644 --- a/nblast-py/requirements.txt +++ b/nblast-py/requirements.txt @@ -3,6 +3,7 @@ maturin==1.4.0 # run numpy==1.26.3 +backports.strenum; python_version < '3.11' # test pandas==2.2.0 diff --git a/nblast-py/src/lib.rs b/nblast-py/src/lib.rs index 53afe14..3e59862 100644 --- a/nblast-py/src/lib.rs +++ b/nblast-py/src/lib.rs @@ -1,19 +1,23 @@ use numpy::ndarray::Array; -use pyo3::exceptions; use pyo3::exceptions::PyValueError; +use pyo3::exceptions::{self, PyRuntimeError}; use pyo3::prelude::*; +use pyo3::types::PyBytes; use std::collections::HashMap; use std::f64::{INFINITY, NEG_INFINITY}; +use std::fmt::Display; +use std::io::{Cursor, Read, Write}; +use std::str::FromStr; use neurarbor::slab_tree::{NodeId, Tree}; use neurarbor::{edges_to_tree_with_data, resample_tree_points, Location, SpatialArbor, TopoArbor}; -use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2}; +use numpy::{Element, IntoPyArray, PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2}; use nblast::nalgebra::base::{Unit, Vector3}; use nblast::{ - BinLookup, NblastArena, Neuron, NeuronIdx, Precision, RangeTable, ScoreCalc, - ScoreMatrixBuilder, Symmetry, TangentAlpha, + BinLookup, NblastArena, NblastNeuron, Neuron, NeuronIdx, Precision, RangeTable, ScoreCalc, + ScoreMatrixBuilder, Symmetry, TangentAlpha, DEFAULT_BACKEND, }; use nblast::rayon; @@ -31,6 +35,7 @@ fn str_to_sym(s: &str) -> Result { } #[pyclass] +#[derive(Debug, Clone)] pub struct ArenaWrapper { arena: NblastArena, k: usize, @@ -62,16 +67,7 @@ impl ArenaWrapper { if pshape[1] != 3 { return Err(PyErr::new::("Points were not 3D")); } - let neuron = Neuron::new( - points - .as_array() - .rows() - .into_iter() - .map(|r| [r[0], r[1], r[2]]) - .collect(), - self.k, - ) - .map_err(PyErr::new::)?; + let neuron = np_to_neuron(points, self.k).map_err(PyErr::new::)?; Ok(self.arena.add_neuron(neuron)) } @@ -194,7 +190,8 @@ impl ArenaWrapper { } pub fn points<'py>(&self, py: Python<'py>, idx: NeuronIdx) -> Option<&'py PyArray2> { - let points = self.arena.points(idx)?; + let nrn = self.arena.neuron(idx)?; + let points = nrn.points(); let v = points.flatten().collect::>(); Some( Array::from_shape_vec((v.len() / 3, 3), v) @@ -208,22 +205,143 @@ impl ArenaWrapper { py: Python<'py>, idx: NeuronIdx, ) -> Option<&'py PyArray2> { - let tangents = self.arena.tangents(idx)?; - let len = self.arena.size_of(idx)?; + let nrn = self.arena.neuron(idx)?; + let tangents = nrn.tangents(); let v = tangents .into_iter() - .fold(Vec::with_capacity(len * 3), |mut v, t| { + .fold(Vec::with_capacity(nrn.len() * 3), |mut v, t| { v.extend(t.into_inner().iter()); v }); - Some(Array::from_shape_vec((len, 3), v).unwrap().into_pyarray(py)) + Some( + Array::from_shape_vec((nrn.len(), 3), v) + .unwrap() + .into_pyarray(py), + ) } pub fn alphas<'py>(&self, py: Python<'py>, idx: NeuronIdx) -> Option<&'py PyArray1> { - self.arena - .alphas(idx) - .map(|v| PyArray1::from_vec(py, v.collect())) + let v = self.arena.neuron(idx)?.alphas().collect(); + Some(PyArray1::from_vec(py, v)) + } + + /// Get the neuron as an array with N rows and 7 columns. + /// The columns are the point locations x, y, z; + /// the normalized tangent vectors x, y, z; + /// and the alpha value, for each of the N points. + pub fn neuron_array<'py>( + &self, + py: Python<'py>, + id: NeuronIdx, + ) -> Option<&'py PyArray2> { + let nrn = self.arena.neuron(id)?; + let v = nrn.points().zip(nrn.tangents()).zip(nrn.alphas()).fold( + Vec::with_capacity(nrn.len() * 7), + |mut v, ((p, t), a)| { + v.extend(p); + v.extend(t.as_slice()); + v.push(a); + v + }, + ); + + Some( + Array::from_shape_vec((nrn.len(), 7), v) + .unwrap() + .into_pyarray(py), + ) + } + + /// Add a neuron serialized by [serialize_neuron](#method.serialize_neuron). + /// + /// Note that a serialized neuron is only valid for one particular backend + /// and neighborhood size. + /// Unlike adding a neuron by its points or points/tangents/alphas, + /// this saves having to rebuild the spatial index. + pub fn add_serialized_neuron<'py>( + &mut self, + _py: Python<'py>, + bytes: &[u8], + format: &str, + ) -> PyResult { + let fmt = NeuronFormat::from_str(format).map_err(|msg| PyValueError::new_err(msg))?; + let n = fmt.read_neuron(bytes)?; + Ok(self.arena.add_neuron(n)) + } + + /// Serialize a neuron and its spatial index. + /// + /// Note that a serialized neuron is only valid for one particular backend + /// and neighborhood size; the format may not be stable. + pub fn serialize_neuron<'py>( + &self, + py: Python<'py>, + idx: NeuronIdx, + format: &str, + ) -> PyResult> { + let fmt = NeuronFormat::from_str(format).map_err(|msg| PyValueError::new_err(msg))?; + let Some(nrn) = self.arena.neuron(idx) else { + return Ok(None); + }; + let mut buf = Cursor::new(Vec::default()); + fmt.write_neuron(nrn, &mut buf)?; + Ok(Some(PyBytes::new(py, buf.into_inner().as_slice()))) + } + + pub fn deepcopy<'py>(&self, _py: Python<'py>) -> Self { + self.clone() + } +} + +enum NeuronFormat { + Json, + Cbor, + // consider bincode, postcard, flexbuffers +} + +impl NeuronFormat { + fn write_neuron(&self, nrn: &Neuron, w: &mut W) -> PyResult<()> { + match self { + Self::Json => { + serde_json::to_writer(w, nrn).map_err(|e| PyRuntimeError::new_err(e.to_string())) + } + Self::Cbor => { + ciborium::into_writer(nrn, w).map_err(|e| PyRuntimeError::new_err(e.to_string())) + } + } + } + + fn read_neuron(&self, r: R) -> PyResult { + match self { + Self::Json => { + serde_json::from_reader(r).map_err(|e| PyRuntimeError::new_err(e.to_string())) + } + Self::Cbor => { + ciborium::from_reader(r).map_err(|e| PyRuntimeError::new_err(e.to_string())) + } + } + } +} + +impl Display for NeuronFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + NeuronFormat::Json => f.write_str("json"), + NeuronFormat::Cbor => f.write_str("cbor"), + } + } +} + +impl FromStr for NeuronFormat { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "json" => Ok(Self::Json), + "cbor" => Ok(Self::Cbor), + _ => Err(format!("Unknown format '{s}'")), + } } } @@ -349,32 +467,60 @@ impl ResamplingArbor { } } +#[allow(non_snake_case)] +fn is_Nx3(p: &PyReadonlyArray2) -> Result<(), &'static str> { + if p.shape()[1] != 3 { + return Err("Array must be Nx3"); + } + Ok(()) +} + +fn np_to_neuron(points: PyReadonlyArray2, k: usize) -> Result { + is_Nx3(&points)?; + Neuron::new( + points + .as_array() + .rows() + .into_iter() + .map(|r| [r[0], r[1], r[2]]) + .collect(), + k, + ) +} + fn make_neurons_many( - points_list: Vec>>, + points_list: Vec>, k: usize, threads: Option, -) -> Vec { +) -> Result, &'static str> { if let Some(t) = threads { + let vp = points_list + .into_iter() + .map(|np| { + is_Nx3(&np)?; + Ok(np + .as_array() + .rows() + .into_iter() + .map(|r| [r[0], r[1], r[2]]) + .collect::>()) + }) + .collect::, &str>>()?; + let pool = rayon::ThreadPoolBuilder::new() .num_threads(t) .build() .unwrap(); - pool.install(|| { - points_list - .into_par_iter() - .map(|ps| { - Neuron::new(ps.into_iter().map(|p| [p[0], p[1], p[2]]).collect(), k) - .expect("Invalid neuron") - }) + + Ok(pool.install(|| { + vp.into_par_iter() + .map(|ps| Neuron::new(ps, k).expect("Invalid neuron")) .collect() - }) + })) } else { points_list .into_iter() - .map(|ps| { - Neuron::new(ps.into_iter().map(|p| [p[0], p[1], p[2]]).collect(), k) - .expect("invalid neuron") - }) + .map(|ps| np_to_neuron(ps, k)) .collect() } } @@ -388,8 +534,8 @@ fn inner_bounds_to_binlookup(mut v: Vec) -> BinLookup { #[pyfunction] #[allow(clippy::too_many_arguments)] fn build_score_matrix( - py: Python, - points: Vec>>, + _py: Python, + points: Vec>, k: usize, seed: u64, use_alpha: bool, @@ -402,50 +548,48 @@ fn build_score_matrix( max_matching_pairs: Option, max_nonmatching_pairs: Option, threads: Option, -) -> (Vec, Vec, Vec) { - py.allow_threads(|| { - let neurons = make_neurons_many(points, k, threads); - let mut smatb = ScoreMatrixBuilder::new(neurons, seed); - smatb.set_threads(threads).set_use_alpha(use_alpha); - - if let Some(mmp) = max_matching_pairs { - smatb.set_max_matching_pairs(mmp); - } +) -> PyResult<(Vec, Vec, Vec)> { + let neurons = make_neurons_many(points, k, threads).map_err(PyRuntimeError::new_err)?; + let mut smatb = ScoreMatrixBuilder::new(neurons, seed); + smatb.set_threads(threads).set_use_alpha(use_alpha); - if let Some(mnmp) = max_nonmatching_pairs { - smatb.set_max_nonmatching_pairs(mnmp); - } + if let Some(mmp) = max_matching_pairs { + smatb.set_max_matching_pairs(mmp); + } - for m in matching_sets.into_iter() { - smatb.add_matching_set(&m); - } - if let Some(ns) = nonmatching_sets { - for n in ns.into_iter() { - smatb.add_nonmatching_set(&n); - } - } + if let Some(mnmp) = max_nonmatching_pairs { + smatb.set_max_nonmatching_pairs(mnmp); + } - if let Some(inner) = dist_inner_bounds { - smatb.set_dist_lookup(inner_bounds_to_binlookup(inner)); - } else if let Some(n) = dist_n_bins { - smatb.set_n_dist_bins(n); - } else { - unimplemented!("should supply dist_inner_bounds or dist_n_bins"); + for m in matching_sets.into_iter() { + smatb.add_matching_set(&m); + } + if let Some(ns) = nonmatching_sets { + for n in ns.into_iter() { + smatb.add_nonmatching_set(&n); } + } - if let Some(inner) = dot_inner_bounds { - smatb.set_dot_lookup(inner_bounds_to_binlookup(inner)); - } else if let Some(n) = dot_n_bins { - smatb.set_n_dot_bins(n); - } else { - unimplemented!("should supply dot_inner_bounds or dot_n_bins"); - } + if let Some(inner) = dist_inner_bounds { + smatb.set_dist_lookup(inner_bounds_to_binlookup(inner)); + } else if let Some(n) = dist_n_bins { + smatb.set_n_dist_bins(n); + } else { + unimplemented!("should supply dist_inner_bounds or dist_n_bins"); + } - let mut table = smatb.build().expect("Failed to build score matrix"); - let dot_bounds = table.bins_lookup.lookups.pop().unwrap().bin_boundaries; - let dist_bounds = table.bins_lookup.lookups.pop().unwrap().bin_boundaries; - (dist_bounds, dot_bounds, table.cells) - }) + if let Some(inner) = dot_inner_bounds { + smatb.set_dot_lookup(inner_bounds_to_binlookup(inner)); + } else if let Some(n) = dot_n_bins { + smatb.set_n_dot_bins(n); + } else { + unimplemented!("should supply dot_inner_bounds or dot_n_bins"); + } + + let mut table = smatb.build().expect("Failed to build score matrix"); + let dot_bounds = table.bins_lookup.lookups.pop().unwrap().bin_boundaries; + let dist_bounds = table.bins_lookup.lookups.pop().unwrap().bin_boundaries; + Ok((dist_bounds, dot_bounds, table.cells)) } #[pyfunction] @@ -453,12 +597,18 @@ fn get_version(_py: Python) -> String { env!("CARGO_PKG_VERSION").to_string() } +#[pyfunction] +fn backend(_py: Python) -> String { + DEFAULT_BACKEND.to_string() +} + #[pymodule] fn pynblast(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(get_version, m)?)?; m.add_function(wrap_pyfunction!(build_score_matrix, m)?)?; + m.add_function(wrap_pyfunction!(backend, m)?)?; Ok(()) } diff --git a/nblast-py/tests/conftest.py b/nblast-py/tests/conftest.py index 6c94fe9..953473b 100644 --- a/nblast-py/tests/conftest.py +++ b/nblast-py/tests/conftest.py @@ -1,6 +1,7 @@ from typing import Callable, Tuple, List, Dict, Optional from pathlib import Path import json +from pynblast.smat_builder import ScoreMatrixBuilder import pytest import pandas as pd @@ -16,7 +17,7 @@ def smat_path() -> Path: @pytest.fixture -def score_mat_tup(smat_path) -> pynblast.ScoreMatrix: +def score_mat(smat_path) -> pynblast.ScoreMatrix: return pynblast.ScoreMatrix.read(smat_path) @@ -33,17 +34,17 @@ def points() -> List[Tuple[str, pd.DataFrame]]: @pytest.fixture -def arena(score_mat_tup): - return pynblast.NblastArena(*score_mat_tup, k=5) +def arena(score_mat): + return pynblast.NblastArena(score_mat, k=5) @pytest.fixture def arena_names_factory( - score_mat_tup, points + score_mat, points ) -> Callable[[bool, Optional[int]], Tuple[pynblast.NblastArena, dict[int, str]]]: def fn(use_alpha, threads): arena = pynblast.NblastArena( - *score_mat_tup, k=5, use_alpha=use_alpha, threads=threads + score_mat, k=5, use_alpha=use_alpha, threads=threads ) idx_to_name = dict() @@ -93,3 +94,20 @@ def skeleton() -> List[Tuple[int, Optional[int], float, float, float]]: @pytest.fixture def resampler(skeleton) -> pynblast.ResamplingArbor: return pynblast.ResamplingArbor(skeleton) + + +@pytest.fixture +def smatbuilder(points): + matching_sets = {"Fru": set(), "Gad": set()} + builder = ScoreMatrixBuilder(1991) + for name, pts in points: + idx = builder.add_points(pts.to_numpy()) + for k, v in matching_sets.items(): + if name.startswith(k): + v.add(idx) + + for ms in matching_sets.values(): + builder.add_matching_set(ms) + + builder.set_dist_bins(6).set_dot_bins(4) + return builder diff --git a/nblast-py/tests/test_bench.py b/nblast-py/tests/test_bench.py new file mode 100644 index 0000000..8c84b57 --- /dev/null +++ b/nblast-py/tests/test_bench.py @@ -0,0 +1,68 @@ +from pynblast import NblastArena +from pynblast.util import Format + +import pytest + + +@pytest.mark.benchmark(group="construct") +def test_empty_arena(benchmark, score_mat): + benchmark(NblastArena, score_mat, k=5) + + +@pytest.mark.benchmark(group="construct") +def test_clone_arena(benchmark, arena: NblastArena): + benchmark(arena.copy, True) + + +@pytest.mark.benchmark(group="construct") +def test_add_points(benchmark, points, arena): + pt = points[0][1].to_numpy() + + def fn(): + ar2 = arena.copy(True) + ar2.add_points(pt) + + benchmark(fn) + + +@pytest.mark.benchmark(group="construct") +def test_add_points_tangents_alphas(benchmark, points, arena: NblastArena): + pt = points[0][1].to_numpy() + empty = arena.copy(True) + idx = arena.add_points(pt) + t = arena.tangents(idx) + a = arena.alphas(idx) + + def fn(): + ar = empty.copy(True) + ar.add_points_tangents_alphas(pt, t, a) + + benchmark(fn) + + +@pytest.mark.benchmark(group="construct") +def test_add_serialized(benchmark, points, arena: NblastArena): + pt = points[0][1].to_numpy() + empty = arena.copy(True) + idx = arena.add_points(pt) + b = arena._impl.serialize_neuron(idx, Format.CBOR) + + def fn(): + ar = empty.copy(True) + # ar.add_serialized_neuron(BytesIO(b), Format.CBOR) + ar._impl.add_serialized_neuron(b, Format.CBOR) + + benchmark(fn) + + +@pytest.mark.benchmark(group="all_v_all") +def test_all_v_all_serial(benchmark, arena_names_factory): + arena: NblastArena + (arena, _) = arena_names_factory(False, None) + + benchmark(arena.all_v_all) + + +@pytest.mark.benchmark(group="smatbuilder") +def test_smatbuilder_bench(benchmark, smatbuilder): + benchmark(smatbuilder.build) diff --git a/nblast-py/tests/test_nblast.py b/nblast-py/tests/test_nblast.py index 333d252..2657e05 100644 --- a/nblast-py/tests/test_nblast.py +++ b/nblast-py/tests/test_nblast.py @@ -17,14 +17,31 @@ def test_read_smat(smat_path): - dist_bins, dot_bins, arr = ScoreMatrix.read(smat_path) + smat = ScoreMatrix.read(smat_path) + dist_bins = smat.dist_thresholds + dot_bins = smat.dot_thresholds + arr = smat.values assert arr.shape == (len(dist_bins) - 1, len(dot_bins) - 1) df = pd.read_csv(smat_path, index_col=0, header=0) assert np.allclose(arr, df.to_numpy()) -def test_construction(score_mat_tup): - NblastArena(*score_mat_tup) +def test_construction(score_mat): + NblastArena(score_mat) + + +def test_shallow_copy(arena: NblastArena): + arena2 = arena.copy(False) + pts = np.random.random((100, 3)) + arena.add_points(pts) + assert len(arena) == len(arena2) + + +def test_deep_copy(arena: NblastArena): + arena2 = arena.copy(True) + pts = np.random.random((100, 3)) + arena.add_points(pts) + assert len(arena) != len(arena2) def test_insertion(arena: NblastArena, points): @@ -189,3 +206,36 @@ def test_prepop_tangents(points, arena): assert arena.query_target(idx0, idx1, normalize=True) == pytest.approx( 1.0, abs=EPSILON ) + + +def test_neuron_table(points, arena: NblastArena): + p = points[0][1].to_numpy() + idx1 = arena.add_points(p) + df = arena.neuron_table(idx1) + + points = df[list("xyz")].to_numpy() + assert np.allclose(points, arena.points(idx1)) + + tangents = df[["tangent_" + d for d in "xyz"]].to_numpy() + assert np.allclose(tangents, arena.tangents(idx1)) + + assert np.allclose(df["alpha"].to_numpy(), arena.alphas(idx1)) + + +# @pytest.mark.parametrize( +# ["format"], +# [ +# # (Format.JSON,), +# (Format.CBOR,) +# ], +# ) +# def test_ser_roundtrip(points, arena: NblastArena, format): +# p = points[0][1].to_numpy() +# idx1 = arena.add_points(p) +# b: bytes = arena.serialize_neuron(idx1, None, format) +# idx2 = arena.add_serialized_neuron(BytesIO(b), format) + +# df1 = arena.neuron_table(idx1) +# df2 = arena.neuron_table(idx2) + +# assert np.allclose(df1.to_numpy(), df2.to_numpy()) diff --git a/nblast-py/tests/test_smat_builder.py b/nblast-py/tests/test_smat_builder.py new file mode 100644 index 0000000..a32b889 --- /dev/null +++ b/nblast-py/tests/test_smat_builder.py @@ -0,0 +1,5 @@ +def test_smatbuilder(smatbuilder): + smat = smatbuilder.build() + assert len(smat.dist_thresholds) == 7 + assert len(smat.dot_thresholds) == 5 + assert smat.values.shape == (6, 4) diff --git a/nblast-rs/Cargo.toml b/nblast-rs/Cargo.toml index 2b95100..4b19b3b 100644 --- a/nblast-rs/Cargo.toml +++ b/nblast-rs/Cargo.toml @@ -28,6 +28,7 @@ nabo = { version = "0.2.1", optional = true } kiddo = { version = "4.0", optional = true } cfg-if = "1.0.0" bosque = { version = "0.2.0", optional = true } +serde = { version = "1", optional = true, features = ["derive"]} [dev-dependencies] @@ -40,8 +41,7 @@ serde = { version = "1", features = ["derive"] } default = ["kiddo"] parallel = ["rayon"] - -# default = ["parallel"] +serde = ["dep:serde", "kiddo?/serialize", "rstar?/serde"] [[bench]] diff --git a/nblast-rs/src/lib.rs b/nblast-rs/src/lib.rs index 94f9be7..49b7921 100644 --- a/nblast-rs/src/lib.rs +++ b/nblast-rs/src/lib.rs @@ -116,6 +116,8 @@ use table_lookup::InvalidRangeTable; pub use rayon; #[cfg(feature = "parallel")] use rayon::prelude::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; pub use nalgebra; @@ -126,7 +128,7 @@ mod table_lookup; pub use table_lookup::{BinLookup, NdBinLookup, RangeTable}; pub mod neurons; -pub use neurons::{NblastNeuron, Neuron, QueryNeuron, TargetNeuron}; +pub use neurons::{NblastNeuron, Neuron, QueryNeuron, TargetNeuron, DEFAULT_BACKEND}; #[cfg(not(any( feature = "nabo", @@ -172,11 +174,78 @@ fn harmonic_mean(a: Precision, b: Precision) -> Precision { /// A tangent, alpha pair associated with a point. #[derive(Copy, Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct TangentAlpha { + #[cfg_attr(feature = "serde", serde(with = "serde_norm"))] pub tangent: Normal3, pub alpha: Precision, } +#[cfg(feature = "serde")] +mod serde_norm { + use super::*; + use serde::{ + de, + de::{SeqAccess, Visitor}, + Deserializer, Serializer, + }; + use std::fmt; + + pub fn serialize(norm: &Normal3, ser: S) -> Result + where + S: Serializer, + { + ser.collect_seq(norm.as_slice().iter()) + } + + struct NormVisitor; + + impl<'de> Visitor<'de> for NormVisitor { + type Value = Normal3; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("ArrayKeyedMap key value sequence.") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut elems = [0.0; 3]; + let mut idx = 0; + let mut mag: f64 = 0.0; + while let Some(val) = seq.next_element()? { + elems[idx] = val; + mag += val * val; + idx += 1; + if idx > elems.len() { + return Err(de::Error::custom("Too many elements")); + } + } + if idx < elems.len() { + return Err(de::Error::custom("Not enough elements")); + } + + let v = Vector3::from(elems); + + if mag == 0.0 { + Err(de::Error::custom("tangent has magnitude 0")) + } else if mag == 1.0 { + Ok(Unit::new_unchecked(v)) + } else { + Ok(Unit::new_unchecked(v / mag.sqrt())) + } + } + } + + pub fn deserialize<'de, D>(deser: D) -> Result + where + D: Deserializer<'de>, + { + deser.deserialize_seq(NormVisitor) + } +} + impl TangentAlpha { fn new_from_points<'a>(points: impl Iterator) -> Self { let inertia = calc_inertia(points); @@ -493,7 +562,7 @@ impl Location for &Point3 { } } -#[derive(Clone)] +#[derive(Clone, Debug)] struct NeuronSelfHit { neuron: N, self_hit: Precision, @@ -553,6 +622,7 @@ impl ScoreCalc { /// Struct for caching a number of neurons for multiple comparable NBLAST queries. #[allow(dead_code)] +#[derive(Debug, Clone)] pub struct NblastArena where N: TargetNeuron, @@ -592,10 +662,6 @@ where } } - pub fn size_of(&self, idx: NeuronIdx) -> Option { - self.neurons_scores.get(idx).map(|n| n.neuron.len()) - } - fn next_id(&self) -> NeuronIdx { self.neurons_scores.len() } @@ -609,6 +675,10 @@ where idx } + pub fn neuron(&self, idx: NeuronIdx) -> Option<&N> { + self.neurons_scores.get(idx).map(|ns| &ns.neuron) + } + /// Make a single query using the given indexes. /// `normalize` divides the result by the self-hit score of the query neuron. /// `symmetry`, if `Some`, also calculates the reverse score @@ -805,18 +875,6 @@ where pub fn len(&self) -> usize { self.neurons_scores.len() } - - pub fn points(&self, idx: NeuronIdx) -> Option + '_> { - self.neurons_scores.get(idx).map(|n| n.neuron.points()) - } - - pub fn tangents(&self, idx: NeuronIdx) -> Option + '_> { - self.neurons_scores.get(idx).map(|n| n.neuron.tangents()) - } - - pub fn alphas(&self, idx: NeuronIdx) -> Option + '_> { - self.neurons_scores.get(idx).map(|n| n.neuron.alphas()) - } } fn pairs_to_raw_serial( diff --git a/nblast-rs/src/neurons/any.rs b/nblast-rs/src/neurons/any.rs new file mode 100644 index 0000000..4da7449 --- /dev/null +++ b/nblast-rs/src/neurons/any.rs @@ -0,0 +1,144 @@ +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// Enum for serializing/ deserializing any kind of neuron. +/// +/// Get neurons in and out of it with `From` and `TryFrom` respectively. +#[derive(Serialize, Deserialize)] +pub enum AnyNeuron { + #[cfg(feature = "bosque")] + Bosque(super::bosque::BosqueNeuron), + #[cfg(feature = "kiddo")] + Kiddo(super::kiddo::KiddoNeuron), + #[cfg(feature = "kiddo")] + ApproxKiddo(super::kiddo::ApproxKiddoNeuron), + #[cfg(feature = "rstar")] + Rstar(super::rstar::RstarNeuron), + // nabo doesn't implement Ser/Deser because it uses the Point trait rather than a specific type internally. + + // can't implement neuron traits because iterators have different types, + // but could Box them +} + +#[derive(Debug, thiserror::Error)] +#[error("Wrong neuron backend: expected '{expected}', got '{got}'")] +pub struct WrongNeuronType { + expected: &'static str, + got: &'static str, +} + +impl WrongNeuronType { + fn new(expected: &'static str, got: &'static str) -> Self { + Self { expected, got } + } +} + +#[cfg(feature = "rstar")] +impl From for AnyNeuron { + fn from(value: super::rstar::RstarNeuron) -> Self { + Self::Rstar(value) + } +} + +#[cfg(feature = "rstar")] +impl TryFrom for super::rstar::RstarNeuron { + type Error = WrongNeuronType; + + fn try_from(value: AnyNeuron) -> Result { + use AnyNeuron::*; + let expected = "rstar"; + match value { + #[cfg(feature = "rstar")] + Rstar(n) => Ok(n), + #[cfg(feature = "kiddo")] + Kiddo(_) => Err(WrongNeuronType::new(expected, "kiddo")), + #[cfg(feature = "kiddo")] + ApproxKiddo(_) => Err(WrongNeuronType::new(expected, "approx_kiddo")), + #[cfg(feature = "bosque")] + Bosque(_) => Err(WrongNeuronType::new(expected, "bosque")), + } + } +} + +#[cfg(feature = "kiddo")] + +impl From for AnyNeuron { + fn from(value: super::kiddo::KiddoNeuron) -> Self { + Self::Kiddo(value) + } +} +#[cfg(feature = "kiddo")] +impl TryFrom for super::kiddo::KiddoNeuron { + type Error = WrongNeuronType; + + fn try_from(value: AnyNeuron) -> Result { + use AnyNeuron::*; + let expected = "kiddo"; + match value { + #[cfg(feature = "rstar")] + Rstar(_) => Err(WrongNeuronType::new(expected, "rstar")), + #[cfg(feature = "kiddo")] + Kiddo(n) => Ok(n), + #[cfg(feature = "kiddo")] + ApproxKiddo(_) => Err(WrongNeuronType::new(expected, "approx_kiddo")), + #[cfg(feature = "bosque")] + Bosque(_) => Err(WrongNeuronType::new(expected, "bosque")), + } + } +} + +#[cfg(feature = "kiddo")] + +impl From for AnyNeuron { + fn from(value: super::kiddo::ApproxKiddoNeuron) -> Self { + Self::ApproxKiddo(value) + } +} + +#[cfg(feature = "kiddo")] +impl TryFrom for super::kiddo::ApproxKiddoNeuron { + type Error = WrongNeuronType; + + fn try_from(value: AnyNeuron) -> Result { + use AnyNeuron::*; + let expected = "approx_kiddo"; + match value { + #[cfg(feature = "rstar")] + Rstar(_) => Err(WrongNeuronType::new(expected, "rstar")), + #[cfg(feature = "kiddo")] + Kiddo(_) => Err(WrongNeuronType::new(expected, "kiddo")), + #[cfg(feature = "kiddo")] + ApproxKiddo(n) => Ok(n), + #[cfg(feature = "bosque")] + Bosque(_) => Err(WrongNeuronType::new(expected, "bosque")), + } + } +} + +#[cfg(feature = "bosque")] + +impl From for AnyNeuron { + fn from(value: super::bosque::BosqueNeuron) -> Self { + Self::Bosque(value) + } +} + +#[cfg(feature = "bosque")] +impl TryFrom for super::bosque::BosqueNeuron { + type Error = WrongNeuronType; + + fn try_from(value: AnyNeuron) -> Result { + use AnyNeuron::*; + let expected = "bosque"; + match value { + #[cfg(feature = "rstar")] + Rstar(_) => Err(WrongNeuronType::new(expected, "rstar")), + #[cfg(feature = "kiddo")] + Kiddo(_) => Err(WrongNeuronType::new(expected, "kiddo")), + #[cfg(feature = "kiddo")] + ApproxKiddo(_) => Err(WrongNeuronType::new(expected, "approx_kiddo")), + #[cfg(feature = "bosque")] + Bosque(n) => Ok(n), + } + } +} diff --git a/nblast-rs/src/neurons/bosque.rs b/nblast-rs/src/neurons/bosque.rs index 4716166..998e0d7 100644 --- a/nblast-rs/src/neurons/bosque.rs +++ b/nblast-rs/src/neurons/bosque.rs @@ -3,8 +3,11 @@ use crate::{ TangentAlpha, TargetNeuron, }; use bosque::tree::{build_tree, build_tree_with_indices, nearest_k, nearest_one}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct BosqueNeuron { points: Vec, tangents_alphas: Vec, diff --git a/nblast-rs/src/neurons/kiddo.rs b/nblast-rs/src/neurons/kiddo.rs index 83e514e..6216f18 100644 --- a/nblast-rs/src/neurons/kiddo.rs +++ b/nblast-rs/src/neurons/kiddo.rs @@ -3,14 +3,14 @@ use super::{NblastNeuron, QueryNeuron, TargetNeuron}; use crate::{geometric_mean, DistDot, Normal3, Point3, Precision, ScoreCalc, TangentAlpha}; use kiddo::{ImmutableKdTree, SquaredEuclidean}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + type KdTree = ImmutableKdTree; /// Target neuron using a KDTree from the kiddo crate. -/// -/// By default, this uses approximate nearest neighbour for one-off lookups (as used in NBLAST scoring). -/// However, in tests it is *very* approximate. -/// See the [KiddoNeuron] for exact 1NN. #[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct KiddoNeuron { tree: KdTree, points_tangents_alphas: Vec<(Point3, TangentAlpha)>, @@ -150,6 +150,7 @@ impl TargetNeuron for KiddoNeuron { } #[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct ApproxKiddoNeuron(KiddoNeuron); impl ApproxKiddoNeuron { diff --git a/nblast-rs/src/neurons/mod.rs b/nblast-rs/src/neurons/mod.rs index 3bdd75a..24bb41b 100644 --- a/nblast-rs/src/neurons/mod.rs +++ b/nblast-rs/src/neurons/mod.rs @@ -14,15 +14,22 @@ pub mod nabo; #[cfg(feature = "rstar")] pub mod rstar; +#[cfg(feature = "serde")] +pub mod any; + cfg_if::cfg_if! { if #[cfg(feature = "kiddo")] { pub type Neuron = self::kiddo::KiddoNeuron; + pub const DEFAULT_BACKEND: &'static str = "kiddo"; } else if #[cfg(feature = "bosque")] { pub type Neuron = self::bosque::BosqueNeuron; + pub const DEFAULT_BACKEND: &'static str = "bosque"; } else if #[cfg(feature = "rstar")] { pub type Neuron = self::rstar::RstarNeuron; + pub const DEFAULT_BACKEND: &'static str = "rstar"; } else if #[cfg(feature = "nabo")] { pub type Neuron = self::nabo::NaboNeuron; + pub const DEFAULT_BACKEND: &'static str = "nabo"; } else { compile_error!("No spatial query backend selected"); } diff --git a/nblast-rs/src/neurons/rstar.rs b/nblast-rs/src/neurons/rstar.rs index e2b0df0..9347b61 100644 --- a/nblast-rs/src/neurons/rstar.rs +++ b/nblast-rs/src/neurons/rstar.rs @@ -4,6 +4,9 @@ use crate::{geometric_mean, DistDot, Normal3, Point3, Precision, ScoreCalc, Tang use rstar::{primitives::GeomWithData, PointDistance, RTree}; use std::borrow::Borrow; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + type PointWithIndex = GeomWithData; fn points_to_rtree( @@ -40,6 +43,7 @@ pub(crate) fn points_to_rtree_tangents_alphas( /// Target neuron using an [R*-tree](https://en.wikipedia.org/wiki/R*_tree) for spatial queries. #[derive(Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct RstarNeuron { rtree: RTree, tangents_alphas: Vec,