Skip to content

Commit e482f6a

Browse files
committed
First pass at an interface for surfacing tool calls
1 parent 791cc31 commit e482f6a

File tree

5 files changed

+191
-95
lines changed

5 files changed

+191
-95
lines changed

chatlas/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ._provider import Provider
1313
from ._snowflake import ChatSnowflake
1414
from ._tokens import token_usage
15-
from ._tools import Tool
15+
from ._tools import Tool, ToolResult
1616
from ._turn import Turn
1717

1818
try:
@@ -41,6 +41,7 @@
4141
"Provider",
4242
"token_usage",
4343
"Tool",
44+
"ToolResult",
4445
"Turn",
4546
"types",
4647
)

chatlas/_chat.py

Lines changed: 123 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
)
4040
from ._logging import log_tool_error
4141
from ._provider import Provider
42-
from ._tools import Tool
42+
from ._tools import Stringable, Tool, ToolResult
4343
from ._turn import Turn, user_turn
4444
from ._typing_extensions import TypedDict
4545
from ._utils import html_escape, wrap_async
@@ -96,6 +96,9 @@ def __init__(
9696
"rich_console": {},
9797
"css_styles": {},
9898
}
99+
self._on_tool_request_default: Optional[
100+
Callable[[ContentToolRequest], Stringable]
101+
] = None
99102

