Skip to content

Commit 7d7c448

Browse files
authored
fix: add input_tokens in usage of Anthropic messages (#2173)
* Add input tokens
* Update test
* Fix the tests
* Small fixes
* Small fixes
* Small fixes
* Update tests
* Remove comment line
* Update test_chat_generator.py
1 parent aab3afc commit 7d7c448

File tree

2 files changed

+46
-6
lines changed

2 files changed

+46
-6
lines changed

integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -407,7 +407,6 @@ def _convert_anthropic_chunk_to_streaming_chunk(
407407
if chunk.delta.type == "text_delta":
408408
content = chunk.delta.text
409409
elif chunk.delta.type == "input_json_delta":
410-
# we assign index=0 because one chunk can have only one ToolCallDelta
411410
tool_calls.append(ToolCallDelta(index=tool_call_index, arguments=chunk.delta.partial_json))
412411
# end of streaming message
413412
elif chunk.type == "message_delta":
@@ -490,12 +489,16 @@ def _process_response(
490489
chunks: List[StreamingChunk] = []
491490
model: Optional[str] = None
492491
tool_call_index = -1
492+
input_tokens = None
493493
component_info = ComponentInfo.from_component(self)
494494
for chunk in response:
495495
if chunk.type in ["message_start", "content_block_start", "content_block_delta", "message_delta"]:
496496
# Extract model from message_start chunks
497497
if chunk.type == "message_start":
498498
model = chunk.message.model
499+
if chunk.message.usage.input_tokens is not None:
500+
input_tokens = chunk.message.usage.input_tokens
501+
499502
if chunk.type == "content_block_start" and chunk.content_block.type == "tool_use":
500503
tool_call_index += 1
501504

@@ -510,6 +513,11 @@ def _process_response(
510513
completion.meta.update(
511514
{"received_at": datetime.now(timezone.utc).isoformat(), "model": model},
512515
)
516+
517+
if input_tokens is not None:
518+
if "usage" not in completion.meta:
519+
completion.meta["usage"] = {}
520+
completion.meta["usage"]["input_tokens"] = input_tokens
513521
return {"replies": [completion]}
514522
else:
515523
return {

integrations/anthropic/tests/test_chat_generator.py

Lines changed: 37 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -337,6 +337,8 @@ def test_convert_anthropic_completion_chunks_with_multiple_tool_calls_to_streami
337337
component = AnthropicChatGenerator(api_key=Secret.from_token("test-api-key"))
338338
component_info = ComponentInfo.from_component(component)
339339

340+
raw_chunks = []
341+
340342
# Test message_start chunk
341343
message_start_chunk = RawMessageStartEvent(
342344
message=Message(
@@ -358,6 +360,7 @@ def test_convert_anthropic_completion_chunks_with_multiple_tool_calls_to_streami
358360
),
359361
type="message_start",
360362
)
363+
raw_chunks.append(message_start_chunk)
361364
streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(
362365
message_start_chunk, component_info=component_info, tool_call_index=0
363366
)
@@ -373,6 +376,7 @@ def test_convert_anthropic_completion_chunks_with_multiple_tool_calls_to_streami
373376
text_block_start_chunk = RawContentBlockStartEvent(
374377
content_block=TextBlock(citations=None, text="", type="text"), index=0, type="content_block_start"
375378
)
379+
raw_chunks.append(text_block_start_chunk)
376380
streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(
377381
text_block_start_chunk, component_info=component_info, tool_call_index=0
378382
)
@@ -390,6 +394,7 @@ def test_convert_anthropic_completion_chunks_with_multiple_tool_calls_to_streami
390394
index=0,
391395
type="content_block_delta",
392396
)
397+
raw_chunks.append(text_delta_chunk)
393398
streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(
394399
text_delta_chunk, component_info=component_info, tool_call_index=0
395400
)
@@ -414,6 +419,7 @@ def test_convert_anthropic_completion_chunks_with_multiple_tool_calls_to_streami
414419
index=1,
415420
type="content_block_start",
416421
)
422+
raw_chunks.append(tool_block_start_chunk)
417423
streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(
418424
tool_block_start_chunk, component_info=component_info, tool_call_index=0
419425
)
@@ -431,6 +437,7 @@ def test_convert_anthropic_completion_chunks_with_multiple_tool_calls_to_streami
431437
empty_json_delta_chunk = RawContentBlockDeltaEvent(
432438
delta=InputJSONDelta(partial_json="", type="input_json_delta"), index=1, type="content_block_delta"
433439
)
440+
raw_chunks.append(empty_json_delta_chunk)
434441
streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(
435442
empty_json_delta_chunk, component_info=component_info, tool_call_index=0
436443
)
@@ -450,6 +457,7 @@ def test_convert_anthropic_completion_chunks_with_multiple_tool_calls_to_streami
450457
index=1,
451458
type="content_block_delta",
452459
)
460+
raw_chunks.append(json_delta_chunk)
453461
streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(
454462
json_delta_chunk, component_info=component_info, tool_call_index=0
455463
)
@@ -473,6 +481,7 @@ def test_convert_anthropic_completion_chunks_with_multiple_tool_calls_to_streami
473481
server_tool_use=None,
474482
),
475483
)
484+
raw_chunks.append(message_delta_chunk)
476485
streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(
477486
message_delta_chunk, component_info=component_info, tool_call_index=0
478487
)
@@ -496,6 +505,7 @@ def test_convert_anthropic_completion_chunks_with_multiple_tool_calls_to_streami
496505
index=2,
497506
type="content_block_start",
498507
)
508+
raw_chunks.append(tool_block_start_chunk)
499509
streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(
500510
tool_block_start_chunk, component_info=component_info, tool_call_index=1
501511
)
@@ -513,6 +523,7 @@ def test_convert_anthropic_completion_chunks_with_multiple_tool_calls_to_streami
513523
empty_json_delta_chunk = RawContentBlockDeltaEvent(
514524
delta=InputJSONDelta(partial_json="", type="input_json_delta"), index=1, type="content_block_delta"
515525
)
526+
raw_chunks.append(empty_json_delta_chunk)
516527
streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(
517528
empty_json_delta_chunk, component_info=component_info, tool_call_index=1
518529
)
@@ -532,6 +543,7 @@ def test_convert_anthropic_completion_chunks_with_multiple_tool_calls_to_streami
532543
index=2,
533544
type="content_block_delta",
534545
)
546+
raw_chunks.append(json_delta_chunk)
535547
streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(
536548
json_delta_chunk, component_info=component_info, tool_call_index=1
537549
)
@@ -555,6 +567,7 @@ def test_convert_anthropic_completion_chunks_with_multiple_tool_calls_to_streami
555567
server_tool_use=None,
556568
),
557569
)
570+
raw_chunks.append(message_delta_chunk)
558571
streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(
559572
message_delta_chunk, component_info=component_info, tool_call_index=0
560573
)
@@ -574,6 +587,16 @@ def test_convert_anthropic_completion_chunks_with_multiple_tool_calls_to_streami
574587
# message_stop_chunk = RawMessageStopEvent(type="message_stop")
575588
# but we don't stream it
576589

