From d9bec15b00f755bdc9f5a7bac11425737cbb3c10 Mon Sep 17 00:00:00 2001 From: Andrew Qu Date: Mon, 22 Sep 2025 17:05:18 -0700 Subject: [PATCH] test --- examples/auth-observability/route.ts | 123 +++++++++++++ examples/observability/route.ts | 70 ++++++++ src/index.ts | 9 + src/observability/index.ts | 8 + src/observability/observability-wrapper.ts | 198 +++++++++++++++++++++ tests/observability.test.ts | 195 ++++++++++++++++++++ 6 files changed, 603 insertions(+) create mode 100644 examples/auth-observability/route.ts create mode 100644 examples/observability/route.ts create mode 100644 src/observability/index.ts create mode 100644 src/observability/observability-wrapper.ts create mode 100644 tests/observability.test.ts diff --git a/examples/auth-observability/route.ts b/examples/auth-observability/route.ts new file mode 100644 index 0000000..fdc5b1d --- /dev/null +++ b/examples/auth-observability/route.ts @@ -0,0 +1,123 @@ +import { AuthInfo } from "@modelcontextprotocol/sdk/server/auth/types"; +import { + createMcpHandler, + withMcpAuth, + withObservability, + type ObservabilityConfig, + addSpanAttribute, +} from "mcp-handler"; +import { z } from "zod"; + +// Define the handler with proper parameter validation +const handler = createMcpHandler( + (server) => { + server.tool( + "secure-echo", + "Echo a message back with both authentication and observability", + { + message: z.string().describe("The message to echo back"), + }, + async ({ message }, extra) => { + // Add custom attributes to the current span + if (extra.authInfo?.clientId) { + addSpanAttribute("user.client_id", extra.authInfo.clientId); + } + + addSpanAttribute("operation.type", "echo"); + addSpanAttribute("message.length", message.length); + + return { + content: [ + { + type: "text", + text: `Secure Echo: ${message}${ + extra.authInfo?.token + ? ` (authenticated as ${extra.authInfo.clientId})` + : "" + }`, + }, + ], + }; + } + ); + }, + // Server capabilities + { + capabilities: { + auth: { + type: "bearer", + required: true, + }, + tools: {}, + }, + }, + // Route configuration + { + streamableHttpEndpoint: "/mcp", + sseEndpoint: "/sse", + sseMessageEndpoint: "/message", + basePath: "/api/mcp", + redisUrl: process.env.REDIS_URL, + } +); + +/** + * Verify the bearer token and return auth information + * In a real implementation, this would validate against your auth service + */ +const verifyToken = async ( + req: Request, + bearerToken?: string +): Promise => { + if (!bearerToken) return undefined; + + // Add tracing for auth verification + addSpanAttribute("auth.token_present", true); + + // TODO: Replace with actual token verification logic + const isValid = bearerToken.startsWith("__TEST_VALUE__"); + + addSpanAttribute("auth.validation_result", isValid); + + if (!isValid) return undefined; + + return { + token: bearerToken, + scopes: ["read:messages", "write:messages"], + clientId: "example-client", + extra: { + userId: "user-123", + permissions: ["user"], + timestamp: new Date().toISOString(), + }, + }; +}; + +// Observability configuration +const observabilityConfig: ObservabilityConfig = { + serviceName: "secure-mcp-service", + serviceVersion: "1.0.0", + traceIdHeader: "x-trace-id", + spanIdHeader: "x-span-id", + customAttributes: { + "service.environment": process.env.NODE_ENV || "development", + "service.instance": process.env.HOSTNAME || "local", + "service.auth_enabled": "true", + }, + enableRequestLogging: true, + enableErrorTracking: true, + ignoreEndpoints: ["/health", "/metrics", "/.well-known/oauth-protected-resource"], + samplingRate: 1.0, +}; + +// Apply wrappers in order: observability first, then auth +// This ensures auth errors are also traced +const observabilityHandler = withObservability(handler, observabilityConfig); +const authAndObservabilityHandler = withMcpAuth(observabilityHandler, verifyToken, { + required: true, + requiredScopes: ["read:messages"], + resourceMetadataPath: "/.well-known/oauth-protected-resource", +}); + +// Export the handler for both GET and POST methods +export { authAndObservabilityHandler as GET, authAndObservabilityHandler as POST }; \ No newline at end of file diff --git a/examples/observability/route.ts b/examples/observability/route.ts new file mode 100644 index 0000000..66b1608 --- /dev/null +++ b/examples/observability/route.ts @@ -0,0 +1,70 @@ +import { + createMcpHandler, + withObservability, + type ObservabilityConfig, +} from "mcp-handler"; +import { z } from "zod"; + +// Define the handler with proper parameter validation +const handler = createMcpHandler( + (server) => { + server.tool( + "echo-with-tracing", + "Echo a message back with observability tracing", + { + message: z.string().describe("The message to echo back"), + }, + async ({ message }, extra) => { + // You can add custom span attributes within your handler + if (typeof req !== 'undefined' && req.traceId) { + console.log(`Processing echo request with trace ID: ${req.traceId}`); + } + + return { + content: [ + { + type: "text", + text: `Echo: ${message}`, + }, + ], + }; + } + ); + }, + // Server capabilities + { + capabilities: { + tools: {}, + }, + }, + // Route configuration + { + streamableHttpEndpoint: "/mcp", + sseEndpoint: "/sse", + sseMessageEndpoint: "/message", + basePath: "/api/mcp", + redisUrl: process.env.REDIS_URL, + } +); + +// Observability configuration +const observabilityConfig: ObservabilityConfig = { + serviceName: "mcp-echo-service", + serviceVersion: "1.0.0", + traceIdHeader: "x-trace-id", + spanIdHeader: "x-span-id", + customAttributes: { + "service.environment": process.env.NODE_ENV || "development", + "service.instance": process.env.HOSTNAME || "local", + }, + enableRequestLogging: true, + enableErrorTracking: true, + ignoreEndpoints: ["/health", "/metrics"], + samplingRate: 1.0, // 100% sampling for demo, adjust as needed +}; + +// Wrap the handler with observability +const observabilityHandler = withObservability(handler, observabilityConfig); + +// Export the handler for both GET and POST methods +export { observabilityHandler as GET, observabilityHandler as POST }; \ No newline at end of file diff --git a/src/index.ts b/src/index.ts index 93b3235..6e83bb1 100644 --- a/src/index.ts +++ b/src/index.ts @@ -13,3 +13,12 @@ export { generateProtectedResourceMetadata, metadataCorsOptionsRequestHandler, } from "./auth/auth-metadata"; + +export { + withObservability, + createObservabilitySpan, + getCurrentSpan, + addSpanAttribute, + addSpanEvent, + type ObservabilityConfig, +} from "./observability"; diff --git a/src/observability/index.ts b/src/observability/index.ts new file mode 100644 index 0000000..7a9d681 --- /dev/null +++ b/src/observability/index.ts @@ -0,0 +1,8 @@ +export { + withObservability, + createObservabilitySpan, + getCurrentSpan, + addSpanAttribute, + addSpanEvent, + type ObservabilityConfig, +} from './observability-wrapper'; \ No newline at end of file diff --git a/src/observability/observability-wrapper.ts b/src/observability/observability-wrapper.ts new file mode 100644 index 0000000..8af3016 --- /dev/null +++ b/src/observability/observability-wrapper.ts @@ -0,0 +1,198 @@ +import { trace, context, SpanStatusCode, SpanKind } from '@opentelemetry/api'; + +export interface ObservabilityConfig { + serviceName: string; + serviceVersion?: string; + traceIdHeader?: string; + spanIdHeader?: string; + customAttributes?: Record; + enableRequestLogging?: boolean; + enableErrorTracking?: boolean; + ignoreEndpoints?: string[]; + samplingRate?: number; +} + +declare global { + interface Request { + traceId?: string; + spanId?: string; + } +} + +export function withObservability( + handler: (req: Request) => Response | Promise, + config: ObservabilityConfig +) { + const { + serviceName, + serviceVersion = '1.0.0', + traceIdHeader = 'x-trace-id', + spanIdHeader = 'x-span-id', + customAttributes = {}, + enableRequestLogging = true, + enableErrorTracking = true, + ignoreEndpoints = [], + samplingRate = 1.0, + } = config; + + const tracer = trace.getTracer(serviceName, serviceVersion); + + return async (req: Request) => { + const url = new URL(req.url); + const method = req.method || 'GET'; + const pathname = url.pathname; + + // Skip tracing for ignored endpoints + if (ignoreEndpoints.some(endpoint => pathname.startsWith(endpoint))) { + return handler(req); + } + + // Apply sampling rate + if (Math.random() > samplingRate) { + return handler(req); + } + + // Extract trace context from headers if present + const traceId = req.headers.get(traceIdHeader); + const spanId = req.headers.get(spanIdHeader); + + // Attach trace info to request + if (traceId) req.traceId = traceId; + if (spanId) req.spanId = spanId; + + const spanName = `${method} ${pathname}`; + + return tracer.startActiveSpan( + spanName, + { + kind: SpanKind.SERVER, + attributes: { + 'http.method': method, + 'http.url': req.url, + 'http.scheme': url.protocol.replace(':', ''), + 'http.host': url.host, + 'http.target': pathname, + 'http.user_agent': req.headers.get('user-agent') || '', + 'service.name': serviceName, + 'service.version': serviceVersion, + ...customAttributes, + }, + }, + async (span) => { + const startTime = Date.now(); + + try { + // Log request if enabled + if (enableRequestLogging) { + console.log(`[${serviceName}] ${method} ${pathname} - Started`, { + traceId: span.spanContext().traceId, + spanId: span.spanContext().spanId, + }); + } + + // Execute the original handler + const response = await handler(req); + const duration = Date.now() - startTime; + + // Add response attributes to span + span.setAttributes({ + 'http.status_code': response.status, + 'http.response.duration_ms': duration, + }); + + // Set span status based on response + if (response.status >= 400) { + span.setStatus({ + code: SpanStatusCode.ERROR, + message: `HTTP ${response.status}`, + }); + } else { + span.setStatus({ code: SpanStatusCode.OK }); + } + + // Log response if enabled + if (enableRequestLogging) { + console.log(`[${serviceName}] ${method} ${pathname} - ${response.status} (${duration}ms)`, { + traceId: span.spanContext().traceId, + spanId: span.spanContext().spanId, + status: response.status, + duration, + }); + } + + // Add trace headers to response + const responseHeaders = new Headers(response.headers); + responseHeaders.set(traceIdHeader, span.spanContext().traceId); + responseHeaders.set(spanIdHeader, span.spanContext().spanId); + + return new Response(response.body, { + status: response.status, + statusText: response.statusText, + headers: responseHeaders, + }); + + } catch (error) { + const duration = Date.now() - startTime; + + // Record error in span + span.recordException(error as Error); + span.setStatus({ + code: SpanStatusCode.ERROR, + message: error instanceof Error ? error.message : String(error), + }); + + span.setAttributes({ + 'error': true, + 'error.type': error instanceof Error ? error.constructor.name : 'Unknown', + 'error.message': error instanceof Error ? error.message : String(error), + 'http.response.duration_ms': duration, + }); + + // Log error if enabled + if (enableErrorTracking) { + console.error(`[${serviceName}] ${method} ${pathname} - Error (${duration}ms)`, { + traceId: span.spanContext().traceId, + spanId: span.spanContext().spanId, + error: error instanceof Error ? error.message : String(error), + duration, + }); + } + + throw error; + } finally { + span.end(); + } + } + ); + }; +} + +export function createObservabilitySpan( + name: string, + serviceName: string, + attributes?: Record +) { + const tracer = trace.getTracer(serviceName); + return tracer.startSpan(name, { + kind: SpanKind.INTERNAL, + attributes, + }); +} + +export function getCurrentSpan() { + return trace.getActiveSpan(); +} + +export function addSpanAttribute(key: string, value: string | number | boolean) { + const span = getCurrentSpan(); + if (span) { + span.setAttribute(key, value); + } +} + +export function addSpanEvent(name: string, attributes?: Record) { + const span = getCurrentSpan(); + if (span) { + span.addEvent(name, attributes); + } +} \ No newline at end of file diff --git a/tests/observability.test.ts b/tests/observability.test.ts new file mode 100644 index 0000000..976fabc --- /dev/null +++ b/tests/observability.test.ts @@ -0,0 +1,195 @@ +import { describe, it, expect, beforeEach, vi } from "vitest"; +import { withObservability, type ObservabilityConfig } from "../src/observability"; + +// Mock OpenTelemetry +vi.mock('@opentelemetry/api', () => ({ + trace: { + getTracer: vi.fn(() => ({ + startActiveSpan: vi.fn((name, options, callback) => { + const mockSpan = { + spanContext: () => ({ + traceId: 'mock-trace-id-123', + spanId: 'mock-span-id-456', + }), + setAttributes: vi.fn(), + setStatus: vi.fn(), + recordException: vi.fn(), + setAttribute: vi.fn(), + addEvent: vi.fn(), + end: vi.fn(), + }; + return callback(mockSpan); + }), + })), + getActiveSpan: vi.fn(() => ({ + spanContext: () => ({ + traceId: 'mock-trace-id-123', + spanId: 'mock-span-id-456', + }), + setAttribute: vi.fn(), + addEvent: vi.fn(), + })), + }, + context: {}, + SpanStatusCode: { + OK: 'OK', + ERROR: 'ERROR', + }, + SpanKind: { + SERVER: 'SERVER', + INTERNAL: 'INTERNAL', + }, +})); + +describe("withObservability", () => { + let mockHandler: ReturnType; + let config: ObservabilityConfig; + + beforeEach(() => { + mockHandler = vi.fn(); + config = { + serviceName: "test-service", + serviceVersion: "1.0.0", + enableRequestLogging: false, // Disable for cleaner tests + enableErrorTracking: true, + }; + }); + + it("should wrap handler and add trace headers to response", async () => { + const mockResponse = new Response("test response", { status: 200 }); + mockHandler.mockResolvedValue(mockResponse); + + const wrappedHandler = withObservability(mockHandler, config); + const request = new Request("https://example.com/test", { method: "GET" }); + + const result = await wrappedHandler(request); + + expect(mockHandler).toHaveBeenCalledWith(request); + expect(result.headers.get("x-trace-id")).toBe("mock-trace-id-123"); + expect(result.headers.get("x-span-id")).toBe("mock-span-id-456"); + }); + + it("should extract trace context from request headers", async () => { + const mockResponse = new Response("test response", { status: 200 }); + mockHandler.mockResolvedValue(mockResponse); + + const wrappedHandler = withObservability(mockHandler, config); + const request = new Request("https://example.com/test", { + method: "GET", + headers: { + "x-trace-id": "existing-trace-123", + "x-span-id": "existing-span-456", + }, + }); + + await wrappedHandler(request); + + expect(request.traceId).toBe("existing-trace-123"); + expect(request.spanId).toBe("existing-span-456"); + }); + + it("should skip tracing for ignored endpoints", async () => { + const mockResponse = new Response("health ok", { status: 200 }); + mockHandler.mockResolvedValue(mockResponse); + + const configWithIgnored: ObservabilityConfig = { + ...config, + ignoreEndpoints: ["/health", "/metrics"], + }; + + const wrappedHandler = withObservability(mockHandler, configWithIgnored); + const request = new Request("https://example.com/health", { method: "GET" }); + + const result = await wrappedHandler(request); + + expect(mockHandler).toHaveBeenCalledWith(request); + // Should not have trace headers for ignored endpoints + expect(result.headers.get("x-trace-id")).toBeNull(); + expect(result.headers.get("x-span-id")).toBeNull(); + }); + + it("should respect sampling rate", async () => { + const mockResponse = new Response("test response", { status: 200 }); + mockHandler.mockResolvedValue(mockResponse); + + // Mock Math.random to return 0.9 + const originalRandom = Math.random; + Math.random = vi.fn(() => 0.9); + + try { + const configWithSampling: ObservabilityConfig = { + ...config, + samplingRate: 0.5, // 50% sampling rate + }; + + const wrappedHandler = withObservability(mockHandler, configWithSampling); + const request = new Request("https://example.com/test", { method: "GET" }); + + const result = await wrappedHandler(request); + + // Should skip tracing due to sampling + expect(result.headers.get("x-trace-id")).toBeNull(); + } finally { + Math.random = originalRandom; + } + }); + + it("should handle errors properly", async () => { + const error = new Error("Test error"); + mockHandler.mockRejectedValue(error); + + const wrappedHandler = withObservability(mockHandler, config); + const request = new Request("https://example.com/test", { method: "POST" }); + + await expect(wrappedHandler(request)).rejects.toThrow("Test error"); + expect(mockHandler).toHaveBeenCalledWith(request); + }); + + it("should use custom trace headers when specified", async () => { + const mockResponse = new Response("test response", { status: 200 }); + mockHandler.mockResolvedValue(mockResponse); + + const customConfig: ObservabilityConfig = { + ...config, + traceIdHeader: "custom-trace-header", + spanIdHeader: "custom-span-header", + }; + + const wrappedHandler = withObservability(mockHandler, customConfig); + const request = new Request("https://example.com/test", { + method: "GET", + headers: { + "custom-trace-header": "custom-trace-123", + "custom-span-header": "custom-span-456", + }, + }); + + const result = await wrappedHandler(request); + + expect(request.traceId).toBe("custom-trace-123"); + expect(request.spanId).toBe("custom-span-456"); + expect(result.headers.get("custom-trace-header")).toBe("mock-trace-id-123"); + expect(result.headers.get("custom-span-header")).toBe("mock-span-id-456"); + }); + + it("should add custom attributes", async () => { + const mockResponse = new Response("test response", { status: 200 }); + mockHandler.mockResolvedValue(mockResponse); + + const customConfig: ObservabilityConfig = { + ...config, + customAttributes: { + "service.environment": "test", + "custom.number": 42, + "custom.boolean": true, + }, + }; + + const wrappedHandler = withObservability(mockHandler, customConfig); + const request = new Request("https://example.com/test", { method: "GET" }); + + await wrappedHandler(request); + + expect(mockHandler).toHaveBeenCalledWith(request); + }); +}); \ No newline at end of file