5 changes: 5 additions & 0 deletions src/mcp_agent/cli/commands/setup.py
@@ -66,6 +66,11 @@
    api_key: <your-api-key-here>
openrouter:
    api_key: <your-api-key-here>
azure_openai:
    api_key: <your-api-key-here>
    base_url: <your-endpoint-here>
    api_version: <your-api-version-here>
    deployment_id: <your-deployment-id-here>


# Example of setting an MCP Server environment variable
15 changes: 15 additions & 0 deletions src/mcp_agent/config.py
@@ -207,6 +207,18 @@ class TensorZeroSettings(BaseModel):
    api_key: Optional[str] = None
    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

class AzureOpenAISettings(BaseModel):
    """
    Settings for using Azure OpenAI models in the fast-agent application.
    """

    api_key: str | None = None
    base_url: str | None = None
    api_version: str = "2025-01-01"
    deployment_id: str | None = None

    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)


class LoggerSettings(BaseModel):
"""
@@ -302,6 +314,9 @@ class Settings(BaseSettings):
    tensorzero: Optional[TensorZeroSettings] = None
    """Settings for using TensorZero inference gateway"""

    azure_openai: AzureOpenAISettings | None = None
    """Settings for using Azure OpenAI models in the fast-agent application"""

    logger: LoggerSettings | None = LoggerSettings()
    """Logger settings for the fast-agent application"""

5 changes: 5 additions & 0 deletions src/mcp_agent/llm/model_factory.py
@@ -16,6 +16,7 @@
from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
from mcp_agent.llm.providers.augmented_llm_openrouter import OpenRouterAugmentedLLM
from mcp_agent.llm.providers.augmented_llm_tensorzero import TensorZeroAugmentedLLM
from mcp_agent.llm.providers.augmented_llm_azure_openai import AzureOpenAIAugmentedLLM
from mcp_agent.mcp.interfaces import AugmentedLLMProtocol

# from mcp_agent.workflows.llm.augmented_llm_deepseek import DeekSeekAugmentedLLM
@@ -30,6 +31,7 @@
    Type[DeepSeekAugmentedLLM],
    Type[OpenRouterAugmentedLLM],
    Type[TensorZeroAugmentedLLM],
    Type[AzureOpenAIAugmentedLLM],
]


@@ -86,6 +88,7 @@ class ModelFactory:
"claude-3-opus-20240229": Provider.ANTHROPIC,
"claude-3-opus-latest": Provider.ANTHROPIC,
"deepseek-chat": Provider.DEEPSEEK,
"azure_openai": Provider.AZURE_OPENAI,
# "deepseek-reasoner": Provider.DEEPSEEK, reinstate on release
}

@@ -101,6 +104,7 @@ class ModelFactory:
"opus3": "claude-3-opus-latest",
"deepseekv3": "deepseek-chat",
"deepseek": "deepseek-chat",
"azure_openai": "azure_openai",
}

    # Mapping of providers to their LLM classes
@@ -113,6 +117,7 @@
        Provider.GOOGLE: GoogleAugmentedLLM,  # type: ignore
        Provider.OPENROUTER: OpenRouterAugmentedLLM,
        Provider.TENSORZERO: TensorZeroAugmentedLLM,
        Provider.AZURE_OPENAI: AzureOpenAIAugmentedLLM,
    }

    # Mapping of special model names to their specific LLM classes
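With the alias, provider, and class mappings above, the "azure_openai" model string resolves to AzureOpenAIAugmentedLLM. A rough usage sketch, assuming fast-agent's usual FastAgent decorator API (the agent name, instruction, and prompt below are placeholders):

import asyncio

from mcp_agent.core.fastagent import FastAgent  # assumed import path

fast = FastAgent("azure-example")

# "azure_openai" resolves through MODEL_ALIASES / DEFAULT_PROVIDERS to
# Provider.AZURE_OPENAI, which maps to AzureOpenAIAugmentedLLM.
@fast.agent(instruction="You are a helpful assistant", model="azure_openai")
async def main():
    async with fast.run() as agent:
        await agent("Hello from Azure OpenAI")

if __name__ == "__main__":
    asyncio.run(main())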
1 change: 1 addition & 0 deletions src/mcp_agent/llm/provider_key_manager.py
@@ -17,6 +17,7 @@
"google": "GOOGLE_API_KEY",
"openrouter": "OPENROUTER_API_KEY",
"generic": "GENERIC_API_KEY",
"azure_openai": "AZURE_OPENAI_API_KEY",
}
API_KEY_HINT_TEXT = "<your-api-key-here>"
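Per this mapping, the Azure key can also be supplied through the environment rather than fastagent.secrets.yaml; a minimal sketch with a placeholder value:

import os

# Equivalent to setting azure_openai.api_key in fastagent.secrets.yaml.
os.environ["AZURE_OPENAI_API_KEY"] = "<your-api-key-here>"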

1 change: 1 addition & 0 deletions src/mcp_agent/llm/provider_types.py
@@ -16,3 +16,4 @@ class Provider(Enum):
    GENERIC = "generic"
    OPENROUTER = "openrouter"
    TENSORZERO = "tensorzero"  # For TensorZero Gateway
    AZURE_OPENAI = "azure_openai"
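The enum value doubles as the config section name and the model alias string used elsewhere in this change, for example:

from mcp_agent.llm.provider_types import Provider

# The string value round-trips through the enum.
assert Provider.AZURE_OPENAI.value == "azure_openai"
assert Provider("azure_openai") is Provider.AZURE_OPENAI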
58 changes: 58 additions & 0 deletions src/mcp_agent/llm/providers/augmented_llm_azure_openai.py
@@ -0,0 +1,58 @@
from mcp_agent.core.request_params import RequestParams
from mcp_agent.llm.provider_types import Provider
from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
from openai import AzureOpenAI, AuthenticationError
from mcp_agent.core.exceptions import ProviderKeyError

DEFAULT_AZURE_OPENAI_MODEL = "azure_openai"
DEFAULT_AZURE_API_VERSION = "2025-01-01"

class AzureOpenAIAugmentedLLM(OpenAIAugmentedLLM):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, provider=Provider.AZURE_OPENAI, **kwargs)

    def _initialize_default_params(self, kwargs: dict) -> RequestParams:
        """Initialize Azure OpenAI default parameters"""
        chosen_model = kwargs.get("model", DEFAULT_AZURE_OPENAI_MODEL)

        return RequestParams(
            model=chosen_model,
            systemPrompt=self.instruction,
            parallel_tool_calls=True,
            max_iterations=20,
            use_history=True,
        )

    def _base_url(self) -> str:
        base_url = None
        if self.context and self.context.config and hasattr(self.context.config, "azure_openai"):
            if self.context.config.azure_openai and hasattr(self.context.config.azure_openai, "base_url"):
                base_url = self.context.config.azure_openai.base_url
        return base_url or ""

    def _openai_client(self) -> AzureOpenAI:
        """Create an Azure OpenAI client with the appropriate configuration"""
        try:
            api_key = self._api_key()
            api_version = DEFAULT_AZURE_API_VERSION
            azure_endpoint = self._base_url()

            # Safely get api_version if available
            if (self.context and self.context.config and
                    hasattr(self.context.config, "azure_openai") and
                    self.context.config.azure_openai):
                if hasattr(self.context.config.azure_openai, "api_version") and self.context.config.azure_openai.api_version:
                    api_version = self.context.config.azure_openai.api_version

            return AzureOpenAI(
                api_key=api_key,
                api_version=api_version,
                azure_endpoint=azure_endpoint,
            )
        except AuthenticationError as e:
            raise ProviderKeyError(
                "Invalid Azure OpenAI API key",
                "The configured Azure OpenAI API key was rejected.\n"
                "Please check that your API key is valid and not expired.",
            ) from e
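For orientation, a standalone sketch of the client that _openai_client constructs; the key, endpoint, and deployment below are placeholders (in fast-agent they come from AzureOpenAISettings), and on Azure the model argument to chat completions is the deployment name:

from openai import AzureOpenAI

# Placeholder credentials and endpoint for illustration only.
client = AzureOpenAI(
    api_key="<your-api-key-here>",
    api_version="2025-01-01",
    azure_endpoint="https://my-resource.openai.azure.com",
)

# Azure routes chat completions by deployment, so `model` is the deployment id.
response = client.chat.completions.create(
    model="<your-deployment-id-here>",
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)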