Skip to content

Commit c703b9b

Browse files
Add number of messages to recorder; add possibility to store as parquet file
1 parent bbcc4d3 commit c703b9b

File tree

2 files changed

+289
-59
lines changed

2 files changed

+289
-59
lines changed

hftbacktest/src/backtest/recorder.rs

Lines changed: 255 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
use std::{
2-
fs::File,
2+
fs::{File, create_dir_all},
33
io::{BufWriter, Error, Write},
44
path::Path,
5+
sync::Arc,
56
};
67

8+
use arrow::{
9+
array::{ArrayRef, PrimitiveBuilder, RecordBatch},
10+
datatypes::{DataType, Float64Type, Int64Type, Schema},
11+
error::ArrowError,
12+
};
713
use hftbacktest_derive::NpyDTyped;
14+
use once_cell::sync::Lazy;
15+
use parquet::{arrow::ArrowWriter, basic::Compression, file::properties::WriterProperties};
816
use zip::{ZipWriter, write::SimpleFileOptions};
917

1018
use crate::{
@@ -22,6 +30,10 @@ struct Record {
2230
balance: f64,
2331
fee: f64,
2432
num_trades: i64,
33+
num_messages: i64,
34+
num_cancellations: i64,
35+
num_creations: i64,
36+
num_modifications: i64,
2537
trading_volume: f64,
2638
trading_value: f64,
2739
}
@@ -34,6 +46,140 @@ pub struct BacktestRecorder {
3446
values: Vec<Vec<Record>>,
3547
}
3648

49+
pub static ACCOUNT_STATE_DATA_POINT_FIELDS: Lazy<Vec<arrow::datatypes::Field>> = Lazy::new(|| {
50+
vec![
51+
arrow::datatypes::Field::new("timestamp", DataType::Int64, true),
52+
arrow::datatypes::Field::new("balance", DataType::Float64, true),
53+
arrow::datatypes::Field::new("position", DataType::Float64, true),
54+
arrow::datatypes::Field::new("fee", DataType::Float64, true),
55+
arrow::datatypes::Field::new("trading_volume", DataType::Float64, true),
56+
arrow::datatypes::Field::new("trading_value", DataType::Float64, true),
57+
arrow::datatypes::Field::new("num_trades", DataType::Int64, true),
58+
arrow::datatypes::Field::new("num_messages", DataType::Int64, true),
59+
arrow::datatypes::Field::new("num_cancellations", DataType::Int64, true),
60+
arrow::datatypes::Field::new("num_creations", DataType::Int64, true),
61+
arrow::datatypes::Field::new("num_modifications", DataType::Int64, true),
62+
arrow::datatypes::Field::new("price", DataType::Float64, true),
63+
]
64+
});
65+
66+
pub trait ColumnsBuilder<'a> {
67+
type T;
68+
69+
fn get_batch(&mut self) -> Result<RecordBatch, ArrowError>;
70+
fn append(&mut self, msg: &'a Self::T) -> Result<(), ArrowError>;
71+
72+
fn reset(&mut self) -> Result<(), ArrowError>;
73+
}
74+
75+
pub struct AccountStateDataPointColumnsBuilder {
76+
schema: Schema,
77+
timestamp_builder: PrimitiveBuilder<Int64Type>,
78+
balance_builder: PrimitiveBuilder<Float64Type>,
79+
position_builder: PrimitiveBuilder<Float64Type>,
80+
fee_builder: PrimitiveBuilder<Float64Type>,
81+
trading_volume_builder: PrimitiveBuilder<Float64Type>,
82+
trading_value_builder: PrimitiveBuilder<Float64Type>,
83+
num_trades_builder: PrimitiveBuilder<Int64Type>,
84+
num_messages_builder: PrimitiveBuilder<Int64Type>,
85+
num_cancellations_builder: PrimitiveBuilder<Int64Type>,
86+
num_creations_builder: PrimitiveBuilder<Int64Type>,
87+
num_modifications_builder: PrimitiveBuilder<Int64Type>,
88+
price_builder: PrimitiveBuilder<Float64Type>,
89+
}
90+
91+
/// One row of recorded account state, as consumed by
/// `AccountStateDataPointColumnsBuilder`.
///
/// Each field maps to the column of the same name in the Arrow/Parquet output.
pub struct AccountStateDataPoint {
    /// Recorded `timestamp` value.
    pub timestamp: i64,
    /// Recorded `balance` value.
    pub balance: f64,
    /// Recorded `position` value.
    pub position: f64,
    /// Recorded `fee` value.
    pub fee: f64,
    /// Recorded `trading_volume` value.
    pub trading_volume: f64,
    /// Recorded `trading_value` value.
    pub trading_value: f64,
    /// Recorded `num_trades` value.
    pub num_trades: i64,
    /// Recorded `num_messages` value.
    pub num_messages: i64,
    /// Recorded `num_cancellations` value.
    pub num_cancellations: i64,
    /// Recorded `num_creations` value.
    pub num_creations: i64,
    /// Recorded `num_modifications` value.
    pub num_modifications: i64,
    /// Recorded `price` value.
    pub price: f64,
}
105+
106+
impl<'a> ColumnsBuilder<'a> for AccountStateDataPointColumnsBuilder {
107+
type T = AccountStateDataPoint;
108+
109+
fn get_batch(&mut self) -> Result<RecordBatch, ArrowError> {
110+
let arrays: Vec<ArrayRef> = vec![
111+
Arc::new(self.timestamp_builder.finish()),
112+
Arc::new(self.balance_builder.finish()),
113+
Arc::new(self.position_builder.finish()),
114+
Arc::new(self.fee_builder.finish()),
115+
Arc::new(self.trading_volume_builder.finish()),
116+
Arc::new(self.trading_value_builder.finish()),
117+
Arc::new(self.num_trades_builder.finish()),
118+
Arc::new(self.num_messages_builder.finish()),
119+
Arc::new(self.num_cancellations_builder.finish()),
120+
Arc::new(self.num_creations_builder.finish()),
121+
Arc::new(self.num_modifications_builder.finish()),
122+
Arc::new(self.price_builder.finish()),
123+
];
124+
let batch = RecordBatch::try_new(Arc::new(self.schema.clone()), arrays)?;
125+
Ok(batch)
126+
}
127+
128+
fn append(&mut self, msg: &AccountStateDataPoint) -> Result<(), ArrowError> {
129+
self.timestamp_builder.append_value(msg.timestamp);
130+
self.balance_builder.append_value(msg.balance);
131+
self.position_builder.append_value(msg.position);
132+
self.fee_builder.append_value(msg.fee);
133+
self.trading_volume_builder.append_value(msg.trading_volume);
134+
self.trading_value_builder.append_value(msg.trading_value);
135+
self.num_trades_builder.append_value(msg.num_trades);
136+
self.num_messages_builder.append_value(msg.num_messages);
137+
self.num_cancellations_builder
138+
.append_value(msg.num_cancellations);
139+
self.num_creations_builder.append_value(msg.num_creations);
140+
self.num_modifications_builder
141+
.append_value(msg.num_modifications);
142+
self.price_builder.append_value(msg.price);
143+
return Ok(());
144+
}
145+
146+
fn reset(&mut self) -> Result<(), ArrowError> {
147+
self.timestamp_builder = Default::default();
148+
self.balance_builder = Default::default();
149+
self.position_builder = Default::default();
150+
self.fee_builder = Default::default();
151+
self.trading_volume_builder = Default::default();
152+
self.trading_value_builder = Default::default();
153+
self.num_trades_builder = Default::default();
154+
self.num_messages_builder = Default::default();
155+
self.num_cancellations_builder = Default::default();
156+
self.num_creations_builder = Default::default();
157+
self.num_modifications_builder = Default::default();
158+
self.price_builder = Default::default();
159+
return Ok(());
160+
}
161+
}
162+
163+
impl AccountStateDataPointColumnsBuilder {
164+
pub fn new(schema: Schema) -> AccountStateDataPointColumnsBuilder {
165+
AccountStateDataPointColumnsBuilder {
166+
schema,
167+
timestamp_builder: Default::default(),
168+
balance_builder: Default::default(),
169+
position_builder: Default::default(),
170+
fee_builder: Default::default(),
171+
trading_volume_builder: Default::default(),
172+
trading_value_builder: Default::default(),
173+
num_trades_builder: Default::default(),
174+
num_messages_builder: Default::default(),
175+
num_cancellations_builder: Default::default(),
176+
num_creations_builder: Default::default(),
177+
num_modifications_builder: Default::default(),
178+
price_builder: Default::default(),
179+
}
180+
}
181+
}
182+
37183
impl Recorder for BacktestRecorder {
38184
type Error = Error;
39185

@@ -57,6 +203,10 @@ impl Recorder for BacktestRecorder {
57203
trading_volume: state_values.trading_volume,
58204
trading_value: state_values.trading_value,
59205
num_trades: state_values.num_trades,
206+
num_messages: state_values.num_messages,
207+
num_cancellations: state_values.num_cancellations,
208+
num_creations: state_values.num_creations,
209+
num_modifications: state_values.num_modifications,
60210
});
61211
}
62212
Ok(())
@@ -91,38 +241,53 @@ impl BacktestRecorder {
91241
P: AsRef<Path>,
92242
{
93243
let prefix = prefix.as_ref();
244+
let base_path = path.as_ref();
245+
create_dir_all(base_path)?;
246+
247+
// Buffer output to reduce frequent file I/O
94248
for (asset_no, values) in self.values.iter().enumerate() {
95-
let file_path = path.as_ref().join(format!("{prefix}{asset_no}.csv"));
96-
let mut file = BufWriter::new(File::create(file_path)?);
97-
writeln!(
98-
file,
99-
"timestamp,balance,position,fee,trading_volume,trading_value,num_trades,price",
249+
let file_path = base_path.join(format!("{prefix}{asset_no}.csv"));
250+
let mut file = BufWriter::new(File::create(file_path)?); // Use BufWriter for buffered writing
251+
252+
// Write header
253+
file.write_all(
254+
b"timestamp,balance,position,fee,trading_volume,trading_value,num_trades,num_messages,num_cancellations,num_creations,num_modifications,price\n",
100255
)?;
101-
for Record {
102-
timestamp,
103-
balance,
104-
position,
105-
fee,
106-
trading_volume,
107-
trading_value,
108-
num_trades,
109-
price: mid_price,
110-
} in values
111-
{
112-
writeln!(
113-
file,
114-
"{timestamp},{balance},{position},{fee},{trading_volume},{trading_value},{num_trades},{mid_price}"
115-
)?;
256+
257+
// Write records
258+
for record in values {
259+
let line = format!(
260+
"{},{},{},{},{},{},{},{},{},{},{},{}\n",
261+
record.timestamp,
262+
record.balance,
263+
record.position,
264+
record.fee,
265+
record.trading_volume,
266+
record.trading_value,
267+
record.num_trades,
268+
record.num_messages,
269+
record.num_cancellations,
270+
record.num_creations,
271+
record.num_modifications,
272+
record.price,
273+
);
274+
file.write_all(line.as_bytes())?;
116275
}
117276
}
118277
Ok(())
119278
}
120279

121-
pub fn to_npz<P>(&self, path: P) -> Result<(), Error>
280+
pub fn to_npz<Prefix, P>(&self, prefix: Prefix, path: P) -> Result<(), Error>
122281
where
282+
Prefix: AsRef<str>,
123283
P: AsRef<Path>,
124284
{
125-
let file = File::create(path)?;
285+
let prefix = prefix.as_ref();
286+
let base_path = path.as_ref();
287+
create_dir_all(base_path)?;
288+
289+
let file_path = base_path.join(format!("{prefix}.npz"));
290+
let file = File::create(file_path)?;
126291

127292
let mut zip = ZipWriter::new(file);
128293

@@ -138,4 +303,71 @@ impl BacktestRecorder {
138303
zip.finish()?;
139304
Ok(())
140305
}
306+
307+
pub fn to_parquet<Prefix, P>(&self, prefix: Prefix, path: P) -> Result<(), Error>
308+
where
309+
Prefix: AsRef<str>,
310+
P: AsRef<Path>,
311+
{
312+
let prefix = prefix.as_ref();
313+
let base_path = path.as_ref();
314+
create_dir_all(base_path)?;
315+
316+
// Buffer output to reduce frequent file I/O
317+
for (asset_no, values) in self.values.iter().enumerate() {
318+
let parquet_schema = Schema::new(ACCOUNT_STATE_DATA_POINT_FIELDS.clone());
319+
let arrow_schema = Arc::new(parquet_schema.clone());
320+
let parquet_props = WriterProperties::builder()
321+
.set_compression(Compression::SNAPPY)
322+
.build();
323+
324+
let file_path = base_path.join(format!("{prefix}{asset_no}.snappy.parquet"));
325+
let file = File::create(file_path).unwrap();
326+
327+
let mut wrt =
328+
ArrowWriter::try_new(file, arrow_schema.clone(), Some(parquet_props)).unwrap();
329+
330+
let mut builder = AccountStateDataPointColumnsBuilder::new(parquet_schema.clone());
331+
332+
let max_rows_per_batch: usize = 10;
333+
let mut row: usize = 0;
334+
335+
// Write records
336+
for record in values {
337+
row += 1;
338+
let single_row = AccountStateDataPoint {
339+
timestamp: record.timestamp,
340+
balance: record.balance,
341+
position: record.position,
342+
fee: record.fee,
343+
trading_volume: record.trading_volume,
344+
trading_value: record.trading_value,
345+
num_trades: record.num_trades,
346+
num_messages: record.num_messages,
347+
num_cancellations: record.num_cancellations,
348+
num_creations: record.num_creations,
349+
num_modifications: record.num_modifications,
350+
price: record.price,
351+
};
352+
builder.append(&single_row).unwrap();
353+
row += 1;
354+
355+
if row > 0 && row % max_rows_per_batch == 0 {
356+
let batch = builder.get_batch().unwrap();
357+
wrt.write(&batch).unwrap();
358+
builder.reset().unwrap();
359+
}
360+
}
361+
362+
// Write remaining data
363+
{
364+
let batch = builder.get_batch().unwrap();
365+
wrt.write(&batch).unwrap();
366+
builder.reset().unwrap();
367+
}
368+
369+
wrt.close().unwrap();
370+
}
371+
Ok(())
372+
}
141373
}

0 commit comments

Comments
 (0)