Skip to content

Commit 440522c

Browse files
committed
add tests
1 parent 2fc7c6b commit 440522c

File tree

5 files changed

+362
-19
lines changed

5 files changed

+362
-19
lines changed

cmd/thv/app/proxy.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,8 @@ func init() {
175175
proxyStdioCmd.Flags().StringVar(&stdioWorkloadName, "workload-name", "", "Workload name for the proxy (required)")
176176
_ = proxyStdioCmd.MarkFlagRequired("workload-name")
177177

178-
// Attach the subcommand to the main proxy command
178+
// Attach the subcommands to the main proxy command
179+
proxyCmd.AddCommand(proxyTunnelCmd)
179180
proxyCmd.AddCommand(proxyStdioCmd)
180181

181182
}

cmd/thv/app/proxy_stdio.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99

1010
"github.com/stacklok/toolhive/pkg/logger"
1111
"github.com/stacklok/toolhive/pkg/transport"
12-
"github.com/stacklok/toolhive/pkg/transport/types"
1312
"github.com/stacklok/toolhive/pkg/workloads"
1413
)
1514

@@ -52,12 +51,6 @@ func proxyStdioCmdFunc(cmd *cobra.Command, args []string) error {
5251
return fmt.Errorf("failed to get workload %q: %w", stdioWorkloadName, err)
5352
}
5453

55-
// check if workload has http/sse transport
56-
if stdioWorkload.TransportType != types.TransportTypeSSE &&
57-
stdioWorkload.TransportType != types.TransportTypeStreamableHTTP {
58-
return fmt.Errorf("only HTTP/SSE workloads are supported for this proxy")
59-
}
60-
6154
logger.Infof("Starting stdio proxy for server=%q -> %s", serverName, stdioWorkloadName)
6255

6356
bridge, err := transport.NewStdioBridge(stdioWorkload.URL)

pkg/container/images/registry.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
"github.com/docker/docker/client"
1010
"github.com/google/go-containerregistry/pkg/authn"
1111
"github.com/google/go-containerregistry/pkg/name"
12-
"github.com/google/go-containerregistry/pkg/v1"
12+
v1 "github.com/google/go-containerregistry/pkg/v1"
1313
"github.com/google/go-containerregistry/pkg/v1/daemon"
1414
"github.com/google/go-containerregistry/pkg/v1/remote"
1515

@@ -62,7 +62,7 @@ func (r *RegistryImageManager) ImageExists(_ context.Context, imageName string)
6262

6363
// PullImage pulls an image from a registry and saves it to the local daemon
6464
func (r *RegistryImageManager) PullImage(ctx context.Context, imageName string) error {
65-
logger.Infof("Pulling image: %s", imageName)
65+
logger.Infof("Pulling1 image: %s", imageName)
6666

6767
// Parse the image reference
6868
ref, err := name.ParseReference(imageName)

pkg/transport/bridge.go

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -195,37 +195,67 @@ func (b *StdioBridge) runStreamableWriter(ctx context.Context) {
195195

196196
func (b *StdioBridge) runLegacyReader(ctx context.Context) {
197197
defer b.wg.Done()
198+
198199
req, err := http.NewRequestWithContext(ctx, "GET", b.baseURL.String(), nil)
199200
if err != nil {
200201
logger.Errorf("Failed to create GET request: %v", err)
201202
return
202203
}
203204
req.Header.Set("Accept", "text/event-stream")
204205
copyHeaders(req.Header, b.headers)
206+
205207
resp, err := http.DefaultClient.Do(req)
206208
if err != nil {
207209
logger.Errorf("SSE connect error: %v", err)
208210
return
209211
}
210212
defer resp.Body.Close()
213+
211214
scanner := bufio.NewScanner(resp.Body)
212-
var sb strings.Builder
215+
var (
216+
eventName string
217+
dataLines []string
218+
)
213219
for scanner.Scan() {
214220
line := scanner.Text()
221+
222+
if strings.HasPrefix(line, "event:") {
223+
eventName = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
224+
continue
225+
}
215226
if strings.HasPrefix(line, "data:") {
216-
sb.WriteString(strings.TrimPrefix(line, "data:"))
227+
dataLines = append(dataLines, strings.TrimPrefix(strings.TrimSpace(line), "data:"))
228+
continue
217229
}
218230
if line == "" {
219-
raw := strings.TrimSpace(sb.String())
220-
sb.Reset()
221-
if strings.HasPrefix(raw, "/") && strings.Contains(raw, "sessionId") {
222-
b.updatePostURL(raw)
223-
b.sendInitialize(ctx)
224-
} else {
225-
emitJSON(raw)
231+
payload := strings.Join(dataLines, "\n")
232+
dataLines = dataLines[:0]
233+
234+
switch eventName {
235+
case "endpoint":
236+
if payload != "" {
237+
b.updatePostURL(payload)
238+
b.sendInitialize(ctx)
239+
}
240+
default:
241+
if isJSON(payload) {
242+
emitJSON(payload)
243+
} else {
244+
logger.Debugf("Skipping non-JSON SSE event (%q): %q", eventName, payload)
245+
}
226246
}
247+
eventName = ""
227248
}
228249
}
250+
251+
if err := scanner.Err(); err != nil {
252+
logger.Errorf("SSE read error: %v", err)
253+
}
254+
}
255+
256+
func isJSON(s string) bool {
257+
s = strings.TrimLeft(s, " \r\n\t")
258+
return len(s) > 0 && (s[0] == '{' || s[0] == '[')
229259
}
230260

231261
func (b *StdioBridge) runLegacyWriter(ctx context.Context) {

0 commit comments

Comments
 (0)