100103
def get_turns(
101104
self,
@@ -658,7 +661,7 @@ def stream(
658661
kwargs=kwargs,
659662
)
660663

661-
def wrapper() -> Generator[str, None, None]:
664+
def wrapper() -> Generator[Stringable, None, None]:
662665
with display:
663666
for chunk in generator:
664667
yield chunk
@@ -695,7 +698,7 @@ async def stream_async(
695698

696699
display = self._markdown_display(echo=echo)
697700

698-
async def wrapper() -> AsyncGenerator[str, None]:
701+
async def wrapper() -> AsyncGenerator[Stringable, None]:
699702
with display:
700703
async for chunk in self._chat_impl_async(
701704
turn,
@@ -831,6 +834,7 @@ def register_tool(
831834
self,
832835
func: Callable[..., Any] | Callable[..., Awaitable[Any]],
833836
*,
837+
on_request: Optional[Callable[[ContentToolRequest], Stringable]] = None,
834838
model: Optional[type[BaseModel]] = None,
835839
):
836840
"""
@@ -900,16 +904,49 @@ def add(a: int, b: int) -> int:
900904
----------
901905
func
902906
The function to be invoked when the tool is called.
907+
on_request
908+
A callable that will be passed a :class:`~chatlas.ContentToolRequest`
909+
when the tool is requested. If defined, and the callable returns a
910+
stringable object, that value will be yielded to the chat as a part
911+
of the response.
903912
model
904913
A Pydantic model that describes the input parameters for the function.
905914
If not provided, the model will be inferred from the function's type hints.
906915
The primary reason why you might want to provide a model in
907916
Note that the name and docstring of the model takes precedence over the
908917
name and docstring of the function.
909918
"""
910-
tool = Tool(func, model=model)
919+
tool = Tool(func, on_request=on_request, model=model)
911920
self._tools[tool.name] = tool
912921

922+
def on_tool_request(
923+
self,
924+
func: Callable[[ContentToolRequest], Stringable],
925+
):
926+
"""
927+
Register a default function to be invoked when a tool is requested.
928+
929+
This function will be invoked if a tool is requested that does not have
930+
a specific `on_request` function defined.
931+
932+
Parameters
933+
----------
934+
func
935+
A callable that will be passed a :class:`~chatlas.ContentToolRequest`
936+
when the tool is requested. If defined, and the callable returns a
937+
stringable object, that value will be yielded to the chat as a part
938+
of the response.
939+
"""
940+
self._on_tool_request_default = func
941+
942+
def _on_tool_request(self, req: ContentToolRequest) -> Stringable | None:
943+
tool_def = self._tools.get(req.name, None)
944+
if tool_def and tool_def.on_request:
945+
return tool_def.on_request(req)
946+
if self._on_tool_request_default:
947+
return self._on_tool_request_default(req)
948+
return None
949+
913950
def export(
914951
self,
915952
filename: str | Path,
@@ -1040,7 +1077,7 @@ def _chat_impl(
10401077
display: MarkdownDisplay,
10411078
stream: bool,
10421079
kwargs: Optional[SubmitInputArgsT] = None,
1043-
) -> Generator[str, None, None]:
1080+
) -> Generator[Stringable, None, None]:
10441081
user_turn_result: Turn | None = user_turn
10451082
while user_turn_result is not None:
10461083
for chunk in self._submit_turns(
@@ -1051,7 +1088,24 @@ def _chat_impl(
10511088
kwargs=kwargs,
10521089
):
10531090
yield chunk
1054-
user_turn_result = self._invoke_tools()
1091+
1092+
turn = self.get_last_turn(role="assistant")
1093+
assert turn is not None
1094+
user_turn_result = None
1095+
1096+
results: list[ContentToolResult] = []
1097+
for x in turn.contents:
1098+
if isinstance(x, ContentToolRequest):
1099+
req = self._on_tool_request(x)
1100+
if req is not None:
1101+
yield req
1102+
result, output = self._invoke_tool_request(x)
1103+
if output is not None:
1104+
yield output
1105+
results.append(result)
1106+
1107+
if results:
1108+
user_turn_result = Turn("user", results)
10551109

10561110
async def _chat_impl_async(
10571111
self,
@@ -1060,7 +1114,7 @@ async def _chat_impl_async(
10601114
display: MarkdownDisplay,
10611115
stream: bool,
10621116
kwargs: Optional[SubmitInputArgsT] = None,
1063-
) -> AsyncGenerator[str, None]:
1117+
) -> AsyncGenerator[Stringable, None]:
10641118
user_turn_result: Turn | None = user_turn
10651119
while user_turn_result is not None:
10661120
async for chunk in self._submit_turns_async(
@@ -1071,7 +1125,24 @@ async def _chat_impl_async(
10711125
kwargs=kwargs,
10721126
):
10731127
yield chunk
1074-
user_turn_result = await self._invoke_tools_async()
1128+
1129+
turn = self.get_last_turn(role="assistant")
1130+
assert turn is not None
1131+
user_turn_result = None
1132+
1133+
results: list[ContentToolResult] = []
1134+
for x in turn.contents:
1135+
if isinstance(x, ContentToolRequest):
1136+
req = self._on_tool_request(x)
1137+
if req is not None:
1138+
yield req
1139+
result, output = await self._invoke_tool_request_async(x)
1140+
if output is not None:
1141+
yield output
1142+
results.append(result)
1143+
1144+
if results:
1145+
user_turn_result = Turn("user", results)
10751146

10761147
def _submit_turns(
10771148
self,
@@ -1085,7 +1156,7 @@ def _submit_turns(
10851156
if any(x._is_async for x in self._tools.values()):
10861157
raise ValueError("Cannot use async tools in a synchronous chat")
10871158

1088-
def emit(text: str | Content):
1159+
def emit(text: Stringable):
10891160
display.update(str(text))
10901161

10911162
emit("<br>\n\n")
@@ -1148,7 +1219,7 @@ async def _submit_turns_async(
11481219
data_model: type[BaseModel] | None = None,
11491220
kwargs: Optional[SubmitInputArgsT] = None,
11501221
) -> AsyncGenerator[str, None]:
1151-
def emit(text: str | Content):
1222+
def emit(text: Stringable):
11521223
display.update(str(text))
11531224

11541225
emit("<br>\n\n")
@@ -1202,88 +1273,62 @@ def emit(text: str | Content):
12021273

12031274
self._turns.extend([user_turn, turn])
12041275

1205-
def _invoke_tools(self) -> Turn | None:
1206-
turn = self.get_last_turn()
1207-
if turn is None:
1208-
return None
1209-
1210-
results: list[ContentToolResult] = []
1211-
for x in turn.contents:
1212-
if isinstance(x, ContentToolRequest):
1213-
tool_def = self._tools.get(x.name, None)
1214-
func = tool_def.func if tool_def is not None else None
1215-
results.append(self._invoke_tool(func, x.arguments, x.id))
1216-
1217-
if not results:
1218-
return None
1276+
def _invoke_tool_request(
1277+
self, x: ContentToolRequest
1278+
) -> tuple[ContentToolResult, Stringable]:
1279+
tool_def = self._tools.get(x.name, None)
1280+
func = tool_def.func if tool_def is not None else None
12191281

1220-
return Turn("user", results)
1221-
1222-
async def _invoke_tools_async(self) -> Turn | None:
1223-
turn = self.get_last_turn()
1224-
if turn is None:
1225-
return None
1226-
1227-
results: list[ContentToolResult] = []
1228-
for x in turn.contents:
1229-
if isinstance(x, ContentToolRequest):
1230-
tool_def = self._tools.get(x.name, None)
1231-
func = None
1232-
if tool_def:
1233-
if tool_def._is_async:
1234-
func = tool_def.func
1235-
else:
1236-
func = wrap_async(tool_def.func)
1237-
results.append(await self._invoke_tool_async(func, x.arguments, x.id))
1238-
1239-
if not results:
1240-
return None
1241-
1242-
return Turn("user", results)
1243-
1244-
@staticmethod
1245-
def _invoke_tool(
1246-
func: Callable[..., Any] | None,
1247-
arguments: object,
1248-
id_: str,
1249-
) -> ContentToolResult:
12501282
if func is None:
1251-
return ContentToolResult(id_, value=None, error="Unknown tool")
1283+
return ContentToolResult(x.id, value=None, error="Unknown tool"), None
12521284

12531285
name = func.__name__
12541286

12551287
try:
1256-
if isinstance(arguments, dict):
1257-
result = func(**arguments)
1288+
if isinstance(x.arguments, dict):
1289+
result = func(**x.arguments)
12581290
else:
1259-
result = func(arguments)
1291+
result = func(x.arguments)
12601292

1261-
return ContentToolResult(id_, value=result, error=None, name=name)
1293+
value, output = (result, None)
1294+
if isinstance(result, ToolResult):
1295+
value, output = (result.assistant, result.output)
1296+
1297+
return ContentToolResult(x.id, value=value, error=None, name=name), output
12621298
except Exception as e:
1263-
log_tool_error(name, str(arguments), e)
1264-
return ContentToolResult(id_, value=None, error=str(e), name=name)
1299+
log_tool_error(name, str(x.arguments), e)
1300+
return ContentToolResult(x.id, value=None, error=str(e), name=name), None
1301+
1302+
async def _invoke_tool_request_async(
1303+
self, x: ContentToolRequest
1304+
) -> tuple[ContentToolResult, Stringable]:
1305+
tool_def = self._tools.get(x.name, None)
1306+
func = None
1307+
if tool_def:
1308+
if tool_def._is_async:
1309+
func = tool_def.func
1310+
else:
1311+
func = wrap_async(tool_def.func)
12651312

1266-
@staticmethod
1267-
async def _invoke_tool_async(
1268-
func: Callable[..., Awaitable[Any]] | None,
1269-
arguments: object,
1270-
id_: str,
1271-
) -> ContentToolResult:
12721313
if func is None:
1273-
return ContentToolResult(id_, value=None, error="Unknown tool")
1314+
return ContentToolResult(x.id, value=None, error="Unknown tool"), None
12741315

12751316
name = func.__name__
12761317

12771318
try:
1278-
if isinstance(arguments, dict):
1279-
result = await func(**arguments)
1319+
if isinstance(x.arguments, dict):
1320+
result = await func(**x.arguments)
12801321
else:
1281-
result = await func(arguments)
1322+
result = await func(x.arguments)
1323+
1324+
value, output = (result, None)
1325+
if isinstance(result, ToolResult):
1326+
value, output = (result.assistant, result.output)
12821327

1283-
return ContentToolResult(id_, value=result, error=None, name=name)
1328+
return ContentToolResult(x.id, value=value, error=None, name=name), output
12841329
except Exception as e:
1285-
log_tool_error(func.__name__, str(arguments), e)
1286-
return ContentToolResult(id_, value=None, error=str(e), name=name)
1330+
log_tool_error(func.__name__, str(x.arguments), e)
1331+
return ContentToolResult(x.id, value=None, error=str(e), name=name), None
12871332

12881333
def _markdown_display(
12891334
self, echo: Literal["text", "all", "none"]
@@ -1378,15 +1423,15 @@ class ChatResponse:
13781423
still be retrieved (via the `content` attribute).
13791424
"""
13801425

1381-
def __init__(self, generator: Generator[str, None]):
1426+
def __init__(self, generator: Generator[Stringable, None]):
13821427
self._generator = generator
13831428
self.content: str = ""
13841429

13851430
def __iter__(self) -> Iterator[str]:
13861431
return self
13871432

13881433
def __next__(self) -> str:
1389-
chunk = next(self._generator)
1434+
chunk = str(next(self._generator))
13901435
self.content += chunk # Keep track of accumulated content
13911436
return chunk
13921437

@@ -1430,15 +1475,15 @@ class ChatResponseAsync:
14301475
still be retrieved (via the `content` attribute).
14311476
"""
14321477

1433-
def __init__(self, generator: AsyncGenerator[str, None]):
1478+
def __init__(self, generator: AsyncGenerator[Stringable, None]):
14341479
self._generator = generator
14351480
self.content: str = ""
14361481

14371482
def __aiter__(self) -> AsyncIterator[str]:
14381483
return self
14391484

14401485
async def __anext__(self) -> str:
1441-
chunk = await self._generator.__anext__()
1486+
chunk = str(await self._generator.__anext__())
14421487
self.content += chunk # Keep track of accumulated content
14431488
return chunk
14441489

0 commit comments

Comments
 (0)