2020//! contract. Since a contract can call other contracts, we need a way of restoring the counter after every execution.
2121//!
2222//! See `cairo-native-run` for an example on how to do it.
23- use std:: { collections:: HashSet , os :: raw :: c_void , ptr } ;
23+ use std:: collections:: HashSet ;
2424
2525use melior:: {
2626 dialect:: { llvm, memref, ods} ,
2727 ir:: {
2828 attribute:: { FlatSymbolRefAttribute , StringAttribute , TypeAttribute } ,
29- operation:: OperationBuilder ,
3029 r#type:: { IntegerType , MemRefType } ,
3130 Attribute , Block , BlockLike , Location , Module , Region , Value ,
3231 } ,
@@ -41,27 +40,15 @@ use crate::{
4140
4241#[ derive( Clone , Copy , Debug , Hash , PartialEq , Eq ) ]
4342pub enum LibfuncCounterBinding {
44- StoreArrayCounter ,
4543 CounterId ,
46- ArrayCounter ,
44+ CounterArray ,
4745}
4846
4947impl LibfuncCounterBinding {
5048 pub const fn symbol ( self ) -> & ' static str {
5149 match self {
52- LibfuncCounterBinding :: StoreArrayCounter => "cairo_native__store_array_counter" ,
5350 LibfuncCounterBinding :: CounterId => "cairo_native__counter_id" ,
54- LibfuncCounterBinding :: ArrayCounter => "cairo_native__array_counter" ,
55- }
56- }
57-
58- const fn function_ptr ( self ) -> * const ( ) {
59- match self {
60- LibfuncCounterBinding :: StoreArrayCounter => {
61- libfunc_counter_runtime:: store_array_counter as * const ( )
62- }
63- LibfuncCounterBinding :: CounterId => ptr:: null ( ) ,
64- LibfuncCounterBinding :: ArrayCounter => ptr:: null ( ) ,
51+ LibfuncCounterBinding :: CounterArray => "cairo_native__counter_array" ,
6552 }
6653 }
6754}
@@ -155,68 +142,24 @@ impl LibfuncCounterMeta {
155142 block. append_op_result ( memref:: load ( libfunc_counter_id_ptr, & [ ] , location) )
156143 }
157144
158- /// Indexes the array of counters and increments the counter relative
159- /// to the given libfunc index
160- pub fn store_array_counter (
161- & mut self ,
162- context : & Context ,
163- module : & Module ,
164- block : & Block < ' _ > ,
165- location : Location ,
166- libfunc_amount : u32 ,
167- ) -> Result < ( ) > {
168- let counter_id = self . build_counter_id ( context, module, block, location) ?;
169- let function_ptr = self . build_function (
170- context,
171- module,
172- block,
173- location,
174- LibfuncCounterBinding :: StoreArrayCounter ,
175- ) ?;
176- let lifuncs_amount = block. const_int ( context, location, libfunc_amount, 32 ) ?;
177- // by this time, the array counter should be initialized
178- let array_counter_ptr_ptr = block. append_op_result (
179- ods:: llvm:: mlir_addressof (
180- context,
181- llvm:: r#type:: pointer ( context, 0 ) ,
182- FlatSymbolRefAttribute :: new ( context, LibfuncCounterBinding :: ArrayCounter . symbol ( ) ) ,
183- location,
184- )
185- . into ( ) ,
186- ) ?;
187- let array_counter_ptr = block. load (
188- context,
189- location,
190- array_counter_ptr_ptr,
191- llvm:: r#type:: pointer ( context, 0 ) ,
192- ) ?;
193-
194- block. append_operation (
195- OperationBuilder :: new ( "llvm.call" , location)
196- . add_operands ( & [ function_ptr] )
197- . add_operands ( & [ counter_id, array_counter_ptr, lifuncs_amount] )
198- . build ( ) ?,
199- ) ;
200-
201- Ok ( ( ) )
202- }
203-
204145 /// Build the array of counters
205- fn get_array_counter < ' c , ' a > (
146+ fn build_array_counter < ' c , ' a > (
206147 & mut self ,
207148 context : & ' c Context ,
208149 module : & Module ,
209150 block : & ' a Block < ' c > ,
210151 location : Location < ' c > ,
211152 libfunc_amount : u32 ,
212153 ) -> Result < Value < ' c , ' a > > {
213- if self . active_map . insert ( LibfuncCounterBinding :: ArrayCounter ) {
154+ if self . active_map . insert ( LibfuncCounterBinding :: CounterArray ) {
155+ self . build_counter_id ( context, module, block, location) ?;
156+
214157 module. body ( ) . append_operation (
215158 ods:: llvm:: mlir_global (
216159 context,
217160 Region :: new ( ) ,
218161 TypeAttribute :: new ( llvm:: r#type:: pointer ( context, 0 ) ) ,
219- StringAttribute :: new ( context, LibfuncCounterBinding :: ArrayCounter . symbol ( ) ) ,
162+ StringAttribute :: new ( context, LibfuncCounterBinding :: CounterArray . symbol ( ) ) ,
220163 Attribute :: parse ( context, "#llvm.linkage<weak>" )
221164 . ok_or ( Error :: ParseAttributeError ) ?,
222165 location,
@@ -240,7 +183,7 @@ impl LibfuncCounterMeta {
240183 llvm:: r#type:: pointer ( context, 0 ) ,
241184 FlatSymbolRefAttribute :: new (
242185 context,
243- LibfuncCounterBinding :: ArrayCounter . symbol ( ) ,
186+ LibfuncCounterBinding :: CounterArray . symbol ( ) ,
244187 ) ,
245188 location,
246189 )
@@ -264,13 +207,13 @@ impl LibfuncCounterMeta {
264207 ods:: llvm:: mlir_addressof (
265208 context,
266209 llvm:: r#type:: pointer ( context, 0 ) ,
267- FlatSymbolRefAttribute :: new ( context, LibfuncCounterBinding :: ArrayCounter . symbol ( ) ) ,
210+ FlatSymbolRefAttribute :: new ( context, LibfuncCounterBinding :: CounterArray . symbol ( ) ) ,
268211 location,
269212 )
270213 . into ( ) ,
271214 ) ?;
272215
273- // // return the pointer to array counter
216+ // return the pointer to array counter
274217 block. load (
275218 context,
276219 location,
@@ -289,10 +232,11 @@ impl LibfuncCounterMeta {
289232 libfuncs_amount : u32 ,
290233 ) -> Result < ( ) > {
291234 let u32_ty = IntegerType :: new ( context, 32 ) . into ( ) ;
292- let k1 = block. const_int ( context, location, 0 , 32 ) ?;
235+ let k1 = block. const_int ( context, location, 1 , 32 ) ?;
293236
294237 let array_counter_ptr =
295- self . get_array_counter ( context, module, block, location, libfuncs_amount) ?;
238+ self . build_array_counter ( context, module, block, location, libfuncs_amount) ?;
239+
296240 let value_counter_ptr = block. gep (
297241 context,
298242 location,
@@ -310,24 +254,13 @@ impl LibfuncCounterMeta {
310254 }
311255}
312256
313- pub fn setup_runtime ( find_symbol_ptr : impl Fn ( & str ) -> Option < * mut c_void > ) {
314- let bindings = & [ LibfuncCounterBinding :: StoreArrayCounter ] ;
315-
316- for binding in bindings {
317- if let Some ( global) = find_symbol_ptr ( binding. symbol ( ) ) {
318- let global = global. cast :: < * const ( ) > ( ) ;
319- unsafe { * global = binding. function_ptr ( ) } ;
320- }
321- }
322- }
323-
324257pub mod libfunc_counter_runtime {
258+ use core:: slice;
325259 use std:: {
326260 collections:: HashMap ,
327261 sync:: { LazyLock , Mutex } ,
328262 } ;
329263
330- use itertools:: Itertools ;
331264 use melior:: {
332265 ir:: { Block , Location , Module } ,
333266 Context ,
@@ -364,16 +297,18 @@ pub mod libfunc_counter_runtime {
364297 )
365298 }
366299
367- pub unsafe extern "C" fn store_array_counter (
368- counter_id : u64 ,
369- array_counter : * const u32 ,
370- libfuncs_amount : u32 ,
300+ pub unsafe fn store_counters_array (
301+ counter_id_ptr : * mut u64 ,
302+ array_ptr_ptr : * mut * mut u32 ,
303+ libfuncs_amount : usize ,
371304 ) {
372- let mut libfunc_counter = LIBFUNC_COUNTER . lock ( ) . unwrap ( ) ;
373- let vec = ( 0 ..libfuncs_amount)
374- . map ( |i| * array_counter. add ( i as usize ) )
375- . collect_vec ( ) ;
305+ let counters_vec = slice:: from_raw_parts ( * array_ptr_ptr, libfuncs_amount) . to_vec ( ) ;
306+
307+ LIBFUNC_COUNTER
308+ . lock ( )
309+ . unwrap ( )
310+ . insert ( * counter_id_ptr, counters_vec) ;
376311
377- libfunc_counter . insert ( counter_id , vec ) ;
312+ libc :: free ( * array_ptr_ptr . cast :: < * mut libc :: c_void > ( ) ) ;
378313 }
379314}
0 commit comments