Skip to content

Commit c84ff94

Browse files
committed
Fix for Anthropic; cleanup
1 parent 3bae3d3 commit c84ff94

File tree

3 files changed

+29
-14
lines changed

3 files changed

+29
-14
lines changed

pkg-py/src/shinychat/_chat_normalize.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,25 @@ def _(chunk: ChatCompletionChunk) -> ChatMessage:
180180

181181
try:
182182
from anthropic.types import Message as AnthropicMessage
183-
from anthropic.types import MessageStreamEvent
183+
from anthropic.types import (
184+
RawContentBlockDeltaEvent,
185+
RawContentBlockStartEvent,
186+
RawContentBlockStopEvent,
187+
RawMessageDeltaEvent,
188+
RawMessageStartEvent,
189+
RawMessageStopEvent,
190+
)
191+
192+
# Create a non-annotated type alias for RawMessageStreamEvent
193+
# (so it works with singledispatch)
194+
RawMessageStreamEvent = (
195+
RawMessageStartEvent
196+
| RawMessageDeltaEvent
197+
| RawMessageStopEvent
198+
| RawContentBlockStartEvent
199+
| RawContentBlockDeltaEvent
200+
| RawContentBlockStopEvent
201+
)
184202

185203
@get_message_content.register
186204
def _(message: AnthropicMessage) -> ChatMessage:
@@ -193,7 +211,7 @@ def _(message: AnthropicMessage) -> ChatMessage:
193211
return ChatMessage(content=content.text, role="assistant")
194212

195213
@get_message_chunk_content.register
196-
def _(chunk: MessageStreamEvent) -> ChatMessage:
214+
def _(chunk: RawMessageStreamEvent) -> ChatMessage:
197215
content = ""
198216
if chunk.type == "content_block_delta":
199217
if chunk.delta.type != "text_delta":

pkg-py/tests/pytest/test_chat.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def test_ollama_normalization():
398398
def test_as_anthropic_message():
399399
from anthropic.resources.messages import AsyncMessages, Messages
400400
from anthropic.types import MessageParam
401-
from shiny.ui._chat_provider_types import as_anthropic_message
401+
from shinychat._chat_provider_types import as_anthropic_message
402402

403403
# Make sure return type of llm.messages.create() hasn't changed
404404
assert (
@@ -416,7 +416,7 @@ def test_as_anthropic_message():
416416

417417

418418
def test_as_google_message():
419-
from shiny.ui._chat_provider_types import as_google_message
419+
from shinychat._chat_provider_types import as_google_message
420420

421421
# Not available for Python 3.8
422422
if sys.version_info < (3, 9):
@@ -460,7 +460,7 @@ def test_as_langchain_message():
460460
MessageLikeRepresentation,
461461
SystemMessage,
462462
)
463-
from shiny.ui._chat_provider_types import as_langchain_message
463+
from shinychat._chat_provider_types import as_langchain_message
464464

465465
assert BaseChatModel.invoke.__annotations__["input"] == "LanguageModelInput"
466466
assert BaseChatModel.stream.__annotations__["input"] == "LanguageModelInput"
@@ -491,7 +491,7 @@ def test_as_openai_message():
491491
ChatCompletionSystemMessageParam,
492492
ChatCompletionUserMessageParam,
493493
)
494-
from shiny.ui._chat_provider_types import as_openai_message
494+
from shinychat._chat_provider_types import as_openai_message
495495

496496
assert (
497497
Completions.create.__annotations__["messages"]
@@ -523,14 +523,11 @@ def test_as_ollama_message():
523523
import ollama
524524
from ollama import Message as OllamaMessage
525525

526-
# ollama 0.4.2 added Callable to the type hints, but pyright complains about
527-
# missing arguments to the Callable type. We'll ignore this for now.
528-
# https://github.com/ollama/ollama-python/commit/b50a65b
529-
chat = ollama.chat # type: ignore
530-
531-
assert "ollama._types.Message" in str(chat.__annotations__["messages"])
526+
assert "ollama._types.Message" in str(
527+
ollama.chat.__annotations__["messages"]
528+
)
532529

533-
from shiny.ui._chat_provider_types import as_ollama_message
530+
from shinychat._chat_provider_types import as_ollama_message
534531

535532
msg = ChatMessageDict(content="I have a question", role="user")
536533
assert as_ollama_message(msg) == OllamaMessage(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ providers = [
2828
"chatlas>=0.6.1",
2929
"google-generativeai;python_version>='3.9'",
3030
"langchain-core",
31-
"ollama",
31+
"ollama>=0.4.0",
3232
"openai",
3333
"tokenizers",
3434
]

0 commit comments

Comments
 (0)