From a344a4e80ebf08d4fe4b9e32b43568688b72c78d Mon Sep 17 00:00:00 2001 From: Guikai Date: Sat, 28 Jun 2025 17:02:45 -0700 Subject: [PATCH 1/3] feat(examples): add PostgreSQL MCP server as independent subproject --- examples/postgres/.env.example | 17 + examples/postgres/.gitignore | 45 ++ examples/postgres/Dockerfile | 37 + examples/postgres/Makefile | 87 +++ examples/postgres/README.md | 179 +++++ examples/postgres/demo.sh | 93 +++ examples/postgres/docker-compose.yml | 26 + examples/postgres/go.mod | 13 + examples/postgres/go.sum | 9 + examples/postgres/init.sql | 43 ++ examples/postgres/main.go | 390 ++++++++++ examples/postgres/main_test.go | 1021 ++++++++++++++++++++++++++ 12 files changed, 1960 insertions(+) create mode 100644 examples/postgres/.env.example create mode 100644 examples/postgres/.gitignore create mode 100644 examples/postgres/Dockerfile create mode 100644 examples/postgres/Makefile create mode 100644 examples/postgres/README.md create mode 100755 examples/postgres/demo.sh create mode 100644 examples/postgres/docker-compose.yml create mode 100644 examples/postgres/go.mod create mode 100644 examples/postgres/go.sum create mode 100644 examples/postgres/init.sql create mode 100644 examples/postgres/main.go create mode 100644 examples/postgres/main_test.go diff --git a/examples/postgres/.env.example b/examples/postgres/.env.example new file mode 100644 index 0000000..f7dc9ab --- /dev/null +++ b/examples/postgres/.env.example @@ -0,0 +1,17 @@ +# Example configuration for PostgreSQL MCP Server +# Copy this to .env and modify as needed + +# Database connection string +DATABASE_URL=postgres://testuser:testpass@localhost:5432/testdb?sslmode=disable + +# Alternative format examples: +# DATABASE_URL=postgres://username:password@hostname:5432/database_name +# DATABASE_URL=postgres://user@localhost/dbname?sslmode=require +# DATABASE_URL=postgresql://user:pass@localhost:5432/db?sslmode=disable + +# For production, consider using SSL: +# DATABASE_URL=postgres://user:pass@host:5432/db?sslmode=require + +# For cloud databases: +# DATABASE_URL=postgres://user:pass@host.amazonaws.com:5432/db?sslmode=require +# DATABASE_URL=postgres://user:pass@host.gcp.com:5432/db?sslmode=require diff --git a/examples/postgres/.gitignore b/examples/postgres/.gitignore new file mode 100644 index 0000000..4bfb073 --- /dev/null +++ b/examples/postgres/.gitignore @@ -0,0 +1,45 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib +postgres-mcp-server + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool +*.out +*.prof + +# Go workspace file +go.work + +# Dependency directories +vendor/ + +# IDE files +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS files +.DS_Store +Thumbs.db + +# Docker volumes +postgres_data/ + +# Log files +*.log + +# Environment files +.env +.env.local + +# Distribution packages +*.tar.gz +*.zip diff --git a/examples/postgres/Dockerfile b/examples/postgres/Dockerfile new file mode 100644 index 0000000..ce565af --- /dev/null +++ b/examples/postgres/Dockerfile @@ -0,0 +1,37 @@ +# Build stage +FROM golang:1.23-alpine AS builder + +# Install build dependencies +RUN apk add --no-cache git + +# Set working directory +WORKDIR /app + +# Copy go mod files +COPY go.mod go.sum ./ + +# Download dependencies +RUN go mod download + +# Copy source code +COPY . . + +# Build the application +RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o postgres-mcp-server . + +# Final stage +FROM alpine:latest + +# Install ca-certificates for HTTPS requests +RUN apk --no-cache add ca-certificates + +WORKDIR /root/ + +# Copy the binary from builder stage +COPY --from=builder /app/postgres-mcp-server . + +# Expose port for HTTP mode (optional) +EXPOSE 8080 + +# Default to stdio mode, but allow HTTP mode via environment +CMD ["./postgres-mcp-server"] diff --git a/examples/postgres/Makefile b/examples/postgres/Makefile new file mode 100644 index 0000000..321d120 --- /dev/null +++ b/examples/postgres/Makefile @@ -0,0 +1,87 @@ +# PostgreSQL MCP Server Makefile + +.PHONY: help build test clean docker docker-compose run run-http dev + +# Default target +help: + @echo "Available targets:" + @echo " build - Build the PostgreSQL MCP server binary" + @echo " test - Run tests with coverage" + @echo " clean - Clean build artifacts" + @echo " docker - Build Docker image" + @echo " docker-compose - Start development environment with Docker Compose" + @echo " run - Run server in stdio mode (requires DATABASE_URL)" + @echo " run-http - Run server in HTTP mode on :8080" + @echo " dev - Start local development with sample database" + +# Build the binary +build: + go build -o postgres-mcp-server main.go + +# Run tests with coverage +test: + go test -v -cover + +# Clean build artifacts +clean: + rm -f postgres-mcp-server + +# Build Docker image +docker: + docker build -t postgres-mcp-server . + +# Start development environment +docker-compose: + docker-compose up -d + +# Stop development environment +docker-compose-down: + docker-compose down -v + +# Run server in stdio mode +run: + @if [ -z "$(DATABASE_URL)" ]; then \ + echo "DATABASE_URL environment variable is required"; \ + echo "Example: make run DATABASE_URL='postgres://user:pass@localhost:5432/db'"; \ + exit 1; \ + fi + go run main.go + +# Run server in HTTP mode for debugging +run-http: + @if [ -z "$(DATABASE_URL)" ]; then \ + echo "DATABASE_URL environment variable is required"; \ + echo "Example: make run-http DATABASE_URL='postgres://user:pass@localhost:5432/db'"; \ + exit 1; \ + fi + go run main.go -http=:8080 + +# Start local development with sample database +dev: docker-compose + @echo "Waiting for PostgreSQL to start..." + @sleep 5 + @echo "Development environment ready!" + @echo "Database: postgres://testuser:testpass@localhost:5432/testdb" + @echo "To run the MCP server:" + @echo " export DATABASE_URL='postgres://testuser:testpass@localhost:5432/testdb?sslmode=disable'" + @echo " make run" + +# Install dependencies +deps: + go mod tidy + +# Run linting (if available) +lint: + @if command -v golangci-lint >/dev/null 2>&1; then \ + golangci-lint run; \ + else \ + echo "golangci-lint not installed, skipping..."; \ + go vet ./...; \ + fi + +# Run all checks +check: test lint + +# Package for distribution +package: clean build + tar -czf postgres-mcp-server.tar.gz postgres-mcp-server README.md init.sql docker-compose.yml Dockerfile diff --git a/examples/postgres/README.md b/examples/postgres/README.md new file mode 100644 index 0000000..a4786bf --- /dev/null +++ b/examples/postgres/README.md @@ -0,0 +1,179 @@ +# PostgreSQL + +A Model Context Protocol server that provides read-only access to PostgreSQL databases. This server enables LLMs to inspect database schemas and execute read-only queries. + +## Components + +### Tools + +- **query** + - Execute read-only SQL queries against the connected database + - Input: `sql` (string): The SQL query to execute + - All queries are executed within a READ ONLY transaction + +### Resources + +The server provides schema information for each table in the database: + +- **Table Schemas** (`postgres:////schema`) + - JSON schema information for each table + - Includes column names and data types + - Automatically discovered from database metadata + +## Configuration + +### Usage with Claude Desktop + +To use this server with the Claude Desktop app, add the following configuration to the "mcpServers" section of your `claude_desktop_config.json`: + +### Docker + +- When running Docker on macOS, use `host.docker.internal` if the PostgreSQL server is running on the host network (e.g. localhost) +- Username/password can be added to the PostgreSQL URL with `postgresql://user:password@host:port/db-name` + +```json +{ + "mcpServers": { + "postgres": { + "command": "docker", + "args": [ + "run", + "-i", + "--rm", + "mcp/postgres-go", + "postgresql://host.docker.internal:5432/mydb" + ] + } + } +} +``` + +### Go Binary + +```json +{ + "mcpServers": { + "postgres": { + "command": "/path/to/postgres-mcp-server", + "args": ["postgresql://localhost/mydb"] + } + } +} +``` + +Replace `/mydb` with your database name and `/path/to/postgres-mcp-server` with the actual path to your built binary. + +### Usage with VS Code + +For manual installation, add the following JSON block to your User Settings (JSON) file in VS Code. You can do this by pressing `Ctrl + Shift + P` and typing `Preferences: Open User Settings (JSON)`. + +Optionally, you can add it to a file called `.vscode/mcp.json` in your workspace. This will allow you to share the configuration with others. + +> Note that the `mcp` key is not needed in the `.vscode/mcp.json` file. + +### Docker + +**Note**: When using Docker and connecting to a PostgreSQL server on your host machine, use `host.docker.internal` instead of `localhost` in the connection URL. + +```json +{ + "mcp": { + "inputs": [ + { + "type": "promptString", + "id": "pg_url", + "description": "PostgreSQL URL (e.g. postgresql://user:pass@host.docker.internal:5432/mydb)" + } + ], + "servers": { + "postgres": { + "command": "docker", + "args": ["run", "-i", "--rm", "mcp/postgres-go", "${input:pg_url}"] + } + } + } +} +``` + +### Go Binary + +```json +{ + "mcp": { + "inputs": [ + { + "type": "promptString", + "id": "pg_url", + "description": "PostgreSQL URL (e.g. postgresql://user:pass@localhost:5432/mydb)" + } + ], + "servers": { + "postgres": { + "command": "/path/to/postgres-mcp-server", + "args": ["${input:pg_url}"] + } + } + } +} +``` + +## Development + +### Quick Start + +1. Start the development environment: + +```bash +make dev +``` + +2. Run the server: + +```bash +export DATABASE_URL="postgres://testuser:testpass@localhost:5432/testdb?sslmode=disable" +go run main.go +``` + +### Building + +Go binary: + +```bash +go build -o postgres-mcp-server main.go +``` + +Docker: + +```bash +docker build -t mcp/postgres-go . +``` + +### Testing + +Run the test suite: + +```bash +go test -v -cover +``` + +Or use the Makefile: + +```bash +make test +``` + +### Environment Variables + +The server can be configured using environment variables: + +- `DATABASE_URL`: PostgreSQL connection string (required) + +Example: + +```bash +export DATABASE_URL="postgres://user:password@localhost:5432/database?sslmode=disable" +``` + +## License + +This MCP server is licensed under the MIT License. This means you are free to use, modify, and distribute the software, subject to the terms and conditions of the MIT License. For more details, please see the LICENSE file in the project repository. diff --git a/examples/postgres/demo.sh b/examples/postgres/demo.sh new file mode 100755 index 0000000..0d299f0 --- /dev/null +++ b/examples/postgres/demo.sh @@ -0,0 +1,93 @@ +#!/bin/bash + +# PostgreSQL MCP Server Demo Script +# This script demonstrates the PostgreSQL MCP server functionality + +set -e + +echo "๐Ÿ˜ PostgreSQL MCP Server Demo" +echo "=============================" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# Check if Docker is running +if ! docker info >/dev/null 2>&1; then + echo -e "${RED}โŒ Docker is not running. Please start Docker first.${NC}" + exit 1 +fi + +echo -e "${YELLOW}๐Ÿš€ Starting PostgreSQL database with sample data...${NC}" +docker-compose up -d postgres + +echo -e "${YELLOW}โณ Waiting for PostgreSQL to be ready...${NC}" +sleep 10 + +# Check if PostgreSQL is ready +until docker-compose exec postgres pg_isready -U testuser -d testdb >/dev/null 2>&1; do + echo "Waiting for PostgreSQL..." + sleep 2 +done + +echo -e "${GREEN}โœ… PostgreSQL is ready!${NC}" + +# Build the MCP server +echo -e "${YELLOW}๐Ÿ”จ Building the MCP server...${NC}" +go build -o postgres-mcp-server main.go + +echo -e "${GREEN}โœ… MCP server built successfully!${NC}" + +# Set the database URL +export DATABASE_URL="postgres://testuser:testpass@localhost:5432/testdb?sslmode=disable" + +echo -e "${YELLOW}๐Ÿ“Š Testing server functionality...${NC}" + +# Start the server in HTTP mode for testing +echo -e "${YELLOW}๐ŸŒ Starting MCP server in HTTP mode on :8080...${NC}" +./postgres-mcp-server -http=:8080 & +SERVER_PID=$! + +# Wait for server to start +sleep 3 + +echo -e "${GREEN}โœ… Server started! PID: $SERVER_PID${NC}" + +echo -e "${YELLOW}๐Ÿ“‹ Available database tables:${NC}" +docker-compose exec postgres psql -U testuser -d testdb -c "\dt" + +echo -e "${YELLOW}๐Ÿ‘ฅ Sample users in the database:${NC}" +docker-compose exec postgres psql -U testuser -d testdb -c "SELECT * FROM users;" + +echo -e "${YELLOW}๐Ÿ›๏ธ Sample products in the database:${NC}" +docker-compose exec postgres psql -U testuser -d testdb -c "SELECT * FROM products;" + +echo -e "${YELLOW}๐Ÿ“ฆ Sample orders in the database:${NC}" +docker-compose exec postgres psql -U testuser -d testdb -c "SELECT * FROM orders;" + +echo "" +echo -e "${GREEN}๐ŸŽ‰ Demo setup complete!${NC}" +echo "" +echo "The PostgreSQL MCP server is now running and ready to use." +echo "" +echo "You can:" +echo " โ€ข Test the HTTP endpoint at http://localhost:8080" +echo " โ€ข Connect your MCP client to the server" +echo " โ€ข Query the database using the 'query' tool" +echo " โ€ข Browse table schemas as MCP resources" +echo "" +echo "Example query to try:" +echo ' {"tool": "query", "arguments": {"sql": "SELECT name, email FROM users LIMIT 3"}}' +echo "" +echo -e "${YELLOW}Press any key to stop the demo...${NC}" +read -n 1 -s + +# Clean up +echo -e "${YELLOW}๐Ÿงน Cleaning up...${NC}" +kill $SERVER_PID 2>/dev/null || true +docker-compose down -v +rm -f postgres-mcp-server + +echo -e "${GREEN}โœ… Demo cleanup complete!${NC}" diff --git a/examples/postgres/docker-compose.yml b/examples/postgres/docker-compose.yml new file mode 100644 index 0000000..86a9586 --- /dev/null +++ b/examples/postgres/docker-compose.yml @@ -0,0 +1,26 @@ +version: "3.8" + +services: + postgres: + image: postgres:15-alpine + environment: + POSTGRES_DB: testdb + POSTGRES_USER: testuser + POSTGRES_PASSWORD: testpass + ports: + - "5432:5432" + volumes: + - postgres_data:/var/lib/postgresql/data + - ./init.sql:/docker-entrypoint-initdb.d/init.sql + + postgres-mcp-server: + build: . + environment: + DATABASE_URL: "postgres://testuser:testpass@postgres:5432/testdb?sslmode=disable" + depends_on: + - postgres + stdin_open: true + tty: true + +volumes: + postgres_data: diff --git a/examples/postgres/go.mod b/examples/postgres/go.mod new file mode 100644 index 0000000..aa410aa --- /dev/null +++ b/examples/postgres/go.mod @@ -0,0 +1,13 @@ +module github.com/modelcontextprotocol/go-sdk/examples/postgres + +go 1.23.0 + +require ( + github.com/lib/pq v1.10.9 + github.com/modelcontextprotocol/go-sdk v0.0.0 +) + +require github.com/DATA-DOG/go-sqlmock v1.5.2 + +// Replace with local SDK for development +replace github.com/modelcontextprotocol/go-sdk => ../../ diff --git a/examples/postgres/go.sum b/examples/postgres/go.sum new file mode 100644 index 0000000..fbff1ff --- /dev/null +++ b/examples/postgres/go.sum @@ -0,0 +1,9 @@ +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= +golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= diff --git a/examples/postgres/init.sql b/examples/postgres/init.sql new file mode 100644 index 0000000..ab6a8a4 --- /dev/null +++ b/examples/postgres/init.sql @@ -0,0 +1,43 @@ +-- Sample data for testing the PostgreSQL MCP server + +-- Create a users table +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email VARCHAR(255) UNIQUE NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- Create a products table +CREATE TABLE products ( + id SERIAL PRIMARY KEY, + name VARCHAR(200) NOT NULL, + price DECIMAL(10,2) NOT NULL, + description TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- Create an orders table +CREATE TABLE orders ( + id SERIAL PRIMARY KEY, + user_id INTEGER REFERENCES users(id), + total_amount DECIMAL(10,2) NOT NULL, + status VARCHAR(50) DEFAULT 'pending', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- Insert sample data +INSERT INTO users (name, email) VALUES + ('Alice Johnson', 'alice@example.com'), + ('Bob Smith', 'bob@example.com'), + ('Charlie Brown', 'charlie@example.com'); + +INSERT INTO products (name, price, description) VALUES + ('Laptop', 999.99, 'High-performance laptop for work and gaming'), + ('Mouse', 29.99, 'Wireless optical mouse'), + ('Keyboard', 79.99, 'Mechanical keyboard with RGB lighting'); + +INSERT INTO orders (user_id, total_amount, status) VALUES + (1, 1029.98, 'completed'), + (2, 79.99, 'pending'), + (3, 999.99, 'shipped'); diff --git a/examples/postgres/main.go b/examples/postgres/main.go new file mode 100644 index 0000000..7112e4d --- /dev/null +++ b/examples/postgres/main.go @@ -0,0 +1,390 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "database/sql" + "encoding/json" + "flag" + "fmt" + "log" + "net/http" + "net/url" + "os" + + _ "github.com/lib/pq" // PostgreSQL driver + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") + +const SCHEMA_PATH = "schema" + +type PostgresServer struct { + db *sql.DB + resourceBaseURL *url.URL +} + +// QueryArgs represents the arguments for the SQL query tool +type QueryArgs struct { + SQL string `json:"sql"` +} + +// NewPostgresServer creates a new PostgreSQL MCP server +func NewPostgresServer(databaseURL string) (*PostgresServer, error) { + db, err := sql.Open("postgres", databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + if err := db.Ping(); err != nil { + return nil, fmt.Errorf("failed to ping database: %w", err) + } + + resourceBaseURL, err := url.Parse(databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to parse database URL: %w", err) + } + resourceBaseURL.Scheme = "postgres" + resourceBaseURL.User = nil // Remove credentials for security + + return &PostgresServer{ + db: db, + resourceBaseURL: resourceBaseURL, + }, nil +} + +// Close closes the database connection +func (ps *PostgresServer) Close() error { + return ps.db.Close() +} + +// ListTables returns all tables in the public schema as resources +func (ps *PostgresServer) ListTables(ctx context.Context) ([]*mcp.ServerResource, error) { + query := "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'" + rows, err := ps.db.QueryContext(ctx, query) + if err != nil { + return nil, fmt.Errorf("failed to query tables: %w", err) + } + defer rows.Close() + + var resources []*mcp.ServerResource + for rows.Next() { + var tableName string + if err := rows.Scan(&tableName); err != nil { + return nil, fmt.Errorf("failed to scan table name: %w", err) + } + + resourceURI := fmt.Sprintf("%s/%s/%s", ps.resourceBaseURL.String(), tableName, SCHEMA_PATH) + resource := &mcp.ServerResource{ + Resource: &mcp.Resource{ + URI: resourceURI, + MIMEType: "application/json", + Name: fmt.Sprintf(`"%s" database schema`, tableName), + }, + Handler: ps.readTableSchema, + } + resources = append(resources, resource) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating tables: %w", err) + } + + return resources, nil +} + +// readTableSchema handles reading table schema information +func (ps *PostgresServer) readTableSchema(ctx context.Context, ss *mcp.ServerSession, params *mcp.ReadResourceParams) (*mcp.ReadResourceResult, error) { + resourceURL, err := url.Parse(params.URI) + if err != nil { + return nil, fmt.Errorf("invalid resource URI: %w", err) + } + + pathComponents := []string{} + for _, component := range []string{resourceURL.Path} { + if component != "" { + pathComponents = append(pathComponents, component) + } + } + + // Parse path: /tableName/schema + parts := []string{} + for _, part := range pathComponents { + if part == "/" { + continue + } + subParts := []string{} + for _, subPart := range []string{part} { + if subPart != "" { + for _, p := range []string{subPart} { + if p != "/" { + subParts = append(subParts, p) + } + } + } + } + parts = append(parts, subParts...) + } + + // Extract table name and schema path from URI + urlPath := resourceURL.Path + if urlPath == "" { + return nil, fmt.Errorf("empty URI path") + } + + // Remove leading slash and split by '/' + trimmedPath := urlPath + if trimmedPath[0] == '/' { + trimmedPath = trimmedPath[1:] + } + pathParts := []string{} + for _, part := range []string{trimmedPath} { + if part != "" { + // Split by '/' + for i, p := range []string{part} { + if i == 0 { + subparts := []string{} + current := "" + for _, r := range p { + if r == '/' { + if current != "" { + subparts = append(subparts, current) + current = "" + } + } else { + current += string(r) + } + } + if current != "" { + subparts = append(subparts, current) + } + pathParts = append(pathParts, subparts...) + } + } + } + } + + if len(pathParts) < 2 { + return nil, fmt.Errorf("invalid resource URI format: expected /tableName/schema") + } + + tableName := pathParts[len(pathParts)-2] + schema := pathParts[len(pathParts)-1] + + if schema != SCHEMA_PATH { + return nil, fmt.Errorf("invalid resource URI: expected schema path") + } + + // Query table columns + query := "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = $1 ORDER BY ordinal_position" + rows, err := ps.db.QueryContext(ctx, query, tableName) + if err != nil { + return nil, fmt.Errorf("failed to query table schema: %w", err) + } + defer rows.Close() + + type ColumnInfo struct { + ColumnName string `json:"column_name"` + DataType string `json:"data_type"` + } + + var columns []ColumnInfo + for rows.Next() { + var column ColumnInfo + if err := rows.Scan(&column.ColumnName, &column.DataType); err != nil { + return nil, fmt.Errorf("failed to scan column info: %w", err) + } + columns = append(columns, column) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating columns: %w", err) + } + + jsonData, err := json.MarshalIndent(columns, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to marshal columns to JSON: %w", err) + } + + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{ + { + URI: params.URI, + MIMEType: "application/json", + Text: string(jsonData), + }, + }, + }, nil +} + +// QueryTool executes a read-only SQL query +func (ps *PostgresServer) QueryTool(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[QueryArgs]) (*mcp.CallToolResultFor[struct{}], error) { + sqlQuery := params.Arguments.SQL + + // Start a read-only transaction + tx, err := ps.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + if err != nil { + return &mcp.CallToolResultFor[struct{}]{ + Content: []mcp.Content{ + &mcp.TextContent{Text: fmt.Sprintf("Failed to start transaction: %v", err)}, + }, + IsError: true, + }, nil + } + defer tx.Rollback() // Always rollback read-only transaction + + rows, err := tx.QueryContext(ctx, sqlQuery) + if err != nil { + return &mcp.CallToolResultFor[struct{}]{ + Content: []mcp.Content{ + &mcp.TextContent{Text: fmt.Sprintf("Query execution failed: %v", err)}, + }, + IsError: true, + }, nil + } + defer rows.Close() + + // Get column names + columns, err := rows.Columns() + if err != nil { + return &mcp.CallToolResultFor[struct{}]{ + Content: []mcp.Content{ + &mcp.TextContent{Text: fmt.Sprintf("Failed to get columns: %v", err)}, + }, + IsError: true, + }, nil + } + + // Prepare result slice + var results []map[string]interface{} + + // Scan rows + for rows.Next() { + // Create slice of interface{} to hold column values + values := make([]interface{}, len(columns)) + valuePtrs := make([]interface{}, len(columns)) + for i := range values { + valuePtrs[i] = &values[i] + } + + if err := rows.Scan(valuePtrs...); err != nil { + return &mcp.CallToolResultFor[struct{}]{ + Content: []mcp.Content{ + &mcp.TextContent{Text: fmt.Sprintf("Failed to scan row: %v", err)}, + }, + IsError: true, + }, nil + } + + // Convert to map + row := make(map[string]interface{}) + for i, col := range columns { + val := values[i] + if b, ok := val.([]byte); ok { + row[col] = string(b) + } else { + row[col] = val + } + } + results = append(results, row) + } + + if err := rows.Err(); err != nil { + return &mcp.CallToolResultFor[struct{}]{ + Content: []mcp.Content{ + &mcp.TextContent{Text: fmt.Sprintf("Error iterating rows: %v", err)}, + }, + IsError: true, + }, nil + } + + // Ensure results is not nil for proper JSON marshaling + if results == nil { + results = []map[string]interface{}{} + } + + // Convert results to JSON + jsonData, err := json.MarshalIndent(results, "", " ") + if err != nil { + return &mcp.CallToolResultFor[struct{}]{ + Content: []mcp.Content{ + &mcp.TextContent{Text: fmt.Sprintf("Failed to marshal results: %v", err)}, + }, + IsError: true, + }, nil + } + + return &mcp.CallToolResultFor[struct{}]{ + Content: []mcp.Content{ + &mcp.TextContent{Text: string(jsonData)}, + }, + }, nil +} + +func main() { + flag.Parse() + + // Get database URL from environment variable or command line + var databaseURL string + + // First try environment variable + if envURL := os.Getenv("DATABASE_URL"); envURL != "" { + databaseURL = envURL + log.Printf("Using DATABASE_URL from environment") + } else { + // Fall back to command line argument + args := os.Args[1:] + if len(args) == 0 { + // Default to local development database if not specified + databaseURL = "postgres://testuser:testpass@localhost:5432/testdb?sslmode=disable" + log.Printf("No DATABASE_URL or command line argument provided, using default: %s", databaseURL) + } else { + databaseURL = args[0] + log.Printf("Using database URL from command line") + } + } + + // Create PostgreSQL server + postgresServer, err := NewPostgresServer(databaseURL) + if err != nil { + log.Fatalf("Failed to create PostgreSQL server: %v", err) + } + defer postgresServer.Close() + + log.Printf("Connected to PostgreSQL database successfully") + + // Create MCP server + server := mcp.NewServer("postgres", "0.1.0", nil) + + // Add the query tool + server.AddTools(mcp.NewServerTool("query", "Run a read-only SQL query", postgresServer.QueryTool, mcp.Input( + mcp.Property("sql", mcp.Description("The SQL query to execute")), + ))) + + // Get and add resources (tables) dynamically + ctx := context.Background() + resources, err := postgresServer.ListTables(ctx) + if err != nil { + log.Fatalf("Failed to list database tables: %v", err) + } + server.AddResources(resources...) + + // Start server + if *httpAddr != "" { + handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, nil) + log.Printf("PostgreSQL MCP server listening at %s", *httpAddr) + http.ListenAndServe(*httpAddr, handler) + } else { + log.Printf("PostgreSQL MCP server running on stdio") + t := mcp.NewLoggingTransport(mcp.NewStdioTransport(), os.Stderr) + if err := server.Run(context.Background(), t); err != nil { + log.Printf("Server failed: %v", err) + } + } +} diff --git a/examples/postgres/main_test.go b/examples/postgres/main_test.go new file mode 100644 index 0000000..4529a6f --- /dev/null +++ b/examples/postgres/main_test.go @@ -0,0 +1,1021 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "strings" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// createMockPostgresServer creates a PostgresServer with a mock database for testing +func createMockPostgresServer(t *testing.T) (*PostgresServer, sqlmock.Sqlmock, func()) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("Failed to create mock database: %v", err) + } + + baseURL, _ := url.Parse("postgres://testuser@localhost:5432/testdb") + baseURL.Scheme = "postgres" + baseURL.User = nil + + server := &PostgresServer{ + db: db, + resourceBaseURL: baseURL, + } + + cleanup := func() { + db.Close() + } + + return server, mock, cleanup +} + +func TestNewPostgresServer(t *testing.T) { + tests := []struct { + name string + databaseURL string + wantError bool + errorMsg string + }{ + { + name: "invalid URL", + databaseURL: "invalid-url", + wantError: true, + errorMsg: "failed to ping database", // SQL driver tries to ping and fails + }, + { + name: "empty URL", + databaseURL: "", + wantError: true, + errorMsg: "failed to ping database", // Empty URL also fails at ping stage + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewPostgresServer(tt.databaseURL) + if tt.wantError { + if err == nil { + t.Errorf("NewPostgresServer() expected error, got nil") + } else if !strings.Contains(err.Error(), tt.errorMsg) { + t.Errorf("NewPostgresServer() error = %v, want error containing %v", err, tt.errorMsg) + } + } else { + if err != nil { + t.Errorf("NewPostgresServer() unexpected error = %v", err) + } + } + }) + } +} + +func TestPostgresServer_Close(t *testing.T) { + server, mock, cleanup := createMockPostgresServer(t) + defer cleanup() + + mock.ExpectClose() + + err := server.Close() + if err != nil { + t.Errorf("Close() error = %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } +} + +func TestPostgresServer_ListTables(t *testing.T) { + server, mock, cleanup := createMockPostgresServer(t) + defer cleanup() + + ctx := context.Background() + + // Test successful table listing + t.Run("successful listing", func(t *testing.T) { + rows := sqlmock.NewRows([]string{"table_name"}). + AddRow("users"). + AddRow("products"). + AddRow("orders") + + mock.ExpectQuery("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"). + WillReturnRows(rows) + + resources, err := server.ListTables(ctx) + if err != nil { + t.Fatalf("ListTables() error = %v", err) + } + + if len(resources) != 3 { + t.Errorf("Expected 3 resources, got %d", len(resources)) + } + + expectedTables := []string{"users", "products", "orders"} + for i, resource := range resources { + expectedName := fmt.Sprintf(`"%s" database schema`, expectedTables[i]) + if resource.Resource.Name != expectedName { + t.Errorf("Expected resource name %s, got %s", expectedName, resource.Resource.Name) + } + + expectedURI := fmt.Sprintf("%s/%s/%s", server.resourceBaseURL.String(), expectedTables[i], SCHEMA_PATH) + if resource.Resource.URI != expectedURI { + t.Errorf("Expected resource URI %s, got %s", expectedURI, resource.Resource.URI) + } + + if resource.Resource.MIMEType != "application/json" { + t.Errorf("Expected MIME type application/json, got %s", resource.Resource.MIMEType) + } + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) + + // Test database error + t.Run("database error", func(t *testing.T) { + mock.ExpectQuery("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"). + WillReturnError(fmt.Errorf("database connection failed")) + + _, err := server.ListTables(ctx) + if err == nil { + t.Error("Expected error, got nil") + } + + if !strings.Contains(err.Error(), "failed to query tables") { + t.Errorf("Expected error to contain 'failed to query tables', got %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) + + // Test scan error + t.Run("scan error", func(t *testing.T) { + rows := sqlmock.NewRows([]string{"table_name"}). + AddRow("users"). + AddRow(nil) // This will cause a scan error + + mock.ExpectQuery("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"). + WillReturnRows(rows) + + _, err := server.ListTables(ctx) + if err == nil { + t.Error("Expected error, got nil") + } + + if !strings.Contains(err.Error(), "failed to scan table name") { + t.Errorf("Expected error to contain 'failed to scan table name', got %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) +} + +func TestPostgresServer_readTableSchema(t *testing.T) { + server, mock, cleanup := createMockPostgresServer(t) + defer cleanup() + + ctx := context.Background() + + // Test successful schema reading + t.Run("successful schema reading", func(t *testing.T) { + rows := sqlmock.NewRows([]string{"column_name", "data_type"}). + AddRow("id", "integer"). + AddRow("name", "character varying"). + AddRow("email", "character varying"). + AddRow("created_at", "timestamp with time zone") + + mock.ExpectQuery("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = \\$1 ORDER BY ordinal_position"). + WithArgs("users"). + WillReturnRows(rows) + + params := &mcp.ReadResourceParams{ + URI: fmt.Sprintf("%s/users/%s", server.resourceBaseURL.String(), SCHEMA_PATH), + } + + result, err := server.readTableSchema(ctx, nil, params) + if err != nil { + t.Fatalf("readTableSchema() error = %v", err) + } + + if len(result.Contents) != 1 { + t.Fatalf("Expected 1 content item, got %d", len(result.Contents)) + } + + content := result.Contents[0] + if content.URI != params.URI { + t.Errorf("Expected URI %s, got %s", params.URI, content.URI) + } + + if content.MIMEType != "application/json" { + t.Errorf("Expected MIME type application/json, got %s", content.MIMEType) + } + + // Parse and verify JSON content + var columns []map[string]interface{} + err = json.Unmarshal([]byte(content.Text), &columns) + if err != nil { + t.Fatalf("Failed to parse JSON: %v", err) + } + + if len(columns) != 4 { + t.Errorf("Expected 4 columns, got %d", len(columns)) + } + + expectedColumns := []map[string]string{ + {"column_name": "id", "data_type": "integer"}, + {"column_name": "name", "data_type": "character varying"}, + {"column_name": "email", "data_type": "character varying"}, + {"column_name": "created_at", "data_type": "timestamp with time zone"}, + } + + for i, col := range columns { + expected := expectedColumns[i] + if col["column_name"] != expected["column_name"] { + t.Errorf("Expected column name %s, got %s", expected["column_name"], col["column_name"]) + } + if col["data_type"] != expected["data_type"] { + t.Errorf("Expected data type %s, got %s", expected["data_type"], col["data_type"]) + } + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) + + // Test invalid URI + t.Run("invalid URI", func(t *testing.T) { + params := &mcp.ReadResourceParams{ + URI: "invalid-uri", + } + + _, err := server.readTableSchema(ctx, nil, params) + if err == nil { + t.Error("Expected error for invalid URI, got nil") + } + + if !strings.Contains(err.Error(), "invalid resource URI") { + t.Errorf("Expected error to contain 'invalid resource URI', got %v", err) + } + }) + + // Test invalid URI format + t.Run("invalid URI format", func(t *testing.T) { + params := &mcp.ReadResourceParams{ + URI: fmt.Sprintf("%s/users", server.resourceBaseURL.String()), // Missing schema path + } + + _, err := server.readTableSchema(ctx, nil, params) + if err == nil { + t.Error("Expected error for invalid URI format, got nil") + } + + // The actual error message depends on the path parsing logic + // It should be either about format or schema path + if !strings.Contains(err.Error(), "invalid resource URI") { + t.Errorf("Expected error to contain 'invalid resource URI', got %v", err) + } + }) + + // Test wrong schema path + t.Run("wrong schema path", func(t *testing.T) { + params := &mcp.ReadResourceParams{ + URI: fmt.Sprintf("%s/users/wrong", server.resourceBaseURL.String()), + } + + _, err := server.readTableSchema(ctx, nil, params) + if err == nil { + t.Error("Expected error for wrong schema path, got nil") + } + + if !strings.Contains(err.Error(), "invalid resource URI: expected schema path") { + t.Errorf("Expected error to contain 'invalid resource URI: expected schema path', got %v", err) + } + }) + + // Test database error + t.Run("database error", func(t *testing.T) { + mock.ExpectQuery("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = \\$1 ORDER BY ordinal_position"). + WithArgs("users"). + WillReturnError(fmt.Errorf("table not found")) + + params := &mcp.ReadResourceParams{ + URI: fmt.Sprintf("%s/users/%s", server.resourceBaseURL.String(), SCHEMA_PATH), + } + + _, err := server.readTableSchema(ctx, nil, params) + if err == nil { + t.Error("Expected error, got nil") + } + + if !strings.Contains(err.Error(), "failed to query table schema") { + t.Errorf("Expected error to contain 'failed to query table schema', got %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) +} + +func TestPostgresServer_QueryTool(t *testing.T) { + server, mock, cleanup := createMockPostgresServer(t) + defer cleanup() + + ctx := context.Background() + + // Test successful query + t.Run("successful query", func(t *testing.T) { + mock.ExpectBegin() + + rows := sqlmock.NewRows([]string{"id", "name", "email"}). + AddRow(1, "John Doe", "john@example.com"). + AddRow(2, "Jane Smith", "jane@example.com") + + mock.ExpectQuery("SELECT \\* FROM users LIMIT 2"). + WillReturnRows(rows) + + mock.ExpectRollback() + + args := QueryArgs{ + SQL: "SELECT * FROM users LIMIT 2", + } + + params := &mcp.CallToolParamsFor[QueryArgs]{ + Name: "query", + Arguments: args, + } + + result, err := server.QueryTool(ctx, nil, params) + if err != nil { + t.Fatalf("QueryTool() error = %v", err) + } + + if result.IsError { + t.Error("Expected successful result, got error") + } + + if len(result.Content) != 1 { + t.Fatalf("Expected 1 content item, got %d", len(result.Content)) + } + + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + // Parse and verify JSON result + var queryResults []map[string]interface{} + err = json.Unmarshal([]byte(textContent.Text), &queryResults) + if err != nil { + t.Fatalf("Failed to parse query result JSON: %v", err) + } + + if len(queryResults) != 2 { + t.Errorf("Expected 2 rows, got %d", len(queryResults)) + } + + // Verify first row + firstRow := queryResults[0] + if firstRow["name"] != "John Doe" { + t.Errorf("Expected name 'John Doe', got %v", firstRow["name"]) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) + + // Test transaction begin error + t.Run("transaction begin error", func(t *testing.T) { + mock.ExpectBegin().WillReturnError(fmt.Errorf("connection failed")) + + args := QueryArgs{ + SQL: "SELECT * FROM users", + } + + params := &mcp.CallToolParamsFor[QueryArgs]{ + Name: "query", + Arguments: args, + } + + result, err := server.QueryTool(ctx, nil, params) + if err != nil { + t.Fatalf("QueryTool() error = %v", err) + } + + if !result.IsError { + t.Error("Expected error result") + } + + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + if !strings.Contains(textContent.Text, "Failed to start transaction") { + t.Errorf("Expected error message about transaction, got %s", textContent.Text) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) + + // Test query execution error + t.Run("query execution error", func(t *testing.T) { + mock.ExpectBegin() + mock.ExpectQuery("SELECT \\* FROM nonexistent"). + WillReturnError(fmt.Errorf("table does not exist")) + mock.ExpectRollback() + + args := QueryArgs{ + SQL: "SELECT * FROM nonexistent", + } + + params := &mcp.CallToolParamsFor[QueryArgs]{ + Name: "query", + Arguments: args, + } + + result, err := server.QueryTool(ctx, nil, params) + if err != nil { + t.Fatalf("QueryTool() error = %v", err) + } + + if !result.IsError { + t.Error("Expected error result") + } + + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + if !strings.Contains(textContent.Text, "Query execution failed") { + t.Errorf("Expected query execution error message, got %s", textContent.Text) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) + + // Test with byte array values (common in PostgreSQL) + t.Run("query with byte array values", func(t *testing.T) { + mock.ExpectBegin() + + rows := sqlmock.NewRows([]string{"id", "data"}). + AddRow(1, []byte("binary data")) + + mock.ExpectQuery("SELECT id, data FROM binary_table"). + WillReturnRows(rows) + + mock.ExpectRollback() + + args := QueryArgs{ + SQL: "SELECT id, data FROM binary_table", + } + + params := &mcp.CallToolParamsFor[QueryArgs]{ + Name: "query", + Arguments: args, + } + + result, err := server.QueryTool(ctx, nil, params) + if err != nil { + t.Fatalf("QueryTool() error = %v", err) + } + + if result.IsError { + t.Error("Expected successful result, got error") + } + + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + // Parse and verify JSON result + var queryResults []map[string]interface{} + err = json.Unmarshal([]byte(textContent.Text), &queryResults) + if err != nil { + t.Fatalf("Failed to parse query result JSON: %v", err) + } + + if len(queryResults) != 1 { + t.Errorf("Expected 1 row, got %d", len(queryResults)) + } + + // Verify byte array was converted to string + row := queryResults[0] + if row["data"] != "binary data" { + t.Errorf("Expected data 'binary data', got %v", row["data"]) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) + + // Test empty result set + t.Run("empty result set", func(t *testing.T) { + mock.ExpectBegin() + + rows := sqlmock.NewRows([]string{"id", "name"}) + + mock.ExpectQuery("SELECT \\* FROM users WHERE id = 999"). + WillReturnRows(rows) + + mock.ExpectRollback() + + args := QueryArgs{ + SQL: "SELECT * FROM users WHERE id = 999", + } + + params := &mcp.CallToolParamsFor[QueryArgs]{ + Name: "query", + Arguments: args, + } + + result, err := server.QueryTool(ctx, nil, params) + if err != nil { + t.Fatalf("QueryTool() error = %v", err) + } + + if result.IsError { + t.Error("Expected successful result, got error") + } + + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + // Should return empty JSON array + if !strings.Contains(textContent.Text, "[]") { + t.Errorf("Expected JSON containing empty array, got %s", textContent.Text) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) +} + +func TestQueryArgs(t *testing.T) { + // Test JSON marshaling/unmarshaling of QueryArgs + args := QueryArgs{ + SQL: "SELECT * FROM users", + } + + data, err := json.Marshal(args) + if err != nil { + t.Fatalf("Failed to marshal QueryArgs: %v", err) + } + + var unmarshaled QueryArgs + err = json.Unmarshal(data, &unmarshaled) + if err != nil { + t.Fatalf("Failed to unmarshal QueryArgs: %v", err) + } + + if unmarshaled.SQL != args.SQL { + t.Errorf("Expected SQL %s, got %s", args.SQL, unmarshaled.SQL) + } +} + +func TestSCHEMA_PATH(t *testing.T) { + // Test that the schema path constant is correct + if SCHEMA_PATH != "schema" { + t.Errorf("Expected SCHEMA_PATH to be 'schema', got %s", SCHEMA_PATH) + } +} + +// TestPostgresServerIntegration tests the integration between different components +func TestPostgresServerIntegration(t *testing.T) { + server, mock, cleanup := createMockPostgresServer(t) + defer cleanup() + + ctx := context.Background() + + // Test complete flow: ListTables -> ReadTableSchema + t.Run("complete flow", func(t *testing.T) { + // Mock ListTables + tableRows := sqlmock.NewRows([]string{"table_name"}). + AddRow("users") + + mock.ExpectQuery("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"). + WillReturnRows(tableRows) + + resources, err := server.ListTables(ctx) + if err != nil { + t.Fatalf("ListTables() error = %v", err) + } + + if len(resources) != 1 { + t.Fatalf("Expected 1 resource, got %d", len(resources)) + } + + // Use the returned resource to test readTableSchema + schemaRows := sqlmock.NewRows([]string{"column_name", "data_type"}). + AddRow("id", "integer"). + AddRow("name", "character varying") + + mock.ExpectQuery("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = \\$1 ORDER BY ordinal_position"). + WithArgs("users"). + WillReturnRows(schemaRows) + + params := &mcp.ReadResourceParams{ + URI: resources[0].Resource.URI, + } + + result, err := resources[0].Handler(ctx, nil, params) + if err != nil { + t.Fatalf("Handler() error = %v", err) + } + + if len(result.Contents) != 1 { + t.Fatalf("Expected 1 content item, got %d", len(result.Contents)) + } + + // Verify the schema content + var columns []map[string]interface{} + err = json.Unmarshal([]byte(result.Contents[0].Text), &columns) + if err != nil { + t.Fatalf("Failed to parse schema JSON: %v", err) + } + + if len(columns) != 2 { + t.Errorf("Expected 2 columns, got %d", len(columns)) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) +} + +// TestEdgeCases tests various edge cases and error conditions +func TestEdgeCases(t *testing.T) { + server, mock, cleanup := createMockPostgresServer(t) + defer cleanup() + + ctx := context.Background() + + // Test readTableSchema with empty path + t.Run("empty path in URI", func(t *testing.T) { + params := &mcp.ReadResourceParams{ + URI: server.resourceBaseURL.String(), // No path + } + + _, err := server.readTableSchema(ctx, nil, params) + if err == nil { + t.Error("Expected error for empty path, got nil") + } + + // The path parsing logic will handle this as invalid format + if !strings.Contains(err.Error(), "invalid resource URI") { + t.Errorf("Expected error about invalid resource URI, got %v", err) + } + }) + + // Test QueryTool with columns() error + t.Run("columns error in QueryTool", func(t *testing.T) { + mock.ExpectBegin() + + rows := sqlmock.NewRows([]string{"id", "name"}). + AddRow(1, "test"). + CloseError(fmt.Errorf("columns error")) + + mock.ExpectQuery("SELECT \\* FROM test"). + WillReturnRows(rows) + + mock.ExpectRollback() + + args := QueryArgs{ + SQL: "SELECT * FROM test", + } + + params := &mcp.CallToolParamsFor[QueryArgs]{ + Name: "query", + Arguments: args, + } + + result, err := server.QueryTool(ctx, nil, params) + if err != nil { + t.Fatalf("QueryTool() error = %v", err) + } + + if !result.IsError { + t.Error("Expected error result") + } + + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + // CloseError actually affects rows.Err(), not Columns() + if !strings.Contains(textContent.Text, "Error iterating rows") { + t.Errorf("Expected row error message, got %s", textContent.Text) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) + + // Test QueryTool with rows.Next() error + t.Run("rows iteration error in QueryTool", func(t *testing.T) { + mock.ExpectBegin() + + rows := sqlmock.NewRows([]string{"id", "name"}). + AddRow(1, "test"). + RowError(0, fmt.Errorf("row iteration error")) + + mock.ExpectQuery("SELECT \\* FROM test"). + WillReturnRows(rows) + + mock.ExpectRollback() + + args := QueryArgs{ + SQL: "SELECT * FROM test", + } + + params := &mcp.CallToolParamsFor[QueryArgs]{ + Name: "query", + Arguments: args, + } + + result, err := server.QueryTool(ctx, nil, params) + if err != nil { + t.Fatalf("QueryTool() error = %v", err) + } + + if !result.IsError { + t.Error("Expected error result") + } + + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + if !strings.Contains(textContent.Text, "Error iterating rows") { + t.Errorf("Expected row iteration error message, got %s", textContent.Text) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) + + // Test readTableSchema with JSON marshal error (simulate with invalid data) + t.Run("JSON marshal error in readTableSchema", func(t *testing.T) { + // This is hard to simulate with sqlmock, but we can test the structure + // The current implementation should handle this gracefully + + schemaRows := sqlmock.NewRows([]string{"column_name", "data_type"}). + AddRow("valid_column", "text") + + mock.ExpectQuery("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = \\$1 ORDER BY ordinal_position"). + WithArgs("test_table"). + WillReturnRows(schemaRows) + + params := &mcp.ReadResourceParams{ + URI: fmt.Sprintf("%s/test_table/%s", server.resourceBaseURL.String(), SCHEMA_PATH), + } + + result, err := server.readTableSchema(ctx, nil, params) + if err != nil { + t.Fatalf("readTableSchema() error = %v", err) + } + + // Should succeed with valid data + if len(result.Contents) != 1 { + t.Fatalf("Expected 1 content item, got %d", len(result.Contents)) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) +} + +// TestResourceURIGeneration tests the URI generation logic +func TestResourceURIGeneration(t *testing.T) { + server, mock, cleanup := createMockPostgresServer(t) + defer cleanup() + + ctx := context.Background() + + // Test with different table names including special characters + testCases := []struct { + tableName string + expectedURI string + }{ + { + tableName: "users", + expectedURI: fmt.Sprintf("%s/users/%s", server.resourceBaseURL.String(), SCHEMA_PATH), + }, + { + tableName: "user_profiles", + expectedURI: fmt.Sprintf("%s/user_profiles/%s", server.resourceBaseURL.String(), SCHEMA_PATH), + }, + { + tableName: "orders123", + expectedURI: fmt.Sprintf("%s/orders123/%s", server.resourceBaseURL.String(), SCHEMA_PATH), + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("table_%s", tc.tableName), func(t *testing.T) { + rows := sqlmock.NewRows([]string{"table_name"}). + AddRow(tc.tableName) + + mock.ExpectQuery("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"). + WillReturnRows(rows) + + resources, err := server.ListTables(ctx) + if err != nil { + t.Fatalf("ListTables() error = %v", err) + } + + if len(resources) != 1 { + t.Fatalf("Expected 1 resource, got %d", len(resources)) + } + + if resources[0].Resource.URI != tc.expectedURI { + t.Errorf("Expected URI %s, got %s", tc.expectedURI, resources[0].Resource.URI) + } + + expectedName := fmt.Sprintf(`"%s" database schema`, tc.tableName) + if resources[0].Resource.Name != expectedName { + t.Errorf("Expected name %s, got %s", expectedName, resources[0].Resource.Name) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) + } +} + +// TestDifferentDataTypes tests handling of various PostgreSQL data types +func TestDifferentDataTypes(t *testing.T) { + server, mock, cleanup := createMockPostgresServer(t) + defer cleanup() + + ctx := context.Background() + + // Test schema reading with various PostgreSQL data types + t.Run("various data types", func(t *testing.T) { + rows := sqlmock.NewRows([]string{"column_name", "data_type"}). + AddRow("id", "bigint"). + AddRow("name", "character varying"). + AddRow("description", "text"). + AddRow("price", "numeric"). + AddRow("created_at", "timestamp with time zone"). + AddRow("updated_at", "timestamp without time zone"). + AddRow("is_active", "boolean"). + AddRow("metadata", "jsonb"). + AddRow("tags", "ARRAY") + + mock.ExpectQuery("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = \\$1 ORDER BY ordinal_position"). + WithArgs("products"). + WillReturnRows(rows) + + params := &mcp.ReadResourceParams{ + URI: fmt.Sprintf("%s/products/%s", server.resourceBaseURL.String(), SCHEMA_PATH), + } + + result, err := server.readTableSchema(ctx, nil, params) + if err != nil { + t.Fatalf("readTableSchema() error = %v", err) + } + + if len(result.Contents) != 1 { + t.Fatalf("Expected 1 content item, got %d", len(result.Contents)) + } + + // Parse and verify all data types + var columns []map[string]interface{} + err = json.Unmarshal([]byte(result.Contents[0].Text), &columns) + if err != nil { + t.Fatalf("Failed to parse JSON: %v", err) + } + + if len(columns) != 9 { + t.Errorf("Expected 9 columns, got %d", len(columns)) + } + + expectedTypes := []string{ + "bigint", "character varying", "text", "numeric", + "timestamp with time zone", "timestamp without time zone", + "boolean", "jsonb", "ARRAY", + } + + for i, col := range columns { + if col["data_type"] != expectedTypes[i] { + t.Errorf("Expected data type %s at index %d, got %s", expectedTypes[i], i, col["data_type"]) + } + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) + + // Test query results with various data types + t.Run("query with various data types", func(t *testing.T) { + mock.ExpectBegin() + + rows := sqlmock.NewRows([]string{"id", "name", "price", "is_active", "created_at"}). + AddRow(1, "Product 1", 19.99, true, "2024-01-01T00:00:00Z"). + AddRow(2, "Product 2", 29.99, false, "2024-01-02T00:00:00Z") + + mock.ExpectQuery("SELECT id, name, price, is_active, created_at FROM products"). + WillReturnRows(rows) + + mock.ExpectRollback() + + args := QueryArgs{ + SQL: "SELECT id, name, price, is_active, created_at FROM products", + } + + params := &mcp.CallToolParamsFor[QueryArgs]{ + Name: "query", + Arguments: args, + } + + result, err := server.QueryTool(ctx, nil, params) + if err != nil { + t.Fatalf("QueryTool() error = %v", err) + } + + if result.IsError { + t.Error("Expected successful result, got error") + } + + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + // Parse and verify results + var queryResults []map[string]interface{} + err = json.Unmarshal([]byte(textContent.Text), &queryResults) + if err != nil { + t.Fatalf("Failed to parse query result JSON: %v", err) + } + + if len(queryResults) != 2 { + t.Errorf("Expected 2 rows, got %d", len(queryResults)) + } + + // Verify first row data types and values + firstRow := queryResults[0] + // SQL mock returns different numeric types depending on the driver + if id, ok := firstRow["id"].(int64); ok { + if id != 1 { + t.Errorf("Expected id 1, got %v", id) + } + } else if id, ok := firstRow["id"].(int); ok { + if id != 1 { + t.Errorf("Expected id 1, got %v", id) + } + } else if id, ok := firstRow["id"].(float64); ok { + if id != 1.0 { + t.Errorf("Expected id 1, got %v", id) + } + } else { + t.Errorf("Expected id to be numeric, got %v (type: %T)", firstRow["id"], firstRow["id"]) + } + + if firstRow["name"] != "Product 1" { + t.Errorf("Expected name 'Product 1', got %v", firstRow["name"]) + } + if firstRow["is_active"] != true { + t.Errorf("Expected is_active true, got %v", firstRow["is_active"]) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Unfulfilled expectations: %v", err) + } + }) +} From b037d1c5621e1a50cee85f45e43a68fbed54ad67 Mon Sep 17 00:00:00 2001 From: Guikai Date: Thu, 10 Jul 2025 01:05:15 -0700 Subject: [PATCH 2/3] Added postgres_server.go and optimized code --- examples/postgres/main.go | 309 ------------------ examples/postgres/postgres_server.go | 276 ++++++++++++++++ .../{main_test.go => postgres_server_test.go} | 171 +++++++--- 3 files changed, 403 insertions(+), 353 deletions(-) create mode 100644 examples/postgres/postgres_server.go rename examples/postgres/{main_test.go => postgres_server_test.go} (85%) diff --git a/examples/postgres/main.go b/examples/postgres/main.go index 7112e4d..5e3e231 100644 --- a/examples/postgres/main.go +++ b/examples/postgres/main.go @@ -6,325 +6,16 @@ package main import ( "context" - "database/sql" - "encoding/json" "flag" - "fmt" "log" "net/http" - "net/url" "os" - _ "github.com/lib/pq" // PostgreSQL driver "github.com/modelcontextprotocol/go-sdk/mcp" ) var httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") -const SCHEMA_PATH = "schema" - -type PostgresServer struct { - db *sql.DB - resourceBaseURL *url.URL -} - -// QueryArgs represents the arguments for the SQL query tool -type QueryArgs struct { - SQL string `json:"sql"` -} - -// NewPostgresServer creates a new PostgreSQL MCP server -func NewPostgresServer(databaseURL string) (*PostgresServer, error) { - db, err := sql.Open("postgres", databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) - } - - if err := db.Ping(); err != nil { - return nil, fmt.Errorf("failed to ping database: %w", err) - } - - resourceBaseURL, err := url.Parse(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to parse database URL: %w", err) - } - resourceBaseURL.Scheme = "postgres" - resourceBaseURL.User = nil // Remove credentials for security - - return &PostgresServer{ - db: db, - resourceBaseURL: resourceBaseURL, - }, nil -} - -// Close closes the database connection -func (ps *PostgresServer) Close() error { - return ps.db.Close() -} - -// ListTables returns all tables in the public schema as resources -func (ps *PostgresServer) ListTables(ctx context.Context) ([]*mcp.ServerResource, error) { - query := "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'" - rows, err := ps.db.QueryContext(ctx, query) - if err != nil { - return nil, fmt.Errorf("failed to query tables: %w", err) - } - defer rows.Close() - - var resources []*mcp.ServerResource - for rows.Next() { - var tableName string - if err := rows.Scan(&tableName); err != nil { - return nil, fmt.Errorf("failed to scan table name: %w", err) - } - - resourceURI := fmt.Sprintf("%s/%s/%s", ps.resourceBaseURL.String(), tableName, SCHEMA_PATH) - resource := &mcp.ServerResource{ - Resource: &mcp.Resource{ - URI: resourceURI, - MIMEType: "application/json", - Name: fmt.Sprintf(`"%s" database schema`, tableName), - }, - Handler: ps.readTableSchema, - } - resources = append(resources, resource) - } - - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("error iterating tables: %w", err) - } - - return resources, nil -} - -// readTableSchema handles reading table schema information -func (ps *PostgresServer) readTableSchema(ctx context.Context, ss *mcp.ServerSession, params *mcp.ReadResourceParams) (*mcp.ReadResourceResult, error) { - resourceURL, err := url.Parse(params.URI) - if err != nil { - return nil, fmt.Errorf("invalid resource URI: %w", err) - } - - pathComponents := []string{} - for _, component := range []string{resourceURL.Path} { - if component != "" { - pathComponents = append(pathComponents, component) - } - } - - // Parse path: /tableName/schema - parts := []string{} - for _, part := range pathComponents { - if part == "/" { - continue - } - subParts := []string{} - for _, subPart := range []string{part} { - if subPart != "" { - for _, p := range []string{subPart} { - if p != "/" { - subParts = append(subParts, p) - } - } - } - } - parts = append(parts, subParts...) - } - - // Extract table name and schema path from URI - urlPath := resourceURL.Path - if urlPath == "" { - return nil, fmt.Errorf("empty URI path") - } - - // Remove leading slash and split by '/' - trimmedPath := urlPath - if trimmedPath[0] == '/' { - trimmedPath = trimmedPath[1:] - } - pathParts := []string{} - for _, part := range []string{trimmedPath} { - if part != "" { - // Split by '/' - for i, p := range []string{part} { - if i == 0 { - subparts := []string{} - current := "" - for _, r := range p { - if r == '/' { - if current != "" { - subparts = append(subparts, current) - current = "" - } - } else { - current += string(r) - } - } - if current != "" { - subparts = append(subparts, current) - } - pathParts = append(pathParts, subparts...) - } - } - } - } - - if len(pathParts) < 2 { - return nil, fmt.Errorf("invalid resource URI format: expected /tableName/schema") - } - - tableName := pathParts[len(pathParts)-2] - schema := pathParts[len(pathParts)-1] - - if schema != SCHEMA_PATH { - return nil, fmt.Errorf("invalid resource URI: expected schema path") - } - - // Query table columns - query := "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = $1 ORDER BY ordinal_position" - rows, err := ps.db.QueryContext(ctx, query, tableName) - if err != nil { - return nil, fmt.Errorf("failed to query table schema: %w", err) - } - defer rows.Close() - - type ColumnInfo struct { - ColumnName string `json:"column_name"` - DataType string `json:"data_type"` - } - - var columns []ColumnInfo - for rows.Next() { - var column ColumnInfo - if err := rows.Scan(&column.ColumnName, &column.DataType); err != nil { - return nil, fmt.Errorf("failed to scan column info: %w", err) - } - columns = append(columns, column) - } - - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("error iterating columns: %w", err) - } - - jsonData, err := json.MarshalIndent(columns, "", " ") - if err != nil { - return nil, fmt.Errorf("failed to marshal columns to JSON: %w", err) - } - - return &mcp.ReadResourceResult{ - Contents: []*mcp.ResourceContents{ - { - URI: params.URI, - MIMEType: "application/json", - Text: string(jsonData), - }, - }, - }, nil -} - -// QueryTool executes a read-only SQL query -func (ps *PostgresServer) QueryTool(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[QueryArgs]) (*mcp.CallToolResultFor[struct{}], error) { - sqlQuery := params.Arguments.SQL - - // Start a read-only transaction - tx, err := ps.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) - if err != nil { - return &mcp.CallToolResultFor[struct{}]{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Failed to start transaction: %v", err)}, - }, - IsError: true, - }, nil - } - defer tx.Rollback() // Always rollback read-only transaction - - rows, err := tx.QueryContext(ctx, sqlQuery) - if err != nil { - return &mcp.CallToolResultFor[struct{}]{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Query execution failed: %v", err)}, - }, - IsError: true, - }, nil - } - defer rows.Close() - - // Get column names - columns, err := rows.Columns() - if err != nil { - return &mcp.CallToolResultFor[struct{}]{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Failed to get columns: %v", err)}, - }, - IsError: true, - }, nil - } - - // Prepare result slice - var results []map[string]interface{} - - // Scan rows - for rows.Next() { - // Create slice of interface{} to hold column values - values := make([]interface{}, len(columns)) - valuePtrs := make([]interface{}, len(columns)) - for i := range values { - valuePtrs[i] = &values[i] - } - - if err := rows.Scan(valuePtrs...); err != nil { - return &mcp.CallToolResultFor[struct{}]{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Failed to scan row: %v", err)}, - }, - IsError: true, - }, nil - } - - // Convert to map - row := make(map[string]interface{}) - for i, col := range columns { - val := values[i] - if b, ok := val.([]byte); ok { - row[col] = string(b) - } else { - row[col] = val - } - } - results = append(results, row) - } - - if err := rows.Err(); err != nil { - return &mcp.CallToolResultFor[struct{}]{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error iterating rows: %v", err)}, - }, - IsError: true, - }, nil - } - - // Ensure results is not nil for proper JSON marshaling - if results == nil { - results = []map[string]interface{}{} - } - - // Convert results to JSON - jsonData, err := json.MarshalIndent(results, "", " ") - if err != nil { - return &mcp.CallToolResultFor[struct{}]{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Failed to marshal results: %v", err)}, - }, - IsError: true, - }, nil - } - - return &mcp.CallToolResultFor[struct{}]{ - Content: []mcp.Content{ - &mcp.TextContent{Text: string(jsonData)}, - }, - }, nil -} - func main() { flag.Parse() diff --git a/examples/postgres/postgres_server.go b/examples/postgres/postgres_server.go new file mode 100644 index 0000000..2ef7e3b --- /dev/null +++ b/examples/postgres/postgres_server.go @@ -0,0 +1,276 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "net/url" + "strings" + + _ "github.com/lib/pq" // PostgreSQL driver + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +const SCHEMA_PATH = "schema" + +type PostgresServer struct { + db *sql.DB + resourceBaseURL *url.URL +} + +// QueryArgs represents the arguments for the SQL query tool +type QueryArgs struct { + SQL string `json:"sql"` +} + +// NewPostgresServer creates a new PostgreSQL MCP server +func NewPostgresServer(databaseURL string) (*PostgresServer, error) { + db, err := sql.Open("postgres", databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + if err := db.Ping(); err != nil { + return nil, fmt.Errorf("failed to ping database: %w", err) + } + + resourceBaseURL, err := url.Parse(databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to parse database URL: %w", err) + } + resourceBaseURL.Scheme = "postgres" + resourceBaseURL.User = nil // Remove credentials for security + + return &PostgresServer{ + db: db, + resourceBaseURL: resourceBaseURL, + }, nil +} + +// Close closes the database connection +func (ps *PostgresServer) Close() error { + return ps.db.Close() +} + +// ListTables returns all tables in the public schema as resources +func (ps *PostgresServer) ListTables(ctx context.Context) ([]*mcp.ServerResource, error) { + query := "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'" + rows, err := ps.db.QueryContext(ctx, query) + if err != nil { + return nil, fmt.Errorf("failed to query tables: %w", err) + } + defer rows.Close() + + var resources []*mcp.ServerResource + for rows.Next() { + var tableName string + if err := rows.Scan(&tableName); err != nil { + return nil, fmt.Errorf("failed to scan table name: %w", err) + } + + resourceURI := fmt.Sprintf("%s/%s/%s", ps.resourceBaseURL.String(), tableName, SCHEMA_PATH) + resource := &mcp.ServerResource{ + Resource: &mcp.Resource{ + URI: resourceURI, + MIMEType: "application/json", + Name: fmt.Sprintf(`"%s" database schema`, tableName), + }, + Handler: ps.readTableSchema, + } + resources = append(resources, resource) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating tables: %w", err) + } + + return resources, nil +} + +// parseResourceURI extracts the table name from a MCP resource URI. +// +// The expected URI format is: postgres://host:port/database/tableName/schema +// where: +// - postgres://host:port/database is the base URL +// - tableName is the name of the database table +// - schema is the literal string "schema" (defined by SCHEMA_PATH constant) +// +// Examples: +// - "postgres://localhost:5432/mydb/users/schema" โ†’ tableName: "users" +// - "postgres://localhost:5432/mydb/order_items/schema" โ†’ tableName: "order_items" +// +// Returns the extracted table name or an error if the URI format is invalid. +func parseResourceURI(resourceURI string) (tableName string, err error) { + resourceURL, err := url.Parse(resourceURI) + if err != nil { + return "", fmt.Errorf("invalid resource URI: %w", err) + } + + // Extract path from URI + urlPath := resourceURL.Path + if urlPath == "" { + return "", fmt.Errorf("empty URI path") + } + + // Remove leading slash and split by '/' + trimmedPath := strings.TrimPrefix(urlPath, "/") + + // Split path into components using strings.Split for better performance (O(n) vs O(nยฒ)) + var pathParts []string + if trimmedPath != "" { + pathParts = strings.Split(trimmedPath, "/") + // Filter out empty parts that might result from consecutive slashes + var filteredParts []string + for _, part := range pathParts { + if part != "" { + filteredParts = append(filteredParts, part) + } + } + pathParts = filteredParts + } + + // Validate path format: expect at least 2 parts (tableName/schema) + if len(pathParts) < 2 { + return "", fmt.Errorf("invalid resource URI format: expected /tableName/schema, got: %s", urlPath) + } + + // Extract table name and schema path + tableName = pathParts[len(pathParts)-2] + schema := pathParts[len(pathParts)-1] + + // Validate schema path + if schema != SCHEMA_PATH { + return "", fmt.Errorf("invalid resource URI: expected schema path '%s', got '%s'", SCHEMA_PATH, schema) + } + + return tableName, nil +} + +// readTableSchema handles reading table schema information +func (ps *PostgresServer) readTableSchema(ctx context.Context, ss *mcp.ServerSession, params *mcp.ReadResourceParams) (*mcp.ReadResourceResult, error) { + // Parse the resource URI to extract table name + tableName, err := parseResourceURI(params.URI) + if err != nil { + return nil, err + } + + // Query table columns + query := "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = $1 ORDER BY ordinal_position" + rows, err := ps.db.QueryContext(ctx, query, tableName) + if err != nil { + return nil, fmt.Errorf("failed to query table schema: %w", err) + } + defer rows.Close() + + type ColumnInfo struct { + ColumnName string `json:"column_name"` + DataType string `json:"data_type"` + } + + var columns []ColumnInfo + for rows.Next() { + var column ColumnInfo + if err := rows.Scan(&column.ColumnName, &column.DataType); err != nil { + return nil, fmt.Errorf("failed to scan column info: %w", err) + } + columns = append(columns, column) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating columns: %w", err) + } + + jsonData, err := json.MarshalIndent(columns, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to marshal columns to JSON: %w", err) + } + + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{ + { + URI: params.URI, + MIMEType: "application/json", + Text: string(jsonData), + }, + }, + }, nil +} + +// QueryTool executes a read-only SQL query +func (ps *PostgresServer) QueryTool(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[QueryArgs]) (*mcp.CallToolResultFor[struct{}], error) { + sqlQuery := params.Arguments.SQL + + // Start a read-only transaction + tx, err := ps.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + if err != nil { + return nil, fmt.Errorf("failed to start transaction: %w", err) + } + defer tx.Rollback() // Always rollback read-only transaction + + rows, err := tx.QueryContext(ctx, sqlQuery) + if err != nil { + return nil, fmt.Errorf("query execution failed: %w", err) + } + defer rows.Close() + + // Get column names + columns, err := rows.Columns() + if err != nil { + return nil, fmt.Errorf("failed to get columns: %w", err) + } + + // Prepare result slice + var results []map[string]any + + // Scan rows + for rows.Next() { + // Create slice of any to hold column values + values := make([]any, len(columns)) + valuePtrs := make([]any, len(columns)) + for i := range values { + valuePtrs[i] = &values[i] + } + + if err := rows.Scan(valuePtrs...); err != nil { + return nil, fmt.Errorf("failed to scan row: %w", err) + } + + // Convert to map + row := make(map[string]any) + for i, col := range columns { + val := values[i] + if b, ok := val.([]byte); ok { + row[col] = string(b) + } else { + row[col] = val + } + } + results = append(results, row) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating rows: %w", err) + } + + // Ensure results is not nil for proper JSON marshaling + if results == nil { + results = []map[string]any{} + } + + // Convert results to JSON + jsonData, err := json.MarshalIndent(results, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to marshal results: %w", err) + } + + return &mcp.CallToolResultFor[struct{}]{ + Content: []mcp.Content{ + &mcp.TextContent{Text: string(jsonData)}, + }, + }, nil +} diff --git a/examples/postgres/main_test.go b/examples/postgres/postgres_server_test.go similarity index 85% rename from examples/postgres/main_test.go rename to examples/postgres/postgres_server_test.go index 4529a6f..7af58cc 100644 --- a/examples/postgres/main_test.go +++ b/examples/postgres/postgres_server_test.go @@ -39,6 +39,109 @@ func createMockPostgresServer(t *testing.T) (*PostgresServer, sqlmock.Sqlmock, f return server, mock, cleanup } +func TestParseResourceURI(t *testing.T) { + tests := []struct { + name string + resourceURI string + expectedTable string + wantError bool + errorContains string + }{ + { + name: "valid URI with users table", + resourceURI: "postgres://localhost:5432/testdb/users/schema", + expectedTable: "users", + wantError: false, + }, + { + name: "valid URI with underscore table name", + resourceURI: "postgres://localhost:5432/testdb/user_profiles/schema", + expectedTable: "user_profiles", + wantError: false, + }, + { + name: "valid URI with numeric table name", + resourceURI: "postgres://localhost:5432/testdb/orders123/schema", + expectedTable: "orders123", + wantError: false, + }, + { + name: "valid URI with multiple path segments in database", + resourceURI: "postgres://localhost:5432/path/to/db/products/schema", + expectedTable: "products", + wantError: false, + }, + { + name: "invalid URI - malformed URL", + resourceURI: "not-a-url", + wantError: true, + errorContains: "invalid resource URI", + }, + { + name: "invalid URI - empty path", + resourceURI: "postgres://localhost:5432", + wantError: true, + errorContains: "empty URI path", + }, + { + name: "invalid URI - missing schema path", + resourceURI: "postgres://localhost:5432/testdb/users", + wantError: true, + errorContains: "invalid resource URI: expected schema path 'schema', got 'users'", + }, + { + name: "invalid URI - only table name, no schema", + resourceURI: "postgres://localhost:5432/testdb/users/", + wantError: true, + errorContains: "invalid resource URI: expected schema path 'schema', got 'users'", + }, + { + name: "invalid URI - wrong schema path", + resourceURI: "postgres://localhost:5432/testdb/users/wrong", + wantError: true, + errorContains: "invalid resource URI: expected schema path 'schema', got 'wrong'", + }, + { + name: "invalid URI - only one path component", + resourceURI: "postgres://localhost:5432/users", + wantError: true, + errorContains: "invalid resource URI format: expected /tableName/schema, got: /users", + }, + { + name: "invalid URI - empty path after slash", + resourceURI: "postgres://localhost:5432/", + wantError: true, + errorContains: "invalid resource URI format: expected /tableName/schema", + }, + { + name: "edge case - table name with special characters", + resourceURI: "postgres://localhost:5432/testdb/test-table_123/schema", + expectedTable: "test-table_123", + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tableName, err := parseResourceURI(tt.resourceURI) + + if tt.wantError { + if err == nil { + t.Errorf("parseResourceURI() expected error, got nil") + } else if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) { + t.Errorf("parseResourceURI() error = %v, want error containing %v", err, tt.errorContains) + } + } else { + if err != nil { + t.Errorf("parseResourceURI() unexpected error = %v", err) + } else if tableName != tt.expectedTable { + t.Errorf("parseResourceURI() tableName = %v, want %v", tableName, tt.expectedTable) + } + } + }) + } +} + func TestNewPostgresServer(t *testing.T) { tests := []struct { name string @@ -414,21 +517,16 @@ func TestPostgresServer_QueryTool(t *testing.T) { } result, err := server.QueryTool(ctx, nil, params) - if err != nil { - t.Fatalf("QueryTool() error = %v", err) + if err == nil { + t.Fatal("Expected error, got nil") } - if !result.IsError { - t.Error("Expected error result") + if result != nil { + t.Error("Expected nil result when error is returned") } - textContent, ok := result.Content[0].(*mcp.TextContent) - if !ok { - t.Fatal("Expected TextContent") - } - - if !strings.Contains(textContent.Text, "Failed to start transaction") { - t.Errorf("Expected error message about transaction, got %s", textContent.Text) + if !strings.Contains(err.Error(), "failed to start transaction") { + t.Errorf("Expected error message about transaction, got %s", err.Error()) } if err := mock.ExpectationsWereMet(); err != nil { @@ -453,21 +551,16 @@ func TestPostgresServer_QueryTool(t *testing.T) { } result, err := server.QueryTool(ctx, nil, params) - if err != nil { - t.Fatalf("QueryTool() error = %v", err) + if err == nil { + t.Fatal("Expected error, got nil") } - if !result.IsError { - t.Error("Expected error result") + if result != nil { + t.Error("Expected nil result when error is returned") } - textContent, ok := result.Content[0].(*mcp.TextContent) - if !ok { - t.Fatal("Expected TextContent") - } - - if !strings.Contains(textContent.Text, "Query execution failed") { - t.Errorf("Expected query execution error message, got %s", textContent.Text) + if !strings.Contains(err.Error(), "query execution failed") { + t.Errorf("Expected query execution error message, got %s", err.Error()) } if err := mock.ExpectationsWereMet(); err != nil { @@ -717,22 +810,17 @@ func TestEdgeCases(t *testing.T) { } result, err := server.QueryTool(ctx, nil, params) - if err != nil { - t.Fatalf("QueryTool() error = %v", err) - } - - if !result.IsError { - t.Error("Expected error result") + if err == nil { + t.Fatal("Expected error, got nil") } - textContent, ok := result.Content[0].(*mcp.TextContent) - if !ok { - t.Fatal("Expected TextContent") + if result != nil { + t.Error("Expected nil result when error is returned") } // CloseError actually affects rows.Err(), not Columns() - if !strings.Contains(textContent.Text, "Error iterating rows") { - t.Errorf("Expected row error message, got %s", textContent.Text) + if !strings.Contains(err.Error(), "error iterating rows") { + t.Errorf("Expected row error message, got %s", err.Error()) } if err := mock.ExpectationsWereMet(); err != nil { @@ -763,21 +851,16 @@ func TestEdgeCases(t *testing.T) { } result, err := server.QueryTool(ctx, nil, params) - if err != nil { - t.Fatalf("QueryTool() error = %v", err) - } - - if !result.IsError { - t.Error("Expected error result") + if err == nil { + t.Fatal("Expected error, got nil") } - textContent, ok := result.Content[0].(*mcp.TextContent) - if !ok { - t.Fatal("Expected TextContent") + if result != nil { + t.Error("Expected nil result when error is returned") } - if !strings.Contains(textContent.Text, "Error iterating rows") { - t.Errorf("Expected row iteration error message, got %s", textContent.Text) + if !strings.Contains(err.Error(), "error iterating rows") { + t.Errorf("Expected row iteration error message, got %s", err.Error()) } if err := mock.ExpectationsWereMet(); err != nil { From 92bfa16ec72025b5b9a45a4470751245cf43454b Mon Sep 17 00:00:00 2001 From: Guikai Date: Thu, 10 Jul 2025 01:12:31 -0700 Subject: [PATCH 3/3] format the code --- examples/postgres/postgres_server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/postgres/postgres_server.go b/examples/postgres/postgres_server.go index 2ef7e3b..c58d02b 100644 --- a/examples/postgres/postgres_server.go +++ b/examples/postgres/postgres_server.go @@ -119,7 +119,7 @@ func parseResourceURI(resourceURI string) (tableName string, err error) { // Remove leading slash and split by '/' trimmedPath := strings.TrimPrefix(urlPath, "/") - + // Split path into components using strings.Split for better performance (O(n) vs O(nยฒ)) var pathParts []string if trimmedPath != "" {