39
39
)
40
40
from ._logging import log_tool_error
41
41
from ._provider import Provider
42
- from ._tools import Tool
42
+ from ._tools import Stringable , Tool , ToolResult
43
43
from ._turn import Turn , user_turn
44
44
from ._typing_extensions import TypedDict
45
45
from ._utils import html_escape , wrap_async
@@ -96,6 +96,9 @@ def __init__(
96
96
"rich_console" : {},
97
97
"css_styles" : {},
98
98
}
99
+ self ._on_tool_request_default : Optional [
100
+ Callable [[ContentToolRequest ], Stringable ]
101
+ ] = None
99
102
100
103
def get_turns (
101
104
self ,
@@ -658,7 +661,7 @@ def stream(
658
661
kwargs = kwargs ,
659
662
)
660
663
661
- def wrapper () -> Generator [str , None , None ]:
664
+ def wrapper () -> Generator [Stringable , None , None ]:
662
665
with display :
663
666
for chunk in generator :
664
667
yield chunk
@@ -695,7 +698,7 @@ async def stream_async(
695
698
696
699
display = self ._markdown_display (echo = echo )
697
700
698
- async def wrapper () -> AsyncGenerator [str , None ]:
701
+ async def wrapper () -> AsyncGenerator [Stringable , None ]:
699
702
with display :
700
703
async for chunk in self ._chat_impl_async (
701
704
turn ,
@@ -831,6 +834,7 @@ def register_tool(
831
834
self ,
832
835
func : Callable [..., Any ] | Callable [..., Awaitable [Any ]],
833
836
* ,
837
+ on_request : Optional [Callable [[ContentToolRequest ], Stringable ]] = None ,
834
838
model : Optional [type [BaseModel ]] = None ,
835
839
):
836
840
"""
@@ -900,16 +904,49 @@ def add(a: int, b: int) -> int:
900
904
----------
901
905
func
902
906
The function to be invoked when the tool is called.
907
+ on_request
908
+ A callable that will be passed a :class:`~chatlas.ContentToolRequest`
909
+ when the tool is requested. If defined, and the callable returns a
910
+ stringable object, that value will be yielded to the chat as a part
911
+ of the response.
903
912
model
904
913
A Pydantic model that describes the input parameters for the function.
905
914
If not provided, the model will be inferred from the function's type hints.
906
915
The primary reason why you might want to provide a model in
907
916
Note that the name and docstring of the model takes precedence over the
908
917
name and docstring of the function.
909
918
"""
910
- tool = Tool (func , model = model )
919
+ tool = Tool (func , on_request = on_request , model = model )
911
920
self ._tools [tool .name ] = tool
912
921
922
+ def on_tool_request (
923
+ self ,
924
+ func : Callable [[ContentToolRequest ], Stringable ],
925
+ ):
926
+ """
927
+ Register a default function to be invoked when a tool is requested.
928
+
929
+ This function will be invoked if a tool is requested that does not have
930
+ a specific `on_request` function defined.
931
+
932
+ Parameters
933
+ ----------
934
+ func
935
+ A callable that will be passed a :class:`~chatlas.ContentToolRequest`
936
+ when the tool is requested. If defined, and the callable returns a
937
+ stringable object, that value will be yielded to the chat as a part
938
+ of the response.
939
+ """
940
+ self ._on_tool_request_default = func
941
+
942
+ def _on_tool_request (self , req : ContentToolRequest ) -> Stringable | None :
943
+ tool_def = self ._tools .get (req .name , None )
944
+ if tool_def and tool_def .on_request :
945
+ return tool_def .on_request (req )
946
+ if self ._on_tool_request_default :
947
+ return self ._on_tool_request_default (req )
948
+ return None
949
+
913
950
def export (
914
951
self ,
915
952
filename : str | Path ,
@@ -1040,7 +1077,7 @@ def _chat_impl(
1040
1077
display : MarkdownDisplay ,
1041
1078
stream : bool ,
1042
1079
kwargs : Optional [SubmitInputArgsT ] = None ,
1043
- ) -> Generator [str , None , None ]:
1080
+ ) -> Generator [Stringable , None , None ]:
1044
1081
user_turn_result : Turn | None = user_turn
1045
1082
while user_turn_result is not None :
1046
1083
for chunk in self ._submit_turns (
@@ -1051,7 +1088,24 @@ def _chat_impl(
1051
1088
kwargs = kwargs ,
1052
1089
):
1053
1090
yield chunk
1054
- user_turn_result = self ._invoke_tools ()
1091
+
1092
+ turn = self .get_last_turn (role = "assistant" )
1093
+ assert turn is not None
1094
+ user_turn_result = None
1095
+
1096
+ results : list [ContentToolResult ] = []
1097
+ for x in turn .contents :
1098
+ if isinstance (x , ContentToolRequest ):
1099
+ req = self ._on_tool_request (x )
1100
+ if req is not None :
1101
+ yield req
1102
+ result , output = self ._invoke_tool_request (x )
1103
+ if output is not None :
1104
+ yield output
1105
+ results .append (result )
1106
+
1107
+ if results :
1108
+ user_turn_result = Turn ("user" , results )
1055
1109
1056
1110
async def _chat_impl_async (
1057
1111
self ,
@@ -1060,7 +1114,7 @@ async def _chat_impl_async(
1060
1114
display : MarkdownDisplay ,
1061
1115
stream : bool ,
1062
1116
kwargs : Optional [SubmitInputArgsT ] = None ,
1063
- ) -> AsyncGenerator [str , None ]:
1117
+ ) -> AsyncGenerator [Stringable , None ]:
1064
1118
user_turn_result : Turn | None = user_turn
1065
1119
while user_turn_result is not None :
1066
1120
async for chunk in self ._submit_turns_async (
@@ -1071,7 +1125,24 @@ async def _chat_impl_async(
1071
1125
kwargs = kwargs ,
1072
1126
):
1073
1127
yield chunk
1074
- user_turn_result = await self ._invoke_tools_async ()
1128
+
1129
+ turn = self .get_last_turn (role = "assistant" )
1130
+ assert turn is not None
1131
+ user_turn_result = None
1132
+
1133
+ results : list [ContentToolResult ] = []
1134
+ for x in turn .contents :
1135
+ if isinstance (x , ContentToolRequest ):
1136
+ req = self ._on_tool_request (x )
1137
+ if req is not None :
1138
+ yield req
1139
+ result , output = await self ._invoke_tool_request_async (x )
1140
+ if output is not None :
1141
+ yield output
1142
+ results .append (result )
1143
+
1144
+ if results :
1145
+ user_turn_result = Turn ("user" , results )
1075
1146
1076
1147
def _submit_turns (
1077
1148
self ,
@@ -1085,7 +1156,7 @@ def _submit_turns(
1085
1156
if any (x ._is_async for x in self ._tools .values ()):
1086
1157
raise ValueError ("Cannot use async tools in a synchronous chat" )
1087
1158
1088
- def emit (text : str | Content ):
1159
+ def emit (text : Stringable ):
1089
1160
display .update (str (text ))
1090
1161
1091
1162
emit ("<br>\n \n " )
@@ -1148,7 +1219,7 @@ async def _submit_turns_async(
1148
1219
data_model : type [BaseModel ] | None = None ,
1149
1220
kwargs : Optional [SubmitInputArgsT ] = None ,
1150
1221
) -> AsyncGenerator [str , None ]:
1151
- def emit (text : str | Content ):
1222
+ def emit (text : Stringable ):
1152
1223
display .update (str (text ))
1153
1224
1154
1225
emit ("<br>\n \n " )
@@ -1202,88 +1273,62 @@ def emit(text: str | Content):
1202
1273
1203
1274
self ._turns .extend ([user_turn , turn ])
1204
1275
1205
- def _invoke_tools (self ) -> Turn | None :
1206
- turn = self .get_last_turn ()
1207
- if turn is None :
1208
- return None
1209
-
1210
- results : list [ContentToolResult ] = []
1211
- for x in turn .contents :
1212
- if isinstance (x , ContentToolRequest ):
1213
- tool_def = self ._tools .get (x .name , None )
1214
- func = tool_def .func if tool_def is not None else None
1215
- results .append (self ._invoke_tool (func , x .arguments , x .id ))
1216
-
1217
- if not results :
1218
- return None
1276
+ def _invoke_tool_request (
1277
+ self , x : ContentToolRequest
1278
+ ) -> tuple [ContentToolResult , Stringable ]:
1279
+ tool_def = self ._tools .get (x .name , None )
1280
+ func = tool_def .func if tool_def is not None else None
1219
1281
1220
- return Turn ("user" , results )
1221
-
1222
- async def _invoke_tools_async (self ) -> Turn | None :
1223
- turn = self .get_last_turn ()
1224
- if turn is None :
1225
- return None
1226
-
1227
- results : list [ContentToolResult ] = []
1228
- for x in turn .contents :
1229
- if isinstance (x , ContentToolRequest ):
1230
- tool_def = self ._tools .get (x .name , None )
1231
- func = None
1232
- if tool_def :
1233
- if tool_def ._is_async :
1234
- func = tool_def .func
1235
- else :
1236
- func = wrap_async (tool_def .func )
1237
- results .append (await self ._invoke_tool_async (func , x .arguments , x .id ))
1238
-
1239
- if not results :
1240
- return None
1241
-
1242
- return Turn ("user" , results )
1243
-
1244
- @staticmethod
1245
- def _invoke_tool (
1246
- func : Callable [..., Any ] | None ,
1247
- arguments : object ,
1248
- id_ : str ,
1249
- ) -> ContentToolResult :
1250
1282
if func is None :
1251
- return ContentToolResult (id_ , value = None , error = "Unknown tool" )
1283
+ return ContentToolResult (x . id , value = None , error = "Unknown tool" ), None
1252
1284
1253
1285
name = func .__name__
1254
1286
1255
1287
try :
1256
- if isinstance (arguments , dict ):
1257
- result = func (** arguments )
1288
+ if isinstance (x . arguments , dict ):
1289
+ result = func (** x . arguments )
1258
1290
else :
1259
- result = func (arguments )
1291
+ result = func (x . arguments )
1260
1292
1261
- return ContentToolResult (id_ , value = result , error = None , name = name )
1293
+ value , output = (result , None )
1294
+ if isinstance (result , ToolResult ):
1295
+ value , output = (result .assistant , result .output )
1296
+
1297
+ return ContentToolResult (x .id , value = value , error = None , name = name ), output
1262
1298
except Exception as e :
1263
- log_tool_error (name , str (arguments ), e )
1264
- return ContentToolResult (id_ , value = None , error = str (e ), name = name )
1299
+ log_tool_error (name , str (x .arguments ), e )
1300
+ return ContentToolResult (x .id , value = None , error = str (e ), name = name ), None
1301
+
1302
+ async def _invoke_tool_request_async (
1303
+ self , x : ContentToolRequest
1304
+ ) -> tuple [ContentToolResult , Stringable ]:
1305
+ tool_def = self ._tools .get (x .name , None )
1306
+ func = None
1307
+ if tool_def :
1308
+ if tool_def ._is_async :
1309
+ func = tool_def .func
1310
+ else :
1311
+ func = wrap_async (tool_def .func )
1265
1312
1266
- @staticmethod
1267
- async def _invoke_tool_async (
1268
- func : Callable [..., Awaitable [Any ]] | None ,
1269
- arguments : object ,
1270
- id_ : str ,
1271
- ) -> ContentToolResult :
1272
1313
if func is None :
1273
- return ContentToolResult (id_ , value = None , error = "Unknown tool" )
1314
+ return ContentToolResult (x . id , value = None , error = "Unknown tool" ), None
1274
1315
1275
1316
name = func .__name__
1276
1317
1277
1318
try :
1278
- if isinstance (arguments , dict ):
1279
- result = await func (** arguments )
1319
+ if isinstance (x . arguments , dict ):
1320
+ result = await func (** x . arguments )
1280
1321
else :
1281
- result = await func (arguments )
1322
+ result = await func (x .arguments )
1323
+
1324
+ value , output = (result , None )
1325
+ if isinstance (result , ToolResult ):
1326
+ value , output = (result .assistant , result .output )
1282
1327
1283
- return ContentToolResult (id_ , value = result , error = None , name = name )
1328
+ return ContentToolResult (x . id , value = value , error = None , name = name ), output
1284
1329
except Exception as e :
1285
- log_tool_error (func .__name__ , str (arguments ), e )
1286
- return ContentToolResult (id_ , value = None , error = str (e ), name = name )
1330
+ log_tool_error (func .__name__ , str (x . arguments ), e )
1331
+ return ContentToolResult (x . id , value = None , error = str (e ), name = name ), None
1287
1332
1288
1333
def _markdown_display (
1289
1334
self , echo : Literal ["text" , "all" , "none" ]
@@ -1378,15 +1423,15 @@ class ChatResponse:
1378
1423
still be retrieved (via the `content` attribute).
1379
1424
"""
1380
1425
1381
- def __init__ (self , generator : Generator [str , None ]):
1426
+ def __init__ (self , generator : Generator [Stringable , None ]):
1382
1427
self ._generator = generator
1383
1428
self .content : str = ""
1384
1429
1385
1430
def __iter__ (self ) -> Iterator [str ]:
1386
1431
return self
1387
1432
1388
1433
def __next__ (self ) -> str :
1389
- chunk = next (self ._generator )
1434
+ chunk = str ( next (self ._generator ) )
1390
1435
self .content += chunk # Keep track of accumulated content
1391
1436
return chunk
1392
1437
@@ -1430,15 +1475,15 @@ class ChatResponseAsync:
1430
1475
still be retrieved (via the `content` attribute).
1431
1476
"""
1432
1477
1433
- def __init__ (self , generator : AsyncGenerator [str , None ]):
1478
+ def __init__ (self , generator : AsyncGenerator [Stringable , None ]):
1434
1479
self ._generator = generator
1435
1480
self .content : str = ""
1436
1481
1437
1482
def __aiter__ (self ) -> AsyncIterator [str ]:
1438
1483
return self
1439
1484
1440
1485
async def __anext__ (self ) -> str :
1441
- chunk = await self ._generator .__anext__ ()
1486
+ chunk = str ( await self ._generator .__anext__ () )
1442
1487
self .content += chunk # Keep track of accumulated content
1443
1488
return chunk
1444
1489
0 commit comments