Skip to content

Commit f7d9a48

Browse files
committed
fix for stdio proxy
1 parent f70fb07 commit f7d9a48

File tree

4 files changed

+61
-39
lines changed

4 files changed

+61
-39
lines changed

cmd/thv/app/proxy.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,6 @@ func init() {
171171
if err := proxyCmd.MarkFlagRequired("target-uri"); err != nil {
172172
logger.Warnf("Warning: Failed to mark flag as required: %v", err)
173173
}
174-
175-
proxyStdioCmd.Flags().StringVar(&stdioWorkloadName, "workload-name", "", "Workload name for the proxy (required)")
176-
_ = proxyStdioCmd.MarkFlagRequired("workload-name")
177-
178174
// Attach the subcommands to the main proxy command
179175
proxyCmd.AddCommand(proxyTunnelCmd)
180176
proxyCmd.AddCommand(proxyStdioCmd)

cmd/thv/app/proxy_stdio.go

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,46 +12,34 @@ import (
1212
"github.com/stacklok/toolhive/pkg/workloads"
1313
)
1414

15-
var (
16-
stdioWorkloadName string
17-
)
18-
1915
var proxyStdioCmd = &cobra.Command{
20-
Use: "stdio [flags] SERVER_NAME",
16+
Use: "stdio WORKLOAD-NAME SERVER_NAME",
2117
Short: "Create a stdio-based proxy for an MCP server",
2218
Long: `Create a stdio-based proxy that connects stdin/stdout to a target MCP server.
2319
2420
Example:
25-
thv proxy stdio --workload-name my-server my-server-proxy
26-
27-
Flags:
28-
--workload-name Workload name for the proxy (required)
21+
thv proxy stdio my-workload my-server-proxy
2922
`,
30-
Args: cobra.ExactArgs(1),
23+
Args: cobra.ExactArgs(2),
3124
RunE: proxyStdioCmdFunc,
3225
}
3326

3427
func proxyStdioCmdFunc(cmd *cobra.Command, args []string) error {
3528
ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM)
3629
defer cancel()
3730

38-
serverName := args[0]
39-
40-
// validate that workload name exists
41-
if stdioWorkloadName == "" {
42-
return fmt.Errorf("workload name must be specified with --workload-name")
43-
}
31+
workloadName := args[0]
32+
serverName := args[1]
4433

4534
workloadManager, err := workloads.NewManager(ctx)
4635
if err != nil {
4736
return fmt.Errorf("failed to create workload manager: %w", err)
4837
}
49-
stdioWorkload, err := workloadManager.GetWorkload(ctx, stdioWorkloadName)
38+
stdioWorkload, err := workloadManager.GetWorkload(ctx, workloadName)
5039
if err != nil {
51-
return fmt.Errorf("failed to get workload %q: %w", stdioWorkloadName, err)
40+
return fmt.Errorf("failed to get workload %q: %w", workloadName, err)
5241
}
53-
54-
logger.Infof("Starting stdio proxy for server=%q -> %s", serverName, stdioWorkloadName)
42+
logger.Infof("Starting stdio proxy for server=%q -> %s", serverName, workloadName)
5543

5644
bridge, err := transport.NewStdioBridge(stdioWorkload.URL)
5745
if err != nil {

pkg/container/images/registry.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func (r *RegistryImageManager) ImageExists(_ context.Context, imageName string)
6262

6363
// PullImage pulls an image from a registry and saves it to the local daemon
6464
func (r *RegistryImageManager) PullImage(ctx context.Context, imageName string) error {
65-
logger.Infof("Pulling1 image: %s", imageName)
65+
logger.Infof("Pulling image: %s", imageName)
6666

6767
// Parse the image reference
6868
ref, err := name.ParseReference(imageName)

pkg/transport/bridge.go

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -73,28 +73,54 @@ func (b *StdioBridge) loop(ctx context.Context) {
7373

7474
// Header include Mcp-Session-Id if needed
7575
func (b *StdioBridge) detectTransport(ctx context.Context) (string, error) {
76-
body, err := buildJSONRPCInitializeRequest()
76+
// 1) Probe for legacy SSE without sending any JSON-RPC
77+
getReq, err := http.NewRequestWithContext(ctx, "GET", b.baseURL.String(), nil)
7778
if err != nil {
78-
logger.Errorf("failed to marshal initialization request: %v", err)
7979
return "", err
8080
}
81-
req, err := http.NewRequestWithContext(ctx, "POST", b.baseURL.String(), bytes.NewReader(body))
81+
getReq.Header.Set("Accept", "text/event-stream")
82+
copyHeaders(getReq.Header, b.headers)
83+
84+
getResp, err := http.DefaultClient.Do(getReq)
85+
if err == nil {
86+
ct := getResp.Header.Get("Content-Type")
87+
err = getResp.Body.Close() // we’re only peeking at headers here
88+
if err != nil {
89+
return "", fmt.Errorf("failed to close GET response body: %w", err)
90+
}
91+
if strings.HasPrefix(ct, "text/event-stream") {
92+
// Legacy SSE: runLegacyReader will open the real stream.
93+
return "legacy-sse", nil
94+
}
95+
}
96+
97+
// 2) Not SSE -> treat as streamable HTTP and do a proper initialize
98+
body, err := buildJSONRPCInitializeRequest() // keep if your server expects initialize for streamable
8299
if err != nil {
83-
logger.Errorf("failed to create HTTP request: %v", err)
84100
return "", err
85101
}
86-
req.Header.Set("Content-Type", "application/json")
87-
req.Header.Set("Accept", "application/json, text/event-stream")
88-
copyHeaders(req.Header, b.headers)
89-
resp, err := http.DefaultClient.Do(req)
102+
postReq, err := http.NewRequestWithContext(ctx, "POST", b.baseURL.String(), bytes.NewReader(body))
90103
if err != nil {
91104
return "", err
92105
}
93-
defer resp.Body.Close()
94-
if resp.StatusCode >= 400 && resp.StatusCode < 500 {
106+
postReq.Header.Set("Content-Type", "application/json")
107+
postReq.Header.Set("Accept", "application/json, text/event-stream")
108+
copyHeaders(postReq.Header, b.headers)
109+
110+
postResp, err := http.DefaultClient.Do(postReq)
111+
if err != nil {
112+
return "", err
113+
}
114+
defer postResp.Body.Close()
115+
116+
// If a streamable server returns info + Mcp-Session-Id, great.
117+
// Some legacy gateways might reply 4xx to POST at base URL.
118+
if postResp.StatusCode >= 400 && postResp.StatusCode < 500 {
95119
return "legacy-sse", nil
96120
}
97-
b.handleInitializeResponse(resp)
121+
122+
// Streamable: capture session headers and emit the response payload.
123+
b.handleInitializeResponse(postResp)
98124
return "streamable-http", nil
99125
}
100126

@@ -104,7 +130,6 @@ func (b *StdioBridge) handleInitializeResponse(resp *http.Response) {
104130
b.headers = make(http.Header)
105131
}
106132
b.headers.Set("Mcp-Session-Id", sid)
107-
logger.Infof("Streamable HTTP session ID: %s", sid)
108133
}
109134
data, err := io.ReadAll(resp.Body)
110135
if err != nil {
@@ -307,15 +332,28 @@ func (b *StdioBridge) updatePostURL(path string) {
307332
return
308333
}
309334
b.postURL = u
310-
logger.Infof("POST URL updated to %s", b.postURL)
311335
}
312336

313337
func buildJSONRPCInitializeRequest() ([]byte, error) {
314338
req := map[string]interface{}{
315339
"jsonrpc": "2.0",
316340
"id": 1,
317341
"method": "initialize",
318-
"params": map[string]interface{}{},
342+
"params": map[string]interface{}{
343+
"protocolVersion": "2024-02-01",
344+
"capabilities": map[string]interface{}{
345+
"prompts": true,
346+
"tools": true,
347+
"resources": map[string]interface{}{
348+
"subscribe": true,
349+
"unsubscribe": true,
350+
},
351+
},
352+
"clientInfo": map[string]interface{}{
353+
"name": "toolhive-stdio-bridge",
354+
"version": "0.1.0",
355+
},
356+
},
319357
}
320358
return json.Marshal(req)
321359
}

0 commit comments

Comments
 (0)