1
1
use std:: collections:: HashSet ;
2
2
3
3
use pyo3:: {
4
- PyResult , exceptions,
4
+ IntoPyObjectExt , PyResult , exceptions,
5
5
prelude:: * ,
6
6
pybacked:: PyBackedStr ,
7
- types:: { PyBytes , PyList , PyTuple } ,
7
+ types:: { PyBytes , PyList } ,
8
8
} ;
9
9
use rustc_hash:: FxHashMap as HashMap ;
10
10
@@ -37,11 +37,14 @@ impl CoreBPE {
37
37
py : Python ,
38
38
text : & str ,
39
39
allowed_special : HashSet < PyBackedStr > ,
40
- ) -> Vec < Rank > {
40
+ ) -> PyResult < Vec < Rank > > {
41
41
py. allow_threads ( || {
42
42
let allowed_special: HashSet < & str > =
43
43
allowed_special. iter ( ) . map ( |s| s. as_ref ( ) ) . collect ( ) ;
44
- self . encode ( text, & allowed_special) . 0
44
+ match self . encode ( text, & allowed_special) {
45
+ Ok ( ( tokens, _) ) => Ok ( tokens) ,
46
+ Err ( e) => Err ( PyErr :: new :: < exceptions:: PyValueError , _ > ( e. message ) ) ,
47
+ }
45
48
} )
46
49
}
47
50
@@ -50,14 +53,20 @@ impl CoreBPE {
50
53
py : Python ,
51
54
text : & str ,
52
55
allowed_special : HashSet < PyBackedStr > ,
53
- ) -> Py < PyAny > {
54
- let tokens = py. allow_threads ( || {
56
+ ) -> PyResult < Py < PyAny > > {
57
+ let tokens_res = py. allow_threads ( || {
55
58
let allowed_special: HashSet < & str > =
56
59
allowed_special. iter ( ) . map ( |s| s. as_ref ( ) ) . collect ( ) ;
57
- self . encode ( text, & allowed_special) . 0
60
+ self . encode ( text, & allowed_special)
58
61
} ) ;
62
+
63
+ let tokens = match tokens_res {
64
+ Ok ( ( tokens, _) ) => tokens,
65
+ Err ( e) => return Err ( PyErr :: new :: < exceptions:: PyValueError , _ > ( e. message ) ) ,
66
+ } ;
67
+
59
68
let buffer = TiktokenBuffer { tokens } ;
60
- buffer. into_py ( py)
69
+ buffer. into_py_any ( py)
61
70
}
62
71
63
72
fn _encode_bytes ( & self , py : Python , bytes : & [ u8 ] ) -> Vec < Rank > {
@@ -69,7 +78,8 @@ impl CoreBPE {
69
78
// Unicode space, so we make our best guess at where we would have splits
70
79
Err ( e) => {
71
80
let text = unsafe { std:: str:: from_utf8_unchecked ( & bytes[ ..e. valid_up_to ( ) ] ) } ;
72
- let ( tokens, last_piece_token_len) = self . encode ( text, & HashSet :: new ( ) ) ;
81
+ let ( tokens, last_piece_token_len) =
82
+ self . encode ( text, & HashSet :: new ( ) ) . unwrap ( ) ;
73
83
let ( mut tokens, last_piece_token_len) =
74
84
self . _increase_last_piece_token_len ( tokens, last_piece_token_len) ;
75
85
@@ -110,19 +120,14 @@ impl CoreBPE {
110
120
py : Python ,
111
121
text : & str ,
112
122
allowed_special : HashSet < PyBackedStr > ,
113
- ) -> Py < PyTuple > {
114
- let ( tokens, completions) = py. allow_threads ( || {
123
+ ) -> PyResult < ( Vec < Rank > , Py < PyList > ) > {
124
+ let ( tokens, completions) : ( Vec < Rank > , HashSet < Vec < Rank > > ) = py. allow_threads ( || {
115
125
let allowed_special: HashSet < & str > =
116
126
allowed_special. iter ( ) . map ( |s| s. as_ref ( ) ) . collect ( ) ;
117
127
self . _encode_unstable_native ( text, & allowed_special)
118
128
} ) ;
119
- let py_completions = PyList :: new_bound (
120
- py,
121
- completions
122
- . iter ( )
123
- . map ( |seq| PyList :: new_bound ( py, & seq[ ..] ) ) ,
124
- ) ;
125
- ( tokens, py_completions) . into_py ( py)
129
+ let py_completions = PyList :: new ( py, completions. into_iter ( ) ) ?;
130
+ Ok ( ( tokens, py_completions. into ( ) ) )
126
131
}
127
132
128
133
fn encode_single_token ( & self , piece : & [ u8 ] ) -> PyResult < Rank > {
@@ -151,17 +156,17 @@ impl CoreBPE {
151
156
#[ pyo3( name = "decode_bytes" ) ]
152
157
fn py_decode_bytes ( & self , py : Python , tokens : Vec < Rank > ) -> Result < Py < PyBytes > , PyErr > {
153
158
match py. allow_threads ( || self . decode_bytes ( & tokens) ) {
154
- Ok ( bytes) => Ok ( PyBytes :: new_bound ( py, & bytes) . into ( ) ) ,
159
+ Ok ( bytes) => Ok ( PyBytes :: new ( py, & bytes) . into ( ) ) ,
155
160
Err ( e) => Err ( pyo3:: exceptions:: PyKeyError :: new_err ( format ! ( "{}" , e) ) ) ,
156
161
}
157
162
}
158
163
159
164
fn decode_single_token_bytes ( & self , py : Python , token : Rank ) -> PyResult < Py < PyBytes > > {
160
165
if let Some ( bytes) = self . decoder . get ( & token) {
161
- return Ok ( PyBytes :: new_bound ( py, bytes) . into ( ) ) ;
166
+ return Ok ( PyBytes :: new ( py, bytes) . into ( ) ) ;
162
167
}
163
168
if let Some ( bytes) = self . special_tokens_decoder . get ( & token) {
164
- return Ok ( PyBytes :: new_bound ( py, bytes) . into ( ) ) ;
169
+ return Ok ( PyBytes :: new ( py, bytes) . into ( ) ) ;
165
170
}
166
171
Err ( PyErr :: new :: < exceptions:: PyKeyError , _ > ( token. to_string ( ) ) )
167
172
}
0 commit comments