Skip to content
Merged
1 change: 1 addition & 0 deletions pkg-py/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### New features

* New and improved UI for tool calls that occur via [chatlas](https://posit-dev.github.io/chatlas/). As a reminder, tool call displays are enabled by setting `content="all"` in chatlas' `.stream()` (or `.stream_async()`) method. See the tests under the `pkg-py/tests/playwright/tools` directory for inspiration of what is now possible with custom tool displays via the new `ToolResultDisplay` class. (#107)
* Added new `message_content()` and `message_content_chunk()` generic (`singledispatch`) functions. These functions aren't intended to be called directly by users, but instead, provide an opportunity to teach `Chat.append_message()`/`Chat.append_message_stream()` to extract message contents from different types of objects. (#96)

### Bug fixes
Expand Down
4 changes: 4 additions & 0 deletions pkg-py/src/shinychat/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
set_chatlas_state,
)
from ._chat_normalize import message_content, message_content_chunk
from ._chat_normalize_chatlas import hide_corresponding_request, is_tool_result
from ._chat_provider_types import (
AnthropicMessage, # pyright: ignore[reportAttributeAccessIssue]
GoogleMessage,
Expand Down Expand Up @@ -754,6 +755,9 @@ async def _append_message_chunk(
# Normalize various message types into a ChatMessage()
msg = message_content_chunk(message)

if is_tool_result(message):
await hide_corresponding_request(message)

if operation == "replace":
self._current_stream_message = (
self._message_stream_checkpoint + msg.content
Expand Down
111 changes: 85 additions & 26 deletions pkg-py/src/shinychat/_chat_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
import sys
from functools import singledispatch

from htmltools import HTML, Tagifiable
from htmltools import HTML, Tagifiable, TagList

from ._chat_normalize_chatlas import tool_request_contents, tool_result_contents
from ._chat_types import ChatMessage

__all__ = ["message_content", "message_content_chunk"]


@singledispatch
def message_content(message) -> ChatMessage:
def message_content(message):
"""
Extract content from various message types into a ChatMessage.

Expand Down Expand Up @@ -42,7 +43,7 @@ def message_content(message) -> ChatMessage:
If the message type is unsupported.
"""
if isinstance(message, (str, HTML)) or message is None:
return ChatMessage(content=message, role="assistant")
return ChatMessage(content=message)
if isinstance(message, dict):
if "content" not in message:
raise ValueError("Message dictionary must have a 'content' key")
Expand All @@ -57,7 +58,7 @@ def message_content(message) -> ChatMessage:


@singledispatch
def message_content_chunk(chunk) -> ChatMessage:
def message_content_chunk(chunk):
"""
Extract content from various message chunk types into a ChatMessage.

Expand Down Expand Up @@ -88,7 +89,7 @@ def message_content_chunk(chunk) -> ChatMessage:
If the chunk type is unsupported.
"""
if isinstance(chunk, (str, HTML)) or chunk is None:
return ChatMessage(content=chunk, role="assistant")
return ChatMessage(content=chunk)
if isinstance(chunk, dict):
if "content" not in chunk:
raise ValueError("Chunk dictionary must have a 'content' key")
Expand All @@ -108,15 +109,71 @@ def message_content_chunk(chunk) -> ChatMessage:


@message_content.register
def _(message: Tagifiable) -> ChatMessage:
return ChatMessage(content=message, role="assistant")
def _(message: Tagifiable):
return ChatMessage(content=message)


@message_content_chunk.register
def _(chunk: Tagifiable) -> ChatMessage:
return ChatMessage(content=chunk, role="assistant")
def _(chunk: Tagifiable):
return ChatMessage(content=chunk)


# -----------------------------------------------------------------
# chatlas tool call display
# -----------------------------------------------------------------
try:
from chatlas import ContentToolRequest, ContentToolResult, Turn
from chatlas.types import Content, ContentText

@message_content.register
def _(message: Content):
return ChatMessage(content=str(message))

@message_content_chunk.register
def _(chunk: Content):
return message_content(chunk)

@message_content.register
def _(message: ContentText):
return ChatMessage(content=message.text)

@message_content_chunk.register
def _(chunk: ContentText):
return message_content(chunk)

@message_content.register
def _(chunk: ContentToolRequest):
return ChatMessage(content=tool_request_contents(chunk))

@message_content_chunk.register
def _(chunk: ContentToolRequest):
return message_content(chunk)

@message_content.register
def _(chunk: ContentToolResult):
return ChatMessage(content=tool_result_contents(chunk))

@message_content_chunk.register
def _(chunk: ContentToolResult):
return message_content(chunk)

@message_content.register
def _(message: Turn):
contents = TagList()
for x in message.contents:
contents.append(message_content(x).content)
return ChatMessage(content=contents)

@message_content_chunk.register
def _(chunk: Turn):
return message_content(chunk)

# N.B., unlike R, Python Chat stores UI state and so can replay
# it with additional workarounds. That's why R currently has a
# shinychat_contents() method for Chat, but Python doesn't.
except ImportError:
pass

# ------------------------------------------------------------------
# LangChain content extractor
# ------------------------------------------------------------------
Expand All @@ -125,7 +182,7 @@ def _(chunk: Tagifiable) -> ChatMessage:
from langchain_core.messages import BaseMessage, BaseMessageChunk

@message_content.register
def _(message: BaseMessage) -> ChatMessage:
def _(message: BaseMessage):
if isinstance(message.content, list):
raise ValueError(
"The `message.content` provided seems to represent numerous messages. "
Expand All @@ -137,7 +194,7 @@ def _(message: BaseMessage) -> ChatMessage:
)

@message_content_chunk.register
def _(chunk: BaseMessageChunk) -> ChatMessage:
def _(chunk: BaseMessageChunk):
if isinstance(chunk.content, list):
raise ValueError(
"The `chunk.content` provided seems to represent numerous message chunks. "
Expand All @@ -159,14 +216,14 @@ def _(chunk: BaseMessageChunk) -> ChatMessage:
from openai.types.chat import ChatCompletion, ChatCompletionChunk

@message_content.register
def _(message: ChatCompletion) -> ChatMessage:
def _(message: ChatCompletion):
return ChatMessage(
content=message.choices[0].message.content,
role="assistant",
)

@message_content_chunk.register
def _(chunk: ChatCompletionChunk) -> ChatMessage:
def _(chunk: ChatCompletionChunk):
return ChatMessage(
content=chunk.choices[0].delta.content,
role="assistant",
Expand All @@ -185,21 +242,23 @@ def _(chunk: ChatCompletionChunk) -> ChatMessage:
)

@message_content.register
def _(message: AnthropicMessage) -> ChatMessage:
def _(message: AnthropicMessage):
content = message.content[0]
if content.type != "text":
raise ValueError(
f"Anthropic message type {content.type} not supported. "
"Only 'text' type is currently supported"
)
return ChatMessage(content=content.text, role="assistant")
return ChatMessage(content=content.text)

# Old versions of singledispatch doesn't seem to support union types
if sys.version_info >= (3, 11):
from anthropic.types import RawMessageStreamEvent
from anthropic.types import ( # pyright: ignore[reportMissingImports]
RawMessageStreamEvent,
)

@message_content_chunk.register
def _(chunk: RawMessageStreamEvent) -> ChatMessage:
def _(chunk: RawMessageStreamEvent):
content = ""
if chunk.type == "content_block_delta":
if chunk.delta.type != "text_delta":
Expand All @@ -209,7 +268,7 @@ def _(chunk: RawMessageStreamEvent) -> ChatMessage:
)
content = chunk.delta.text

return ChatMessage(content=content, role="assistant")
return ChatMessage(content=content)
except ImportError:
pass

Expand All @@ -224,12 +283,12 @@ def _(chunk: RawMessageStreamEvent) -> ChatMessage:
)

@message_content.register
def _(message: GenerateContentResponse) -> ChatMessage:
return ChatMessage(content=message.text, role="assistant")
def _(message: GenerateContentResponse):
return ChatMessage(content=message.text)

@message_content_chunk.register
def _(chunk: GenerateContentResponse) -> ChatMessage:
return ChatMessage(content=chunk.text, role="assistant")
def _(chunk: GenerateContentResponse):
return ChatMessage(content=chunk.text)

except ImportError:
pass
Expand All @@ -243,14 +302,14 @@ def _(chunk: GenerateContentResponse) -> ChatMessage:
from ollama import ChatResponse

@message_content.register
def _(message: ChatResponse) -> ChatMessage:
def _(message: ChatResponse):
msg = message.message
return ChatMessage(msg.content, role="assistant")
return ChatMessage(msg.content)

@message_content_chunk.register
def _(chunk: ChatResponse) -> ChatMessage:
def _(chunk: ChatResponse):
msg = chunk.message
return ChatMessage(msg.content, role="assistant")
return ChatMessage(msg.content)

except ImportError:
pass
Loading