diff --git a/hftbacktest/src/backtest/recorder.rs b/hftbacktest/src/backtest/recorder.rs index d26d83f8..e3b7e21c 100644 --- a/hftbacktest/src/backtest/recorder.rs +++ b/hftbacktest/src/backtest/recorder.rs @@ -1,10 +1,18 @@ use std::{ - fs::File, + fs::{File, create_dir_all}, io::{BufWriter, Error, Write}, path::Path, + sync::Arc, }; +use arrow::{ + array::{ArrayRef, PrimitiveBuilder, RecordBatch}, + datatypes::{DataType, Float64Type, Int64Type, Schema}, + error::ArrowError, +}; use hftbacktest_derive::NpyDTyped; +use once_cell::sync::Lazy; +use parquet::{arrow::ArrowWriter, basic::Compression, file::properties::WriterProperties}; use zip::{ZipWriter, write::SimpleFileOptions}; use crate::{ @@ -22,6 +30,10 @@ struct Record { balance: f64, fee: f64, num_trades: i64, + num_messages: i64, + num_cancellations: i64, + num_creations: i64, + num_modifications: i64, trading_volume: f64, trading_value: f64, } @@ -34,6 +46,140 @@ pub struct BacktestRecorder { values: Vec>, } +pub static ACCOUNT_STATE_DATA_POINT_FIELDS: Lazy> = Lazy::new(|| { + vec![ + arrow::datatypes::Field::new("timestamp", DataType::Int64, true), + arrow::datatypes::Field::new("balance", DataType::Float64, true), + arrow::datatypes::Field::new("position", DataType::Float64, true), + arrow::datatypes::Field::new("fee", DataType::Float64, true), + arrow::datatypes::Field::new("trading_volume", DataType::Float64, true), + arrow::datatypes::Field::new("trading_value", DataType::Float64, true), + arrow::datatypes::Field::new("num_trades", DataType::Int64, true), + arrow::datatypes::Field::new("num_messages", DataType::Int64, true), + arrow::datatypes::Field::new("num_cancellations", DataType::Int64, true), + arrow::datatypes::Field::new("num_creations", DataType::Int64, true), + arrow::datatypes::Field::new("num_modifications", DataType::Int64, true), + arrow::datatypes::Field::new("price", DataType::Float64, true), + ] +}); + +pub trait ColumnsBuilder<'a> { + type T; + + fn get_batch(&mut self) -> Result; + fn append(&mut self, msg: &'a Self::T) -> Result<(), ArrowError>; + + fn reset(&mut self) -> Result<(), ArrowError>; +} + +pub struct AccountStateDataPointColumnsBuilder { + schema: Schema, + timestamp_builder: PrimitiveBuilder, + balance_builder: PrimitiveBuilder, + position_builder: PrimitiveBuilder, + fee_builder: PrimitiveBuilder, + trading_volume_builder: PrimitiveBuilder, + trading_value_builder: PrimitiveBuilder, + num_trades_builder: PrimitiveBuilder, + num_messages_builder: PrimitiveBuilder, + num_cancellations_builder: PrimitiveBuilder, + num_creations_builder: PrimitiveBuilder, + num_modifications_builder: PrimitiveBuilder, + price_builder: PrimitiveBuilder, +} + +pub struct AccountStateDataPoint { + pub timestamp: i64, + pub balance: f64, + pub position: f64, + pub fee: f64, + pub trading_volume: f64, + pub trading_value: f64, + pub num_trades: i64, + pub num_messages: i64, + pub num_cancellations: i64, + pub num_creations: i64, + pub num_modifications: i64, + pub price: f64, +} + +impl<'a> ColumnsBuilder<'a> for AccountStateDataPointColumnsBuilder { + type T = AccountStateDataPoint; + + fn get_batch(&mut self) -> Result { + let arrays: Vec = vec![ + Arc::new(self.timestamp_builder.finish()), + Arc::new(self.balance_builder.finish()), + Arc::new(self.position_builder.finish()), + Arc::new(self.fee_builder.finish()), + Arc::new(self.trading_volume_builder.finish()), + Arc::new(self.trading_value_builder.finish()), + Arc::new(self.num_trades_builder.finish()), + Arc::new(self.num_messages_builder.finish()), + Arc::new(self.num_cancellations_builder.finish()), + Arc::new(self.num_creations_builder.finish()), + Arc::new(self.num_modifications_builder.finish()), + Arc::new(self.price_builder.finish()), + ]; + let batch = RecordBatch::try_new(Arc::new(self.schema.clone()), arrays)?; + Ok(batch) + } + + fn append(&mut self, msg: &AccountStateDataPoint) -> Result<(), ArrowError> { + self.timestamp_builder.append_value(msg.timestamp); + self.balance_builder.append_value(msg.balance); + self.position_builder.append_value(msg.position); + self.fee_builder.append_value(msg.fee); + self.trading_volume_builder.append_value(msg.trading_volume); + self.trading_value_builder.append_value(msg.trading_value); + self.num_trades_builder.append_value(msg.num_trades); + self.num_messages_builder.append_value(msg.num_messages); + self.num_cancellations_builder + .append_value(msg.num_cancellations); + self.num_creations_builder.append_value(msg.num_creations); + self.num_modifications_builder + .append_value(msg.num_modifications); + self.price_builder.append_value(msg.price); + return Ok(()); + } + + fn reset(&mut self) -> Result<(), ArrowError> { + self.timestamp_builder = Default::default(); + self.balance_builder = Default::default(); + self.position_builder = Default::default(); + self.fee_builder = Default::default(); + self.trading_volume_builder = Default::default(); + self.trading_value_builder = Default::default(); + self.num_trades_builder = Default::default(); + self.num_messages_builder = Default::default(); + self.num_cancellations_builder = Default::default(); + self.num_creations_builder = Default::default(); + self.num_modifications_builder = Default::default(); + self.price_builder = Default::default(); + return Ok(()); + } +} + +impl AccountStateDataPointColumnsBuilder { + pub fn new(schema: Schema) -> AccountStateDataPointColumnsBuilder { + AccountStateDataPointColumnsBuilder { + schema, + timestamp_builder: Default::default(), + balance_builder: Default::default(), + position_builder: Default::default(), + fee_builder: Default::default(), + trading_volume_builder: Default::default(), + trading_value_builder: Default::default(), + num_trades_builder: Default::default(), + num_messages_builder: Default::default(), + num_cancellations_builder: Default::default(), + num_creations_builder: Default::default(), + num_modifications_builder: Default::default(), + price_builder: Default::default(), + } + } +} + impl Recorder for BacktestRecorder { type Error = Error; @@ -57,6 +203,10 @@ impl Recorder for BacktestRecorder { trading_volume: state_values.trading_volume, trading_value: state_values.trading_value, num_trades: state_values.num_trades, + num_messages: state_values.num_messages, + num_cancellations: state_values.num_cancellations, + num_creations: state_values.num_creations, + num_modifications: state_values.num_modifications, }); } Ok(()) @@ -91,38 +241,53 @@ impl BacktestRecorder { P: AsRef, { let prefix = prefix.as_ref(); + let base_path = path.as_ref(); + create_dir_all(base_path)?; + + // Buffer output to reduce frequent file I/O for (asset_no, values) in self.values.iter().enumerate() { - let file_path = path.as_ref().join(format!("{prefix}{asset_no}.csv")); - let mut file = BufWriter::new(File::create(file_path)?); - writeln!( - file, - "timestamp,balance,position,fee,trading_volume,trading_value,num_trades,price", + let file_path = base_path.join(format!("{prefix}{asset_no}.csv")); + let mut file = BufWriter::new(File::create(file_path)?); // Use BufWriter for buffered writing + + // Write header + file.write_all( + b"timestamp,balance,position,fee,trading_volume,trading_value,num_trades,num_messages,num_cancellations,num_creations,num_modifications,price\n", )?; - for Record { - timestamp, - balance, - position, - fee, - trading_volume, - trading_value, - num_trades, - price: mid_price, - } in values - { - writeln!( - file, - "{timestamp},{balance},{position},{fee},{trading_volume},{trading_value},{num_trades},{mid_price}" - )?; + + // Write records + for record in values { + let line = format!( + "{},{},{},{},{},{},{},{},{},{},{},{}\n", + record.timestamp, + record.balance, + record.position, + record.fee, + record.trading_volume, + record.trading_value, + record.num_trades, + record.num_messages, + record.num_cancellations, + record.num_creations, + record.num_modifications, + record.price, + ); + file.write_all(line.as_bytes())?; } } Ok(()) } - pub fn to_npz

(&self, path: P) -> Result<(), Error> + pub fn to_npz(&self, prefix: Prefix, path: P) -> Result<(), Error> where + Prefix: AsRef, P: AsRef, { - let file = File::create(path)?; + let prefix = prefix.as_ref(); + let base_path = path.as_ref(); + create_dir_all(base_path)?; + + let file_path = base_path.join(format!("{prefix}.npz")); + let file = File::create(file_path)?; let mut zip = ZipWriter::new(file); @@ -138,4 +303,71 @@ impl BacktestRecorder { zip.finish()?; Ok(()) } + + pub fn to_parquet(&self, prefix: Prefix, path: P) -> Result<(), Error> + where + Prefix: AsRef, + P: AsRef, + { + let prefix = prefix.as_ref(); + let base_path = path.as_ref(); + create_dir_all(base_path)?; + + // Buffer output to reduce frequent file I/O + for (asset_no, values) in self.values.iter().enumerate() { + let parquet_schema = Schema::new(ACCOUNT_STATE_DATA_POINT_FIELDS.clone()); + let arrow_schema = Arc::new(parquet_schema.clone()); + let parquet_props = WriterProperties::builder() + .set_compression(Compression::SNAPPY) + .build(); + + let file_path = base_path.join(format!("{prefix}{asset_no}.snappy.parquet")); + let file = File::create(file_path).unwrap(); + + let mut wrt = + ArrowWriter::try_new(file, arrow_schema.clone(), Some(parquet_props)).unwrap(); + + let mut builder = AccountStateDataPointColumnsBuilder::new(parquet_schema.clone()); + + let max_rows_per_batch: usize = 10; + let mut row: usize = 0; + + // Write records + for record in values { + row += 1; + let single_row = AccountStateDataPoint { + timestamp: record.timestamp, + balance: record.balance, + position: record.position, + fee: record.fee, + trading_volume: record.trading_volume, + trading_value: record.trading_value, + num_trades: record.num_trades, + num_messages: record.num_messages, + num_cancellations: record.num_cancellations, + num_creations: record.num_creations, + num_modifications: record.num_modifications, + price: record.price, + }; + builder.append(&single_row).unwrap(); + row += 1; + + if row > 0 && row % max_rows_per_batch == 0 { + let batch = builder.get_batch().unwrap(); + wrt.write(&batch).unwrap(); + builder.reset().unwrap(); + } + } + + // Write remaining data + { + let batch = builder.get_batch().unwrap(); + wrt.write(&batch).unwrap(); + builder.reset().unwrap(); + } + + wrt.close().unwrap(); + } + Ok(()) + } } diff --git a/hftbacktest/src/live/recorder.rs b/hftbacktest/src/live/recorder.rs index 1cd50788..e5070bc1 100644 --- a/hftbacktest/src/live/recorder.rs +++ b/hftbacktest/src/live/recorder.rs @@ -2,16 +2,14 @@ use std::collections::{HashMap, hash_map::Entry}; use tracing::info; -use crate::{ - depth::MarketDepth, - prelude::{Bot, get_precision}, - types::{Recorder, StateValues}, -}; +use crate::{depth::MarketDepth, prelude::Bot, types::Recorder}; /// Provides logging of the live strategy's state values. #[derive(Default)] pub struct LoggingRecorder { - state: HashMap, + position: HashMap, + symbol: String, + asset_no: usize, } impl Recorder for LoggingRecorder { @@ -22,44 +20,44 @@ impl Recorder for LoggingRecorder { MD: MarketDepth, I: Bot, { - for asset_no in 0..hbt.num_assets() { - let depth = hbt.depth(asset_no); - let price_prec = get_precision(depth.tick_size()); - let mid = (depth.best_bid() + depth.best_ask()) / 2.0; - let state_values = hbt.state_values(asset_no); - let updated = match self.state.entry(asset_no) { - Entry::Occupied(mut entry) => { - let (prev_mid, prev_state_values) = entry.get(); - if (*prev_mid != mid) || (prev_state_values != state_values) { - *entry.get_mut() = (mid, state_values.clone()); - true - } else { - false - } - } - Entry::Vacant(entry) => { - entry.insert((mid, state_values.clone())); + let position = hbt.position(self.asset_no); + + let updated = match self.position.entry(self.asset_no) { + Entry::Occupied(mut entry) => { + let prev_position = entry.get(); + if *prev_position != position { + *entry.get_mut() = position; true + } else { + false } - }; - if updated { - info!( - %asset_no, - %mid, - bid = format!("{:.prec$}", depth.best_bid(), prec = price_prec), - ask = format!("{:.prec$}", depth.best_ask(), prec = price_prec), - ?state_values, - "The state of asset number {asset_no} has been updated." - ); } + Entry::Vacant(entry) => { + entry.insert(position); + true + } + }; + + if updated { + info!( + asset_no = %self.asset_no, + symbol = %self.symbol, + %position, + "Position updated" + ); } + Ok(()) } } impl LoggingRecorder { - /// Constructs an instance of `LoggingRecorder`. - pub fn new() -> Self { - Default::default() + /// Constructs an instance of `LoggingRecorder` for a single symbol. + pub fn new(symbol: String, asset_no: usize) -> Self { + Self { + position: HashMap::new(), + symbol, + asset_no, + } } }