11use std:: {
2- fs:: File ,
2+ fs:: { File , create_dir_all } ,
33 io:: { BufWriter , Error , Write } ,
44 path:: Path ,
5+ sync:: Arc ,
56} ;
67
8+ use arrow:: {
9+ array:: { ArrayRef , PrimitiveBuilder , RecordBatch } ,
10+ datatypes:: { DataType , Float64Type , Int64Type , Schema } ,
11+ error:: ArrowError ,
12+ } ;
713use hftbacktest_derive:: NpyDTyped ;
14+ use once_cell:: sync:: Lazy ;
15+ use parquet:: { arrow:: ArrowWriter , basic:: Compression , file:: properties:: WriterProperties } ;
816use zip:: { ZipWriter , write:: SimpleFileOptions } ;
917
1018use crate :: {
@@ -22,6 +30,10 @@ struct Record {
2230 balance : f64 ,
2331 fee : f64 ,
2432 num_trades : i64 ,
33+ num_messages : i64 ,
34+ num_cancellations : i64 ,
35+ num_creations : i64 ,
36+ num_modifications : i64 ,
2537 trading_volume : f64 ,
2638 trading_value : f64 ,
2739}
@@ -34,6 +46,140 @@ pub struct BacktestRecorder {
3446 values : Vec < Vec < Record > > ,
3547}
3648
49+ pub static ACCOUNT_STATE_DATA_POINT_FIELDS : Lazy < Vec < arrow:: datatypes:: Field > > = Lazy :: new ( || {
50+ vec ! [
51+ arrow:: datatypes:: Field :: new( "timestamp" , DataType :: Int64 , true ) ,
52+ arrow:: datatypes:: Field :: new( "balance" , DataType :: Float64 , true ) ,
53+ arrow:: datatypes:: Field :: new( "position" , DataType :: Float64 , true ) ,
54+ arrow:: datatypes:: Field :: new( "fee" , DataType :: Float64 , true ) ,
55+ arrow:: datatypes:: Field :: new( "trading_volume" , DataType :: Float64 , true ) ,
56+ arrow:: datatypes:: Field :: new( "trading_value" , DataType :: Float64 , true ) ,
57+ arrow:: datatypes:: Field :: new( "num_trades" , DataType :: Int64 , true ) ,
58+ arrow:: datatypes:: Field :: new( "num_messages" , DataType :: Int64 , true ) ,
59+ arrow:: datatypes:: Field :: new( "num_cancellations" , DataType :: Int64 , true ) ,
60+ arrow:: datatypes:: Field :: new( "num_creations" , DataType :: Int64 , true ) ,
61+ arrow:: datatypes:: Field :: new( "num_modifications" , DataType :: Int64 , true ) ,
62+ arrow:: datatypes:: Field :: new( "price" , DataType :: Float64 , true ) ,
63+ ]
64+ } ) ;
65+
66+ pub trait ColumnsBuilder < ' a > {
67+ type T ;
68+
69+ fn get_batch ( & mut self ) -> Result < RecordBatch , ArrowError > ;
70+ fn append ( & mut self , msg : & ' a Self :: T ) -> Result < ( ) , ArrowError > ;
71+
72+ fn reset ( & mut self ) -> Result < ( ) , ArrowError > ;
73+ }
74+
75+ pub struct AccountStateDataPointColumnsBuilder {
76+ schema : Schema ,
77+ timestamp_builder : PrimitiveBuilder < Int64Type > ,
78+ balance_builder : PrimitiveBuilder < Float64Type > ,
79+ position_builder : PrimitiveBuilder < Float64Type > ,
80+ fee_builder : PrimitiveBuilder < Float64Type > ,
81+ trading_volume_builder : PrimitiveBuilder < Float64Type > ,
82+ trading_value_builder : PrimitiveBuilder < Float64Type > ,
83+ num_trades_builder : PrimitiveBuilder < Int64Type > ,
84+ num_messages_builder : PrimitiveBuilder < Int64Type > ,
85+ num_cancellations_builder : PrimitiveBuilder < Int64Type > ,
86+ num_creations_builder : PrimitiveBuilder < Int64Type > ,
87+ num_modifications_builder : PrimitiveBuilder < Int64Type > ,
88+ price_builder : PrimitiveBuilder < Float64Type > ,
89+ }
90+
91+ pub struct AccountStateDataPoint {
92+ pub timestamp : i64 ,
93+ pub balance : f64 ,
94+ pub position : f64 ,
95+ pub fee : f64 ,
96+ pub trading_volume : f64 ,
97+ pub trading_value : f64 ,
98+ pub num_trades : i64 ,
99+ pub num_messages : i64 ,
100+ pub num_cancellations : i64 ,
101+ pub num_creations : i64 ,
102+ pub num_modifications : i64 ,
103+ pub price : f64 ,
104+ }
105+
106+ impl < ' a > ColumnsBuilder < ' a > for AccountStateDataPointColumnsBuilder {
107+ type T = AccountStateDataPoint ;
108+
109+ fn get_batch ( & mut self ) -> Result < RecordBatch , ArrowError > {
110+ let arrays: Vec < ArrayRef > = vec ! [
111+ Arc :: new( self . timestamp_builder. finish( ) ) ,
112+ Arc :: new( self . balance_builder. finish( ) ) ,
113+ Arc :: new( self . position_builder. finish( ) ) ,
114+ Arc :: new( self . fee_builder. finish( ) ) ,
115+ Arc :: new( self . trading_volume_builder. finish( ) ) ,
116+ Arc :: new( self . trading_value_builder. finish( ) ) ,
117+ Arc :: new( self . num_trades_builder. finish( ) ) ,
118+ Arc :: new( self . num_messages_builder. finish( ) ) ,
119+ Arc :: new( self . num_cancellations_builder. finish( ) ) ,
120+ Arc :: new( self . num_creations_builder. finish( ) ) ,
121+ Arc :: new( self . num_modifications_builder. finish( ) ) ,
122+ Arc :: new( self . price_builder. finish( ) ) ,
123+ ] ;
124+ let batch = RecordBatch :: try_new ( Arc :: new ( self . schema . clone ( ) ) , arrays) ?;
125+ Ok ( batch)
126+ }
127+
128+ fn append ( & mut self , msg : & AccountStateDataPoint ) -> Result < ( ) , ArrowError > {
129+ self . timestamp_builder . append_value ( msg. timestamp ) ;
130+ self . balance_builder . append_value ( msg. balance ) ;
131+ self . position_builder . append_value ( msg. position ) ;
132+ self . fee_builder . append_value ( msg. fee ) ;
133+ self . trading_volume_builder . append_value ( msg. trading_volume ) ;
134+ self . trading_value_builder . append_value ( msg. trading_value ) ;
135+ self . num_trades_builder . append_value ( msg. num_trades ) ;
136+ self . num_messages_builder . append_value ( msg. num_messages ) ;
137+ self . num_cancellations_builder
138+ . append_value ( msg. num_cancellations ) ;
139+ self . num_creations_builder . append_value ( msg. num_creations ) ;
140+ self . num_modifications_builder
141+ . append_value ( msg. num_modifications ) ;
142+ self . price_builder . append_value ( msg. price ) ;
143+ return Ok ( ( ) ) ;
144+ }
145+
146+ fn reset ( & mut self ) -> Result < ( ) , ArrowError > {
147+ self . timestamp_builder = Default :: default ( ) ;
148+ self . balance_builder = Default :: default ( ) ;
149+ self . position_builder = Default :: default ( ) ;
150+ self . fee_builder = Default :: default ( ) ;
151+ self . trading_volume_builder = Default :: default ( ) ;
152+ self . trading_value_builder = Default :: default ( ) ;
153+ self . num_trades_builder = Default :: default ( ) ;
154+ self . num_messages_builder = Default :: default ( ) ;
155+ self . num_cancellations_builder = Default :: default ( ) ;
156+ self . num_creations_builder = Default :: default ( ) ;
157+ self . num_modifications_builder = Default :: default ( ) ;
158+ self . price_builder = Default :: default ( ) ;
159+ return Ok ( ( ) ) ;
160+ }
161+ }
162+
163+ impl AccountStateDataPointColumnsBuilder {
164+ pub fn new ( schema : Schema ) -> AccountStateDataPointColumnsBuilder {
165+ AccountStateDataPointColumnsBuilder {
166+ schema,
167+ timestamp_builder : Default :: default ( ) ,
168+ balance_builder : Default :: default ( ) ,
169+ position_builder : Default :: default ( ) ,
170+ fee_builder : Default :: default ( ) ,
171+ trading_volume_builder : Default :: default ( ) ,
172+ trading_value_builder : Default :: default ( ) ,
173+ num_trades_builder : Default :: default ( ) ,
174+ num_messages_builder : Default :: default ( ) ,
175+ num_cancellations_builder : Default :: default ( ) ,
176+ num_creations_builder : Default :: default ( ) ,
177+ num_modifications_builder : Default :: default ( ) ,
178+ price_builder : Default :: default ( ) ,
179+ }
180+ }
181+ }
182+
37183impl Recorder for BacktestRecorder {
38184 type Error = Error ;
39185
@@ -57,6 +203,10 @@ impl Recorder for BacktestRecorder {
57203 trading_volume : state_values. trading_volume ,
58204 trading_value : state_values. trading_value ,
59205 num_trades : state_values. num_trades ,
206+ num_messages : state_values. num_messages ,
207+ num_cancellations : state_values. num_cancellations ,
208+ num_creations : state_values. num_creations ,
209+ num_modifications : state_values. num_modifications ,
60210 } ) ;
61211 }
62212 Ok ( ( ) )
@@ -91,38 +241,53 @@ impl BacktestRecorder {
91241 P : AsRef < Path > ,
92242 {
93243 let prefix = prefix. as_ref ( ) ;
244+ let base_path = path. as_ref ( ) ;
245+ create_dir_all ( base_path) ?;
246+
247+ // Buffer output to reduce frequent file I/O
94248 for ( asset_no, values) in self . values . iter ( ) . enumerate ( ) {
95- let file_path = path. as_ref ( ) . join ( format ! ( "{prefix}{asset_no}.csv" ) ) ;
96- let mut file = BufWriter :: new ( File :: create ( file_path) ?) ;
97- writeln ! (
98- file,
99- "timestamp,balance,position,fee,trading_volume,trading_value,num_trades,price" ,
249+ let file_path = base_path. join ( format ! ( "{prefix}{asset_no}.csv" ) ) ;
250+ let mut file = BufWriter :: new ( File :: create ( file_path) ?) ; // Use BufWriter for buffered writing
251+
252+ // Write header
253+ file. write_all (
254+ b"timestamp,balance,position,fee,trading_volume,trading_value,num_trades,num_messages,num_cancellations,num_creations,num_modifications,price\n " ,
100255 ) ?;
101- for Record {
102- timestamp,
103- balance,
104- position,
105- fee,
106- trading_volume,
107- trading_value,
108- num_trades,
109- price : mid_price,
110- } in values
111- {
112- writeln ! (
113- file,
114- "{timestamp},{balance},{position},{fee},{trading_volume},{trading_value},{num_trades},{mid_price}"
115- ) ?;
256+
257+ // Write records
258+ for record in values {
259+ let line = format ! (
260+ "{},{},{},{},{},{},{},{},{},{},{},{}\n " ,
261+ record. timestamp,
262+ record. balance,
263+ record. position,
264+ record. fee,
265+ record. trading_volume,
266+ record. trading_value,
267+ record. num_trades,
268+ record. num_messages,
269+ record. num_cancellations,
270+ record. num_creations,
271+ record. num_modifications,
272+ record. price,
273+ ) ;
274+ file. write_all ( line. as_bytes ( ) ) ?;
116275 }
117276 }
118277 Ok ( ( ) )
119278 }
120279
121- pub fn to_npz < P > ( & self , path : P ) -> Result < ( ) , Error >
280+ pub fn to_npz < Prefix , P > ( & self , prefix : Prefix , path : P ) -> Result < ( ) , Error >
122281 where
282+ Prefix : AsRef < str > ,
123283 P : AsRef < Path > ,
124284 {
125- let file = File :: create ( path) ?;
285+ let prefix = prefix. as_ref ( ) ;
286+ let base_path = path. as_ref ( ) ;
287+ create_dir_all ( base_path) ?;
288+
289+ let file_path = base_path. join ( format ! ( "{prefix}.npz" ) ) ;
290+ let file = File :: create ( file_path) ?;
126291
127292 let mut zip = ZipWriter :: new ( file) ;
128293
@@ -138,4 +303,71 @@ impl BacktestRecorder {
138303 zip. finish ( ) ?;
139304 Ok ( ( ) )
140305 }
306+
307+ pub fn to_parquet < Prefix , P > ( & self , prefix : Prefix , path : P ) -> Result < ( ) , Error >
308+ where
309+ Prefix : AsRef < str > ,
310+ P : AsRef < Path > ,
311+ {
312+ let prefix = prefix. as_ref ( ) ;
313+ let base_path = path. as_ref ( ) ;
314+ create_dir_all ( base_path) ?;
315+
316+ // Buffer output to reduce frequent file I/O
317+ for ( asset_no, values) in self . values . iter ( ) . enumerate ( ) {
318+ let parquet_schema = Schema :: new ( ACCOUNT_STATE_DATA_POINT_FIELDS . clone ( ) ) ;
319+ let arrow_schema = Arc :: new ( parquet_schema. clone ( ) ) ;
320+ let parquet_props = WriterProperties :: builder ( )
321+ . set_compression ( Compression :: SNAPPY )
322+ . build ( ) ;
323+
324+ let file_path = base_path. join ( format ! ( "{prefix}{asset_no}.snappy.parquet" ) ) ;
325+ let file = File :: create ( file_path) . unwrap ( ) ;
326+
327+ let mut wrt =
328+ ArrowWriter :: try_new ( file, arrow_schema. clone ( ) , Some ( parquet_props) ) . unwrap ( ) ;
329+
330+ let mut builder = AccountStateDataPointColumnsBuilder :: new ( parquet_schema. clone ( ) ) ;
331+
332+ let max_rows_per_batch: usize = 10 ;
333+ let mut row: usize = 0 ;
334+
335+ // Write records
336+ for record in values {
337+ row += 1 ;
338+ let single_row = AccountStateDataPoint {
339+ timestamp : record. timestamp ,
340+ balance : record. balance ,
341+ position : record. position ,
342+ fee : record. fee ,
343+ trading_volume : record. trading_volume ,
344+ trading_value : record. trading_value ,
345+ num_trades : record. num_trades ,
346+ num_messages : record. num_messages ,
347+ num_cancellations : record. num_cancellations ,
348+ num_creations : record. num_creations ,
349+ num_modifications : record. num_modifications ,
350+ price : record. price ,
351+ } ;
352+ builder. append ( & single_row) . unwrap ( ) ;
353+ row += 1 ;
354+
355+ if row > 0 && row % max_rows_per_batch == 0 {
356+ let batch = builder. get_batch ( ) . unwrap ( ) ;
357+ wrt. write ( & batch) . unwrap ( ) ;
358+ builder. reset ( ) . unwrap ( ) ;
359+ }
360+ }
361+
362+ // Write remaining data
363+ {
364+ let batch = builder. get_batch ( ) . unwrap ( ) ;
365+ wrt. write ( & batch) . unwrap ( ) ;
366+ builder. reset ( ) . unwrap ( ) ;
367+ }
368+
369+ wrt. close ( ) . unwrap ( ) ;
370+ }
371+ Ok ( ( ) )
372+ }
141373}
0 commit comments