
Commit 7e460e6

lukeramsden and aaronvg authored
Add cached input token tracking to Usage reporting (#2394)
## Summary

This PR adds comprehensive cached input token tracking to the BAML Usage reporting system. The Collector and Usage now track and report cached input tokens alongside the existing input and output tokens.

### Changes Made

- **Core Usage struct**: Added a `cached_input_tokens: Option<u64>` field to track cached tokens
- **LLM Provider Integration**: Implemented cached token extraction for all supported providers:
  - **Anthropic**: Extracts from `cache_read_input_tokens`
  - **OpenAI**: Extracts from `input_tokens_details.cached_tokens`
  - **Google/Vertex**: Uses the `cached_content_token_count` field
  - **AWS Bedrock**: Set to `None` (the SDK version BAML currently uses has no cached token support, and upgrading is blocked by a dependency issue; see Cargo.toml)
- **Token Aggregation**: Updated all token aggregation logic in Collector and FunctionLog to sum cached tokens
- **Language Bindings**: Added cached token support to all client libraries:
  - TypeScript: `usage.cachedInputTokens`
  - Python: `usage.cached_input_tokens`
  - Go: `usage.CachedInputTokens()`
  - Ruby: `usage.cached_input_tokens`
- **RPC Integration**: Updated RPC types and converters to include cached token data

### Test Plan

- [x] Core library compilation verified
- [x] All provider response handlers updated with cached token extraction
- [x] Language binding interfaces expanded with cached token accessors
- [x] Token aggregation logic preserves cached token counts across multiple calls
- [x] RPC serialization includes cached token data

### Technical Notes

- Cached tokens are tracked separately from input/output tokens for better cost analysis
- Provider-specific token extraction handles cases where cached token data is unavailable
- All changes are backward compatible with the existing Usage API
- Language bindings maintain consistent naming conventions across all supported languages

Closes #2349

---

> [!IMPORTANT]
> Add cached input token tracking to BAML Usage reporting, updating core structures, provider integrations, token aggregation, language bindings, and tests.
>
> - **Behavior**:
>   - Added a `cached_input_tokens` field to `LLMUsage` (`events.rs`, `trace_event.rs`) and `LLMCompleteResponseMetadata` (`mod.rs`) to track cached tokens.
>   - Implemented cached token extraction for providers: `Anthropic` (from `cache_read_input_tokens`), `OpenAI` (from `input_tokens_details.cached_tokens`), `Google/Vertex` (from `cached_content_token_count`), and `AWS Bedrock` (set to `None`).
>   - Updated token aggregation logic in `storage.rs` and `llm_response_to_log_event.rs` to sum cached tokens.
> - **Language Bindings**:
>   - Added cached token support to TypeScript (`native.d.ts`), Python (`log_collector.rs`), Go (`rawobjects_public.go`), and Ruby (`log_collector.rs`).
> - **RPC Integration**:
>   - Updated RPC types and converters in `trace_data.rs` to include cached token data.
> - **Tests**:
>   - Added tests in `test_collector.py` and `collector.test.ts` to verify cached token tracking for various providers and scenarios.

---

Co-authored-by: aaronvg <[email protected]>
1 parent 013577d commit 7e460e6

36 files changed: +1603 -1873 lines
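The Collector/FunctionLog aggregation change itself (in `storage.rs` and `llm_response_to_log_event.rs`) is not among the excerpts below. As a rough illustration of what "sum cached tokens" across multiple calls can look like when every provider reports them optionally, here is a hedged sketch — illustrative only, with made-up type and function names, not the PR's actual aggregation code:

```rust
// Hedged sketch (not the PR's code): summing optional per-call token counts so that a
// call whose provider reports no cached-token data does not erase an existing total.
#[derive(Debug, Default, Clone, Copy)]
struct Usage {
    input_tokens: Option<u64>,
    output_tokens: Option<u64>,
    cached_input_tokens: Option<u64>,
}

// Treat None as "no data": only stay None if neither side has a value.
fn add_opt(total: Option<u64>, delta: Option<u64>) -> Option<u64> {
    match (total, delta) {
        (None, None) => None,
        (t, d) => Some(t.unwrap_or(0) + d.unwrap_or(0)),
    }
}

fn aggregate(calls: &[Usage]) -> Usage {
    calls.iter().fold(Usage::default(), |acc, call| Usage {
        input_tokens: add_opt(acc.input_tokens, call.input_tokens),
        output_tokens: add_opt(acc.output_tokens, call.output_tokens),
        cached_input_tokens: add_opt(acc.cached_input_tokens, call.cached_input_tokens),
    })
}

fn main() {
    let calls = [
        Usage { input_tokens: Some(120), output_tokens: Some(30), cached_input_tokens: Some(64) },
        // Second call's provider reported no cached-token data at all.
        Usage { input_tokens: Some(80), output_tokens: Some(25), cached_input_tokens: None },
    ];
    let total = aggregate(&calls);
    assert_eq!(total.input_tokens, Some(200));
    assert_eq!(total.cached_input_tokens, Some(64));
}
```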

engine/baml-lib/baml-types/src/tracing/events.rs

Lines changed: 1 addition & 0 deletions
```diff
@@ -593,6 +593,7 @@ pub struct LLMUsage {
     pub input_tokens: Option<u64>,
     pub output_tokens: Option<u64>,
     pub total_tokens: Option<u64>,
+    pub cached_input_tokens: Option<u64>,
 }
 
 #[cfg(test)]
```

engine/baml-rpc/src/runtime_api/trace_event.rs

Lines changed: 1 addition & 0 deletions
```diff
@@ -183,4 +183,5 @@ pub struct LLMUsage {
     pub input_tokens: Option<u64>,
     pub output_tokens: Option<u64>,
     pub total_tokens: Option<u64>,
+    pub cached_input_tokens: Option<u64>,
 }
```

engine/baml-runtime/src/internal/llm_client/mod.rs

Lines changed: 1 addition & 0 deletions
```diff
@@ -294,6 +294,7 @@ pub struct LLMCompleteResponseMetadata {
     pub prompt_tokens: Option<u64>,
     pub output_tokens: Option<u64>,
     pub total_tokens: Option<u64>,
+    pub cached_input_tokens: Option<u64>,
 }
 
 // This is how the response gets logged if you print the result to the console.
```

engine/baml-runtime/src/internal/llm_client/primitive/anthropic/response_handler.rs

Lines changed: 4 additions & 0 deletions
```diff
@@ -90,6 +90,7 @@ pub fn parse_anthropic_response<C: WithClient + RequestBuilder>(
             prompt_tokens: Some(response.usage.input_tokens),
             output_tokens: Some(response.usage.output_tokens),
             total_tokens: Some(response.usage.input_tokens + response.usage.output_tokens),
+            cached_input_tokens: response.usage.cache_read_input_tokens,
         },
     })
 }
@@ -137,6 +138,7 @@ pub fn scan_anthropic_response_stream(
                 inner.prompt_tokens = Some(body.usage.input_tokens);
                 inner.output_tokens = Some(body.usage.output_tokens);
                 inner.total_tokens = Some(body.usage.input_tokens + body.usage.output_tokens);
+                inner.cached_input_tokens = body.usage.cache_read_input_tokens;
             }
             MessageChunk::ContentBlockDelta(event) => {
                 if let super::types::ContentBlockDelta::TextDelta { text } = event.delta {
@@ -153,6 +155,7 @@ pub fn scan_anthropic_response_stream(
                 inner.finish_reason = body.delta.stop_reason.clone();
                 inner.output_tokens = Some(body.usage.output_tokens);
                 inner.total_tokens = Some(inner.prompt_tokens.unwrap_or(0) + body.usage.output_tokens);
+                inner.cached_input_tokens = body.usage.cache_read_input_tokens;
             }
             MessageChunk::MessageStop => (),
             MessageChunk::Error { error } => {
@@ -218,6 +221,7 @@ mod tests {
                 prompt_tokens: Some(321),
                 output_tokens: Some(158),
                 total_tokens: Some(479),
+                cached_input_tokens: Some(0),
             },
         };
 
```

engine/baml-runtime/src/internal/llm_client/primitive/anthropic/types.rs

Lines changed: 3 additions & 8 deletions
```diff
@@ -34,6 +34,8 @@ pub enum AnthropicMessageContent {
 pub struct AnthropicUsage {
     pub input_tokens: u64,
     pub output_tokens: u64,
+    pub cache_creation_input_tokens: Option<u64>,
+    pub cache_read_input_tokens: Option<u64>,
 }
 
 #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
@@ -200,7 +202,7 @@ pub struct MessageDeltaChunk {
     /// The result of this stream.
     pub delta: StreamStop,
     /// The billing and rate-limit usage of this stream.
-    pub usage: DeltaUsage,
+    pub usage: AnthropicUsage,
 }
 
 /// The text delta content block.
@@ -222,13 +224,6 @@ pub struct StreamStop {
     pub stop_sequence: Option<StopSequence>,
 }
 
-/// The delta usage of the stream.
-#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
-pub struct DeltaUsage {
-    /// The number of output tokens which were used.
-    pub output_tokens: u64,
-}
-
 #[cfg(test)]
 mod tests {
     use anyhow::Result;
```
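Because both cache fields are `Option<u64>`, usage payloads that omit them (for example, requests without prompt caching) still deserialize cleanly. A small standalone sketch of that behavior, using an illustrative struct rather than the crate's own types and assuming `serde`/`serde_json` as dependencies:

```rust
// Illustrative sketch (not the crate's own types): the cache fields are Option<u64>,
// so a usage payload that omits them still deserializes, with both falling back to None.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Usage {
    input_tokens: u64,
    output_tokens: u64,
    cache_creation_input_tokens: Option<u64>,
    cache_read_input_tokens: Option<u64>,
}

fn main() -> Result<(), serde_json::Error> {
    // Cache hit: the provider reports how many input tokens were read from the cache.
    let with_cache: Usage = serde_json::from_str(
        r#"{"input_tokens": 12, "output_tokens": 40, "cache_read_input_tokens": 2048}"#,
    )?;
    assert_eq!(with_cache.cache_read_input_tokens, Some(2048));

    // No caching involved: the cache fields are simply absent and become None.
    let without_cache: Usage =
        serde_json::from_str(r#"{"input_tokens": 12, "output_tokens": 40}"#)?;
    assert_eq!(without_cache.cache_read_input_tokens, None);

    println!("{with_cache:?}\n{without_cache:?}");
    Ok(())
}
```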

engine/baml-runtime/src/internal/llm_client/primitive/aws/aws_client.rs

Lines changed: 4 additions & 0 deletions
```diff
@@ -899,6 +899,7 @@ impl WithStreamChat for AwsClient {
                     prompt_tokens: None,
                     output_tokens: None,
                     total_tokens: None,
+                    cached_input_tokens: None,
                 },
             }),
             response,
@@ -962,6 +963,8 @@ impl WithStreamChat for AwsClient {
                                 Some(usage.output_tokens() as u64);
                             new_state.metadata.total_tokens =
                                 Some((usage.total_tokens()) as u64);
+                            // AWS Bedrock does not currently support cached tokens
+                            new_state.metadata.cached_input_tokens = None;
                         }
                     }
                     _ => {
@@ -1303,6 +1306,7 @@ impl WithChat for AwsClient {
                         .usage
                         .as_ref()
                        .and_then(|i| i.total_tokens.try_into().ok()),
+                    cached_input_tokens: None, // AWS Bedrock does not currently support cached tokens
                 },
             }),
             Err(e) => LLMResponse::LLMFailure(LLMErrorResponse {
```

engine/baml-runtime/src/internal/llm_client/primitive/google/response_handler.rs

Lines changed: 4 additions & 0 deletions
```diff
@@ -97,6 +97,7 @@ pub fn parse_google_response<C: WithClient + RequestBuilder>(
             prompt_tokens: response.usage_metadata.prompt_token_count,
             output_tokens: response.usage_metadata.candidates_token_count,
             total_tokens: response.usage_metadata.total_token_count,
+            cached_input_tokens: response.usage_metadata.cached_content_token_count,
         },
     })
 }
@@ -171,6 +172,7 @@ pub fn scan_google_response_stream(
             inner.metadata.prompt_tokens = event.usage_metadata.prompt_token_count;
             inner.metadata.output_tokens = event.usage_metadata.candidates_token_count;
             inner.metadata.total_tokens = event.usage_metadata.total_token_count;
+            inner.metadata.cached_input_tokens = event.usage_metadata.cached_content_token_count;
 
             inner.latency = instant_now.elapsed();
             Ok(())
@@ -285,6 +287,7 @@ mod tests {
                 prompt_token_count: Some(166),
                 candidates_token_count: Some(39),
                 total_token_count: Some(205),
+                cached_content_token_count: None,
             },
         };
 
@@ -331,6 +334,7 @@ mod tests {
                 prompt_tokens: Some(166),
                 output_tokens: Some(39),
                 total_tokens: Some(205),
+                cached_input_tokens: None,
             },
         };
 
```

engine/baml-runtime/src/internal/llm_client/primitive/google/types.rs

Lines changed: 1 addition & 0 deletions
```diff
@@ -338,6 +338,7 @@ pub struct UsageMetaData {
     pub prompt_token_count: Option<u64>,
     pub candidates_token_count: Option<u64>,
     pub total_token_count: Option<u64>,
+    pub cached_content_token_count: Option<u64>,
 }
 
 #[cfg(test)]
```

engine/baml-runtime/src/internal/llm_client/primitive/openai/response_handler.rs

Lines changed: 34 additions & 0 deletions
```diff
@@ -92,6 +92,13 @@ pub fn parse_openai_response<C: WithClient + RequestBuilder>(
             prompt_tokens: usage.map(|u| u.prompt_tokens),
             output_tokens: usage.map(|u| u.completion_tokens),
             total_tokens: usage.map(|u| u.total_tokens),
+            cached_input_tokens: usage.and_then(|u| {
+                // Extract cached tokens from input_tokens_details if available
+                u.input_tokens_details
+                    .as_ref()
+                    .and_then(|details| details.get("cached_tokens"))
+                    .and_then(|cached| cached.as_u64())
+            }),
         },
     })
 }
@@ -143,6 +150,12 @@ pub fn scan_openai_chat_completion_stream(
                 inner.metadata.prompt_tokens = Some(usage.prompt_tokens);
                 inner.metadata.output_tokens = Some(usage.completion_tokens);
                 inner.metadata.total_tokens = Some(usage.total_tokens);
+                inner.metadata.cached_input_tokens =
+                    usage.input_tokens_details.as_ref().and_then(|details| {
+                        details
+                            .get("cached_tokens")
+                            .and_then(|cached| cached.as_u64())
+                    })
             }
 
             Ok(())
@@ -226,6 +239,7 @@ mod tests {
                 prompt_tokens: Some(128),
                 output_tokens: Some(71),
                 total_tokens: Some(199),
+                cached_input_tokens: Some(0),
             },
         };
 
@@ -322,6 +336,13 @@ pub fn parse_openai_responses_response<C: WithClient + RequestBuilder>(
             prompt_tokens: usage.map(|u| u.prompt_tokens),
             output_tokens: usage.map(|u| u.completion_tokens),
             total_tokens: usage.map(|u| u.total_tokens),
+            cached_input_tokens: usage.and_then(|u| {
+                // Extract cached tokens from input_tokens_details if available
+                u.input_tokens_details
+                    .as_ref()
+                    .and_then(|details| details.get("cached_tokens"))
+                    .and_then(|cached| cached.as_u64())
+            }),
         },
     })
 }
@@ -390,6 +411,12 @@ pub fn scan_openai_responses_stream(
                     inner.metadata.prompt_tokens = Some(usage.prompt_tokens);
                     inner.metadata.output_tokens = Some(usage.completion_tokens);
                     inner.metadata.total_tokens = Some(usage.total_tokens);
+                    inner.metadata.cached_input_tokens =
+                        usage.input_tokens_details.as_ref().and_then(|details| {
+                            details
+                                .get("cached_tokens")
+                                .and_then(|cached| cached.as_u64())
+                        })
                 }
             }
             ResponseFailed { response, .. } => {
@@ -441,6 +468,12 @@ pub fn scan_openai_responses_stream(
                     inner.metadata.prompt_tokens = Some(usage.prompt_tokens);
                     inner.metadata.output_tokens = Some(usage.completion_tokens);
                     inner.metadata.total_tokens = Some(usage.total_tokens);
+                    inner.metadata.cached_input_tokens =
+                        usage.input_tokens_details.as_ref().and_then(|details| {
+                            details
+                                .get("cached_tokens")
+                                .and_then(|cached| cached.as_u64())
+                        })
                 }
             }
             OutputTextDelta { delta, .. } => {
@@ -507,6 +540,7 @@ mod responses_tests {
                 prompt_tokens: Some(36),
                 output_tokens: Some(87),
                 total_tokens: Some(123),
+                cached_input_tokens: Some(0),
             },
         };
 
```
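The extraction pattern above keeps `input_tokens_details` as an untyped `serde_json::Value` and walks it with `Option` combinators, so a missing details object, a missing `cached_tokens` key, or a non-numeric value all fall through to `None`. A standalone sketch of just that lookup (hypothetical helper name, not part of the PR; assumes `serde_json` as a dependency):

```rust
// Standalone sketch of the lookup: the details object stays an untyped serde_json::Value,
// and the cached token count is pulled out with a chain of Option combinators so that any
// missing piece simply yields None.
use serde_json::{json, Value};

fn cached_tokens(input_tokens_details: Option<&Value>) -> Option<u64> {
    input_tokens_details
        .and_then(|details| details.get("cached_tokens"))
        .and_then(|cached| cached.as_u64())
}

fn main() {
    // Details object as reported alongside usage, e.g. {"cached_tokens": 1920, "audio_tokens": 0}.
    let details = json!({"cached_tokens": 1920, "audio_tokens": 0});
    assert_eq!(cached_tokens(Some(&details)), Some(1920));

    // No details object at all (older models or providers): the whole chain collapses to None.
    assert_eq!(cached_tokens(None), None);
}
```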

engine/baml-runtime/src/internal/llm_client/primitive/openai/types.rs

Lines changed: 2 additions & 0 deletions
```diff
@@ -255,7 +255,9 @@ pub struct CompletionUsage {
     /// Total number of tokens used in the request (prompt + completion).
     pub total_tokens: u64,
     /// Additional fields that may be present in responses API
+    #[serde(alias = "prompt_tokens_details")]
     pub input_tokens_details: Option<serde_json::Value>,
+    #[serde(alias = "completion_tokens_details")]
     pub output_tokens_details: Option<serde_json::Value>,
 }
 
```
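The aliases let a single `CompletionUsage` struct absorb both naming schemes: the Chat Completions API reports `prompt_tokens_details`/`completion_tokens_details`, while the Responses API uses `input_tokens_details`/`output_tokens_details`. A minimal sketch of the alias behavior with an illustrative struct (not the crate's actual type; assumes `serde`/`serde_json` as dependencies):

```rust
// Illustrative sketch of #[serde(alias = ...)]: either JSON key deserializes into the same
// field, so one struct covers both the Chat Completions and Responses payload shapes.
use serde::Deserialize;
use serde_json::Value;

#[derive(Debug, Deserialize)]
struct UsageDetails {
    #[serde(alias = "prompt_tokens_details")]
    input_tokens_details: Option<Value>,
}

fn main() -> Result<(), serde_json::Error> {
    // Responses API shape uses the field's real name.
    let responses: UsageDetails =
        serde_json::from_str(r#"{"input_tokens_details": {"cached_tokens": 64}}"#)?;
    // Chat Completions shape is accepted through the alias.
    let chat: UsageDetails =
        serde_json::from_str(r#"{"prompt_tokens_details": {"cached_tokens": 64}}"#)?;
    assert!(responses.input_tokens_details.is_some());
    assert!(chat.input_tokens_details.is_some());
    Ok(())
}
```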
