Skip to content

Commit dcd7d87

Browse files
committed
feat: add MoonshotAI provider with Kimi-K2 model support
- Add MoonshotAIProvider with OpenAI-compatible API - Implements OpenAI-style interface with custom base URL - Supports tool definitions but not strict tool validation - Add moonshotai:kimi-k2-0711-preview as known model - Configure to use OpenAIModel for compatibility - Add comprehensive tests for provider functionality - Update CLI and model name tests
1 parent 883e1ea commit dcd7d87

File tree

10 files changed

+204
-2
lines changed

10 files changed

+204
-2
lines changed

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,8 @@
286286
'openai:o3-mini-2025-01-31',
287287
'openai:o4-mini',
288288
'openai:o4-mini-2025-04-16',
289+
'openai:computer-use-preview-2025-03-11',
290+
'moonshotai:kimi-k2-0711-preview',
289291
'test',
290292
],
291293
)
@@ -588,6 +590,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
588590
'azure',
589591
'openrouter',
590592
'grok',
593+
'moonshotai',
591594
'fireworks',
592595
'together',
593596
'heroku',

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,16 @@ def __init__(
190190
model_name: OpenAIModelName,
191191
*,
192192
provider: Literal[
193-
'openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku', 'github'
193+
'openai',
194+
'deepseek',
195+
'azure',
196+
'openrouter',
197+
'grok',
198+
'moonshotai',
199+
'fireworks',
200+
'together',
201+
'heroku',
202+
'github',
194203
]
195204
| Provider[AsyncOpenAI] = 'openai',
196205
profile: ModelProfileSpec | None = None,
@@ -598,7 +607,18 @@ def __init__(
598607
self,
599608
model_name: OpenAIModelName,
600609
*,
601-
provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
610+
provider: Literal[
611+
'openai',
612+
'deepseek',
613+
'azure',
614+
'openrouter',
615+
'grok',
616+
'moonshotai',
617+
'fireworks',
618+
'together',
619+
'heroku',
620+
'github',
621+
]
602622
| Provider[AsyncOpenAI] = 'openai',
603623
profile: ModelProfileSpec | None = None,
604624
settings: ModelSettings | None = None,
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from __future__ import annotations as _annotations
2+
3+
from . import ModelProfile
4+
5+
6+
def moonshotai_model_profile(model_name: str) -> ModelProfile | None:
7+
"""Get the model profile for a MoonshotAI model."""
8+
return None

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
9999
from .grok import GrokProvider
100100

101101
return GrokProvider
102+
elif provider == 'moonshotai':
103+
from .moonshotai import MoonshotAIProvider
104+
105+
return MoonshotAIProvider
102106
elif provider == 'fireworks':
103107
from .fireworks import FireworksProvider
104108

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from __future__ import annotations as _annotations
2+
3+
import os
4+
from typing import overload
5+
6+
from httpx import AsyncClient as AsyncHTTPClient
7+
from openai import AsyncOpenAI
8+
9+
from pydantic_ai.exceptions import UserError
10+
from pydantic_ai.models import cached_async_http_client
11+
from pydantic_ai.profiles import ModelProfile
12+
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
13+
from pydantic_ai.profiles.openai import (
14+
OpenAIJsonSchemaTransformer,
15+
OpenAIModelProfile,
16+
)
17+
from pydantic_ai.providers import Provider
18+
19+
20+
class MoonshotAIProvider(Provider[AsyncOpenAI]):
21+
"""Provider for MoonshotAI platform (Kimi models)."""
22+
23+
@property
24+
def name(self) -> str:
25+
return 'moonshotai'
26+
27+
@property
28+
def base_url(self) -> str:
29+
# OpenAI-compatible endpoint, see MoonshotAI docs
30+
return 'https://api.moonshot.ai/v1'
31+
32+
@property
33+
def client(self) -> AsyncOpenAI:
34+
return self._client
35+
36+
def model_profile(self, model_name: str) -> ModelProfile | None:
37+
profile = moonshotai_model_profile(model_name)
38+
39+
# As the MoonshotAI API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer,
40+
# unless json_schema_transformer is set explicitly.
41+
# Also, MoonshotAI does not support strict tool definitions
42+
# https://platform.moonshot.ai/docs/guide/migrating-from-openai-to-kimi#about-tool_choice
43+
# "Please note that the current version of Kimi API does not support the tool_choice=required parameter."
44+
return OpenAIModelProfile(
45+
json_schema_transformer=OpenAIJsonSchemaTransformer,
46+
openai_supports_strict_tool_definition=False,
47+
).update(profile)
48+
49+
# ---------------------------------------------------------------------
50+
# Construction helpers
51+
# ---------------------------------------------------------------------
52+
@overload
53+
def __init__(self) -> None: ...
54+
55+
@overload
56+
def __init__(self, *, api_key: str) -> None: ...
57+
58+
@overload
59+
def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
60+
61+
@overload
62+
def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
63+
64+
def __init__(
65+
self,
66+
*,
67+
api_key: str | None = None,
68+
openai_client: AsyncOpenAI | None = None,
69+
http_client: AsyncHTTPClient | None = None,
70+
) -> None:
71+
api_key = api_key or os.getenv('MOONSHOT_API_KEY')
72+
if not api_key and openai_client is None:
73+
raise UserError(
74+
'Set the `MOONSHOT_API_KEY` environment variable or pass it via '
75+
'`MoonshotAIProvider(api_key=...)` to use the MoonshotAI provider.'
76+
)
77+
78+
if openai_client is not None:
79+
self._client = openai_client
80+
elif http_client is not None:
81+
self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
82+
else:
83+
http_client = cached_async_http_client(provider='moonshotai')
84+
self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)

tests/models/test_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@
7272
'github',
7373
'OpenAIModel',
7474
),
75+
(
76+
'MOONSHOT_API_KEY',
77+
'moonshotai:kimi-k2-0711-preview',
78+
'kimi-k2-0711-preview',
79+
'moonshotai',
80+
'openai',
81+
'OpenAIModel',
82+
),
7583
]
7684

