Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 57 additions & 18 deletions examples/mcp/elicitations/elicitation_forms_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import logging
import sys
from typing import Optional
from typing import List, Optional, TypedDict

from mcp import ReadResourceResult
from mcp.server.elicitation import (
Expand All @@ -31,9 +31,40 @@
mcp = FastMCP("Elicitation Forms Demo Server", log_level="INFO")


class TitledEnumOption(TypedDict):
"""Type definition for oneOf/anyOf schema options."""

const: str
title: str


def _create_enum_schema_options(data: dict[str, str]) -> list[TitledEnumOption]:
"""Convert a dictionary to oneOf/anyOf schema format.

Args:
data: Dictionary mapping enum values to display titles

Returns:
List of schema options with 'const' and 'title' fields

Example:
>>> _create_enum_schema_options({"dark": "Dark Mode", "light": "Light Mode"})
[{"const": "dark", "title": "Dark Mode"}, {"const": "light", "title": "Light Mode"}]
"""
return [{"const": k, "title": v} for k, v in data.items()]


@mcp.resource(uri="elicitation://event-registration")
async def event_registration() -> ReadResourceResult:
"""Register for a tech conference event."""
workshop_names = {
"ai_basics": "AI Fundamentals",
"llm_apps": "Building LLM Applications",
"prompt_eng": "Prompt Engineering",
"rag_systems": "RAG Systems",
"fine_tuning": "Model Fine-tuning",
"deployment": "Production Deployment",
}

class EventRegistration(BaseModel):
name: str = Field(description="Your full name", min_length=2, max_length=100)
Expand All @@ -44,6 +75,13 @@ class EventRegistration(BaseModel):
event_date: str = Field(
description="Which event date works for you?", json_schema_extra={"format": "date"}
)
workshops: List[str] = Field(
default=["ai_basics", "llm_apps"],
description="Select workshops to attend (2-4 required)",
min_length=2,
max_length=4,
json_schema_extra={"items": {"anyOf": _create_enum_schema_options(workshop_names)}},
)
dietary_requirements: Optional[str] = Field(
None, description="Any dietary requirements? (optional)", max_length=200
)
Expand All @@ -61,7 +99,10 @@ class EventRegistration(BaseModel):
f"🏢 Company: {data.company_website or 'Not provided'}",
f"📅 Event Date: {data.event_date}",
f"🍽️ Dietary Requirements: {data.dietary_requirements or 'None'}",
f"🎓 Workshops ({len(data.workshops)} selected):",
]
for workshop in data.workshops:
lines.append(f" • {workshop_names.get(workshop, workshop)}")
response = "\n".join(lines)
case DeclinedElicitation():
response = "Registration declined - no ticket reserved"
Expand All @@ -80,6 +121,13 @@ class EventRegistration(BaseModel):
@mcp.resource(uri="elicitation://product-review")
async def product_review() -> ReadResourceResult:
"""Submit a product review with rating and comments."""
categories = {
"electronics": "Electronics",
"books": "Books & Media",
"clothing": "Clothing",
"home": "Home & Garden",
"sports": "Sports & Outdoors",
}

class ProductReview(BaseModel):
rating: int = Field(description="Rate this product (1-5 stars)", ge=1, le=5)
Expand All @@ -88,23 +136,15 @@ class ProductReview(BaseModel):
)
category: str = Field(
description="What type of product is this?",
json_schema_extra={
"enum": ["electronics", "books", "clothing", "home", "sports"],
"enumNames": [
"Electronics",
"Books & Media",
"Clothing",
"Home & Garden",
"Sports & Outdoors",
],
},
json_schema_extra={"oneOf": _create_enum_schema_options(categories)},
)
review_text: str = Field(
description="Tell us about your experience", min_length=10, max_length=1000
)

result = await mcp.get_context().elicit(
"Share your product review - Help others make informed decisions!", schema=ProductReview
"Share your product review - Help others make informed decisions!",
schema=ProductReview,
)

match result:
Expand All @@ -114,7 +154,7 @@ class ProductReview(BaseModel):
"🎯 Product Review Submitted!",
f"⭐ Rating: {stars} ({data.rating}/5)",
f"📊 Satisfaction: {data.satisfaction}/10.0",
f"📦 Category: {data.category.replace('_', ' ').title()}",
f"📦 Category: {categories.get(data.category, data.category)}",
f"💬 Review: {data.review_text}",
]
response = "\n".join(lines)
Expand All @@ -136,16 +176,15 @@ class ProductReview(BaseModel):
async def account_settings() -> ReadResourceResult:
"""Configure your account settings and preferences."""

