Commit 6a7456d

zachmayer authored and DouweM committed
Add MoonshotAI provider with Kimi-K2 model support (pydantic#2211)
Co-authored-by: Douwe Maan <[email protected]>
1 parent edc48a5 commit 6a7456d

File tree

12 files changed: +246 -1 lines changed

docs/models/openai.md

Lines changed: 18 additions & 0 deletions
@@ -401,6 +401,24 @@ agent = Agent(model)
 ...
 ```
 
+### MoonshotAI
+
+Create an API key in the [Moonshot Console](https://platform.moonshot.ai/console).
+With that key you can instantiate the [`MoonshotAIProvider`][pydantic_ai.providers.moonshotai.MoonshotAIProvider]:
+
+```python
+from pydantic_ai import Agent
+from pydantic_ai.models.openai import OpenAIModel
+from pydantic_ai.providers.moonshotai import MoonshotAIProvider
+
+model = OpenAIModel(
+    'kimi-k2-0711-preview',
+    provider=MoonshotAIProvider(api_key='your-moonshot-api-key'),
+)
+agent = Agent(model)
+...
+```
+
 ### GitHub Models
 
 To use [GitHub Models](https://docs.github.com/en/github-models), you'll need a GitHub personal access token with the `models: read` permission.

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -233,6 +233,15 @@
     'mistral:mistral-large-latest',
     'mistral:mistral-moderation-latest',
     'mistral:mistral-small-latest',
+    'moonshotai:moonshot-v1-8k',
+    'moonshotai:moonshot-v1-32k',
+    'moonshotai:moonshot-v1-128k',
+    'moonshotai:moonshot-v1-8k-vision-preview',
+    'moonshotai:moonshot-v1-32k-vision-preview',
+    'moonshotai:moonshot-v1-128k-vision-preview',
+    'moonshotai:kimi-latest',
+    'moonshotai:kimi-thinking-preview',
+    'moonshotai:kimi-k2-0711-preview',
     'o1',
     'o1-2024-12-17',
     'o1-mini',
@@ -617,6 +626,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
         'openrouter',
         'vercel',
         'grok',
+        'moonshotai',
         'fireworks',
         'together',
         'heroku',
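
A hedged usage note: with these names registered and the matching `infer_model` branch, the `moonshotai:` prefix should be usable directly as a known model name, assuming `MOONSHOTAI_API_KEY` is set in the environment:

```python
from pydantic_ai import Agent

# Assumes MOONSHOTAI_API_KEY is set; the 'moonshotai:' prefix is resolved by
# infer_model to an OpenAIModel backed by MoonshotAIProvider (see below).
agent = Agent('moonshotai:kimi-k2-0711-preview')
```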

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 5 additions & 1 deletion
@@ -195,6 +195,7 @@ def __init__(
             'deepseek',
             'azure',
             'openrouter',
+            'moonshotai',
             'vercel',
             'grok',
             'fireworks',
@@ -299,7 +300,10 @@ async def _completions_create(
         tools = self._get_tools(model_request_parameters)
         if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not model_request_parameters.allow_text_output:
+        elif (
+            not model_request_parameters.allow_text_output
+            and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required
+        ):
             tool_choice = 'required'
         else:
             tool_choice = 'auto'
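
In effect, a provider whose profile sets `openai_supports_tool_choice_required=False` now falls back to `tool_choice='auto'` even when text output is not allowed. A minimal standalone sketch of that selection logic (the helper name below is illustrative, not part of the diff):

```python
from __future__ import annotations

from typing import Literal


def pick_tool_choice(
    has_tools: bool,
    allow_text_output: bool,
    supports_tool_choice_required: bool,
) -> Literal['none', 'required', 'auto'] | None:
    """Mirror the branch above: no tools -> None; otherwise fall back to 'auto'
    when the provider (e.g. MoonshotAI) rejects tool_choice='required'."""
    if not has_tools:
        return None
    if not allow_text_output and supports_tool_choice_required:
        return 'required'
    return 'auto'


assert pick_tool_choice(True, False, True) == 'required'  # default OpenAI behaviour
assert pick_tool_choice(True, False, False) == 'auto'  # MoonshotAI-style fallback
```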

pydantic_ai_slim/pydantic_ai/profiles/openai.py

Lines changed: 8 additions & 0 deletions
@@ -21,6 +21,14 @@ class OpenAIModelProfile(ModelProfile):
     openai_supports_sampling_settings: bool = True
     """Turn off to don't send sampling settings like `temperature` and `top_p` to models that don't support them, like OpenAI's o-series reasoning models."""
 
+    # Some OpenAI-compatible providers (e.g. MoonshotAI) currently do **not** accept
+    # `tool_choice="required"`. This flag lets the calling model know whether it's
+    # safe to pass that value along. Default is `True` to preserve existing
+    # behaviour for OpenAI itself and most providers.
+    openai_supports_tool_choice_required: bool = True
+    """Whether the provider accepts the value ``tool_choice='required'`` in the
+    request payload."""
+
 
 def openai_model_profile(model_name: str) -> ModelProfile:
     """Get the model profile for an OpenAI model."""

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -103,6 +103,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
         from .grok import GrokProvider
 
         return GrokProvider
+    elif provider == 'moonshotai':
+        from .moonshotai import MoonshotAIProvider
+
+        return MoonshotAIProvider
    elif provider == 'fireworks':
         from .fireworks import FireworksProvider
 
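
A quick sanity check of the new branch, assuming `infer_provider_class` is importable from `pydantic_ai.providers` as the hunk header suggests:

```python
from pydantic_ai.providers import infer_provider_class
from pydantic_ai.providers.moonshotai import MoonshotAIProvider

# 'moonshotai' should now map to the new provider class.
assert infer_provider_class('moonshotai') is MoonshotAIProvider
```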

pydantic_ai_slim/pydantic_ai/providers/moonshotai.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import Literal, overload
+
+from httpx import AsyncClient as AsyncHTTPClient
+from openai import AsyncOpenAI
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
+from pydantic_ai.profiles.openai import (
+    OpenAIJsonSchemaTransformer,
+    OpenAIModelProfile,
+)
+from pydantic_ai.providers import Provider
+
+MoonshotAIModelName = Literal[
+    'moonshot-v1-8k',
+    'moonshot-v1-32k',
+    'moonshot-v1-128k',
+    'moonshot-v1-8k-vision-preview',
+    'moonshot-v1-32k-vision-preview',
+    'moonshot-v1-128k-vision-preview',
+    'kimi-latest',
+    'kimi-thinking-preview',
+    'kimi-k2-0711-preview',
+]
+
+
+class MoonshotAIProvider(Provider[AsyncOpenAI]):
+    """Provider for MoonshotAI platform (Kimi models)."""
+
+    @property
+    def name(self) -> str:
+        return 'moonshotai'
+
+    @property
+    def base_url(self) -> str:
+        # OpenAI-compatible endpoint, see MoonshotAI docs
+        return 'https://api.moonshot.ai/v1'
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        profile = moonshotai_model_profile(model_name)
+
+        # As the MoonshotAI API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer,
+        # unless json_schema_transformer is set explicitly.
+        # Also, MoonshotAI does not support strict tool definitions
+        # https://platform.moonshot.ai/docs/guide/migrating-from-openai-to-kimi#about-tool_choice
+        # "Please note that the current version of Kimi API does not support the tool_choice=required parameter."
+        return OpenAIModelProfile(
+            json_schema_transformer=OpenAIJsonSchemaTransformer,
+            openai_supports_tool_choice_required=False,
+            supports_json_object_output=True,
+        ).update(profile)
+
+    # ---------------------------------------------------------------------
+    # Construction helpers
+    # ---------------------------------------------------------------------
+    @overload
+    def __init__(self) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
+
+    @overload
+    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ) -> None:
+        api_key = api_key or os.getenv('MOONSHOTAI_API_KEY')
+        if not api_key and openai_client is None:
+            raise UserError(
+                'Set the `MOONSHOTAI_API_KEY` environment variable or pass it via '
+                '`MoonshotAIProvider(api_key=...)` to use the MoonshotAI provider.'
+            )
+
+        if openai_client is not None:
+            self._client = openai_client
+        elif http_client is not None:
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+        else:
+            http_client = cached_async_http_client(provider='moonshotai')
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
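
Based on the constructor overloads above, the provider can presumably be built in any of these ways (the timeout value is just illustrative):

```python
import httpx

from pydantic_ai.providers.moonshotai import MoonshotAIProvider

# Rely on the MOONSHOTAI_API_KEY environment variable.
provider = MoonshotAIProvider()

# Or pass the key (and optionally a custom HTTP client) explicitly.
provider = MoonshotAIProvider(
    api_key='your-moonshot-api-key',
    http_client=httpx.AsyncClient(timeout=30),
)
```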

tests/models/test_model.py

Lines changed: 8 additions & 0 deletions
@@ -72,6 +72,14 @@
             'github',
             'OpenAIModel',
         ),
+        (
+            'MOONSHOTAI_API_KEY',
+            'moonshotai:kimi-k2-0711-preview',
+            'kimi-k2-0711-preview',
+            'moonshotai',
+            'moonshotai',
+            'OpenAIModel',
+        ),
         (
             'GROK_API_KEY',
             'grok:grok-3',

tests/models/test_model_names.py

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,7 @@
 from pydantic_ai.models.mistral import MistralModelName
 from pydantic_ai.models.openai import OpenAIModelName
 from pydantic_ai.providers.grok import GrokModelName
+from pydantic_ai.providers.moonshotai import MoonshotAIModelName
 
 pytestmark = [
     pytest.mark.skipif(not imports_successful(), reason='some model package was not installed'),
@@ -51,6 +52,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]:
     ]
     grok_names = [f'grok:{n}' for n in get_model_names(GrokModelName)]
     groq_names = [f'groq:{n}' for n in get_model_names(GroqModelName)]
+    moonshotai_names = [f'moonshotai:{n}' for n in get_model_names(MoonshotAIModelName)]
     mistral_names = [f'mistral:{n}' for n in get_model_names(MistralModelName)]
     openai_names = [f'openai:{n}' for n in get_model_names(OpenAIModelName)] + [
         n for n in get_model_names(OpenAIModelName) if n.startswith('o1') or n.startswith('gpt') or n.startswith('o3')
@@ -68,6 +70,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]:
         + grok_names
         + groq_names
         + mistral_names
+        + moonshotai_names
         + openai_names
         + bedrock_names
         + deepseek_names

tests/models/test_openai.py

Lines changed: 20 additions & 0 deletions
@@ -36,6 +36,7 @@
     ToolReturnPart,
     UserPromptPart,
 )
+from pydantic_ai.models import ModelRequestParameters
 from pydantic_ai.models.gemini import GeminiModel
 from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput
 from pydantic_ai.profiles import ModelProfile
@@ -2631,3 +2632,22 @@ async def test_process_response_no_created_timestamp(allow_model_requests: None)
     response_message = messages[1]
     assert isinstance(response_message, ModelResponse)
     assert response_message.timestamp == IsNow(tz=timezone.utc)
+
+
+@pytest.mark.anyio()
+async def test_tool_choice_fallback(allow_model_requests: None) -> None:
+    profile = OpenAIModelProfile(openai_supports_tool_choice_required=False).update(openai_model_profile('stub'))
+
+    mock_client = MockOpenAI.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant')))
+    model = OpenAIModel('stub', provider=OpenAIProvider(openai_client=mock_client), profile=profile)
+
+    params = ModelRequestParameters(function_tools=[ToolDefinition(name='x')], allow_text_output=False)
+
+    await model._completions_create(  # pyright: ignore[reportPrivateUsage]
+        messages=[],
+        stream=False,
+        model_settings={},
+        model_request_parameters=params,
+    )
+
+    assert get_mock_chat_completion_kwargs(mock_client)[0]['tool_choice'] == 'auto'

tests/providers/test_moonshotai.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+import re
+
+import httpx
+import pytest
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
+
+from ..conftest import TestEnv, try_import
+
+with try_import() as imports_successful:
+    import openai
+
+    from pydantic_ai.models.openai import OpenAIModel
+    from pydantic_ai.providers.moonshotai import MoonshotAIProvider
+
+pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed')
+
+
+def test_moonshotai_provider():
+    """Test basic MoonshotAI provider initialization."""
+    provider = MoonshotAIProvider(api_key='api-key')
+    assert provider.name == 'moonshotai'
+    assert provider.base_url == 'https://api.moonshot.ai/v1'
+    assert isinstance(provider.client, openai.AsyncOpenAI)
+    assert provider.client.api_key == 'api-key'
+
+
+def test_moonshotai_provider_need_api_key(env: TestEnv) -> None:
+    """Test that MoonshotAI provider requires an API key."""
+    env.remove('MOONSHOTAI_API_KEY')
+    with pytest.raises(
+        UserError,
+        match=re.escape(
+            'Set the `MOONSHOTAI_API_KEY` environment variable or pass it via `MoonshotAIProvider(api_key=...)`'
+            ' to use the MoonshotAI provider.'
+        ),
+    ):
+        MoonshotAIProvider()
+
+
+def test_moonshotai_provider_pass_http_client() -> None:
+    """Test passing a custom HTTP client to MoonshotAI provider."""
+    http_client = httpx.AsyncClient()
+    provider = MoonshotAIProvider(http_client=http_client, api_key='api-key')
+    assert provider.client._client == http_client  # type: ignore[reportPrivateUsage]
+
+
+def test_moonshotai_pass_openai_client() -> None:
+    """Test passing a custom OpenAI client to MoonshotAI provider."""
+    openai_client = openai.AsyncOpenAI(api_key='api-key')
+    provider = MoonshotAIProvider(openai_client=openai_client)
+    assert provider.client == openai_client
+
+
+def test_moonshotai_provider_with_cached_http_client() -> None:
+    """Test MoonshotAI provider using cached HTTP client (covers line 76)."""
+    # This should use the else branch with cached_async_http_client
+    provider = MoonshotAIProvider(api_key='api-key')
+    assert isinstance(provider.client, openai.AsyncOpenAI)
+    assert provider.client.api_key == 'api-key'
+
+
+def test_moonshotai_model_profile():
+    provider = MoonshotAIProvider(api_key='api-key')
+    model = OpenAIModel('kimi-k2-0711-preview', provider=provider)
+    assert isinstance(model.profile, OpenAIModelProfile)
+    assert model.profile.json_schema_transformer == OpenAIJsonSchemaTransformer
+    assert model.profile.openai_supports_tool_choice_required is False
+    assert model.profile.supports_json_object_output is True
