diff --git a/cmd/thv/app/proxy.go b/cmd/thv/app/proxy.go index 90da87e18..633c64b42 100644 --- a/cmd/thv/app/proxy.go +++ b/cmd/thv/app/proxy.go @@ -171,9 +171,10 @@ func init() { if err := proxyCmd.MarkFlagRequired("target-uri"); err != nil { logger.Warnf("Warning: Failed to mark flag as required: %v", err) } - - // Attach the subcommand to the main proxy command + // Attach the subcommands to the main proxy command proxyCmd.AddCommand(proxyTunnelCmd) + proxyCmd.AddCommand(proxyStdioCmd) + } func proxyCmdFunc(cmd *cobra.Command, args []string) error { diff --git a/cmd/thv/app/proxy_stdio.go b/cmd/thv/app/proxy_stdio.go new file mode 100644 index 000000000..d17a430c9 --- /dev/null +++ b/cmd/thv/app/proxy_stdio.go @@ -0,0 +1,53 @@ +package app + +import ( + "fmt" + "os/signal" + "syscall" + + "github.com/spf13/cobra" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/transport" + "github.com/stacklok/toolhive/pkg/workloads" +) + +var proxyStdioCmd = &cobra.Command{ + Use: "stdio WORKLOAD-NAME", + Short: "Create a stdio-based proxy for an MCP server", + Long: `Create a stdio-based proxy that connects stdin/stdout to a target MCP server. + +Example: + thv proxy stdio my-workload +`, + Args: cobra.ExactArgs(1), + RunE: proxyStdioCmdFunc, +} + +func proxyStdioCmdFunc(cmd *cobra.Command, args []string) error { + ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + workloadName := args[0] + workloadManager, err := workloads.NewManager(ctx) + if err != nil { + return fmt.Errorf("failed to create workload manager: %w", err) + } + stdioWorkload, err := workloadManager.GetWorkload(ctx, workloadName) + if err != nil { + return fmt.Errorf("failed to get workload %q: %w", workloadName, err) + } + logger.Infof("Starting stdio proxy for workload=%q", workloadName) + + bridge, err := transport.NewStdioBridge(stdioWorkload.URL, stdioWorkload.TransportType) + if err != nil { + return fmt.Errorf("failed to create stdio bridge: %w", err) + } + bridge.Start(ctx) + + // Consume until interrupt + <-ctx.Done() + logger.Info("Shutting down bridge") + bridge.Shutdown() + return nil +} diff --git a/docs/cli/thv_proxy.md b/docs/cli/thv_proxy.md index a4e0a9513..cd38f232f 100644 --- a/docs/cli/thv_proxy.md +++ b/docs/cli/thv_proxy.md @@ -115,5 +115,6 @@ thv proxy [flags] SERVER_NAME ### SEE ALSO * [thv](thv.md) - ToolHive (thv) is a lightweight, secure, and fast manager for MCP servers +* [thv proxy stdio](thv_proxy_stdio.md) - Create a stdio-based proxy for an MCP server * [thv proxy tunnel](thv_proxy_tunnel.md) - Create a tunnel proxy for exposing internal endpoints diff --git a/docs/cli/thv_proxy_stdio.md b/docs/cli/thv_proxy_stdio.md new file mode 100644 index 000000000..5006ade6a --- /dev/null +++ b/docs/cli/thv_proxy_stdio.md @@ -0,0 +1,43 @@ +--- +title: thv proxy stdio +hide_title: true +description: Reference for ToolHive CLI command `thv proxy stdio` +last_update: + author: autogenerated +slug: thv_proxy_stdio +mdx: + format: md +--- + +## thv proxy stdio + +Create a stdio-based proxy for an MCP server + +### Synopsis + +Create a stdio-based proxy that connects stdin/stdout to a target MCP server. + +Example: + thv proxy stdio my-workload + + +``` +thv proxy stdio WORKLOAD-NAME [flags] +``` + +### Options + +``` + -h, --help help for stdio +``` + +### Options inherited from parent commands + +``` + --debug Enable debug mode +``` + +### SEE ALSO + +* [thv proxy](thv_proxy.md) - Create a transparent proxy for an MCP server with authentication support + diff --git a/pkg/container/images/registry.go b/pkg/container/images/registry.go index 55f9f8908..acc588f3f 100644 --- a/pkg/container/images/registry.go +++ b/pkg/container/images/registry.go @@ -9,7 +9,7 @@ import ( "github.com/docker/docker/client" "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/name" - "github.com/google/go-containerregistry/pkg/v1" + v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/google/go-containerregistry/pkg/v1/daemon" "github.com/google/go-containerregistry/pkg/v1/remote" diff --git a/pkg/transport/bridge.go b/pkg/transport/bridge.go new file mode 100644 index 000000000..7dfc69813 --- /dev/null +++ b/pkg/transport/bridge.go @@ -0,0 +1,236 @@ +package transport + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/transport/types" +) + +// StdioBridge connects stdin/stdout to a target MCP server using the specified transport type. +type StdioBridge struct { + mode types.TransportType + rawTarget string // upstream base URL + + up *client.Client + srv *server.MCPServer + + wg sync.WaitGroup + cancel context.CancelFunc +} + +// NewStdioBridge creates a new StdioBridge instance for the given target URL and transport type. +func NewStdioBridge(rawURL string, mode types.TransportType) (*StdioBridge, error) { + return &StdioBridge{mode: mode, rawTarget: rawURL}, nil +} + +// Start initializes the bridge and connects to the upstream MCP server. +func (b *StdioBridge) Start(ctx context.Context) { + ctx, b.cancel = context.WithCancel(ctx) + b.wg.Add(1) + go b.run(ctx) +} + +// Shutdown gracefully stops the bridge, closing connections and waiting for cleanup. +func (b *StdioBridge) Shutdown() { + if b.cancel != nil { + b.cancel() + } + if b.up != nil { + _ = b.up.Close() + } + b.wg.Wait() +} + +func (b *StdioBridge) run(ctx context.Context) { + logger.Infof("Starting StdioBridge for %s in mode %s", b.rawTarget, b.mode) + defer b.wg.Done() + + up, err := b.connectUpstream(ctx) + if err != nil { + logger.Errorf("upstream connect failed: %v", err) + return + } + b.up = up + logger.Infof("Connected to upstream %s", b.rawTarget) + + if err := b.initializeUpstream(ctx); err != nil { + logger.Errorf("upstream initialize failed: %v", err) + return + } + logger.Infof("Upstream initialized successfully") + + // Tiny local stdio server + b.srv = server.NewMCPServer( + "toolhive-stdio-bridge", + "0.1.0", + server.WithToolCapabilities(true), + server.WithResourceCapabilities(true, true), + server.WithPromptCapabilities(true), + ) + logger.Infof("Starting local stdio server") + + b.up.OnConnectionLost(func(err error) { logger.Warnf("upstream lost: %v", err) }) + + // Handle upstream notifications + b.up.OnNotification(func(n mcp.JSONRPCNotification) { + logger.Infof("upstream → downstream notify: %s %v", n.Method, n.Params) + // Convert the Params struct to JSON and back to a generic map + var params map[string]any + if buf, err := json.Marshal(n.Params); err != nil { + logger.Warnf("Failed to marshal params: %v", err) + params = map[string]any{} + } else if err := json.Unmarshal(buf, ¶ms); err != nil { + logger.Warnf("Failed to unmarshal to map: %v", err) + params = map[string]any{} + } + + b.srv.SendNotificationToAllClients(n.Method, params) + }) + + // Forwarders (register once; no pagination/refresh to keep it simple) + b.forwardAll(ctx) + + // Serve stdio (blocks) + if err := server.ServeStdio(b.srv); err != nil { + logger.Errorf("stdio server error: %v", err) + } +} + +func (b *StdioBridge) connectUpstream(_ context.Context) (*client.Client, error) { + logger.Infof("Connecting to upstream %s using mode %s", b.rawTarget, b.mode) + + switch b.mode { + case types.TransportTypeStreamableHTTP: + c, err := client.NewStreamableHttpClient( + b.rawTarget, + transport.WithHTTPTimeout(0), + transport.WithContinuousListening(), + ) + if err != nil { + return nil, err + } + // use separate, never-ending context for the client + if err := c.Start(context.Background()); err != nil { + return nil, err + } + return c, nil + case types.TransportTypeSSE: + c, err := client.NewSSEMCPClient( + b.rawTarget, + ) + if err != nil { + return nil, err + } + if err := c.Start(context.Background()); err != nil { + return nil, err + } + return c, nil + case types.TransportTypeStdio: + // if url contains sse it's sse else streamable-http + var c *client.Client + var err error + if strings.Contains(b.rawTarget, "sse") { + c, err = client.NewSSEMCPClient( + b.rawTarget, + ) + if err != nil { + return nil, err + } + } else { + c, err = client.NewStreamableHttpClient( + b.rawTarget, + ) + if err != nil { + return nil, err + } + } + if err := c.Start(context.Background()); err != nil { + return nil, err + } + return c, nil + case types.TransportTypeInspector: + fallthrough + default: + return nil, fmt.Errorf("unsupported mode %q", b.mode) + } +} + +func (b *StdioBridge) initializeUpstream(ctx context.Context) error { + logger.Infof("Initializing upstream %s", b.rawTarget) + _, err := b.up.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{Name: "toolhive-bridge", Version: "0.1.0"}, + Capabilities: mcp.ClientCapabilities{}, + }, + }) + if err != nil { + return err + } + return nil +} + +func (b *StdioBridge) forwardAll(ctx context.Context) { + logger.Infof("Forwarding all upstream data to local stdio server") + // Tools -> straight passthrough + logger.Infof("Forwarding tools from upstream to local stdio server") + if lt, err := b.up.ListTools(ctx, mcp.ListToolsRequest{}); err == nil { + for _, tool := range lt.Tools { + toolCopy := tool + b.srv.AddTool(toolCopy, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return b.up.CallTool(ctx, req) + }) + } + } + + // Resources -> return []mcp.ResourceContents + logger.Infof("Forwarding resources from upstream to local stdio server") + if lr, err := b.up.ListResources(ctx, mcp.ListResourcesRequest{}); err == nil { + for _, res := range lr.Resources { + resCopy := res + b.srv.AddResource(resCopy, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + out, err := b.up.ReadResource(ctx, req) + if err != nil { + return nil, err + } + return out.Contents, nil + }) + } + } + + // Resource templates -> same return type as resources + logger.Infof("Forwarding resource templates from upstream to local stdio server") + if lt, err := b.up.ListResourceTemplates(ctx, mcp.ListResourceTemplatesRequest{}); err == nil { + for _, tpl := range lt.ResourceTemplates { + tplCopy := tpl + b.srv.AddResourceTemplate(tplCopy, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + out, err := b.up.ReadResource(ctx, req) + if err != nil { + return nil, err + } + return out.Contents, nil + }) + } + } + + // Prompts -> straight passthrough + logger.Infof("Forwarding prompts from upstream to local stdio server") + if lp, err := b.up.ListPrompts(ctx, mcp.ListPromptsRequest{}); err == nil { + for _, p := range lp.Prompts { + pCopy := p + b.srv.AddPrompt(pCopy, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return b.up.GetPrompt(ctx, req) + }) + } + } +} diff --git a/test/e2e/proxy_stdio_test.go b/test/e2e/proxy_stdio_test.go new file mode 100644 index 000000000..3d1f36486 --- /dev/null +++ b/test/e2e/proxy_stdio_test.go @@ -0,0 +1,311 @@ +package e2e_test + +import ( + "bytes" + "fmt" + "io" + "os" + "os/exec" + "strings" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/stacklok/toolhive/pkg/transport/types" + "github.com/stacklok/toolhive/test/e2e" +) + +func generateUniqueProxyStdioServerName(prefix string) string { + return fmt.Sprintf("%s-%d-%d-%d", prefix, os.Getpid(), time.Now().UnixNano(), GinkgoRandomSeed()) +} + +var _ = Describe("Proxy Stdio E2E", Serial, func() { + var ( + config *e2e.TestConfig + proxyCmd *exec.Cmd + mcpServerName string + workloadName string + transportType types.TransportType + proxyMode string // e.g. "sse" or "streamable-http" + ) + + BeforeEach(func() { + config = e2e.NewTestConfig() + err := e2e.CheckTHVBinaryAvailable(config) + Expect(err).ToNot(HaveOccurred()) + workloadName = generateUniqueProxyStdioServerName("mcpserver-proxy-stdio-target") + }) + + JustBeforeEach(func() { + // Build args after mcpServerName is set + args := []string{"run", "--name", workloadName, "--transport", transportType.String()} + + if transportType == types.TransportTypeStdio { + Expect(proxyMode).ToNot(BeEmpty()) + args = append(args, "--proxy-mode", proxyMode) + } + + args = append(args, mcpServerName) + + By("Starting MCP server as target") + e2e.NewTHVCommand(config, args...).ExpectSuccess() + + err := e2e.WaitForMCPServer(config, workloadName, 60*time.Second) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + By("Cleaning up test resources") + + // Stop proxy if running + if proxyCmd != nil && proxyCmd.Process != nil { + proxyCmd.Process.Kill() + proxyCmd.Wait() + } + + // Stop and remove server + if config.CleanupAfter { + err := e2e.StopAndRemoveMCPServer(config, workloadName) + Expect(err).ToNot(HaveOccurred(), "Should be able to stop and remove server") + } + }) + + Context("testing proxy stdio with sse protocol", func() { + BeforeEach(func() { + transportType = types.TransportTypeSSE + mcpServerName = "osv" + }) + It("should proxy MCP requests successfully", func() { + By("Getting OSV server URL") + osvServerURL, err := e2e.GetMCPServerURL(config, workloadName) + Expect(err).ToNot(HaveOccurred()) + + By("Extracting base URL for transparent proxy") + // The URL from thv list is like: http://127.0.0.1:21929/sse#container-name + // But the transparent proxy needs the base URL: http://127.0.0.1:21929 + baseURL := strings.TrimSuffix(strings.Split(osvServerURL, "#")[0], "/sse") + GinkgoWriter.Printf("Original server URL: %s\n", osvServerURL) + GinkgoWriter.Printf("Base URL for proxy: %s\n", baseURL) + + By("Starting the stdio proxy") + proxyCmd, stdin, outputBuffer := startProxyStdioForMCP( + config, + workloadName, + ) + + // Ensure the proxy started + Eventually(func() string { + return outputBuffer.String() + }, 10*time.Second, 1*time.Second).Should(ContainSubstring("Starting stdio proxy")) + + // Basic JSON-RPC message to initialize session + message := `{"jsonrpc":"2.0","id":-1,"method":"initialize","params":{}}` + "\n" + _, err = stdin.Write([]byte(message)) + Expect(err).ToNot(HaveOccurred()) + + By("Validating response is received through stdout (proxied)") + Eventually(func() string { + return outputBuffer.String() + }, 15*time.Second, 1*time.Second).Should(ContainSubstring(`"id":-1`)) + Eventually(func() string { + return outputBuffer.String() + }, 15*time.Second, 1*time.Second).Should(ContainSubstring(`"jsonrpc":"2.0"`)) + + By("Validating that response came from the SSE server via proxy") + Expect(outputBuffer.String()).To(ContainSubstring("result")) // Or other expected field in the response + + By("Shutting down proxy") + proxyCmd.Process.Kill() + proxyCmd.Wait() + }) + }) + + Context("testing proxy stdio with streamable-http protocol", func() { + BeforeEach(func() { + transportType = types.TransportTypeStreamableHTTP + mcpServerName = "osv" + }) + + It("should proxy MCP requests successfully", func() { + By("Getting OSV server URL") + osvServerURL, err := e2e.GetMCPServerURL(config, workloadName) + Expect(err).ToNot(HaveOccurred()) + + By("Extracting base URL for transparent proxy") + // URL will be like: http://127.0.0.1:21929/mcp#container-name + baseURL := strings.Split(osvServerURL, "#")[0] + baseURL = strings.TrimSuffix(baseURL, "/mcp") + GinkgoWriter.Printf("Original server URL: %s\n", osvServerURL) + GinkgoWriter.Printf("Base URL for proxy: %s\n", baseURL) + + By("Starting the stdio proxy") + proxyCmd, stdin, outputBuffer := startProxyStdioForMCP( + config, + workloadName, + ) + + // Ensure the proxy started + Eventually(func() string { + return outputBuffer.String() + }, 10*time.Second, 1*time.Second).Should(ContainSubstring("Starting stdio proxy")) + + By("Sending JSON-RPC initialize message through the proxy stdin") + message := `{"jsonrpc":"2.0","id":-1,"method":"initialize","params":{}}` + "\n" + _, err = stdin.Write([]byte(message)) + Expect(err).ToNot(HaveOccurred()) + + By("Validating response is received through stdout (proxied)") + Eventually(func() string { + return outputBuffer.String() + }, 15*time.Second, 1*time.Second).Should(ContainSubstring(`"id":-1`)) + Eventually(func() string { + return outputBuffer.String() + }, 15*time.Second, 1*time.Second).Should(ContainSubstring(`"jsonrpc":"2.0"`)) + + By("Validating that response came from the streamable-http server via proxy") + Expect(outputBuffer.String()).To(ContainSubstring("result")) + + By("Shutting down proxy") + proxyCmd.Process.Kill() + proxyCmd.Wait() + }) + }) + + Context("testing proxy stdio with stdio protocol+sse proxy mode", func() { + BeforeEach(func() { + transportType = types.TransportTypeStdio + proxyMode = "sse" + mcpServerName = "time" + }) + It("should proxy MCP requests successfully", func() { + By("Getting time server URL") + timeServerURL, err := e2e.GetMCPServerURL(config, workloadName) + Expect(err).ToNot(HaveOccurred()) + + By("Extracting base URL for transparent proxy") + // The URL from thv list is like: http://127.0.0.1:21929/sse#container-name + // But the transparent proxy needs the base URL: http://127.0.0.1:21929 + baseURL := strings.TrimSuffix(strings.Split(timeServerURL, "#")[0], "/sse") + GinkgoWriter.Printf("Original server URL: %s\n", timeServerURL) + GinkgoWriter.Printf("Base URL for proxy: %s\n", baseURL) + + By("Starting the stdio proxy") + proxyCmd, stdin, outputBuffer := startProxyStdioForMCP( + config, + workloadName, + ) + + // Ensure the proxy started + Eventually(func() string { + return outputBuffer.String() + }, 10*time.Second, 1*time.Second).Should(ContainSubstring("Starting stdio proxy")) + + // Basic JSON-RPC message to initialize session + message := `{"jsonrpc":"2.0","id":-1,"method":"initialize","params":{}}` + "\n" + _, err = stdin.Write([]byte(message)) + Expect(err).ToNot(HaveOccurred()) + + By("Validating response is received through stdout (proxied)") + Eventually(func() string { + return outputBuffer.String() + }, 15*time.Second, 1*time.Second).Should(ContainSubstring(`"id":-1`)) + Eventually(func() string { + return outputBuffer.String() + }, 15*time.Second, 1*time.Second).Should(ContainSubstring(`"jsonrpc":"2.0"`)) + + By("Validating that response came from the SSE server via proxy") + Expect(outputBuffer.String()).To(ContainSubstring("result")) // Or other expected field in the response + + By("Shutting down proxy") + proxyCmd.Process.Kill() + proxyCmd.Wait() + }) + }) + + Context("testing proxy stdio with stdio protocol+streamable-http proxy mode", func() { + BeforeEach(func() { + transportType = types.TransportTypeStdio + proxyMode = "streamable-http" + mcpServerName = "time" + }) + It("should proxy MCP requests successfully", func() { + By("Getting time server URL") + timeServerURL, err := e2e.GetMCPServerURL(config, workloadName) + Expect(err).ToNot(HaveOccurred()) + + By("Extracting base URL for transparent proxy") + // URL will be like: http://127.0.0.1:21929/mcp#container-name + baseURL := strings.Split(timeServerURL, "#")[0] + baseURL = strings.TrimSuffix(baseURL, "/mcp") + GinkgoWriter.Printf("Original server URL: %s\n", timeServerURL) + GinkgoWriter.Printf("Base URL for proxy: %s\n", baseURL) + + By("Starting the stdio proxy") + proxyCmd, stdin, outputBuffer := startProxyStdioForMCP( + config, + workloadName, + ) + + // Ensure the proxy started + Eventually(func() string { + return outputBuffer.String() + }, 10*time.Second, 1*time.Second).Should(ContainSubstring("Starting stdio proxy")) + + By("Sending JSON-RPC initialize message through the proxy stdin") + message := `{"jsonrpc":"2.0","id":-1,"method":"initialize","params":{}}` + "\n" + _, err = stdin.Write([]byte(message)) + Expect(err).ToNot(HaveOccurred()) + + By("Validating response is received through stdout (proxied)") + Eventually(func() string { + return outputBuffer.String() + }, 15*time.Second, 1*time.Second).Should(ContainSubstring(`"id":-1`)) + Eventually(func() string { + return outputBuffer.String() + }, 15*time.Second, 1*time.Second).Should(ContainSubstring(`"jsonrpc":"2.0"`)) + + By("Validating that response came from the streamable-http server via proxy") + Expect(outputBuffer.String()).To(ContainSubstring("result")) + + By("Shutting down proxy") + proxyCmd.Process.Kill() + proxyCmd.Wait() + }) + }) + +}) + +// Helper functions +func startProxyStdioForMCP(config *e2e.TestConfig, workloadName string) (*exec.Cmd, io.WriteCloser, *bytes.Buffer) { + args := []string{ + "proxy", + "stdio", + workloadName, + } + + // Log the command for debugging + GinkgoWriter.Printf("Starting proxy stdio for MCP with args: %v\n", args) + + // Create command + cmd := exec.Command(config.THVBinary, args...) + cmd.Env = os.Environ() + + // Create buffer to capture output (capture both stdout and stderr) + var outputBuffer bytes.Buffer + + // Use MultiWriter to write to both buffer and GinkgoWriter + multiWriter := io.MultiWriter(&outputBuffer, GinkgoWriter) + cmd.Stdout = multiWriter + cmd.Stderr = multiWriter // Capture stderr too since logger might write there + + // Get stdin pipe BEFORE starting + stdin, err := cmd.StdinPipe() + Expect(err).ToNot(HaveOccurred()) + + // Start the command + err = cmd.Start() + Expect(err).ToNot(HaveOccurred()) + + return cmd, stdin, &outputBuffer +}