Skip to content

Commit d5fdb96

Browse files
committed
Require a ToolResultDisplay() for better typing, doc, and serialization experience
1 parent 2f225d6 commit d5fdb96

File tree

7 files changed

+156
-38
lines changed

7 files changed

+156
-38
lines changed

pkg-py/src/shinychat/_chat_normalize_chatlas.py

Lines changed: 126 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import json
44
import os
55
import warnings
6-
from dataclasses import dataclass
7-
from typing import TYPE_CHECKING, Literal
6+
from typing import TYPE_CHECKING, Any, Literal
87

9-
from htmltools import RenderedHTML, Tag, TagChild, Tagifiable, TagList
8+
from htmltools import HTML, RenderedHTML, Tag, Tagifiable, TagList
109
from packaging import version
10+
from pydantic import BaseModel, field_serializer, field_validator
1111

1212
from ._typing_extensions import TypeGuard
1313

@@ -18,10 +18,12 @@
1818
"ToolResultDisplay",
1919
]
2020

21+
# Pydantic doesn't work with htmltool's recursive TagChild type
22+
TagChild = Tag | TagList | HTML | str | None
2123

22-
@dataclass
23-
class ToolCardComponent:
24-
"Data class mirroring the ShinyToolCard component class in chat-tools.ts"
24+
25+
class ToolCardComponent(BaseModel):
26+
"A class that mirrors the ShinyToolCard component class in chat-tools.ts"
2527

