@@ -230,16 +230,18 @@ fn is_valid_token_pair(
) -> bool {
    // Keep track of the maximum token which can still be chosen across the split point.
    let mut limit = u32::MAX;
+    // println!("checking if {token1}, {token2} is a valid token_pair");
    loop {
        // Check whether BPE would choose a different token pair across the split point.
        // this is super super important
        if let Some(combined) = pair_lookup.get(&(token1, token2)) {
            if *combined < limit {
+                // println!("Done1");
                return false;
            }
        }
        // Reverse the merge operation from BPE.
-        // println!("{token1}, {token2}");
+
        // println!("{:?}", split_table);
        if token1 > token2 {
            limit = token1;
@@ -248,6 +250,7 @@ fn is_valid_token_pair(
                limit = token2 + 1;
                token2 = unsafe { split_table.get_unchecked(token2 as usize).0 };
                if token2 + 1 == limit {
+                    // println!("Done2");
                    return true;
                }
            }
@@ -258,11 +261,13 @@ fn is_valid_token_pair(
                limit = token1;
                token1 = unsafe { split_table.get_unchecked(token1 as usize).1 };
                if token1 == limit {
+                    // println!("Done3");
                    return true;
                }
            }
        }
    }
+
}

fn token_range(token_starts: &[u32], token_id: u32) -> Range<usize> {
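
For orientation: `is_valid_token_pair` decides whether BPE, run over the concatenated text, would preserve the boundary between two adjacent tokens, by repeatedly undoing the later merge via `split_table` and checking it against the higher-priority merges recorded in `pair_lookup`. The following is a minimal sketch, not code from this diff; `can_concat` is a hypothetical helper, and the full signature (the two tables plus the two token ids) is assumed from the usage visible above.

fn can_concat(
    pair_lookup: &FnvHashMap<(u32, u32), u32>,
    split_table: &[(u32, u32)],
    left: &[u32],
    right: &[u32],
) -> bool {
    match (left.last(), right.first()) {
        // The seam is safe only if BPE would not merge a different pair across it.
        (Some(&t1), Some(&t2)) => is_valid_token_pair(pair_lookup, split_table, t1, t2),
        // An empty side introduces no seam to re-merge.
        _ => true,
    }
}
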
@@ -477,36 +482,36 @@ impl BacktrackingBpe
        let mut split_table = vec![];
        let mut pair_lookup = FnvHashMap::default();

-        // First option, use the input merge table.
-        if let Some(ref merges) = merges {
-            for (index, pair) in merges.into_iter().enumerate() {
-                let token1 = &pair.0.clone();
-                let token2 = &pair.1.clone();
-                // TODO something is weird here
-                if token1.len() ==1 {
-                    split_table.push((vocab[token1], vocab[token1]));
-                }
-                if token2.len() == 1 {
-                    split_table.push((vocab[token2], vocab[token2]));
-                }
-                let id1 = vocab[token1];
-                let id2 = vocab[token2];
-                let new_token = format!("{}{}", token1, &token2);
-                let new_id = vocab
-                    .get(&new_token)
-                    .ok_or(Error::MergeTokenOutOfVocabulary(new_token));
-                if let Ok(id) = new_id {
-                    pair_lookup.insert((id1, id2), *id);
-                    split_table.push((id1, id2));
-                    merge_map.insert(Pair::from((id1, id2)), (index as u32, *id));
-                } else {
-                    println!("Token not added?");
-                }
-
-                // TODO wrong
-            }
-            split_table.push((merges.len() as u32, merges.len() as u32));
-        }
+        // // First option, use the input merge table.
+        // if let Some(ref merges) = merges {
+        //     for (index, pair) in merges.into_iter().enumerate() {
+        //         let token1 = &pair.0.clone();
+        //         let token2 = &pair.1.clone();
+        //         // TODO something is weird here
+        //         if token1.len() ==1{
+        //             split_table.push((vocab[token1], vocab[token1]));
+        //         }
+        //         if token2.len() == 1 {
+        //             split_table.push((vocab[token2], vocab[token2]));
+        //         }
+        //         let id1 = vocab[token1];
+        //         let id2 = vocab[token2];
+        //         let new_token = format!("{}{}", token1, &token2);
+        //         let new_id = vocab
+        //             .get(&new_token)
+        //             .ok_or(Error::MergeTokenOutOfVocabulary(new_token));
+        //         if let Ok(id) = new_id {
+        //             pair_lookup.insert((id1, id2), *id);
+        //             split_table.push((id1, id2));
+        //             merge_map.insert(Pair::from((id1, id2)), (index as u32, *id));
+        //         } else {
+        //             println!("Token not added?");
+        //         }
+
+        //         // TODO wrong
+        //     }
+        //     split_table.push((merges.len() as u32, merges.len() as u32));
+        // }
        // Second option, reverse engineer the merge/split table from the vocabulary.
        {
            for (id, token) in token_iter(&all_tokens, &token_starts).enumerate() {
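
For orientation: whichever path builds them, the two tables only work with `is_valid_token_pair` if they stay mutually consistent. `pair_lookup` maps a pair of token ids to the id their merge produces, and `split_table`, indexed by a token id, returns the pair that token splits back into (base tokens split into themselves). A small consistency check, a sketch only with the hypothetical name `check_tables`, makes that invariant explicit:

fn check_tables(pair_lookup: &FnvHashMap<(u32, u32), u32>, split_table: &[(u32, u32)]) {
    for (&(id1, id2), &merged) in pair_lookup {
        // Undoing the merge of `merged` must give back exactly the pair it was built from.
        assert_eq!(split_table[merged as usize], (id1, id2));
    }
}
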
@@ -684,6 +689,7 @@ impl BacktrackingBpe
                        last_token.push(new_token);
                        break;
                    }
+                    // println!("Finished encoding prefix")
                }
            }
        }
@@ -729,9 +735,9 @@ impl BacktrackingBpe
        let mut token = backtrack_state.next_token?;
        let last = backtrack_state.tokens.last().copied();
        loop {
+            // println!("in step, token: {last:?}, {token}");
            let token_len = self.token_len(token);
            let end_pos = backtrack_state.pos + token_len;
-            // println!("in step, token: {last:?}, {token}");
            if backtrack_state.bitfield.is_set(end_pos)
                && last
                    .map(|last_token| self.is_valid_token_pair(last_token, token))
@@ -755,6 +761,8 @@ impl BacktrackingBpe
                break;
            }
        }
+        // println!("finished step, token: {last:?}, {token}");
+
        backtrack_state.next_token
    }
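
For orientation: `step` either appends `token` when the resulting split point is still reachable (`bitfield.is_set(end_pos)`) and the pair with the previous token is BPE-compatible, or falls back to a shorter candidate. A typical driver simply calls it until it returns `None`. The sketch below is not code from this diff; the `BacktrackState::new` constructor and `into_tokens` accessor are hypothetical names.

fn encode_via_backtracking(bpe: &BacktrackingBpe, text: &[u8]) -> Vec<u32> {
    let mut state = BacktrackState::new(bpe, text); // hypothetical constructor
    while bpe.step(&mut state).is_some() {}
    state.into_tokens() // hypothetical accessor over the accumulated token ids
}
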