Skip to content

Commit f900916

Browse files
committedJun 22, 2024·
sais python functions
1 parent 8a46937 commit f900916

File tree

5 files changed

+64
-2
lines changed

5 files changed

+64
-2
lines changed
 

‎src/in_memory_index.rs

+28
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,34 @@ impl InMemoryIndex {
5050
})
5151
}
5252

53+
#[staticmethod]
54+
pub fn from_token_file_sais(
55+
path: String,
56+
token_limit: Option<usize>,
57+
) -> PyResult<Self> {
58+
let mut buffer = Vec::new();
59+
let mut file = File::open(&path)?;
60+
61+
if let Some(max_tokens) = token_limit {
62+
// Limit on the number of tokens to consider is provided
63+
let max_bytes = max_tokens * std::mem::size_of::<u16>();
64+
file.take(max_bytes as u64).read_to_end(&mut buffer)?;
65+
} else {
66+
file.read_to_end(&mut buffer)?;
67+
};
68+
69+
Ok(InMemoryIndex {
70+
table: SuffixTable::new_sais(transmute_slice(buffer.as_slice())),
71+
})
72+
}
73+
74+
#[staticmethod]
75+
pub fn sais(tokens: Vec<u16>) -> Self {
76+
InMemoryIndex {
77+
table: SuffixTable::new_sais(tokens),
78+
}
79+
}
80+
5381
/// Reports whether `query` is found in the index.
///
/// The check is delegated entirely to `self.table.contains`; the owned `query` vector is
/// only borrowed for the duration of that call.
pub fn contains(&self, query: Vec<u16>) -> bool {
5482
self.table.contains(&query)
5583
}

‎src/sais.rs

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ extern crate utf16_literal;
3232
use rayon::prelude::*;
3333
use std::u64;
3434

35+
use crate::util::par_bincount;
3536
use self::SuffixType::{Ascending, Descending, Valley};
3637

3738

@@ -55,6 +56,7 @@ fn sais<T: Text + ?Sized>(sa: &mut [u64], stypes: &mut SuffixTypes, bins: &mut B
5556
}
5657
sa.fill(0);
5758

59+
// TODO: Parallelize this step across batches / chunks / documents in the corpus
5860
stypes.compute(text);
5961
bins.find_sizes((0..text.len()).map(|i| text.char_at(i)));
6062
bins.find_tail_pointers();

‎src/util.rs

+29
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,32 @@
1+
use funty::Unsigned;
2+
use rayon::prelude::*;
3+
use std::sync::atomic::{AtomicUsize, Ordering};
4+
5+
/// Essentially np.bincount(data) in parallel.
6+
pub fn par_bincount<'a, I, T>(data: &'a I) -> Vec<usize>
7+
where
8+
I: IntoParallelRefIterator<'a, Item = T>,
9+
T: Unsigned,
10+
{
11+
// Find the maximum value in the data
12+
let max = match data.par_iter().max() {
13+
Some(m) => m,
14+
None => return Vec::new(),
15+
};
16+
17+
// Create a vector of atomic counters
18+
let mut counts = Vec::with_capacity(max.as_usize() + 1);
19+
for _ in 0..=max.as_usize() {
20+
counts.push(AtomicUsize::new(0));
21+
}
22+
23+
// Increment the counters in parallel
24+
data.par_iter().for_each(|x| {
25+
counts[x.as_usize()].fetch_add(1, Ordering::Relaxed);
26+
});
27+
counts.into_iter().map(|c| c.into_inner()).collect()
28+
}
29+
130
/// Return a zero-copy view of the given slice with the given type.
231
/// The resulting view has the same lifetime as the provided slice.
332
#[inline]

‎tests/tests.rs

-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ fn qc<T: Testable>(f: T) {
1515

1616

1717
// Do some testing on substring search.
18-
1918
#[test]
2019
fn empty_find_empty() {
2120
let sa = sais("");

‎tokengrams/tokengrams.pyi

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ class InMemoryIndex:
77
@staticmethod
88
def from_token_file(path: str, verbose: bool, token_limit: int | None) -> "InMemoryIndex":
99
"""Construct a `InMemoryIndex` from a file containing raw little-endian tokens."""
10+
11+
@staticmethod
12+
def from_token_file_sais(path: str, verbose: bool, token_limit: int | None) -> "InMemoryIndex":
13+
"""Construct a `InMemoryIndex` from a file containing raw little-endian tokens."""
1014

1115
def contains(self, query: list[int]) -> bool:
1216
"""Check if `query` has nonzero count. Faster than `count(query) > 0`."""
@@ -46,7 +50,7 @@ class MemmapIndex:
4650
@staticmethod
4751
def build(token_file: str, index_file: str) -> "MemmapIndex":
4852
"""Build a memory-mapped index from a token file."""
49-
53+
5054
def contains(self, query: list[int]) -> bool:
5155
"""Check if `query` has nonzero count. Faster than `count(query) > 0`."""
5256

0 commit comments

Comments
 (0)
Please sign in to comment.