590+
generator = AnthropicChatGenerator(Secret.from_token("test-api-key"))
591+
message = generator._process_response(raw_chunks)
592+
assert message["replies"][0].meta["usage"] == {
593+
"cache_creation_input_tokens": None,
594+
"cache_read_input_tokens": None,
595+
"input_tokens": 393,
596+
"output_tokens": 77,
597+
"server_tool_use": None,
598+
}
599+
577600
def test_convert_streaming_chunks_to_chat_message_with_multiple_tool_calls(self):
578601
"""
579602
Test converting streaming chunks to a chat message with tool calls
@@ -703,7 +726,7 @@ def test_convert_streaming_chunks_to_chat_message_with_multiple_tool_calls(self)
703726
meta={
704727
"type": "message_delta",
705728
"delta": {"stop_reason": "tool_calls", "stop_sequence": None},
706-
"usage": {"completion_tokens": 40},
729+
"usage": {"output_tokens": 40},
707730
},
708731
component_info=ComponentInfo.from_component(self),
709732
finish_reason="tool_calls",
@@ -728,7 +751,7 @@ def test_convert_streaming_chunks_to_chat_message_with_multiple_tool_calls(self)
728751
# Verify meta information
729752
assert message._meta["index"] == 0
730753
assert message._meta["finish_reason"] == "tool_calls"
731-
assert message._meta["usage"] == {"completion_tokens": 40}
754+
assert message._meta["usage"] == {"output_tokens": 40}
732755

733756
def test_convert_streaming_chunks_to_chat_message_tool_call_with_empty_arguments(self):
734757
"""
@@ -815,7 +838,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_call_with_empty_arguments
815838
meta={
816839
"type": "message_delta",
817840
"delta": {"stop_reason": "tool_calls", "stop_sequence": None},
818-
"usage": {"completion_tokens": 40},
841+
"usage": {"output_tokens": 40},
819842
},
820843
component_info=ComponentInfo.from_component(self),
821844
index=1,
@@ -838,7 +861,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_call_with_empty_arguments
838861
# Verify meta information
839862
assert message._meta["index"] == 0
840863
assert message._meta["finish_reason"] == "tool_calls"
841-
assert message._meta["usage"] == {"completion_tokens": 40}
864+
assert message._meta["usage"] == {"output_tokens": 40}
842865

