Skip to content

Commit 6cb4558

Browse files
committedJul 19, 2024·
wip
1 parent 51fc305 commit 6cb4558

10 files changed

+538
-426
lines changed
 

‎Cargo.lock

+85
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ crate-type = ["cdylib", "rlib"]
1313
[dependencies]
1414
anyhow = "1.0.81"
1515
bincode = "1.3.3"
16+
derive_builder = "0.20.0"
1617
funty = "2.0.0"
1718
indicatif = "0.17.8"
1819
memmap2 = "0.9.4"

‎src/architecture.md

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
``` rust
2+
struct SuffixTable {}
3+
impl SuffixTable {
4+
pub fn count_next(&self);
5+
6+
/// To be deleted - not a natural fit for a suffix table data structure, and depends on count_next
7+
/// which is "overridden" in ShardedMemmapIndex. I want to move it into a sampler struct.
8+
pub fn sample(&self) {
9+
counts = self.count_next()
10+
do_sample(counts)
11+
}
12+
}
13+
14+
struct MemmapIndex {
15+
table: SuffixTable
16+
}
17+
impl MemmapIndex {
18+
...lots of index/table creation code
19+
20+
pub fn count_next(&self) {
21+
self.table.count_next()
22+
}
23+
24+
pub fn sample(&self) {
25+
self.table.sample()
26+
}
27+
}
28+
29+
struct ShardedInMemoryIndex {
30+
shards: Vec<MemmapIndex>
31+
}
32+
impl ShardedMemmapIndex {
33+
pub fn count_next(&self) {
34+
counts = self.shards.iter().map(|shard| shard.count_next())
35+
merge_counts(counts)
36+
}
37+
}
38+
39+
// WIP
40+
struct Sampler {
41+
// Not possible due to PyO3 not supporting lifetimes
42+
index: Box<Index>
43+
}
44+
impl Sampler {
45+
pub fn sample(&self) {
46+
counts = self.index.count_next()
47+
do_sample(counts)
48+
}
49+
}
50+
```

‎src/countable_index.rs

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
use std::collections::HashMap;
2+
3+
pub trait CountableIndex: Send + Sync {
4+
fn count_next(&self, query: Vec<u16>, vocab: Option<u16>) -> Vec<usize>;
5+
6+
/// Generate a frequency map from an occurrence frequency
7+
/// to the number of n-grams in the data structure with that
8+
/// frequency.
9+
fn count_ngrams(&self, n: usize) -> HashMap<usize, usize>;
10+
}

‎src/in_memory_index.rs

+17-43
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
use bincode::{deserialize, serialize};
2-
use pyo3::exceptions::PyValueError;
2+
use std::collections::HashMap;
33
use pyo3::prelude::*;
44
use std::fs::File;
55
use std::io::Read;
66

77
use crate::table::SuffixTable;
8+
use crate::countable_index::CountableIndex;
89
use crate::util::transmute_slice;
910

1011
/// An in-memory index exposes suffix table functionality over text corpora small enough to fit in memory.
1112
#[pyclass]
13+
#[derive(Builder)]
1214
pub struct InMemoryIndex {
1315
table: SuffixTable,
1416
}
@@ -59,14 +61,14 @@ impl InMemoryIndex {
5961
self.table.contains(&query)
6062
}
6163

62-
pub fn count(&self, query: Vec<u16>) -> usize {
63-
self.table.positions(&query).len()
64-
}
65-
6664
pub fn positions(&self, query: Vec<u16>) -> Vec<u64> {
6765
self.table.positions(&query).to_vec()
6866
}
6967

68+
pub fn count(&self, query: Vec<u16>) -> usize {
69+
self.table.positions(&query).len()
70+
}
71+
7072
pub fn count_next(&self, query: Vec<u16>, vocab: Option<u16>) -> Vec<usize> {
7173
self.table.count_next(&query, vocab)
7274
}
@@ -75,48 +77,20 @@ impl InMemoryIndex {
7577
self.table.batch_count_next(&queries, vocab)
7678
}
7779

78-
pub fn sample_unsmoothed(
79-
&self,
80-
query: Vec<u16>,
81-
n: usize,
82-
k: usize,
83-
num_samples: usize,
84-
vocab: Option<u16>
85-
) -> Result<Vec<Vec<u16>>, PyErr> {
86-
self.table
87-
.sample_unsmoothed(&query, n, k, num_samples, vocab)
88-
.map_err(|error| PyValueError::new_err(error.to_string()))
89-
}
90-
91-
pub fn sample_smoothed(
92-
&mut self,
93-
query: Vec<u16>,
94-
n: usize,
95-
k: usize,
96-
num_samples: usize,
97-
vocab: Option<u16>
98-
) -> Result<Vec<Vec<u16>>, PyErr> {
99-
self.table
100-
.sample_smoothed(&query, n, k, num_samples, vocab)
101-
.map_err(|error| PyValueError::new_err(error.to_string()))
102-
}
103-
104-
pub fn smoothed_probs(&mut self, query: Vec<u16>, vocab: Option<u16>) -> Vec<f64> {
105-
self.table.get_smoothed_probs(&query, vocab)
106-
}
107-
108-
pub fn batch_smoothed_probs(&mut self, queries: Vec<Vec<u16>>, vocab: Option<u16>) -> Vec<Vec<f64>> {
109-
self.table.batch_get_smoothed_probs(&queries, vocab)
110-
}
111-
112-
pub fn estimate_deltas(&mut self, n: usize) {
113-
self.table.estimate_deltas(n);
114-
}
115-
11680
pub fn save(&self, path: String) -> PyResult<()> {
11781
// TODO: handle errors here
11882
let bytes = serialize(&self.table).unwrap();
11983
std::fs::write(&path, bytes)?;
12084
Ok(())
12185
}
12286
}
87+
88+
impl CountableIndex for InMemoryIndex {
89+
fn count_next(&self, query: Vec<u16>, vocab: Option<u16>) -> Vec<usize> {
90+
self.table.count_next(&query, vocab)
91+
}
92+
93+
fn count_ngrams(&self, n: usize) -> HashMap<usize, usize> {
94+
self.table.count_ngrams(n)
95+
}
96+
}

‎src/lib.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,28 @@
11
pub mod mmap_slice;
22
pub use in_memory_index::InMemoryIndex;
33
pub use memmap_index::MemmapIndex;
4-
pub use sharded_index::ShardedIndex;
4+
pub use sharded_memmap_index::ShardedMemmapIndex;
55
pub use table::SuffixTable;
66

77
/// Python bindings
88
use pyo3::prelude::*;
99

10+
mod sharded_memmap_index;
1011
mod in_memory_index;
1112
mod memmap_index;
12-
mod sharded_index;
13-
mod par_quicksort;
13+
mod countable_index;
14+
mod sampler;
1415
mod table;
16+
mod par_quicksort;
1517
mod util;
1618

19+
#[macro_use]
20+
extern crate derive_builder;
21+
1722
#[pymodule]
1823
fn tokengrams(_py: Python, m: &PyModule) -> PyResult<()> {
1924
m.add_class::<InMemoryIndex>()?;
2025
m.add_class::<MemmapIndex>()?;
21-
m.add_class::<ShardedIndex>()?;
26+
m.add_class::<ShardedMemmapIndex>()?;
2227
Ok(())
2328
}

0 commit comments

Comments
 (0)
Please sign in to comment.