diff --git a/pkg-py/CHANGELOG.md b/pkg-py/CHANGELOG.md index eee98b76..1c27a47b 100644 --- a/pkg-py/CHANGELOG.md +++ b/pkg-py/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New features +* New and improved UI for tool calls that occur via [chatlas](https://posit-dev.github.io/chatlas/). As a reminder, tool call displays are enabled by setting `content="all"` in chatlas' `.stream()` (or `.stream_async()`) method. See the tests under the `pkg-py/tests/playwright/tools` directory for inspiration of what is now possible with custom tool displays via the new `ToolResultDisplay` class. (#107) * Added new `message_content()` and `message_content_chunk()` generic (`singledispatch`) functions. These functions aren't intended to be called directly by users, but instead, provide an opportunity to teach `Chat.append_message()`/`Chat.append_message_stream()` to extract message contents from different types of objects. (#96) ### Bug fixes diff --git a/pkg-py/src/shinychat/_chat.py b/pkg-py/src/shinychat/_chat.py index 18137bea..8803702f 100644 --- a/pkg-py/src/shinychat/_chat.py +++ b/pkg-py/src/shinychat/_chat.py @@ -54,6 +54,7 @@ set_chatlas_state, ) from ._chat_normalize import message_content, message_content_chunk +from ._chat_normalize_chatlas import hide_corresponding_request, is_tool_result from ._chat_provider_types import ( AnthropicMessage, # pyright: ignore[reportAttributeAccessIssue] GoogleMessage, @@ -754,6 +755,9 @@ async def _append_message_chunk( # Normalize various message types into a ChatMessage() msg = message_content_chunk(message) + if is_tool_result(message): + await hide_corresponding_request(message) + if operation == "replace": self._current_stream_message = ( self._message_stream_checkpoint + msg.content diff --git a/pkg-py/src/shinychat/_chat_normalize.py b/pkg-py/src/shinychat/_chat_normalize.py index 892c4dde..9a014335 100644 --- a/pkg-py/src/shinychat/_chat_normalize.py +++ 
b/pkg-py/src/shinychat/_chat_normalize.py @@ -3,15 +3,16 @@ import sys from functools import singledispatch -from htmltools import HTML, Tagifiable +from htmltools import HTML, Tagifiable, TagList +from ._chat_normalize_chatlas import tool_request_contents, tool_result_contents from ._chat_types import ChatMessage __all__ = ["message_content", "message_content_chunk"] @singledispatch -def message_content(message) -> ChatMessage: +def message_content(message): """ Extract content from various message types into a ChatMessage. @@ -42,7 +43,7 @@ def message_content(message) -> ChatMessage: If the message type is unsupported. """ if isinstance(message, (str, HTML)) or message is None: - return ChatMessage(content=message, role="assistant") + return ChatMessage(content=message) if isinstance(message, dict): if "content" not in message: raise ValueError("Message dictionary must have a 'content' key") @@ -57,7 +58,7 @@ def message_content(message) -> ChatMessage: @singledispatch -def message_content_chunk(chunk) -> ChatMessage: +def message_content_chunk(chunk): """ Extract content from various message chunk types into a ChatMessage. @@ -88,7 +89,7 @@ def message_content_chunk(chunk) -> ChatMessage: If the chunk type is unsupported. 
""" if isinstance(chunk, (str, HTML)) or chunk is None: - return ChatMessage(content=chunk, role="assistant") + return ChatMessage(content=chunk) if isinstance(chunk, dict): if "content" not in chunk: raise ValueError("Chunk dictionary must have a 'content' key") @@ -108,15 +109,71 @@ def message_content_chunk(chunk) -> ChatMessage: @message_content.register -def _(message: Tagifiable) -> ChatMessage: - return ChatMessage(content=message, role="assistant") +def _(message: Tagifiable): + return ChatMessage(content=message) @message_content_chunk.register -def _(chunk: Tagifiable) -> ChatMessage: - return ChatMessage(content=chunk, role="assistant") +def _(chunk: Tagifiable): + return ChatMessage(content=chunk) +# ----------------------------------------------------------------- +# chatlas tool call display +# ----------------------------------------------------------------- +try: + from chatlas import ContentToolRequest, ContentToolResult, Turn + from chatlas.types import Content, ContentText + + @message_content.register + def _(message: Content): + return ChatMessage(content=str(message)) + + @message_content_chunk.register + def _(chunk: Content): + return message_content(chunk) + + @message_content.register + def _(message: ContentText): + return ChatMessage(content=message.text) + + @message_content_chunk.register + def _(chunk: ContentText): + return message_content(chunk) + + @message_content.register + def _(chunk: ContentToolRequest): + return ChatMessage(content=tool_request_contents(chunk)) + + @message_content_chunk.register + def _(chunk: ContentToolRequest): + return message_content(chunk) + + @message_content.register + def _(chunk: ContentToolResult): + return ChatMessage(content=tool_result_contents(chunk)) + + @message_content_chunk.register + def _(chunk: ContentToolResult): + return message_content(chunk) + + @message_content.register + def _(message: Turn): + contents = TagList() + for x in message.contents: + 
contents.append(message_content(x).content) + return ChatMessage(content=contents) + + @message_content_chunk.register + def _(chunk: Turn): + return message_content(chunk) + + # N.B., unlike R, Python Chat stores UI state and so can replay + # it with additional workarounds. That's why R currently has a + # shinychat_contents() method for Chat, but Python doesn't. +except ImportError: + pass + # ------------------------------------------------------------------ # LangChain content extractor # ------------------------------------------------------------------ @@ -125,7 +182,7 @@ def _(chunk: Tagifiable) -> ChatMessage: from langchain_core.messages import BaseMessage, BaseMessageChunk @message_content.register - def _(message: BaseMessage) -> ChatMessage: + def _(message: BaseMessage): if isinstance(message.content, list): raise ValueError( "The `message.content` provided seems to represent numerous messages. " @@ -137,7 +194,7 @@ def _(message: BaseMessage) -> ChatMessage: ) @message_content_chunk.register - def _(chunk: BaseMessageChunk) -> ChatMessage: + def _(chunk: BaseMessageChunk): if isinstance(chunk.content, list): raise ValueError( "The `chunk.content` provided seems to represent numerous message chunks. 
" @@ -159,14 +216,14 @@ def _(chunk: BaseMessageChunk) -> ChatMessage: from openai.types.chat import ChatCompletion, ChatCompletionChunk @message_content.register - def _(message: ChatCompletion) -> ChatMessage: + def _(message: ChatCompletion): return ChatMessage( content=message.choices[0].message.content, role="assistant", ) @message_content_chunk.register - def _(chunk: ChatCompletionChunk) -> ChatMessage: + def _(chunk: ChatCompletionChunk): return ChatMessage( content=chunk.choices[0].delta.content, role="assistant", @@ -185,21 +242,23 @@ def _(chunk: ChatCompletionChunk) -> ChatMessage: ) @message_content.register - def _(message: AnthropicMessage) -> ChatMessage: + def _(message: AnthropicMessage): content = message.content[0] if content.type != "text": raise ValueError( f"Anthropic message type {content.type} not supported. " "Only 'text' type is currently supported" ) - return ChatMessage(content=content.text, role="assistant") + return ChatMessage(content=content.text) # Old versions of singledispatch doesn't seem to support union types if sys.version_info >= (3, 11): - from anthropic.types import RawMessageStreamEvent + from anthropic.types import ( # pyright: ignore[reportMissingImports] + RawMessageStreamEvent, + ) @message_content_chunk.register - def _(chunk: RawMessageStreamEvent) -> ChatMessage: + def _(chunk: RawMessageStreamEvent): content = "" if chunk.type == "content_block_delta": if chunk.delta.type != "text_delta": @@ -209,7 +268,7 @@ def _(chunk: RawMessageStreamEvent) -> ChatMessage: ) content = chunk.delta.text - return ChatMessage(content=content, role="assistant") + return ChatMessage(content=content) except ImportError: pass @@ -224,12 +283,12 @@ def _(chunk: RawMessageStreamEvent) -> ChatMessage: ) @message_content.register - def _(message: GenerateContentResponse) -> ChatMessage: - return ChatMessage(content=message.text, role="assistant") + def _(message: GenerateContentResponse): + return ChatMessage(content=message.text) 
@message_content_chunk.register - def _(chunk: GenerateContentResponse) -> ChatMessage: - return ChatMessage(content=chunk.text, role="assistant") + def _(chunk: GenerateContentResponse): + return ChatMessage(content=chunk.text) except ImportError: pass @@ -243,14 +302,14 @@ def _(chunk: GenerateContentResponse) -> ChatMessage: from ollama import ChatResponse @message_content.register - def _(message: ChatResponse) -> ChatMessage: + def _(message: ChatResponse): msg = message.message - return ChatMessage(msg.content, role="assistant") + return ChatMessage(msg.content) @message_content_chunk.register - def _(chunk: ChatResponse) -> ChatMessage: + def _(chunk: ChatResponse): msg = chunk.message - return ChatMessage(msg.content, role="assistant") + return ChatMessage(msg.content) except ImportError: pass diff --git a/pkg-py/src/shinychat/_chat_normalize_chatlas.py b/pkg-py/src/shinychat/_chat_normalize_chatlas.py new file mode 100644 index 00000000..991060f7 --- /dev/null +++ b/pkg-py/src/shinychat/_chat_normalize_chatlas.py @@ -0,0 +1,439 @@ +from __future__ import annotations + +import json +import os +import warnings +from typing import TYPE_CHECKING, Any, Literal, Optional, Sequence, Union + +from htmltools import ( + HTML, + MetadataNode, + RenderedHTML, + ReprHtml, + Tag, + Tagifiable, + TagList, +) +from packaging import version +from pydantic import BaseModel, field_serializer, field_validator +from typing_extensions import TypeAliasType + +from ._typing_extensions import TypeGuard + +if TYPE_CHECKING: + from chatlas.types import ContentToolRequest, ContentToolResult + +__all__ = [ + "ToolResultDisplay", +] + +# A version of the (recursive) TagChild type that actually works with Pydantic +# https://docs.pydantic.dev/2.11/concepts/types/#named-type-aliases +TagNode = Union[Tagifiable, MetadataNode, ReprHtml, str, HTML] +TagChild = TypeAliasType( + "TagChild", + "Union[TagNode, TagList, float, None, Sequence[TagChild]]", +) + + +class 
ToolCardComponent(BaseModel): + "A class that mirrors the ShinyToolCard component class in chat-tools.ts" + + request_id: str + """ + Unique identifier for the tool request or result. + This value links a request to a result and is therefore not unique on the page. + """ + + tool_name: str + "Name of the tool being executed, e.g. `get_weather`." + + tool_title: Optional[str] = None + "Display title for the card. If not provided, falls back to `tool_name`." + + icon: TagChild = None + "HTML content for the icon displayed in the card header." + + intent: Optional[str] = None + "Optional intent description explaining the purpose of the tool execution." + + expanded: bool = False + "Controls whether the card content is expanded/visible." + + model_config = {"arbitrary_types_allowed": True} + + @field_serializer("icon") + def _serialize_icon(self, value: TagChild): + return TagList(value).render() + + @field_validator("icon", mode="before") + @classmethod + def _validate_icon(cls, value: TagChild) -> TagChild: + if isinstance(value, dict): + return restore_rendered_html(value) + else: + return value + + +class ToolRequestComponent(ToolCardComponent): + "A class that mirrors the ShinyToolRequest component class from chat-tools.ts" + + arguments: str = "" + "The function arguments as requested by the LLM, typically in JSON format." + + def tagify(self): + icon_ui = TagList(self.icon).render() + + return Tag( + "shiny-tool-request", + request_id=self.request_id, + tool_name=self.tool_name, + tool_title=self.tool_title, + icon=icon_ui["html"] if self.icon else None, + intent=self.intent, + expanded="" if self.expanded else None, + arguments=self.arguments, + *icon_ui["dependencies"], + ) + + +ValueType = Literal["html", "markdown", "text", "code"] + + +class ToolResultComponent(ToolCardComponent): + "A class that mirrors the ShinyToolResult component class from chat-tools.ts" + + request_call: str = "" + "The original tool call that generated this result. 
Used to display the tool invocation." + + status: Literal["success", "error"] = "success" + """ + The status of the tool execution. When set to "error", displays in an error state with + red text and an exclamation icon. + """ + + show_request: bool = True + "Should the tool request be displayed alongside the result?" + + value: TagChild = None + "The actual result content returned by the tool execution." + + value_type: ValueType = "code" + """ + Specifies how the value should be rendered. Supported types: + - "html": Renders the value as raw HTML + - "text": Renders the value as plain text in a paragraph + - "markdown": Renders the value as Markdown (default) + - "code": Renders the value as a code block + Any other value defaults to markdown rendering. + """ + + def tagify(self): + icon_ui = TagList(self.icon).render() + + if self.value_type == "html": + value_ui = TagList(self.value).render() + else: + value_ui: "RenderedHTML" = { + "html": str(self.value), + "dependencies": [], + } + + return Tag( + "shiny-tool-result", + request_id=self.request_id, + tool_name=self.tool_name, + tool_title=self.tool_title, + icon=icon_ui["html"] if self.icon else None, + intent=self.intent, + request_call=self.request_call, + status=self.status, + value=value_ui["html"], + value_type=self.value_type, + show_request="" if self.show_request else None, + expanded="" if self.expanded else None, + *icon_ui["dependencies"], + *value_ui["dependencies"], + ) + + +class ToolResultDisplay(BaseModel): + """ + Customize how tool results are displayed. + + Assign a `ToolResultDisplay` instance to a + [`chatlas.ContentToolResult`](https://posit-dev.github.io/chatlas/reference/types.ContentToolResult.html) + to customize the UI shown to the user when tool calls occur.
+ + Examples + -------- + + ```python + import chatlas as ctl + from shinychat.types import ToolResultDisplay + + + def my_tool(): + display = ToolResultDisplay( + title="Tool result title", + markdown="A _markdown_ message shown to user.", + ) + return ctl.ContentToolResult( + value="Value the model sees", + extra={"display": display}, + ) + + + chat_client = ctl.ChatAuto() + chat_client.register_tool(my_tool) + ``` + + Parameters + --------- + title + The title to display in the header of the tool result. + icon + An icon to display in the header (alongside the title). + show_request + Whether to show the tool request inside the tool result container. + open + Whether or not the tool result details are expanded by default. + html + Custom HTML content (to use in place of the default result display). + markdown + Custom Markdown string (to use in place of the default result display). + text + Custom plain text string (to use in place of the default result display). + """ + + title: Optional[str] = None + icon: TagChild = None + html: TagChild = None + show_request: bool = True + open: bool = False + markdown: Optional[str] = None + text: Optional[str] = None + + model_config = {"arbitrary_types_allowed": True} + + @field_serializer("html", "icon") + def _serialize_html_icon(self, value: TagChild): + return TagList(value).render() + + @field_validator("html", "icon", mode="before") + @classmethod + def _validate_html_icon(cls, value: TagChild) -> TagChild: + if isinstance(value, dict): + return restore_rendered_html(value) + else: + return value + + +def tool_request_contents(x: "ContentToolRequest") -> Tagifiable: + if tool_display_override() == "none": + return TagList() + + # These content objects do have tagify() methods, + # but that's for legacy behavior + if is_legacy(): + return x + + intent = None + if isinstance(x.arguments, dict): + intent = x.arguments.get("_intent") + + tool_title = None + if x.tool and x.tool.annotations: + tool_title = 
x.tool.annotations.get("title") + + return ToolRequestComponent( + request_id=x.id, + tool_name=x.name, + arguments=json.dumps(x.arguments), + intent=intent, + tool_title=tool_title, + ) + + +def tool_result_contents(x: "ContentToolResult") -> Tagifiable: + if tool_display_override() == "none": + return TagList() + + # These content objects do have tagify() methods, + # but that's the legacy behavior + if is_legacy(): + return x + + if x.request is None: + raise ValueError( + "`ContentToolResult` objects must have an associated `.request` attribute." + ) + + # TODO: look into better formatting of the call? + request_call = json.dumps( + { + "id": x.id, + "name": x.request.name, + "arguments": x.request.arguments, + }, + indent=2, + ) + + display = get_tool_result_display(x, x.request) + value, value_type = tool_result_display(x, display) + + intent = None + if isinstance(x.arguments, dict): + intent = x.arguments.get("_intent") + + tool = x.request.tool + tool_title = None + if tool and tool.annotations: + tool_title = tool.annotations.get("title") + + # display (tool *result* level) takes precedence over + # annotations (tool *definition* level) + return ToolResultComponent( + request_id=x.id, + request_call=request_call, + tool_name=x.request.name, + tool_title=display.title or tool_title, + status="success" if x.error is None else "error", + value=value, + value_type=value_type, + icon=display.icon, + intent=intent, + show_request=display.show_request, + expanded=display.open, + ) + + +def get_tool_result_display( + x: "ContentToolResult", + request: "ContentToolRequest", +) -> ToolResultDisplay: + if not isinstance(x.extra, dict) or tool_display_override() == "basic": + return ToolResultDisplay() + + display = x.extra.get("display", ToolResultDisplay()) + + if isinstance(display, ToolResultDisplay): + return display + + if isinstance(display, dict): + return ToolResultDisplay(**display) + + warnings.warn( + "Invalid `display` value inside
`ContentToolResult(extra={'display': display})` " + f"from {request.name} (call id: {request.id}). " + "Expected either a `shinychat.ToolResultDisplay()` instance or a dictionary, " + f"but got {type(display)}." + ) + + return ToolResultDisplay() + + +def tool_result_display( + x: "ContentToolResult", + display: ToolResultDisplay, +) -> tuple[TagChild, ValueType]: + if x.error is not None: + return str(x.error), "code" + + if tool_display_override() == "basic": + return str(x.get_model_value()), "code" + + if display.html is not None: + return display.html, "html" + + if display.markdown is not None: + return display.markdown, "markdown" + + if display.text is not None: + return display.text, "text" + + return str(x.get_model_value()), "code" + + +async def hide_corresponding_request(x: "ContentToolResult"): + if x.request is None: + return + + session = None + try: + from shiny.session import get_current_session + + session = get_current_session() + except Exception: + return + + if session is None: + return + + await session.send_custom_message( + "shiny-tool-request-hide", + x.request.id, # type: ignore + ) + + +def is_tool_result(val: object) -> "TypeGuard[ContentToolResult]": + try: + from chatlas.types import ContentToolResult + + return isinstance(val, ContentToolResult) + except ImportError: + return False + + +# Tools started getting added to ContentToolRequest starting with 0.11.1 +def is_legacy(): + import chatlas + + v = chatlas._version.version_tuple + ver = f"{v[0]}.{v[1]}.{v[2]}" + return version.parse(ver) < version.parse("0.11.1") + + +def tool_display_override() -> Literal["none", "basic", "rich"]: + val = os.getenv("SHINYCHAT_TOOL_DISPLAY", "rich") + if val == "rich" or val == "basic" or val == "none": + return val + else: + raise ValueError( + 'The `SHINYCHAT_TOOL_DISPLAY` env var must be one of: "none", "basic", or "rich"' + ) + + +def restore_rendered_html(x: dict[str, Any]): + from htmltools import HTMLDependency + + if "html" not in x or
"dependencies" not in x: + raise ValueError(f"Don't know how to restore HTML from {x}") + + deps: list[HTMLDependency] = [] + for d in x["dependencies"]: + if not isinstance(d, dict): + continue + name = d["name"] + version = d["version"] + other = {k: v for k, v in d.items() if k not in ("name", "version")} + # TODO: warn if the source is a tempdir? + deps.append(HTMLDependency(name=name, version=version, **other)) + + res = TagList(HTML(x["html"]), *deps) + if not deps: + return res + + session = None + try: + from shiny.session import get_current_session + + session = get_current_session() + except Exception: + pass + + # De-dupe dependencies for the current Shiny session + if session: + session._process_ui(res) + + return res diff --git a/pkg-py/src/shinychat/_chat_types.py b/pkg-py/src/shinychat/_chat_types.py index ea7dda41..d0e7cc62 100644 --- a/pkg-py/src/shinychat/_chat_types.py +++ b/pkg-py/src/shinychat/_chat_types.py @@ -22,7 +22,7 @@ class ChatMessage: def __init__( self, content: TagChild, - role: Role, + role: Role = "assistant", ): self.role: Role = role diff --git a/pkg-py/src/shinychat/types/__init__.py b/pkg-py/src/shinychat/types/__init__.py index 36060196..0a9fb8c7 100644 --- a/pkg-py/src/shinychat/types/__init__.py +++ b/pkg-py/src/shinychat/types/__init__.py @@ -1,6 +1,10 @@ from .._chat import ChatMessage, ChatMessageDict +from .._chat_normalize_chatlas import ToolResultDisplay + +ToolResultDisplay.model_rebuild() __all__ = [ "ChatMessage", "ChatMessageDict", + "ToolResultDisplay", ] diff --git a/pkg-py/tests/playwright/tools/basic/app.py b/pkg-py/tests/playwright/tools/basic/app.py new file mode 100644 index 00000000..b8e7a433 --- /dev/null +++ b/pkg-py/tests/playwright/tools/basic/app.py @@ -0,0 +1,144 @@ +import asyncio +import os +import random +import time + +import faicons +from chatlas import ChatAuto, ContentToolResult +from chatlas.types import ToolAnnotations +from pydantic import BaseModel, Field +from shiny import reactive +from 
shiny.express import input, ui +from shinychat.express import Chat +from shinychat.types import ToolResultDisplay + +TOOL_OPTS = { + "async": os.getenv("TEST_TOOL_ASYNC", "TRUE").lower() == "true", + "with_intent": os.getenv("TEST_TOOL_WITH_INTENT", "TRUE").lower() == "true", + "with_title": os.getenv("TEST_TOOL_WITH_TITLE", "TRUE").lower() == "true", + "with_icon": os.getenv("TEST_TOOL_WITH_ICON", "TRUE").lower() == "true", +} + +chat_client = ChatAuto(provider="openai", model="gpt-4.1-nano") + + +def list_files_impl(): + # Randomly fail sometimes to test error handling + if random.choice([True, False, False, False]): + raise Exception("An error occurred while listing files.") + + extra = {} + if TOOL_OPTS["with_icon"]: + extra = { + "display": ToolResultDisplay(icon=faicons.icon_svg("folder-open")), + } + + return ContentToolResult( + value=["app.py", "data.csv"], + extra=extra, + ) + + +class ListFileParams(BaseModel): + """ + List files in the user's current directory. Always check again when asked. 
+ """ + + path: str = Field(..., description="The path to list files from") + + +class ListFileParamsWithIntent(ListFileParams): + intent: str = Field( + ..., description="The user's intent for this tool", alias="_intent" + ) + + +annotations: ToolAnnotations = {} +if TOOL_OPTS["with_title"]: + annotations["title"] = "List Files" + +# Define the tool function based on configuration +if TOOL_OPTS["async"]: + if TOOL_OPTS["with_intent"]: + + async def list_files_func1(path: str, _intent: str): + await asyncio.sleep(random.uniform(1, 10)) + return list_files_impl() + + chat_client.register_tool( + list_files_func1, + name="list_files", + model=ListFileParamsWithIntent, + annotations=annotations, + ) + + else: + + async def list_files_func2(path: str): + await asyncio.sleep(random.uniform(1, 10)) + return list_files_impl() + + chat_client.register_tool( + list_files_func2, + name="list_files", + model=ListFileParams, + annotations=annotations, + ) + +else: + if TOOL_OPTS["with_intent"]: + + def list_files_func3(path: str, _intent: str): + time.sleep(random.uniform(1, 3)) + return list_files_impl() + + chat_client.register_tool( + list_files_func3, + name="list_files", + model=ListFileParamsWithIntent, + annotations=annotations, + ) + + else: + + def list_files_func4(path: str): + time.sleep(random.uniform(1, 3)) + return list_files_impl() + + chat_client.register_tool( + list_files_func4, + name="list_files", + model=ListFileParams, + annotations=annotations, + ) + +ui.page_opts(fillable=True) + +chat = Chat(id="chat") +chat.ui( + messages=[ + """ +

In three separate but parallel tool calls list the files in apps, data, docs

+

Write some basic Python code that demonstrates how to use pandas.

+

Brainstorm 10 ideas for a name for a package that creates interactive sparklines in tables.

+ """, + ], +) + + +@chat.on_user_submit +async def handle_user_input(user_input: str): + response = await chat_client.stream_async(user_input, content="all") + await chat.append_message_stream(response) + + +ui.input_action_button("click", "Click me") + + +@reactive.effect +@reactive.event(input.click) +def _(): + ui.update_action_button( + "click", + label=f"Clicked {input.click()} times", + ) diff --git a/pkg-py/tests/playwright/tools/map/app.py b/pkg-py/tests/playwright/tools/map/app.py new file mode 100644 index 00000000..3df088b1 --- /dev/null +++ b/pkg-py/tests/playwright/tools/map/app.py @@ -0,0 +1,66 @@ +import uuid + +import ipywidgets +from chatlas import ChatOpenAI, ContentToolResult +from ipyleaflet import CircleMarker, Map +from shiny.express import ui +from shinychat.express import Chat +from shinychat.types import ToolResultDisplay +from shinywidgets import output_widget, register_widget + + +def tool_show_map( + lat: float, + lon: float, + title: str, + description: str, +) -> ContentToolResult: + """Show a map with a marker. + + Use this tool whenever you're talking about a location with the user. + """ + + info = f"{title}
{description}" + + loc = (lat, lon) + m = Map(center=loc, zoom=10) + m.add_layer(CircleMarker(location=loc, popup=ipywidgets.HTML(info))) + + id = f"map_{uuid.uuid4().hex}" + register_widget(id, m) + + return ContentToolResult( + value="Map shown to the user.", + extra={ + "display": ToolResultDisplay( + html=output_widget(id), + show_request=False, + open=True, + title=f"Map of {title}", + ), + }, + ) + + +ui.page_opts(fillable=True, title="Map Tool") + +client = ChatOpenAI( + model="gpt-4.1-nano", + system_prompt=""" +You're a helpful guide who can tell users about places and show them maps. + +Anytime you mention a location, use the `tool_show_map` tool to show a map with a marker at the location. Don't make the user ask to see the map, just show it automatically when it'd be relevant to have a visual. +""", +) +client.register_tool(tool_show_map) + +chat = Chat(id="chat") +chat.ui( + messages=["Ask me about any location, and I'll show you a map!"], +) + + +@chat.on_user_submit +async def handle_user_input(user_input: str): + response = await client.stream_async(user_input, content="all") + await chat.append_message_stream(response) diff --git a/pkg-py/tests/playwright/tools/weather/app_01_simple.py b/pkg-py/tests/playwright/tools/weather/app_01_simple.py new file mode 100644 index 00000000..048130a0 --- /dev/null +++ b/pkg-py/tests/playwright/tools/weather/app_01_simple.py @@ -0,0 +1,22 @@ +from chatlas import ChatOpenAI +from shiny.express import app_opts, ui +from shinychat.express import Chat + +from .tools import get_weather_forecast + +client = ChatOpenAI(model="gpt-4.1-nano") +client.register_tool(get_weather_forecast) + +ui.page_opts(title="Weather Tool - Simple") +app_opts(bookmark_store="url") + +chat = Chat(id="chat") +chat.ui() + +chat.enable_bookmarking(client) + + +@chat.on_user_submit +async def handle_user_input(user_input: str): + response = await client.stream_async(user_input, content="all") + await chat.append_message_stream(response) diff 
--git a/pkg-py/tests/playwright/tools/weather/app_02_annotations.py b/pkg-py/tests/playwright/tools/weather/app_02_annotations.py new file mode 100644 index 00000000..a8577d24 --- /dev/null +++ b/pkg-py/tests/playwright/tools/weather/app_02_annotations.py @@ -0,0 +1,25 @@ +from chatlas import ChatOpenAI +from shiny.express import app_opts, ui +from shinychat.express import Chat + +from .tools import get_weather_forecast + +client = ChatOpenAI(model="gpt-4.1-nano") +client.register_tool( + get_weather_forecast, + annotations={"title": "Weather Forecast"}, +) + +ui.page_opts(title="Weather Tool - Annotations") +app_opts(bookmark_store="url") + +chat = Chat(id="chat") +chat.ui() + +chat.enable_bookmarking(client) + + +@chat.on_user_submit +async def handle_user_input(user_input: str): + response = await client.stream_async(user_input, content="all") + await chat.append_message_stream(response) diff --git a/pkg-py/tests/playwright/tools/weather/app_03_tool_result_simple.py b/pkg-py/tests/playwright/tools/weather/app_03_tool_result_simple.py new file mode 100644 index 00000000..5e6d047f --- /dev/null +++ b/pkg-py/tests/playwright/tools/weather/app_03_tool_result_simple.py @@ -0,0 +1,58 @@ +import faicons +from chatlas import ChatOpenAI, ContentToolResult +from shiny.express import app_opts, ui +from shinychat.express import Chat +from shinychat.types import ToolResultDisplay + +from . 
import tools + + +def get_weather_forecast( + lat: float, lon: float, location_name: str +) -> ContentToolResult: + """Get the weather forecast for a location.""" + forecast_data = tools.get_weather_forecast(lat, lon) + + # Determine icon based on temperature + if forecast_data["temperature_2m"] > 21: + icon = "sun" + elif forecast_data["temperature_2m"] < 7: + icon = "snowflake" + else: + icon = "cloud-sun" + + # Return ContentToolResult with extra display metadata + return ContentToolResult( + value=forecast_data, + extra={ + "display": ToolResultDisplay( + title=f"Weather Forecast for {location_name}", + icon=faicons.icon_svg(icon), + ) + }, + ) + + +# Create chat client and register tool +chat_client = ChatOpenAI(model="gpt-4.1-nano") +chat_client.register_tool( + get_weather_forecast, + annotations={"title": "Weather Forecast"}, +) + + +# The Shiny app +ui.page_opts(title="Weather Tool - Tool Result Simple") +app_opts(bookmark_store="url") + + +chat = Chat(id="chat") +chat.ui() + +chat.enable_bookmarking(chat_client) + + +@chat.on_user_submit +async def handle_user_input(user_input: str): + response = await chat_client.stream_async(user_input, content="all") + await chat.append_message_stream(response) diff --git a/pkg-py/tests/playwright/tools/weather/app_04_tool_result_table.py b/pkg-py/tests/playwright/tools/weather/app_04_tool_result_table.py new file mode 100644 index 00000000..af4ef8ed --- /dev/null +++ b/pkg-py/tests/playwright/tools/weather/app_04_tool_result_table.py @@ -0,0 +1,60 @@ +import os + +import pandas as pd +from chatlas import ChatOpenAI, ContentToolResult +from shiny.express import ui +from shinychat.express import Chat +from shinychat.types import ToolResultDisplay + +# Set environment variable for tool display +os.environ["SHINYCHAT_TOOL_DISPLAY"] = "rich" + + +def get_weather_forecast( + lat: float, lon: float, location_name: str +) -> ContentToolResult: + """Get the weather forecast for a location.""" + # Mock detailed forecast data as 
a pandas DataFrame (similar to R's data.frame) + forecast_data = pd.DataFrame( + { + "time": ["06:00", "12:00", "18:00", "24:00"], + "temperature": [65, 72, 68, 62], + "humidity": [70, 65, 72, 80], + "conditions": ["Clear", "Partly cloudy", "Cloudy", "Clear"], + "wind_speed": [5, 8, 6, 4], + } + ) + + # Convert DataFrame to HTML table + forecast_table = forecast_data.to_html( + index=False, classes="table table-striped" + ) + + # Return ContentToolResult with extra display metadata + return ContentToolResult( + value=forecast_table, + extra={ + "display": ToolResultDisplay( + html=ui.HTML(forecast_table), + title=f"Weather Forecast for {location_name}", + ) + }, + ) + + +client = ChatOpenAI(model="gpt-4.1-nano") +client.register_tool( + get_weather_forecast, + annotations={"title": "Weather Forecast"}, +) + +ui.page_opts(title="Weather Tool - Tool Result Table") + +chat = Chat(id="chat") +chat.ui() + + +@chat.on_user_submit +async def handle_user_input(user_input: str): + response = await client.stream_async(user_input, content="all") + await chat.append_message_stream(response) diff --git a/pkg-py/tests/playwright/tools/weather/app_05_tool_custom_result_class.py b/pkg-py/tests/playwright/tools/weather/app_05_tool_custom_result_class.py new file mode 100644 index 00000000..30113d97 --- /dev/null +++ b/pkg-py/tests/playwright/tools/weather/app_05_tool_custom_result_class.py @@ -0,0 +1,69 @@ +import pandas as pd +from chatlas import ChatOpenAI, ContentToolResult +from shiny.express import ui +from shinychat.express import Chat +from shinychat.types import ToolResultDisplay + + +class WeatherToolResult(ContentToolResult): + """ + Custom tool result class for weather forecasts. + + This example shows how to use a custom tool result class. In the R version, + this extends the contents_shinychat() generic to compute the HTML table on + the fly when rendering the result. In Python, we'll include the extra + rendering logic directly in the constructor. 
+ """ + + def __init__(self, forecast_data, location_name: str, **kwargs): + # Create HTML table from the data + if isinstance(forecast_data, pd.DataFrame): + html_table = forecast_data.to_html( + index=False, classes="table table-striped" + ) + else: + html_table = str(forecast_data) # Fallback + + extra = { + "display": ToolResultDisplay( + html=ui.HTML(html_table), + title=f"Weather Forecast for {location_name}", + ) + } + + super().__init__(value=forecast_data, extra=extra, **kwargs) + + +def get_weather_forecast( + lat: float, lon: float, location_name: str +) -> WeatherToolResult: + """Get the weather forecast for a location.""" + # Mock detailed forecast data as a pandas DataFrame + forecast_data = pd.DataFrame( + { + "time": ["06:00", "12:00", "18:00", "24:00"], + "temperature": [65, 72, 68, 62], + "humidity": [70, 65, 72, 80], + "conditions": ["Clear", "Partly cloudy", "Cloudy", "Clear"], + "wind_speed": [5, 8, 6, 4], + } + ) + return WeatherToolResult(forecast_data, location_name) + + +client = ChatOpenAI(model="gpt-4.1-nano") +client.register_tool( + get_weather_forecast, + annotations={"title": "Weather Forecast"}, +) + +ui.page_opts(title="Weather Tool - Custom Result Class") + +chat = Chat(id="chat") +chat.ui() + + +@chat.on_user_submit +async def handle_user_input(user_input: str): + response = await client.stream_async(user_input, content="all") + await chat.append_message_stream(response) diff --git a/pkg-py/tests/playwright/tools/weather/app_06_tool_custom_output.py b/pkg-py/tests/playwright/tools/weather/app_06_tool_custom_output.py new file mode 100644 index 00000000..0644d5af --- /dev/null +++ b/pkg-py/tests/playwright/tools/weather/app_06_tool_custom_output.py @@ -0,0 +1,70 @@ +import faicons +import pandas as pd +from chatlas import ChatOpenAI, ContentToolResult +from shiny.express import ui +from shiny.ui import value_box +from shinychat import message_content_chunk +from shinychat.express import Chat +from shinychat.types import ChatMessage 
+ + +class WeatherToolResult(ContentToolResult): + """ + Custom tool result class for weather forecasts with custom value box output. + This example shows how to use a custom tool result class that renders + the weather data as a custom UI component (value box in R, custom HTML in Python). + """ + + location_name: str + + +@message_content_chunk.register +def _(message: WeatherToolResult): + val = message.value + high_temp = str(val["temperature"].max()) + low_temp = str(val["temperature"].min()) + current = val.iloc[0] + + content = value_box( + message.location_name, + str(current["temperature"]), + f"{current['temperature']}°F (High: {high_temp}°F, Low: {low_temp}°F)", + showcase=faicons.icon_svg("sun"), + full_screen=True, + ) + return ChatMessage(content=content) + + +def get_weather_forecast( + lat: float, lon: float, location_name: str +) -> WeatherToolResult: + """Get the weather forecast for a location.""" + # Mock detailed forecast data as a pandas DataFrame + forecast_data = pd.DataFrame( + { + "time": ["Current", "06:00", "12:00", "18:00"], + "temperature": [68, 65, 72, 66], + "humidity": [75, 70, 65, 72], + "conditions": ["Partly cloudy", "Clear", "Sunny", "Cloudy"], + "wind_speed": [7, 5, 8, 6], + } + ) + return WeatherToolResult(value=forecast_data, location_name=location_name) + + +client = ChatOpenAI(model="gpt-4.1-nano") +client.register_tool( + get_weather_forecast, + annotations={"title": "Weather Forecast"}, +) + +ui.page_opts(title="Weather Tool - Custom Output") + +chat = Chat(id="chat") +chat.ui() + + +@chat.on_user_submit +async def handle_user_input(user_input: str): + response = await client.stream_async(user_input, content="all") + await chat.append_message_stream(response) diff --git a/pkg-py/tests/playwright/tools/weather/tools.py b/pkg-py/tests/playwright/tools/weather/tools.py new file mode 100644 index 00000000..a4fbd47f --- /dev/null +++ b/pkg-py/tests/playwright/tools/weather/tools.py @@ -0,0 +1,10 @@ +import requests + + +def 
get_weather_forecast(lat: float, lon: float) -> dict: +    """Get the weather forecast for a location.""" +    lat_lng = f"latitude={lat}&longitude={lon}" +    url = f"https://api.open-meteo.com/v1/forecast?{lat_lng}&current=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m" +    response = requests.get(url) +    json = response.json() +    return json["current"] diff --git a/pyproject.toml b/pyproject.toml index 131bd921..98cdbd53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ Changelog = "https://github.com/posit-dev/shinychat/blob/main/pkg-py/CHANGELOG.m [project.optional-dependencies] providers = [ "anthropic;python_version>='3.11'", - "chatlas>=0.6.1", + "chatlas[mcp]>=0.11.1", "google-generativeai", "langchain-core", "ollama>=0.4.0",