1
1
use std:: collections:: HashSet ;
2
2
3
3
use pyo3:: {
4
- exceptions,
4
+ PyResult , exceptions,
5
5
prelude:: * ,
6
6
pybacked:: PyBackedStr ,
7
7
types:: { PyBytes , PyList , PyTuple } ,
8
- PyResult ,
9
8
} ;
10
9
use rustc_hash:: FxHashMap as HashMap ;
11
10
12
- use crate :: { byte_pair_encode , CoreBPE , Rank } ;
11
+ use crate :: { CoreBPE , Rank , byte_pair_encode } ;
13
12
14
13
#[ pymethods]
15
14
impl CoreBPE {
@@ -19,12 +18,8 @@ impl CoreBPE {
19
18
special_tokens_encoder : HashMap < String , Rank > ,
20
19
pattern : & str ,
21
20
) -> PyResult < Self > {
22
- Self :: new_internal (
23
- encoder,
24
- special_tokens_encoder,
25
- pattern,
26
- )
27
- . map_err ( |e| PyErr :: new :: < exceptions:: PyValueError , _ > ( e. to_string ( ) ) )
21
+ Self :: new_internal ( encoder, special_tokens_encoder, pattern)
22
+ . map_err ( |e| PyErr :: new :: < exceptions:: PyValueError , _ > ( e. to_string ( ) ) )
28
23
}
29
24
30
25
// ====================
@@ -178,7 +173,7 @@ impl CoreBPE {
178
173
fn token_byte_values ( & self , py : Python ) -> Vec < Py < PyBytes > > {
179
174
self . sorted_token_bytes
180
175
. iter ( )
181
- . map ( |x| PyBytes :: new_bound ( py, x) . into ( ) )
176
+ . map ( |x| PyBytes :: new ( py, x) . into ( ) )
182
177
. collect ( )
183
178
}
184
179
}
@@ -204,39 +199,47 @@ impl TiktokenBuffer {
204
199
"Object is not writable" ,
205
200
) ) ;
206
201
}
207
-
208
- ( * view) . obj = slf. clone ( ) . into_any ( ) . into_ptr ( ) ;
209
-
210
- let data = & slf. borrow ( ) . tokens ;
211
- ( * view) . buf = data. as_ptr ( ) as * mut std:: os:: raw:: c_void ;
212
- ( * view) . len = ( data. len ( ) * std:: mem:: size_of :: < Rank > ( ) ) as isize ;
213
- ( * view) . readonly = 1 ;
214
- ( * view) . itemsize = std:: mem:: size_of :: < Rank > ( ) as isize ;
215
- ( * view) . format = if ( flags & pyo3:: ffi:: PyBUF_FORMAT ) == pyo3:: ffi:: PyBUF_FORMAT {
216
- let msg = std:: ffi:: CString :: new ( "I" ) . unwrap ( ) ;
217
- msg. into_raw ( )
218
- } else {
219
- std:: ptr:: null_mut ( )
220
- } ;
221
- ( * view) . ndim = 1 ;
222
- ( * view) . shape = if ( flags & pyo3:: ffi:: PyBUF_ND ) == pyo3:: ffi:: PyBUF_ND {
223
- & mut ( * view) . len
224
- } else {
225
- std:: ptr:: null_mut ( )
226
- } ;
227
- ( * view) . strides = if ( flags & pyo3:: ffi:: PyBUF_STRIDES ) == pyo3:: ffi:: PyBUF_STRIDES {
228
- & mut ( * view) . itemsize
229
- } else {
230
- std:: ptr:: null_mut ( )
231
- } ;
232
- ( * view) . suboffsets = std:: ptr:: null_mut ( ) ;
233
- ( * view) . internal = std:: ptr:: null_mut ( ) ;
202
+ unsafe {
203
+ let view_ref = & mut * view;
204
+ view_ref. obj = slf. clone ( ) . into_any ( ) . into_ptr ( ) ;
205
+
206
+ let data = & slf. borrow ( ) . tokens ;
207
+ view_ref. buf = data. as_ptr ( ) as * mut std:: os:: raw:: c_void ;
208
+ view_ref. len = ( data. len ( ) * std:: mem:: size_of :: < Rank > ( ) ) as isize ;
209
+ view_ref. readonly = 1 ;
210
+ view_ref. itemsize = std:: mem:: size_of :: < Rank > ( ) as isize ;
211
+ view_ref. format = if ( flags & pyo3:: ffi:: PyBUF_FORMAT ) == pyo3:: ffi:: PyBUF_FORMAT {
212
+ let msg = std:: ffi:: CString :: new ( "I" ) . unwrap ( ) ;
213
+ msg. into_raw ( )
214
+ } else {
215
+ std:: ptr:: null_mut ( )
216
+ } ;
217
+ view_ref. ndim = 1 ;
218
+ view_ref. shape = if ( flags & pyo3:: ffi:: PyBUF_ND ) == pyo3:: ffi:: PyBUF_ND {
219
+ & mut view_ref. len
220
+ } else {
221
+ std:: ptr:: null_mut ( )
222
+ } ;
223
+ view_ref. strides = if ( flags & pyo3:: ffi:: PyBUF_STRIDES ) == pyo3:: ffi:: PyBUF_STRIDES {
224
+ & mut view_ref. itemsize
225
+ } else {
226
+ std:: ptr:: null_mut ( )
227
+ } ;
228
+ view_ref. suboffsets = std:: ptr:: null_mut ( ) ;
229
+ view_ref. internal = std:: ptr:: null_mut ( ) ;
230
+ }
234
231
235
232
Ok ( ( ) )
236
233
}
237
234
238
235
unsafe fn __releasebuffer__ ( & self , view : * mut pyo3:: ffi:: Py_buffer ) {
239
- std:: mem:: drop ( std:: ffi:: CString :: from_raw ( ( * view) . format ) ) ;
236
+ // Note that Py_buffer doesn't have a Drop impl
237
+ unsafe {
238
+ let view_ref = & mut * view;
239
+ if !view_ref. format . is_null ( ) {
240
+ std:: mem:: drop ( std:: ffi:: CString :: from_raw ( view_ref. format ) ) ;
241
+ }
242
+ }
240
243
}
241
244
}
242
245
0 commit comments