Commit e5976a6

Is this what's expected?
1 parent b51819d commit e5976a6

3 files changed: +30 -3 lines

tokenizers/src/pre_tokenizers/byte_level.rs (+28 -1)
@@ -41,6 +41,14 @@ lazy_static! {
         r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
     )
     .unwrap();
+    static ref RE_VEC: Vec<SysRegex> = {
+        let pattern = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
+        let mut vec = Vec::with_capacity(MAX_NUM_THREADS);
+        for _ in 0..MAX_NUM_THREADS {
+            vec.push(SysRegex::new(pattern).unwrap());
+        }
+        vec
+    };
     static ref BYTES_CHAR: HashMap<u8, char> = bytes_char();
     static ref CHAR_BYTES: HashMap<char, u8> =
         bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
@@ -111,12 +119,31 @@ impl ByteLevel {
     }
 }
 
+use std::num::NonZeroU64;
+use std::thread;
+
+pub struct FakeThreadId(NonZeroU64);
+
+fn hash_current_thread() -> usize {
+    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
+    // that works great for our use case of avoiding collisions in our array. Unfortunately,
+    // it's private. However, there are only so many ways you can layout a u64, so just transmute
+    // https://github.com/rust-lang/rust/issues/67939
+    const _: [u8; 8] = [0; std::mem::size_of::<thread::ThreadId>()];
+    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
+    let x =
+        unsafe { std::mem::transmute::<thread::ThreadId, FakeThreadId>(thread::current().id()).0 };
+    u64::from(x) as usize - 1
+}
+
+const MAX_NUM_THREADS: usize = 128;
+
 /// As a `PreTokenizer`, `ByteLevel` is in charge of transforming all the unicode characters into
 /// their byte-level counterpart. It also splits the input according to the configured regex.
 // TODO: Give the ability to modify this regex
 impl PreTokenizer for ByteLevel {
     fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
-        let re_ref: &SysRegex = &RE;
+        let re_ref: &SysRegex = &RE_VEC[hash_current_thread() % MAX_NUM_THREADS]; // TODO use the thread thing here as well!
         pretokenized.split(|_, mut normalized| {
             if self.add_prefix_space && !normalized.get().starts_with(' ') {
                 normalized.prepend(" ");

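For context on the byte_level.rs change: instead of funneling every pre_tokenize call through the single lazy_static RE, each OS thread now picks its own SysRegex out of RE_VEC, indexed by a hash of its thread id. The likely motivation is avoiding contention on one shared regex when many threads encode in parallel. The following is a self-contained sketch of the slot-selection trick; FakeThreadId and hash_current_thread mirror the commit's code, while the main driver is purely illustrative.

use std::num::NonZeroU64;
use std::thread;

const MAX_NUM_THREADS: usize = 128;

pub struct FakeThreadId(NonZeroU64);

fn hash_current_thread() -> usize {
    // ThreadId wraps a private, monotonically increasing NonZeroU64
    // (see rust-lang/rust#67939). Transmuting it out is the commit's
    // shortcut around the missing stable accessor; the two const items
    // break the build if either type stops being exactly 8 bytes.
    const _: [u8; 8] = [0; std::mem::size_of::<thread::ThreadId>()];
    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
    let x =
        unsafe { std::mem::transmute::<thread::ThreadId, FakeThreadId>(thread::current().id()).0 };
    // Thread ids start at 1, so subtract 1 to get a 0-based index.
    u64::from(x) as usize - 1
}

fn main() {
    // Every spawned thread resolves to its own slot, so indexing into a
    // Vec of per-thread resources (like RE_VEC) needs no locking.
    let handles: Vec<_> = (0..4)
        .map(|_| thread::spawn(|| println!("slot {}", hash_current_thread() % MAX_NUM_THREADS)))
        .collect();
    for h in handles {
        h.join().unwrap();
    }
}

One caveat: once more than MAX_NUM_THREADS distinct thread ids have been handed out, the modulo maps two threads onto the same slot. That stays sound, since each slot is only ever borrowed immutably, but those threads are back to sharing one regex.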
tokenizers/src/tokenizer/added_vocabulary.rs (+1 -1)
@@ -514,7 +514,7 @@ impl AddedVocabulary {
         // 1. We extract all the non-normalized tokens from the non-normalized string
         pretokenized
             .split(|_, sequence| {
-                Ok(self.split_with_indices(
+                Ok(self.fast_split_with_indices(
                     sequence,
                     &self.split_trie_vec[hash_current_thread() % MAX_NUM_THREADS],
                 ))

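This hunk shows the same pattern already in place for tries: split_trie_vec is indexed with hash_current_thread() % MAX_NUM_THREADS, and only the split call is swapped for its fast_ variant. Since the commit hand-rolls the per-thread-slot pool twice, here is a hypothetical generic wrapper capturing it; PerThread is illustrative, not tokenizers API, and it uses stable std hashing of ThreadId instead of the transmute above.

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::thread;

const MAX_NUM_THREADS: usize = 128;

// Hypothetical pool holding one clone of some resource T per thread slot.
struct PerThread<T> {
    slots: Vec<T>,
}

impl<T> PerThread<T> {
    fn new(mut make: impl FnMut() -> T) -> Self {
        PerThread {
            slots: (0..MAX_NUM_THREADS).map(|_| make()).collect(),
        }
    }

    // Slot selection via safe, stable hashing of ThreadId. Distinct threads
    // usually land on distinct slots; a collision merely means two threads
    // share one immutably borrowed resource.
    fn get(&self) -> &T {
        let mut h = DefaultHasher::new();
        thread::current().id().hash(&mut h);
        &self.slots[h.finish() as usize % MAX_NUM_THREADS]
    }
}

fn main() {
    let pool = Arc::new(PerThread::new(|| String::from("per-thread resource")));
    let handles: Vec<_> = (0..4)
        .map(|_| {
            let pool = Arc::clone(&pool);
            // Print the address of each thread's slot to show the spread.
            thread::spawn(move || println!("{:p}", pool.get()))
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }
}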
tokenizers/src/tokenizer/mod.rs (+1 -1)
@@ -894,7 +894,7 @@ where
     ) -> Result<Encoding> {
         let mut pretokenized: PreTokenizedString = pretokenized.into();
         pretokenized.tokenize(|normalized| self.model.tokenize(normalized.get()))?;
-        pretokenized.into_encoding(word_idx, type_id, offsets_type)
+        pretokenized.fast_into_encoding()
     }
 }