themes = {"light": "Light Theme", "dark": "Dark Theme", "auto": "Auto (System)"}

class AccountSettings(BaseModel):
email_notifications: bool = Field(True, description="Receive email notifications?")
marketing_emails: bool = Field(False, description="Subscribe to marketing emails?")
theme: str = Field(
"dark",
description="Choose your preferred theme",
json_schema_extra={
"enum": ["light", "dark", "auto"],
"enumNames": ["Light Theme", "Dark Theme", "Auto (System)"],
},
json_schema_extra={"oneOf": _create_enum_schema_options(themes)},
)
privacy_public: bool = Field(False, description="Make your profile public?")
items_per_page: int = Field(
Expand All @@ -160,7 +199,7 @@ class AccountSettings(BaseModel):
"⚙️ Account Settings Updated!",
f"📧 Email notifications: {'On' if data.email_notifications else 'Off'}",
f"📬 Marketing emails: {'On' if data.marketing_emails else 'Off'}",
f"🎨 Theme: {data.theme.title()}",
f"🎨 Theme: {themes.get(data.theme, data.theme)}",
f"👥 Public profile: {'Yes' if data.privacy_public else 'No'}",
f"📄 Items per page: {data.items_per_page}",
]
Expand Down
103 changes: 101 additions & 2 deletions src/mcp_agent/human_input/elicitation_form.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Simplified, robust elicitation form dialog."""

from datetime import date, datetime
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional, Tuple

from mcp.types import ElicitRequestedSchema
from prompt_toolkit import Application
Expand All @@ -26,6 +26,7 @@

from mcp_agent.human_input.elicitation_forms import ELICITATION_STYLE
from mcp_agent.human_input.elicitation_state import elicitation_state
from mcp_agent.human_input.form_elements import ValidatedCheckboxList


class SimpleNumberValidator(Validator):
Expand Down Expand Up @@ -413,6 +414,38 @@ def set_initial_focus():
self.app.invalidate() # Ensure layout is built
set_initial_focus()

def _extract_enum_schema_options(self, schema_def: Dict[str, Any]) -> List[Tuple[str, str]]:
"""Extract options from oneOf/anyOf/enum schema patterns.

Args:
schema_def: Schema definition potentially containing oneOf/anyOf/enum