2628
request_id: str
2729
"""
@@ -44,10 +46,23 @@ class ToolCardComponent:
4446
expanded: bool = False
4547
"Controls whether the card content is expanded/visible."
4648

49+
model_config = {"arbitrary_types_allowed": True}
50+
51+
@field_serializer("icon")
52+
def _serialize_icon(self, value: TagChild):
53+
return TagList(value).render()
54+
55+
@field_validator("icon", mode="before")
56+
@classmethod
57+
def _validate_icon(cls, value: TagChild) -> TagChild:
58+
if isinstance(value, dict):
59+
return restore_rendered_html(value)
60+
else:
61+
return value
62+
4763

48-
@dataclass
4964
class ToolRequestComponent(ToolCardComponent):
50-
"Data class mirroring the ShinyToolRequest component class from chat-tools.ts"
65+
"A class that mirrors the ShinyToolRequest component class from chat-tools.ts"
5166

5267
arguments: str = ""
5368
"The function arguments as requested by the LLM, typically in JSON format."
@@ -71,9 +86,8 @@ def tagify(self):
7186
ValueType = Literal["html", "markdown", "text", "code"]
7287

7388

74-
@dataclass
7589
class ToolResultComponent(ToolCardComponent):
76-
"Data class mirroring the ShinyToolResult component class from chat-tools.ts"
90+
"A class that mirrors the ShinyToolResult component class from chat-tools.ts"
7791

7892
request_call: str = ""
7993
"The original tool call that generated this result. Used to display the tool invocation."
@@ -129,18 +143,76 @@ def tagify(self):
129143
)
130144

131145

132-
@dataclass
133-
class ToolResultDisplay:
134-
"Data class to for users to customize how tool results are displayed"
146+
class ToolResultDisplay(BaseModel):
147+
"""
148+
Customize how tool results are displayed.
135149
150+
Assign a `ToolResultDisplay` instance to a
151+
[`chatlas.ContentToolResult`](https://posit-dev.github.io/chatlas/reference/types.ContentToolResult.html)
152+
to customize the UI shown to the user when tool calls occur.
153+
154+
Examples
155+
--------
156+
157+
```python
158+
import chatlas as ctl
159+
from shinychat.types import ToolResultDisplay
160+
161+
162+
def my_tool():
163+
display = ToolResultDisplay(
164+
title="Tool result title",
165+
markdown="A _markdown_ message shown to user.",
166+
)
167+
return ctl.ContentToolResult(
168+
value="Value the model sees",
169+
extra={"display": display},
170+
)
171+
172+
173+
chat_client = ctl.ChatAuto()
174+
chat_client.register_tool(my_tool)
175+
```
176+
177+
Parameters
178+
---------
179+
title
180+
The title to display in the header of the tool result.
181+
icon
182+
An icon to display in the header (alongside the title).
183+
show_request
184+
Whether to show the tool request inside the tool result container.
185+
open
186+
Whether or not the tool result details are expanded by default.
187+
html
188+
Custom HTML content (to use in place of the default result display).
189+
markdown
190+
Custom Markdown string (to use in place of the default result display).
191+
text
192+
Custom plain text string (to use in place of the default result display).
193+
"""
194+
195+
title: str | None = None
196+
icon: TagChild = None
136197
html: TagChild = None
137-
markdown: str | None = None
138-
text: str | None = None
139198
show_request: bool = True
140199
open: bool = False
141-
title: str | None = None
142-
icon: TagChild = None
143-
expanded: bool | None = None
200+
markdown: str | None = None
201+
text: str | None = None
202+
203+
model_config = {"arbitrary_types_allowed": True}
204+
205+
@field_serializer("html", "icon")
206+
def _serialize_html_icon(self, value: TagChild):
207+
return TagList(value).render()
208+
209+
@field_validator("html", "icon", mode="before")
210+
@classmethod
211+
def _validate_html_icon(cls, value: TagChild) -> TagChild:
212+
if isinstance(value, dict):
213+
return restore_rendered_html(value)
214+
else:
215+
return value
144216

145217

146218
def tool_request_contents(x: "ContentToolRequest") -> Tagifiable:
@@ -299,7 +371,7 @@ def is_tool_result(val: object) -> "TypeGuard[ContentToolResult]":
299371
return False
300372

301373

302-
# Tools were added to ContentToolRequest class until 0.11.1
374+
# Tools started getting added to ContentToolRequest staring with 0.11.1
303375
def is_legacy():
304376
import chatlas
305377

@@ -316,3 +388,38 @@ def tool_display_override() -> Literal["none", "basic", "rich"]:
316388
raise ValueError(
317389
'The `SHINYCHAT_TOOL_DISPLAY` env var must be one of: "none", "basic", or "rich"'
318390
)
391+
392+
393+
def restore_rendered_html(x: dict[str, Any]):
394+
from htmltools import HTML, HTMLDependency, TagList
395+
396+
if "html" not in x or "dependencies" not in x:
397+
raise ValueError(f"Don't know how to restore HTML from {x}")
398+
399+
deps: list[HTMLDependency] = []
400+
for d in x["dependencies"]:
401+
if not isinstance(d, dict):
402+
continue
403+
name = d["name"]
404+
version = d["version"]
405+
other = {k: v for k, v in d.items() if k not in ("name", "version")}
406+
# TODO: warn if the source is a tempdir?
407+
deps.append(HTMLDependency(name=name, version=version, **other))
408+
409+
res = TagList(HTML(x["html"]), *deps)
410+
if not deps:
411+
return res
412+
413+
session = None
414+
try:
415+
from shiny.session import get_current_session
416+
417+
session = get_current_session()
418+
except Exception:
419+
pass
420+
421+
# De-dupe dependencies for the current Shiny session
422+
if session:
423+
session._process_ui(res)
424+
425+
return res
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from .._chat import ChatMessage, ChatMessageDict
2+
from .._chat_normalize_chatlas import ToolResultDisplay
3+
4+
ToolResultDisplay.model_rebuild()
25

36
__all__ = [
47
"ChatMessage",
58
"ChatMessageDict",
9+
"ToolResultDisplay",
610
]

pkg-py/tests/playwright/tools/basic/app.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from shiny import reactive
1111
from shiny.express import input, ui
1212
from shinychat.express import Chat
13+
from shinychat.types import ToolResultDisplay
1314

1415
TOOL_OPTS = {
1516
"async": os.getenv("TEST_TOOL_ASYNC", "TRUE").lower() == "true",
@@ -28,7 +29,9 @@ def list_files_impl():
2829

2930
extra = {}
3031
if TOOL_OPTS["with_icon"]:
31-
extra = {"display": {"icon": faicons.icon_svg("folder-open")}}
32+
extra = {
33+
"display": ToolResultDisplay(icon=faicons.icon_svg("folder-open")),
34+
}
3235

3336
return ContentToolResult(
3437
value=["app.py", "data.csv"],

pkg-py/tests/playwright/tools/map/app.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ipyleaflet import CircleMarker, Map
66
from shiny.express import ui
77
from shinychat.express import Chat
8+
from shinychat.types import ToolResultDisplay
89
from shinywidgets import output_widget, register_widget
910

1011

@@ -31,12 +32,12 @@ def tool_show_map(
3132
return ContentToolResult(
3233
value="Map shown to the user.",
3334
extra={
34-
"display": {
35-
"html": output_widget(id),
36-
"show_request": False,
37-
"open": True,
38-
"title": f"Map of {title}",
39-
},
35+
"display": ToolResultDisplay(
36+
html=output_widget(id),
37+
show_request=False,
38+
open=True,
39+
title=f"Map of {title}",
40+
),
4041
},
4142
)
4243

pkg-py/tests/playwright/tools/weather/app_03_tool_result_simple.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from chatlas import ChatOpenAI, ContentToolResult
33
from shiny.express import app_opts, ui
44
from shinychat.express import Chat
5+
from shinychat.types import ToolResultDisplay
56

67
from . import tools
78

@@ -24,10 +25,10 @@ def get_weather_forecast(
2425
return ContentToolResult(
2526
value=forecast_data,
2627
extra={
27-
"display": {
28-
"title": f"Weather Forecast for {location_name}",
29-
"icon": faicons.icon_svg(icon),
30-
}
28+
"display": ToolResultDisplay(
29+
title=f"Weather Forecast for {location_name}",
30+
icon=faicons.icon_svg(icon),
31+
)
3132
},
3233
)
3334

pkg-py/tests/playwright/tools/weather/app_04_tool_result_table.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from chatlas import ChatOpenAI, ContentToolResult
55
from shiny.express import ui
66
from shinychat.express import Chat
7+
from shinychat.types import ToolResultDisplay
78

89
# Set environment variable for tool display
910
os.environ["SHINYCHAT_TOOL_DISPLAY"] = "rich"
@@ -33,10 +34,10 @@ def get_weather_forecast(
3334
return ContentToolResult(
3435
value=forecast_table,
3536
extra={
36-
"display": {
37-
"html": ui.HTML(forecast_table),
38-
"title": f"Weather Forecast for {location_name}",
39-
}
37+
"display": ToolResultDisplay(
38+
html=ui.HTML(forecast_table),
39+
title=f"Weather Forecast for {location_name}",
40+
)
4041
},
4142
)
4243

pkg-py/tests/playwright/tools/weather/app_05_tool_custom_result_class.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from chatlas import ChatOpenAI, ContentToolResult
33
from shiny.express import ui
44
from shinychat.express import Chat
5+
from shinychat.types import ToolResultDisplay
56

67

78
class WeatherToolResult(ContentToolResult):
@@ -24,10 +25,10 @@ def __init__(self, forecast_data, location_name: str, **kwargs):
2425
html_table = str(forecast_data) # Fallback
2526

2627
extra = {
27-
"display": {
28-
"html": ui.HTML(html_table),
29-
"title": f"Weather Forecast for {location_name}",
30-
}
28+
"display": ToolResultDisplay(
29+
html=ui.HTML(html_table),
30+
title=f"Weather Forecast for {location_name}",
31+
)
3132
}
3233

3334
super().__init__(value=forecast_data, extra=extra, **kwargs)

0 commit comments

Comments
 (0)