From 5211ea28612f1d9350997e2132b3a877c1d0a6e5 Mon Sep 17 00:00:00 2001 From: zorro30 Date: Wed, 14 May 2025 11:10:30 +0530 Subject: [PATCH] Added support for azure Open AI --- src/mcp_agent/cli/commands/setup.py | 5 ++ src/mcp_agent/config.py | 15 +++++ src/mcp_agent/llm/model_factory.py | 5 ++ src/mcp_agent/llm/provider_key_manager.py | 1 + src/mcp_agent/llm/provider_types.py | 1 + .../providers/augmented_llm_azure_openai.py | 58 +++++++++++++++++++ 6 files changed, 85 insertions(+) create mode 100644 src/mcp_agent/llm/providers/augmented_llm_azure_openai.py diff --git a/src/mcp_agent/cli/commands/setup.py b/src/mcp_agent/cli/commands/setup.py index cbe1e3d8..212c1d53 100644 --- a/src/mcp_agent/cli/commands/setup.py +++ b/src/mcp_agent/cli/commands/setup.py @@ -66,6 +66,11 @@ api_key: openrouter: api_key: +azure_openai: + api_key: + base_url: + api_version: + deployment_id: # Example of setting an MCP Server environment variable diff --git a/src/mcp_agent/config.py b/src/mcp_agent/config.py index 9f93a038..8b10c4cf 100644 --- a/src/mcp_agent/config.py +++ b/src/mcp_agent/config.py @@ -207,6 +207,18 @@ class TensorZeroSettings(BaseModel): api_key: Optional[str] = None model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) +class AzureOpenAISettings(BaseModel): + """ + Settings for using Azure OpenAI models in the fast-agent application. + """ + + api_key: str | None = None + base_url: str | None = None + api_version: str = "2025-01-01" + deployment_id: str | None = None + + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + class LoggerSettings(BaseModel): """ @@ -302,6 +314,9 @@ class Settings(BaseSettings): tensorzero: Optional[TensorZeroSettings] = None """Settings for using TensorZero inference gateway""" + azure_openai: AzureOpenAISettings | None = None + """Settings for using Azure OpenAI models in the fast-agent application""" + logger: LoggerSettings | None = LoggerSettings() """Logger settings for the fast-agent application""" diff --git a/src/mcp_agent/llm/model_factory.py b/src/mcp_agent/llm/model_factory.py index d6d61e1b..af7ba334 100644 --- a/src/mcp_agent/llm/model_factory.py +++ b/src/mcp_agent/llm/model_factory.py @@ -16,6 +16,7 @@ from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM from mcp_agent.llm.providers.augmented_llm_openrouter import OpenRouterAugmentedLLM from mcp_agent.llm.providers.augmented_llm_tensorzero import TensorZeroAugmentedLLM +from mcp_agent.llm.providers.augmented_llm_azure_openai import AzureOpenAIAugmentedLLM from mcp_agent.mcp.interfaces import AugmentedLLMProtocol # from mcp_agent.workflows.llm.augmented_llm_deepseek import DeekSeekAugmentedLLM @@ -30,6 +31,7 @@ Type[DeepSeekAugmentedLLM], Type[OpenRouterAugmentedLLM], Type[TensorZeroAugmentedLLM], + Type[AzureOpenAIAugmentedLLM], ] @@ -86,6 +88,7 @@ class ModelFactory: "claude-3-opus-20240229": Provider.ANTHROPIC, "claude-3-opus-latest": Provider.ANTHROPIC, "deepseek-chat": Provider.DEEPSEEK, + "azure_openai": Provider.AZURE_OPENAI, # "deepseek-reasoner": Provider.DEEPSEEK, reinstate on release } @@ -101,6 +104,7 @@ class ModelFactory: "opus3": "claude-3-opus-latest", "deepseekv3": "deepseek-chat", "deepseek": "deepseek-chat", + "azure_openai": "azure_openai", } # Mapping of providers to their LLM classes @@ -113,6 +117,7 @@ class ModelFactory: Provider.GOOGLE: GoogleAugmentedLLM, # type: ignore Provider.OPENROUTER: OpenRouterAugmentedLLM, Provider.TENSORZERO: TensorZeroAugmentedLLM, + Provider.AZURE_OPENAI: AzureOpenAIAugmentedLLM, } # Mapping of special model names to their specific LLM classes diff --git a/src/mcp_agent/llm/provider_key_manager.py b/src/mcp_agent/llm/provider_key_manager.py index dc71e6de..88de52d2 100644 --- a/src/mcp_agent/llm/provider_key_manager.py +++ b/src/mcp_agent/llm/provider_key_manager.py @@ -17,6 +17,7 @@ "google": "GOOGLE_API_KEY", "openrouter": "OPENROUTER_API_KEY", "generic": "GENERIC_API_KEY", + "azure_openai": "AZURE_OPENAI_API_KEY", } API_KEY_HINT_TEXT = "" diff --git a/src/mcp_agent/llm/provider_types.py b/src/mcp_agent/llm/provider_types.py index 9316101a..71d88013 100644 --- a/src/mcp_agent/llm/provider_types.py +++ b/src/mcp_agent/llm/provider_types.py @@ -16,3 +16,4 @@ class Provider(Enum): GENERIC = "generic" OPENROUTER = "openrouter" TENSORZERO = "tensorzero" # For TensorZero Gateway + AZURE_OPENAI = "azure_openai" \ No newline at end of file diff --git a/src/mcp_agent/llm/providers/augmented_llm_azure_openai.py b/src/mcp_agent/llm/providers/augmented_llm_azure_openai.py new file mode 100644 index 00000000..937e7f5c --- /dev/null +++ b/src/mcp_agent/llm/providers/augmented_llm_azure_openai.py @@ -0,0 +1,58 @@ +from mcp_agent.core.request_params import RequestParams +from mcp_agent.llm.provider_types import Provider +from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM +from openai import AzureOpenAI, AuthenticationError +from mcp_agent.core.exceptions import ProviderKeyError + +DEFAULT_AZURE_OPENAI_MODEL = "azure_openai" +DEFAULT_AZURE_API_VERSION = "2025-01-01" + +class AzureOpenAIAugmentedLLM(OpenAIAugmentedLLM): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, provider=Provider.AZURE_OPENAI, **kwargs) + + def _initialize_default_params(self, kwargs: dict) -> RequestParams: + """Initialize Azure OpenAI default parameters""" + chosen_model = kwargs.get("model", DEFAULT_AZURE_OPENAI_MODEL) + + return RequestParams( + model=chosen_model, + systemPrompt=self.instruction, + parallel_tool_calls=True, + max_iterations=20, + use_history=True, + ) + + def _base_url(self) -> str: + base_url = None + if self.context and self.context.config and hasattr(self.context.config, "azure_openai"): + if self.context.config.azure_openai and hasattr(self.context.config.azure_openai, "base_url"): + base_url = self.context.config.azure_openai.base_url + return base_url or "" + + def _openai_client(self) -> AzureOpenAI: + """Create an Azure OpenAI client with the appropriate configuration""" + try: + api_key = self._api_key() + api_version = DEFAULT_AZURE_API_VERSION + azure_endpoint = self._base_url() + + # Safely get api_version if available + if (self.context and self.context.config and + hasattr(self.context.config, "azure_openai") and + self.context.config.azure_openai): + + if hasattr(self.context.config.azure_openai, "api_version") and self.context.config.azure_openai.api_version: + api_version = self.context.config.azure_openai.api_version + + return AzureOpenAI( + api_key=api_key, + api_version=api_version, + azure_endpoint=azure_endpoint + ) + except AuthenticationError as e: + raise ProviderKeyError( + "Invalid OpenAI API key", + "The configured OpenAI API key was rejected.\n" + "Please check that your API key is valid and not expired.", + ) from e \ No newline at end of file