Returns:
List of (value, title) tuples for the options
"""
values = []

# First check for bare enum (most common pattern for arrays)
if "enum" in schema_def:
enum_values = schema_def["enum"]
enum_names = schema_def.get("enumNames", enum_values)
for val, name in zip(enum_values, enum_names):
values.append((val, str(name)))
return values

# Then check for oneOf/anyOf patterns
options = schema_def.get("oneOf", [])
if not options:
options = schema_def.get("anyOf", [])

for option in options:
if "const" in option:
value = option["const"]
title = option.get("title", str(value))
values.append((value, title))

return values

def _extract_string_constraints(self, field_def: Dict[str, Any]) -> Dict[str, Any]:
"""Extract string constraints from field definition, handling anyOf schemas."""
constraints = {}
Expand All @@ -435,7 +468,6 @@ def _extract_string_constraints(self, field_def: Dict[str, Any]) -> Dict[str, An

return constraints


def _create_field(self, field_name: str, field_def: Dict[str, Any]):
"""Create a field widget."""

Expand All @@ -455,6 +487,24 @@ def _create_field(self, field_name: str, field_def: Dict[str, Any]):
hints = []
format_hint = None

# Check if this is an array type with enum/oneOf/anyOf items
if field_type == "array" and "items" in field_def:
items_def = field_def["items"]

# Add minItems/maxItems hints
min_items = field_def.get("minItems")
max_items = field_def.get("maxItems")

if min_items is not None and max_items is not None:
if min_items == max_items:
hints.append(f"select exactly {min_items}")
else:
hints.append(f"select {min_items}-{max_items}")
elif min_items is not None:
hints.append(f"select at least {min_items}")
elif max_items is not None:
hints.append(f"select up to {max_items}")

if field_type == "string":
constraints = self._extract_string_constraints(field_def)
if constraints.get("minLength"):
Expand Down Expand Up @@ -507,6 +557,7 @@ def _create_field(self, field_name: str, field_def: Dict[str, Any]):
return HSplit([label, Frame(checkbox)])

elif field_type == "string" and "enum" in field_def:
# Leaving this here for existing enum schema
enum_values = field_def["enum"]
enum_names = field_def.get("enumNames", enum_values)
values = [(val, name) for val, name in zip(enum_values, enum_names)]
Expand All @@ -517,6 +568,39 @@ def _create_field(self, field_name: str, field_def: Dict[str, Any]):

return HSplit([label, Frame(radio_list, height=min(len(values) + 2, 6))])

elif field_type == "string" and "oneOf" in field_def:
# Handle oneOf pattern for single selection enums
values = self._extract_enum_schema_options(field_def)
if values:
default_value = field_def.get("default")
radio_list = RadioList(values=values, default=default_value)
self.field_widgets[field_name] = radio_list
return HSplit([label, Frame(radio_list, height=min(len(values) + 2, 6))])

elif field_type == "array" and "items" in field_def:
# Handle array types with enum/oneOf/anyOf items
items_def = field_def["items"]
values = self._extract_enum_schema_options(items_def)
if values:
# Create checkbox list for multi-selection
min_items = field_def.get("minItems")
max_items = field_def.get("maxItems")
default_values = field_def.get("default", [])

checkbox_list = ValidatedCheckboxList(
values=values,
default_values=default_values,
min_items=min_items,
max_items=max_items,
)

# Store the widget directly (consistent with other widgets)
self.field_widgets[field_name] = checkbox_list

# Create scrollable frame if many options
height = min(len(values) + 2, 8)
return HSplit([label, Frame(checkbox_list, height=height)])

else:
# Text/number input
validator = None
Expand Down Expand Up @@ -630,6 +714,10 @@ def _validate_form(self) -> tuple[bool, Optional[str]]:
if widget.validation_error:
title = field_def.get("title", field_name)
return False, f"'{title}': {widget.validation_error.message}"
elif isinstance(widget, ValidatedCheckboxList):
if widget.validation_error:
title = field_def.get("title", field_name)
return False, f"'{title}': {widget.validation_error.message}"

# Then check if required fields are empty
for field_name in self.required_fields:
Expand All @@ -646,6 +734,10 @@ def _validate_form(self) -> tuple[bool, Optional[str]]:
if widget.current_value is None:
title = self.properties[field_name].get("title", field_name)
return False, f"'{title}' is required"
elif isinstance(widget, ValidatedCheckboxList):
if not widget.current_values:
title = self.properties[field_name].get("title", field_name)
return False, f"'{title}' is required"

return True, None

Expand Down Expand Up @@ -687,6 +779,13 @@ def _get_form_data(self) -> Dict[str, Any]:
if widget.current_value is not None:
data[field_name] = widget.current_value

elif isinstance(widget, ValidatedCheckboxList):
selected_values = widget.current_values
if selected_values:
data[field_name] = list(selected_values)
elif field_name not in self.required_fields:
data[field_name] = []

return data

def _accept(self):
Expand Down
59 changes: 59 additions & 0 deletions src/mcp_agent/human_input/form_elements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Custom form elements for elicitation forms."""

from typing import Optional, Sequence, TypeVar

from prompt_toolkit.formatted_text import AnyFormattedText
from prompt_toolkit.validation import ValidationError
from prompt_toolkit.widgets import CheckboxList

_T = TypeVar("_T")


class ValidatedCheckboxList(CheckboxList[_T]):
"""CheckboxList with min/max items validation."""

def __init__(
self,
values: Sequence[tuple[_T, AnyFormattedText]],
default_values: Optional[Sequence[_T]] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
):
"""
Initialize checkbox list with validation.

Args:
values: List of (value, label) tuples
default_values: Initially selected values
min_items: Minimum number of items that must be selected
max_items: Maximum number of items that can be selected
"""
super().__init__(values, default_values=default_values)
self.min_items = min_items
self.max_items = max_items

@property
def validation_error(self) -> Optional[ValidationError]:
"""
Check if current selection is valid.

Returns:
ValidationError if invalid, None if valid
"""
selected_count = len(self.current_values)

if self.min_items is not None and selected_count < self.min_items:
if self.min_items == 1:
message = "At least 1 selection required"
else:
message = f"At least {self.min_items} selections required"
return ValidationError(message=message)

if self.max_items is not None and selected_count > self.max_items:
if self.max_items == 1:
message = "Only 1 selection allowed"
else:
message = f"Maximum {self.max_items} selections allowed"
return ValidationError(message=message)

return None
Loading