33import json
44import os
55import warnings
6- from dataclasses import dataclass
7- from typing import TYPE_CHECKING , Literal
6+ from typing import TYPE_CHECKING , Any , Literal
87
9- from htmltools import RenderedHTML , Tag , TagChild , Tagifiable , TagList
8+ from htmltools import HTML , RenderedHTML , Tag , Tagifiable , TagList
109from packaging import version
10+ from pydantic import BaseModel , field_serializer , field_validator
1111
1212from ._typing_extensions import TypeGuard
1313
1818 "ToolResultDisplay" ,
1919]
2020
21+ # Pydantic doesn't work with htmltool's recursive TagChild type
22+ TagChild = Tag | TagList | HTML | str | None
2123
22- @ dataclass
23- class ToolCardComponent :
24- "Data class mirroring the ShinyToolCard component class in chat-tools.ts"
24+
25+ class ToolCardComponent ( BaseModel ) :
26+ "A class that mirrors the ShinyToolCard component class in chat-tools.ts"
2527
2628 request_id : str
2729 """
@@ -44,10 +46,23 @@ class ToolCardComponent:
4446 expanded : bool = False
4547 "Controls whether the card content is expanded/visible."
4648
49+ model_config = {"arbitrary_types_allowed" : True }
50+
51+ @field_serializer ("icon" )
52+ def _serialize_icon (self , value : TagChild ):
53+ return TagList (value ).render ()
54+
55+ @field_validator ("icon" , mode = "before" )
56+ @classmethod
57+ def _validate_icon (cls , value : TagChild ) -> TagChild :
58+ if isinstance (value , dict ):
59+ return restore_rendered_html (value )
60+ else :
61+ return value
62+
4763
48- @dataclass
4964class ToolRequestComponent (ToolCardComponent ):
50- "Data class mirroring the ShinyToolRequest component class from chat-tools.ts"
65+ "A class that mirrors the ShinyToolRequest component class from chat-tools.ts"
5166
5267 arguments : str = ""
5368 "The function arguments as requested by the LLM, typically in JSON format."
@@ -71,9 +86,8 @@ def tagify(self):
7186ValueType = Literal ["html" , "markdown" , "text" , "code" ]
7287
7388
74- @dataclass
7589class ToolResultComponent (ToolCardComponent ):
76- "Data class mirroring the ShinyToolResult component class from chat-tools.ts"
90+ "A class that mirrors the ShinyToolResult component class from chat-tools.ts"
7791
7892 request_call : str = ""
7993 "The original tool call that generated this result. Used to display the tool invocation."
@@ -129,18 +143,76 @@ def tagify(self):
129143 )
130144
131145
132- @ dataclass
133- class ToolResultDisplay :
134- "Data class to for users to customize how tool results are displayed"
146+ class ToolResultDisplay ( BaseModel ):
147+ """
148+ Customize how tool results are displayed.
135149
150+ Assign a `ToolResultDisplay` instance to a
151+ [`chatlas.ContentToolResult`](https://posit-dev.github.io/chatlas/reference/types.ContentToolResult.html)
152+ to customize the UI shown to the user when tool calls occur.
153+
154+ Examples
155+ --------
156+
157+ ```python
158+ import chatlas as ctl
159+ from shinychat.types import ToolResultDisplay
160+
161+
162+ def my_tool():
163+ display = ToolResultDisplay(
164+ title="Tool result title",
165+ markdown="A _markdown_ message shown to user.",
166+ )
167+ return ctl.ContentToolResult(
168+ value="Value the model sees",
169+ extra={"display": display},
170+ )
171+
172+
173+ chat_client = ctl.ChatAuto()
174+ chat_client.register_tool(my_tool)
175+ ```
176+
177+ Parameters
178+ ---------
179+ title
180+ The title to display in the header of the tool result.
181+ icon
182+ An icon to display in the header (alongside the title).
183+ show_request
184+ Whether to show the tool request inside the tool result container.
185+ open
186+ Whether or not the tool result details are expanded by default.
187+ html
188+ Custom HTML content (to use in place of the default result display).
189+ markdown
190+ Custom Markdown string (to use in place of the default result display).
191+ text
192+ Custom plain text string (to use in place of the default result display).
193+ """
194+
195+ title : str | None = None
196+ icon : TagChild = None
136197 html : TagChild = None
137- markdown : str | None = None
138- text : str | None = None
139198 show_request : bool = True
140199 open : bool = False
141- title : str | None = None
142- icon : TagChild = None
143- expanded : bool | None = None
200+ markdown : str | None = None
201+ text : str | None = None
202+
203+ model_config = {"arbitrary_types_allowed" : True }
204+
205+ @field_serializer ("html" , "icon" )
206+ def _serialize_html_icon (self , value : TagChild ):
207+ return TagList (value ).render ()
208+
209+ @field_validator ("html" , "icon" , mode = "before" )
210+ @classmethod
211+ def _validate_html_icon (cls , value : TagChild ) -> TagChild :
212+ if isinstance (value , dict ):
213+ return restore_rendered_html (value )
214+ else :
215+ return value
144216
145217
146218def tool_request_contents (x : "ContentToolRequest" ) -> Tagifiable :
@@ -299,7 +371,7 @@ def is_tool_result(val: object) -> "TypeGuard[ContentToolResult]":
299371 return False
300372
301373
302- # Tools were added to ContentToolRequest class until 0.11.1
374+ # Tools started getting added to ContentToolRequest staring with 0.11.1
303375def is_legacy ():
304376 import chatlas
305377
@@ -316,3 +388,38 @@ def tool_display_override() -> Literal["none", "basic", "rich"]:
316388 raise ValueError (
317389 'The `SHINYCHAT_TOOL_DISPLAY` env var must be one of: "none", "basic", or "rich"'
318390 )
391+
392+
393+ def restore_rendered_html (x : dict [str , Any ]):
394+ from htmltools import HTML , HTMLDependency , TagList
395+
396+ if "html" not in x or "dependencies" not in x :
397+ raise ValueError (f"Don't know how to restore HTML from { x } " )
398+
399+ deps : list [HTMLDependency ] = []
400+ for d in x ["dependencies" ]:
401+ if not isinstance (d , dict ):
402+ continue
403+ name = d ["name" ]
404+ version = d ["version" ]
405+ other = {k : v for k , v in d .items () if k not in ("name" , "version" )}
406+ # TODO: warn if the source is a tempdir?
407+ deps .append (HTMLDependency (name = name , version = version , ** other ))
408+
409+ res = TagList (HTML (x ["html" ]), * deps )
410+ if not deps :
411+ return res
412+
413+ session = None
414+ try :
415+ from shiny .session import get_current_session
416+
417+ session = get_current_session ()
418+ except Exception :
419+ pass
420+
421+ # De-dupe dependencies for the current Shiny session
422+ if session :
423+ session ._process_ui (res )
424+
425+ return res
0 commit comments