Skip to content

Commit 3a1bd9c

Browse files
committed
Smarter formatting default for tool results, better way to customize/opt-out, and better naming/docs
1 parent 973b7aa commit 3a1bd9c

File tree

10 files changed

+116
-44
lines changed

10 files changed

+116
-44
lines changed

chatlas/_anthropic.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -472,12 +472,15 @@ def _as_content_block(content: Content) -> "ContentBlockParam":
472472
"input": content.arguments,
473473
}
474474
elif isinstance(content, ContentToolResult):
475-
return {
475+
res: ToolResultBlockParam = {
476476
"type": "tool_result",
477477
"tool_use_id": content.id,
478-
"content": content.get_final_value(),
479478
"is_error": content.error is not None,
480479
}
480+
# Anthropic supports non-text contents like ImageBlockParam
481+
res["content"] = content.get_final_value() # type: ignore
482+
return res
483+
481484
raise ValueError(f"Unknown content type: {type(content)}")
482485

483486
@staticmethod

chatlas/_chat.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,8 +1100,8 @@ def _chat_impl(
11001100
if req is not None:
11011101
yield req
11021102
res = self._invoke_tool_request(x)
1103-
if res.result and res.result.response_output is not None:
1104-
yield res.result.response_output
1103+
if res.result and res.result.user is not None:
1104+
yield res.result.user
11051105
results.append(res)
11061106

11071107
if results:
@@ -1137,8 +1137,8 @@ async def _chat_impl_async(
11371137
if req is not None:
11381138
yield req
11391139
res = await self._invoke_tool_request_async(x)
1140-
if res.result and res.result.response_output is not None:
1141-
yield res.result.response_output
1140+
if res.result and res.result.user is not None:
1141+
yield res.result.user
11421142
results.append(res)
11431143

11441144
if results:
@@ -1289,7 +1289,7 @@ def _invoke_tool_request(self, x: ContentToolRequest) -> ContentToolResult:
12891289
result = func(x.arguments)
12901290

12911291
if not isinstance(result, ToolResult):
1292-
result = ToolResult(value=result)
1292+
result = ToolResult(result)
12931293

12941294
return ContentToolResult(x.id, result=result, error=None, name=name)
12951295
except Exception as e:
@@ -1319,7 +1319,7 @@ async def _invoke_tool_request_async(
13191319
result = await func(x.arguments)
13201320

13211321
if not isinstance(result, ToolResult):
1322-
result = ToolResult(value=result)
1322+
result = ToolResult(result)
13231323

13241324
return ContentToolResult(x.id, result=result, error=None, name=name)
13251325
except Exception as e:

chatlas/_content.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -207,17 +207,17 @@ class ContentToolResult(Content):
207207
name: Optional[str] = None
208208
error: Optional[str] = None
209209

210-
def _get_value(self, pretty: bool = False) -> str:
210+
def _get_value(self, pretty: bool = False) -> Stringable:
211211
if self.error:
212212
return f"Tool calling failed with error: '{self.error}'"
213213
result = cast("ToolResult", self.result)
214214
if not pretty:
215-
return result.serialized_value
215+
return result.assistant
216216
try:
217-
json_val = json.loads(result.serialized_value) # type: ignore
217+
json_val = json.loads(result.assistant) # type: ignore
218218
return pformat(json_val, indent=2, sort_dicts=False)
219219
except: # noqa: E722
220-
return result.serialized_value
220+
return result.assistant
221221

222222
# Primarily used for `echo="all"`...
223223
def __str__(self):
@@ -231,14 +231,14 @@ def _repr_markdown_(self):
231231

232232
def __repr__(self, indent: int = 0):
233233
res = " " * indent
234-
value = None if self.result is None else self.result.value
234+
value = None if self.result is None else self.result.assistant
235235
res += f"<ContentToolResult value='{value}' id='{self.id}'"
236236
if self.error:
237237
res += f" error='{self.error}'"
238238
return res + ">"
239239

240240
# The actual value to send to the model
241-
def get_final_value(self) -> str:
241+
def get_final_value(self) -> Stringable:
242242
return self._get_value()
243243

244244

chatlas/_openai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,8 +483,8 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
483483
elif isinstance(x, ContentToolResult):
484484
tool_results.append(
485485
ChatCompletionToolMessageParam(
486-
# TODO: a tool could return an image!?!
487-
content=x.get_final_value(),
486+
# Currently, OpenAI only allows for text content in tool results
487+
content=cast(str, x.get_final_value()),
488488
tool_call_id=x.id,
489489
role="tool",
490490
)

chatlas/_tools.py

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

33
import inspect
4+
import json
45
import warnings
5-
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Protocol
6+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Literal, Optional, Protocol
67

78
from pydantic import BaseModel, Field, create_model
89

@@ -62,41 +63,69 @@ class ToolResult:
6263
"""
6364
A result from a tool invocation
6465
65-
Return this value from a tool if you want to separate what gets sent
66-
to the model vs what value gets yielded to the user.
66+
Return an instance of this class from a tool function in order to:
67+
68+
1. Yield content for the user (i.e., the downstream consumer of a `.stream()` or `.chat()`)
69+
to display.
70+
2. Control how the tool result gets formatted for the model (i.e., the assistant).
6771
6872
Parameters
6973
----------
70-
value
71-
The tool's return value. If `serialized_value` is not provided, the
72-
string representation of this value is sent to the model.
73-
response_output
74-
A value to yield when the tool is called during response generation. If
75-
`None`, no value is yielded. This is primarily useful for producing
76-
custom UI in the response output to indicate to the user that a tool
77-
call has completed (for example, return shiny UI here when
78-
`.stream()`-ing inside a shiny app).
79-
serialized_value
80-
The serialized value to send to the model. If `None`, the value is serialized
81-
using `str()`. This is useful when the value is not JSON
74+
assistant
75+
The tool result to send to the llm (i.e., assistant). If the result is
76+
not a string, `format_as` determines how to the value is formatted
77+
before sending it to the model.
78+
user
79+
A value to yield to the user (i.e., the consumer of a `.stream()`) when
80+
the tool is called. If `None`, no value is yielded. This is primarily
81+
useful for producing custom UI in the response output to indicate to the
82+
user that a tool call has completed (for example, return shiny UI here
83+
when `.stream()`-ing inside a shiny app).
84+
format_as
85+
How to format the `assistant` value for the model. The default,
86+
`"auto"`, first attempts to format the value as a JSON string. If that
87+
fails, it gets converted to a string via `str()`. To force
88+
`json.dumps()` or `str()`, set to `"json"` or `"str"`. Finally,
89+
`"as_is"` is useful for doing your own formatting and/or passing a
90+
non-string value (e.g., a list or dict) straight to the model.
91+
Non-string values are useful for tools that return images or other
92+
'known' non-text content types.
8293
"""
8394

8495
def __init__(
8596
self,
86-
value: Stringable,
87-
response_output: Optional[Stringable] = None,
88-
serialized_value: Optional[str] = None,
97+
assistant: Stringable,
98+
*,
99+
user: Optional[Stringable] = None,
100+
format_as: Literal["auto", "json", "str", "as_is"] = "auto",
89101
):
90-
self.value = value
91-
self.response_output = response_output
92-
if serialized_value is None:
93-
serialized_value = str(value)
94-
self.serialized_value = serialized_value
102+
# TODO: if called when an active user session, perhaps we could
103+
# provide a smart default here
104+
self.user = user
105+
self.assistant = self._format_value(assistant, format_as)
95106
# TODO: we could consider adding an "emit value" -- that is, the thing to
96107
# display when `echo="all"` is used. I imagine that might be useful for
97108
# advanced users, but let's not worry about it until someone asks for it.
98109
# self.emit = emit
99110

111+
def _format_value(self, value: Stringable, mode: str) -> Stringable:
112+
if isinstance(value, str):
113+
return value
114+
115+
if mode == "auto":
116+
try:
117+
return json.dumps(value)
118+
except Exception:
119+
return str(value)
120+
elif mode == "json":
121+
return json.dumps(value)
122+
elif mode == "str":
123+
return str(value)
124+
elif mode == "as_is":
125+
return value
126+
else:
127+
raise ValueError(f"Unknown format mode: {mode}")
128+
100129

101130
def func_to_schema(
102131
func: Callable[..., Any] | Callable[..., Awaitable[Any]],

tests/conftest.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from typing import Awaitable, Callable
44

55
import pytest
6-
from chatlas import Chat, Turn, content_image_file, content_image_url
76
from PIL import Image
87
from pydantic import BaseModel
98

9+
from chatlas import Chat, Turn, content_image_file, content_image_url
10+
1011
ChatFun = Callable[..., Chat]
1112

1213

@@ -223,3 +224,8 @@ def assert_images_remote_error(chat_fun: ChatFun):
223224
chat.chat("What's in this image?", image_remote)
224225

225226
assert len(chat.get_turns()) == 0
227+
228+
229+
@pytest.fixture
230+
def test_images_dir():
231+
return Path(__file__).parent / "images"

tests/images/dice.png

219 KB
Loading

tests/test_chat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def test_tool_results():
9797

9898
def get_date():
9999
"""Gets the current date"""
100-
return ToolResult("2024-01-01", response_output=["Tool result..."])
100+
return ToolResult("2024-01-01", user=["Tool result..."])
101101

102102
chat.register_tool(get_date)
103103
chat.on_tool_request(lambda req: [f"Requesting tool {req.name}..."])
@@ -134,7 +134,7 @@ async def get_date():
134134
import asyncio
135135

136136
await asyncio.sleep(0.1)
137-
return ToolResult("2024-01-01", response_output=["Tool result..."])
137+
return ToolResult("2024-01-01", user=["Tool result..."])
138138

139139
chat.register_tool(get_date)
140140
chat.on_tool_request(lambda req: [f"Requesting tool {req.name}..."])

tests/test_content_tools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def tool():
114114
assert isinstance(res, ContentToolResult)
115115
assert res.id == "x"
116116
assert res.error is None
117-
assert res.result.value == 1
117+
assert res.result.assistant == "1"
118118

119119
res = chat._invoke_tool_request(
120120
ContentToolRequest(id="x", name="tool", arguments={"x": 1})
@@ -149,7 +149,7 @@ async def tool():
149149
assert isinstance(res, ContentToolResult)
150150
assert res.id == "x"
151151
assert res.error is None
152-
assert res.result.value == 1
152+
assert res.result.assistant == "1"
153153

154154
res = await chat._invoke_tool_request_async(
155155
ContentToolRequest(id="x", name="tool", arguments={"x": 1})

tests/test_provider_anthropic.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import base64
2+
13
import pytest
4+
25
from chatlas import ChatAnthropic
36

47
from .conftest import (
@@ -94,3 +97,34 @@ def run_inlineassert():
9497

9598
retryassert(run_inlineassert, retries=3)
9699
assert_images_remote_error(chat_fun)
100+
101+
102+
def test_anthropic_image_tool(test_images_dir):
103+
from chatlas import ToolResult
104+
105+
def get_picture():
106+
"Returns an image"
107+
with open(test_images_dir / "dice.png", "rb") as image:
108+
bytez = image.read()
109+
res = [
110+
{
111+
"type": "image",
112+
"source": {
113+
"type": "base64",
114+
"media_type": "image/png",
115+
"data": base64.b64encode(bytez).decode("utf-8"),
116+
},
117+
}
118+
]
119+
return ToolResult(res, format_as="as_is")
120+
121+
chat = ChatAnthropic()
122+
chat.register_tool(get_picture)
123+
124+
res = chat.chat(
125+
"You have a tool called 'get_picture' available to you. "
126+
"When called, it returns an image. "
127+
"Tell me what you see in the image."
128+
)
129+
130+
assert "dice" in res.get_content()

0 commit comments

Comments
 (0)