843866
def test_serde_in_pipeline(self):
844867
tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)
@@ -971,9 +994,10 @@ def __call__(self, chunk: StreamingChunk) -> None:
971994

972995
assert "claude-sonnet-4-20250514" in message.meta["model"]
973996
assert message.meta["finish_reason"] == "stop"
974-
975997
assert callback.counter > 1
976998
assert "Paris" in callback.responses
999+
assert "input_tokens" in message.meta["usage"]
1000+
assert "output_tokens" in message.meta["usage"]
9771001

9781002
def test_convert_message_to_anthropic_format(self):
9791003
"""
@@ -1171,6 +1195,7 @@ def test_live_run_with_tools(self, tools):
11711195
assert tool_call.tool_name == "weather"
11721196
assert tool_call.arguments == {"city": "Paris"}
11731197
assert message.meta["finish_reason"] == "tool_calls"
1198+
assert "completion_tokens" in message.meta["usage"]
11741199

11751200
new_messages = [
11761201
*initial_messages,
@@ -1268,6 +1293,8 @@ def test_live_run_with_tools_streaming(self, tools):
12681293
assert tool_call.tool_name == "weather"
12691294
assert tool_call.arguments == {"city": "Paris"}
12701295
assert message.meta["finish_reason"] == "tool_calls"
1296+
assert "output_tokens" in message.meta["usage"]
1297+
assert "input_tokens" in message.meta["usage"]
12711298

12721299
new_messages = [
12731300
*initial_messages,
@@ -1673,6 +1700,7 @@ async def test_run_async_with_params(self, chat_messages, mock_anthropic_complet
16731700
assert "Hello! I'm Claude." in response["replies"][0].text
16741701
assert response["replies"][0].meta["model"] == "claude-sonnet-4-20250514"
16751702
assert response["replies"][0].meta["finish_reason"] == "stop"
1703+
assert "completion_tokens" in response["replies"][0].meta["usage"]
16761704

16771705
@pytest.mark.asyncio
16781706
@pytest.mark.skipif(
@@ -1691,6 +1719,7 @@ async def test_live_run_async(self):
16911719
assert "Paris" in message.text
16921720
assert "claude-sonnet-4-20250514" in message.meta["model"]
16931721
assert message.meta["finish_reason"] == "stop"
1722+
assert "completion_tokens" in message.meta["usage"]
16941723

16951724
@pytest.mark.asyncio
16961725
@pytest.mark.skipif(
@@ -1726,6 +1755,8 @@ async def callback(chunk: StreamingChunk) -> None:
17261755
assert "paris" in message.text.lower()
17271756
assert "claude-sonnet-4-20250514" in message.meta["model"]
17281757
assert message.meta["finish_reason"] == "stop"
1758+
assert "input_tokens" in message.meta["usage"]
1759+
assert "output_tokens" in message.meta["usage"]
17291760

17301761
# Verify streaming behavior
17311762
assert counter > 1 # Should have received multiple chunks
@@ -1767,3 +1798,4 @@ async def test_live_run_async_with_tools(self, tools):
17671798
assert not final_message.tool_calls
17681799
assert len(final_message.text) > 0
17691800
assert "paris" in final_message.text.lower()
1801+
assert "completion_tokens" in final_message.meta["usage"]

0 commit comments

Comments
 (0)