Skip to content

Commit bc6d036

Browse files
committed Aug 11, 2024
move in memory index tests into private rust-side in memory index file
1 parent 842ac42 commit bc6d036

File tree

4 files changed

+112
-144
lines changed

4 files changed

+112
-144
lines changed
 

‎src/bindings/in_memory_index.rs

-16
Original file line number | Diff line number | Diff line change
@@ -39,22 +39,6 @@ pub trait InMemoryIndexTrait {
3939
fn estimate_deltas(&mut self, n: usize);
4040
}
4141

42-
impl InMemoryIndex {
43-
pub fn new(tokens: Vec<usize>, vocab: Option<usize>, verbose: bool) -> Self {
44-
let vocab = vocab.unwrap_or(u16::MAX as usize + 1);
45-
46-
let index: Box<dyn InMemoryIndexTrait + Send + Sync> = if vocab <= u16::MAX as usize + 1 {
47-
let tokens: Vec<u16> = tokens.iter().map(|&x| x as u16).collect();
48-
Box::new(InMemoryIndexRs::<u16>::new(tokens, Some(vocab), verbose))
49-
} else {
50-
let tokens: Vec<u32> = tokens.iter().map(|&x| x as u32).collect();
51-
Box::new(InMemoryIndexRs::<u32>::new(tokens, Some(vocab), verbose))
52-
};
53-
54-
InMemoryIndex { index }
55-
}
56-
}
57-
5842
#[pymethods]
5943
impl InMemoryIndex {
6044
#[new]

‎src/bindings/memmap_index.rs

+1-15
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ pub trait MemmapIndexTrait {
4040
impl MemmapIndex {
4141
#[new]
4242
#[pyo3(signature = (text_path, table_path, vocab=u16::MAX as usize + 1))]
43-
pub fn new_py(
43+
pub fn new(
4444
_py: Python,
4545
text_path: String,
4646
table_path: String,
@@ -146,17 +146,3 @@ impl MemmapIndex {
146146
self.index.estimate_deltas(n);
147147
}
148148
}
149-
150-
impl MemmapIndex {
151-
pub fn new(text_path: String, table_path: String, vocab: usize) -> Result<Self> {
152-
if vocab <= u16::MAX as usize + 1 {
153-
Ok(MemmapIndex {
154-
index: Box::new(MemmapIndexRs::<u16>::new(text_path, table_path, vocab)?),
155-
})
156-
} else {
157-
Ok(MemmapIndex {
158-
index: Box::new(MemmapIndexRs::<u32>::new(text_path, table_path, vocab)?),
159-
})
160-
}
161-
}
162-
}

‎src/in_memory_index.rs

+110
Original file line number | Diff line number | Diff line change
@@ -283,3 +283,113 @@ impl<T: Unsigned> InMemoryIndexTrait for InMemoryIndexRs<T> {
283283
<Self as Sample<T>>::estimate_deltas(self, n)
284284
}
285285
}
286+
287+
288+
#[cfg(test)]
289+
pub mod tests {
290+
use super::*;
291+
use utf16_literal::utf16;
292+
use crate::table::SuffixTable;
293+
294+
fn sais(text: &str) -> SuffixTable {
295+
SuffixTable::new(text.encode_utf16().collect::<Vec<_>>(), None, false)
296+
}
297+
298+
fn utf16_as_usize(s: &str) -> Vec<usize> {
299+
s.encode_utf16().map(|x| x as usize).collect()
300+
}
301+
302+
#[test]
303+
fn sample_unsmoothed_empty_query_exists() {
304+
let s = utf16!("aaa");
305+
let index: Box<dyn Sample<u16>> = Box::new(InMemoryIndexRs::new(s.to_vec(), None, false));
306+
307+
let seqs = index.sample_unsmoothed(&[], 3, 10, 1).unwrap();
308+
309+
assert_eq!(*seqs[0].last().unwrap(), s[0]);
310+
}
311+
312+
#[test]
313+
fn sample_unsmoothed_u16_exists() {
314+
let s = utf16!("aaaa");
315+
let a = &s[0..1];
316+
let index: Box<dyn Sample<u16>> = Box::new(InMemoryIndexRs::new(s.to_vec(), None, false));
317+
318+
let seqs = index.sample_unsmoothed(a, 3, 10, 1).unwrap();
319+
320+
assert_eq!(*seqs[0].last().unwrap(), a[0]);
321+
}
322+
323+
#[test]
324+
fn sample_unsmoothed_u32_exists() {
325+
let s: Vec<u32> = "aaaa".encode_utf16().map(|c| c as u32).collect();
326+
let u32_vocab = Some(u16::MAX as usize + 2);
327+
let index: Box<dyn Sample<u32>> = Box::new(InMemoryIndexRs::<u32>::new(s.clone(), u32_vocab, false));
328+
329+
let seqs = index.sample_unsmoothed(&s[0..1], 3, 10, 1).unwrap();
330+
331+
assert_eq!(*seqs[0].last().unwrap(), s[0]);
332+
}
333+
334+
#[test]
335+
fn sample_unsmoothed_usize_exists() {
336+
let s = utf16_as_usize("aaaa");
337+
let index: Box<dyn InMemoryIndexTrait> = Box::new(InMemoryIndexRs::new(s.to_vec(), None, false));
338+
339+
let seqs = index.sample_unsmoothed(s[0..1].to_vec(), 3, 10, 1).unwrap();
340+
341+
assert_eq!(*seqs[0].last().unwrap(), s[0]);
342+
}
343+
344+
#[test]
345+
fn sample_smoothed_exists() {
346+
let s = utf16!("aabbccabccba");
347+
let mut index: Box<dyn Sample<u16>> = Box::new(InMemoryIndexRs::new(s.to_vec(), None, false));
348+
349+
let tokens = &index.sample_smoothed(&s[0..1], 3, 10, 1).unwrap()[0];
350+
351+
assert_eq!(tokens.len(), 11);
352+
}
353+
354+
#[test]
355+
fn sample_smoothed_empty_query_exists() {
356+
let s: Vec<u16> = "aabbccabccba".encode_utf16().collect();
357+
let mut index: Box<dyn Sample<u16>> = Box::new(InMemoryIndexRs::new(s, None, false));
358+
359+
let tokens = &index.sample_smoothed(&[], 1, 10, 10).unwrap()[0];
360+
361+
assert_eq!(tokens.len(), 10);
362+
}
363+
364+
#[test]
365+
fn smoothed_probs_exists() {
366+
let tokens = "aaaaaaaabc".to_string();
367+
let tokens_vec: Vec<u16> = tokens.encode_utf16().collect();
368+
let query: Vec<_> = vec![utf16!("b")[0]];
369+
370+
// Get unsmoothed probs for query
371+
let sa: SuffixTable = sais(&tokens);
372+
let bigram_counts = sa.count_next(&query);
373+
let unsmoothed_probs = bigram_counts
374+
.iter()
375+
.map(|&x| x as f64 / bigram_counts.iter().sum::<usize>() as f64)
376+
.collect::<Vec<f64>>();
377+
378+
// Get smoothed probs for query
379+
let mut index: Box<dyn Sample<u16>> = Box::new(InMemoryIndexRs::new(tokens_vec, None, false));
380+
let smoothed_probs = index.get_smoothed_probs(&query);
381+
382+
// Compare unsmoothed and smoothed probabilities
383+
let a = utf16!("a")[0] as usize;
384+
let c = utf16!("c")[0] as usize;
385+
386+
// The naive bigram probability for query 'b' is p(c) = 1.0.
387+
assert!(unsmoothed_probs[a] == 0.0);
388+
assert!(unsmoothed_probs[c] == 1.0);
389+
390+
// The smoothed bigram probabilities interpolate with the lower-order unigram
391+
// probabilities where p(a) is high, lowering p(c)
392+
assert!(smoothed_probs[a] > 0.1);
393+
assert!(smoothed_probs[c] < 1.0);
394+
}
395+
}

‎tests/tests.rs

+1-113
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@ extern crate quickcheck;
22
extern crate utf16_literal;
33

44
use quickcheck::{QuickCheck, Testable};
5-
use tokengrams::{InMemoryIndex, SuffixTable};
5+
use tokengrams::SuffixTable;
66
use utf16_literal::utf16;
77

88
fn sais(text: &str) -> SuffixTable {
@@ -13,10 +13,6 @@ fn qc<T: Testable>(f: T) {
1313
QuickCheck::new().tests(1000).max_tests(10000).quickcheck(f);
1414
}
1515

16-
fn utf16_as_usize(s: &str) -> Vec<usize> {
17-
s.encode_utf16().map(|x| x as usize).collect()
18-
}
19-
2016
// Do some testing on substring search.
2117

2218
#[test]
@@ -150,111 +146,3 @@ fn prop_positions() {
150146
}
151147
qc(prop as fn(String, u16) -> bool);
152148
}
153-
154-
#[test]
155-
fn sample_unsmoothed_exists() {
156-
let s = utf16_as_usize("aaaa");
157-
let a = &s[0..1];
158-
159-
let index = InMemoryIndex::new(s.clone(), None, false);
160-
let seqs = index.sample_unsmoothed(a.to_vec(), 3, 10, 20).unwrap();
161-
162-
assert_eq!(*seqs[0].last().unwrap(), a[0]);
163-
assert_eq!(*seqs[19].last().unwrap(), a[0]);
164-
}
165-
166-
#[test]
167-
fn sample_unsmoothed_empty_query_exists() {
168-
let s = utf16_as_usize("aaa");
169-
let a = s[0];
170-
let index = InMemoryIndex::new(s.clone(), None, false);
171-
let seqs = index.sample_unsmoothed(Vec::new(), 3, 10, 20).unwrap();
172-
173-
assert_eq!(*seqs[0].last().unwrap(), a);
174-
assert_eq!(*seqs[19].last().unwrap(), a);
175-
}
176-
177-
#[test]
178-
fn sample_smoothed_exists() {
179-
let s = utf16_as_usize("aabbccabccba");
180-
let mut index = InMemoryIndex::new(s.clone(), None, false);
181-
182-
let tokens = &index.sample_smoothed(s[0..1].to_vec(), 3, 10, 1).unwrap()[0];
183-
184-
assert_eq!(tokens.len(), 11);
185-
}
186-
187-
#[test]
188-
fn sample_smoothed_unigrams_exists() {
189-
let s = utf16_as_usize("aabbccabccba");
190-
let mut index = InMemoryIndex::new(s.clone(), None, false);
191-
192-
let tokens = &index.sample_smoothed(s[0..1].to_vec(), 1, 10, 10).unwrap()[0];
193-
194-
assert_eq!(tokens.len(), 11);
195-
}
196-
197-
#[test]
198-
fn prop_sample() {
199-
fn prop(s: String) -> bool {
200-
let s = utf16_as_usize(&s);
201-
if s.len() < 2 {
202-
return true;
203-
}
204-
205-
let query = match s.get(0..1) {
206-
Some(slice) => slice,
207-
None => &[],
208-
};
209-
let index = InMemoryIndex::new(s.clone(), None, false);
210-
211-
let got = &index.sample_unsmoothed(query.to_vec(), 2, 1, 1).unwrap()[0];
212-
s.contains(got.first().unwrap())
213-
}
214-
215-
qc(prop as fn(String) -> bool);
216-
}
217-
218-
#[test]
219-
fn smoothed_probs_exists() {
220-
let tokens = "aaaaaaaabc".to_string();
221-
222-
let sa: SuffixTable = sais(&tokens);
223-
let query = vec![utf16!("b")[0]];
224-
let vocab = utf16!("c")[0] + 1;
225-
let a = utf16!("a")[0] as usize;
226-
let c = utf16!("c")[0] as usize;
227-
228-
let bigram_counts = sa.count_next(&query);
229-
let unsmoothed_probs = bigram_counts
230-
.iter()
231-
.map(|&x| x as f64 / bigram_counts.iter().sum::<usize>() as f64)
232-
.collect::<Vec<f64>>();
233-
234-
let s = utf16_as_usize(&tokens);
235-
let query = utf16_as_usize("b");
236-
let mut index = InMemoryIndex::new(s.clone(), Some(vocab as usize), false);
237-
let smoothed_probs = index.get_smoothed_probs(query);
238-
239-
// The naive bigram probability for query 'b' is p(c) = 1.0.
240-
assert!(unsmoothed_probs[a] == 0.0);
241-
assert!(unsmoothed_probs[c] == 1.0);
242-
243-
// The smoothed bigram probabilities interpolate with the lower-order unigram
244-
// probabilities where p(a) is high, lowering p(c)
245-
assert!(smoothed_probs[a] > 0.1);
246-
assert!(smoothed_probs[c] < 1.0);
247-
}
248-
249-
#[test]
250-
fn smoothed_probs_empty_query_exists() {
251-
let s = utf16_as_usize("aaa");
252-
let vocab = s[0] + 1;
253-
254-
let mut index = InMemoryIndex::new(s, Some(vocab), false);
255-
256-
let probs = index.get_smoothed_probs(Vec::new());
257-
let residual = (probs.iter().sum::<f64>() - 1.0).abs();
258-
259-
assert!(residual < 1e-4);
260-
}

0 commit comments

Comments
 (0)
Please sign in to comment.