7785

tests/models/test_model_names.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]:
4949
f'google-vertex:{n}' for n in get_model_names(GeminiModelName)
5050
]
5151
groq_names = [f'groq:{n}' for n in get_model_names(GroqModelName)]
52+
moonshotai_names = ['moonshotai:kimi-k2-0711-preview']
5253
mistral_names = [f'mistral:{n}' for n in get_model_names(MistralModelName)]
5354
openai_names = [f'openai:{n}' for n in get_model_names(OpenAIModelName)] + [
5455
n for n in get_model_names(OpenAIModelName) if n.startswith('o1') or n.startswith('gpt') or n.startswith('o3')
@@ -64,6 +65,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]:
6465
+ cohere_names
6566
+ google_names
6667
+ groq_names
68+
+ moonshotai_names
6769
+ mistral_names
6870
+ openai_names
6971
+ bedrock_names

tests/providers/test_moonshotai.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import re
2+
3+
import httpx
4+
import pytest
5+
6+
from pydantic_ai.exceptions import UserError
7+
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
8+
9+
from ..conftest import TestEnv, try_import
10+
11+
with try_import() as imports_successful:
12+
import openai
13+
14+
from pydantic_ai.models.openai import OpenAIModel
15+
from pydantic_ai.providers.moonshotai import MoonshotAIProvider
16+
17+
pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed')
18+
19+
20+
def test_moonshotai_provider():
21+
"""Test basic MoonshotAI provider initialization."""
22+
provider = MoonshotAIProvider(api_key='api-key')
23+
assert provider.name == 'moonshotai'
24+
assert provider.base_url == 'https://api.moonshot.ai/v1'
25+
assert isinstance(provider.client, openai.AsyncOpenAI)
26+
assert provider.client.api_key == 'api-key'
27+
28+
29+
def test_moonshotai_provider_need_api_key(env: TestEnv) -> None:
30+
"""Test that MoonshotAI provider requires an API key."""
31+
env.remove('MOONSHOT_API_KEY')
32+
with pytest.raises(
33+
UserError,
34+
match=re.escape(
35+
'Set the `MOONSHOT_API_KEY` environment variable or pass it via `MoonshotAIProvider(api_key=...)`'
36+
' to use the MoonshotAI provider.'
37+
),
38+
):
39+
MoonshotAIProvider()
40+
41+
42+
def test_moonshotai_provider_pass_http_client() -> None:
43+
"""Test passing a custom HTTP client to MoonshotAI provider."""
44+
http_client = httpx.AsyncClient()
45+
provider = MoonshotAIProvider(http_client=http_client, api_key='api-key')
46+
assert provider.client._client == http_client # type: ignore[reportPrivateUsage]
47+
48+
49+
def test_moonshotai_pass_openai_client() -> None:
50+
"""Test passing a custom OpenAI client to MoonshotAI provider."""
51+
openai_client = openai.AsyncOpenAI(api_key='api-key')
52+
provider = MoonshotAIProvider(openai_client=openai_client)
53+
assert provider.client == openai_client
54+
55+
56+
def test_moonshotai_provider_with_cached_http_client() -> None:
57+
"""Test MoonshotAI provider using cached HTTP client (covers line 76)."""
58+
# This should use the else branch with cached_async_http_client
59+
provider = MoonshotAIProvider(api_key='api-key')
60+
assert isinstance(provider.client, openai.AsyncOpenAI)
61+
assert provider.client.api_key == 'api-key'
62+
63+
64+
def test_moonshotai_model_profile():
65+
provider = MoonshotAIProvider(api_key='api-key')
66+
model = OpenAIModel('kimi-k2-0711-preview', provider=provider)
67+
assert isinstance(model.profile, OpenAIModelProfile)
68+
assert model.profile.json_schema_transformer == OpenAIJsonSchemaTransformer
69+
assert model.profile.openai_supports_strict_tool_definition is False

tests/providers/test_provider_names.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pydantic_ai.providers.groq import GroqProvider
2727
from pydantic_ai.providers.heroku import HerokuProvider
2828
from pydantic_ai.providers.mistral import MistralProvider
29+
from pydantic_ai.providers.moonshotai import MoonshotAIProvider
2930
from pydantic_ai.providers.openai import OpenAIProvider
3031
from pydantic_ai.providers.openrouter import OpenRouterProvider
3132
from pydantic_ai.providers.together import TogetherProvider
@@ -42,6 +43,7 @@
4243
('groq', GroqProvider, 'GROQ_API_KEY'),
4344
('mistral', MistralProvider, 'MISTRAL_API_KEY'),
4445
('grok', GrokProvider, 'GROK_API_KEY'),
46+
('moonshotai', MoonshotAIProvider, 'MOONSHOT_API_KEY'),
4547
('fireworks', FireworksProvider, 'FIREWORKS_API_KEY'),
4648
('together', TogetherProvider, 'TOGETHER_API_KEY'),
4749
('heroku', HerokuProvider, 'HEROKU_INFERENCE_KEY'),

tests/test_cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ def test_list_models(capfd: CaptureFixture[str]):
144144
'cohere',
145145
'deepseek',
146146
'heroku',
147+
'grok',
148+
'moonshotai',
147149
'huggingface',
148150
)
149151
models = {line.strip().split(' ')[0] for line in output[3:]}

0 commit comments

Comments
 (0)