Skip to content

Commit 81d8336

Browse files
committed
Fix the `Unigram::from` calls
1 parent 167ecde commit 81d8336

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

tokenizers/src/models/unigram/model.rs

+6 -3
Original file line number | Diff line number | Diff line change
@@ -548,7 +548,8 @@ mod tests {
548548
("abcd".to_string(), 10.0),
549549
];
550550

551-
let model = Unigram::from(sentencepieces, Some(0), false).unwrap();
551+
let model =
552+
Unigram::from(sentencepieces, Some(0), false, &AddedVocabulary::default()).unwrap();
552553
let result = model.encode("abcd").unwrap();
553554
assert_eq!(result, vec!["abcd"]);
554555
}
@@ -570,7 +571,8 @@ mod tests {
570571
("qr".to_string(), -0.5),
571572
];
572573

573-
let mut model = Unigram::from(sentencepieces, Some(0), false).unwrap();
574+
let mut model =
575+
Unigram::from(sentencepieces, Some(0), false, &AddedVocabulary::default()).unwrap();
574576

575577
for is_optimized in &[true, false] {
576578
model.set_optimized(*is_optimized);
@@ -617,7 +619,8 @@ mod tests {
617619
("<0xC3>".to_string(), -0.01),
618620
("<0xA9>".to_string(), -0.03),
619621
];
620-
let unigram = Unigram::from(sentencepieces, Some(0), true).unwrap();
622+
let unigram =
623+
Unigram::from(sentencepieces, Some(0), true, &AddedVocabulary::default()).unwrap();
621624
let tokens: Vec<Token> = unigram.tokenize("é").unwrap();
622625
assert_eq!(
623626
tokens,

tokenizers/src/models/unigram/serialization.rs

+8 -2
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
use crate::AddedVocabulary;
2+
13
use super::model::Unigram;
24
use serde::{
35
de::{Error, MapAccess, Visitor},
@@ -69,8 +71,12 @@ impl<'de> Visitor<'de> for UnigramVisitor {
6971
}
7072
}
7173
match (vocab, unk_id, byte_fallback) {
72-
(Some(vocab), unk_id, byte_fallback) => Ok(Unigram::from(vocab, unk_id, byte_fallback)
73-
.map_err(|err| Error::custom(format!("Unable to load vocab {err:?}")))?),
74+
(Some(vocab), unk_id, byte_fallback) => {
75+
Ok(
76+
Unigram::from(vocab, unk_id, byte_fallback, &AddedVocabulary::default())
77+
.map_err(|err| Error::custom(format!("Unable to load vocab {err:?}")))?,
78+
)
79+
}
7480
(None, _, _) => Err(Error::custom("Missing vocab")),
7581
}
7682
}

0 commit comments

Comments
 (0)