Skip to content

feat: add session management for proxy #1081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 169 additions & 6 deletions pkg/transport/proxy/transparent/transparent_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,26 @@
package transparent

import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime"
"net/http"
"net/http/httputil"
"net/url"
"regexp"
"strings"
"sync"
"time"

"golang.org/x/exp/jsonrpc2"

"github.com/stacklok/toolhive/pkg/healthcheck"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/transport/session"
"github.com/stacklok/toolhive/pkg/transport/types"
)

Expand Down Expand Up @@ -47,6 +55,12 @@ type TransparentProxy struct {

// Optional Prometheus metrics handler
prometheusHandler http.Handler

// Sessions for tracking state
sessionManager *session.Manager

// If mcp server has been initialized
IsServerInitialized bool
}

// NewTransparentProxy creates a new transparent proxy with optional middlewares.
Expand All @@ -66,6 +80,7 @@ func NewTransparentProxy(
middlewares: middlewares,
shutdownCh: make(chan struct{}),
prometheusHandler: prometheusHandler,
sessionManager: session.NewManager(30*time.Minute, session.NewProxySession),
}

// Create MCP pinger and health checker
Expand All @@ -75,6 +90,144 @@ func NewTransparentProxy(
return proxy
}

type tracingTransport struct {
base http.RoundTripper
p *TransparentProxy
}

func (p *TransparentProxy) setServerInitialized() {
if !p.IsServerInitialized {
p.mutex.Lock()
p.IsServerInitialized = true
p.mutex.Unlock()
logger.Infof("Server was initialized successfully for %s", p.containerName)
}
}

func (t *tracingTransport) forward(req *http.Request) (*http.Response, error) {
tr := t.base
if tr == nil {
tr = http.DefaultTransport
}
return tr.RoundTrip(req)
}

func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
reqBody := readRequestBody(req)

path := req.URL.Path
isMCP := strings.HasPrefix(path, "/mcp")
isJSON := strings.Contains(req.Header.Get("Content-Type"), "application/json")
sawInitialize := false

if isMCP && isJSON && len(reqBody) > 0 {
sawInitialize = t.detectInitialize(reqBody)
}

resp, err := t.forward(req)
if err != nil {
logger.Errorf("Failed to forward request: %v", err)
return nil, err
}

if resp.StatusCode == http.StatusOK {
// check if we saw a valid mcp header
ct := resp.Header.Get("Mcp-Session-Id")
if ct != "" {
logger.Infof("Detected Mcp-Session-Id header: %s", ct)
if _, ok := t.p.sessionManager.Get(ct); !ok {
if err := t.p.sessionManager.AddWithID(ct); err != nil {
logger.Errorf("Failed to create session from header %s: %v", ct, err)
}
}
t.p.setServerInitialized()
return resp, nil
}
// status was ok and we saw an initialize call
if sawInitialize && !t.p.IsServerInitialized {
t.p.setServerInitialized()
return resp, nil
}
}

return resp, nil
}

func readRequestBody(req *http.Request) []byte {
reqBody := []byte{}
if req.Body != nil {
buf, err := io.ReadAll(req.Body)
if err != nil {
logger.Errorf("Failed to read request body: %v", err)
} else {
reqBody = buf
}
req.Body = io.NopCloser(bytes.NewReader(reqBody))
}
return reqBody
}

func (t *tracingTransport) detectInitialize(body []byte) bool {
var rpc struct {
Method string `json:"method"`
}
if err := json.Unmarshal(body, &rpc); err != nil {
logger.Errorf("Failed to parse JSON-RPC body: %v", err)
return false
}
if rpc.Method == "initialize" {
logger.Infof("Detected initialize method call for %s", t.p.containerName)
return true
}
return false
}

var sessionRe = regexp.MustCompile(`sessionId=([0-9A-Fa-f-]+)|"sessionId"\s*:\s*"([^"]+)"`)

func (p *TransparentProxy) modifyForSessionID(resp *http.Response) error {
mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
if mediaType != "text/event-stream" {
return nil
}

pr, pw := io.Pipe()
originalBody := resp.Body
resp.Body = pr

go func() {
defer pw.Close()
scanner := bufio.NewScanner(originalBody)
found := false

for scanner.Scan() {
line := scanner.Bytes()
if !found {
if m := sessionRe.FindSubmatch(line); m != nil {
sid := string(m[1])
if sid == "" {
sid = string(m[2])
}
p.setServerInitialized()
err := p.sessionManager.AddWithID(sid)
if err != nil {
logger.Errorf("Failed to create session from SSE line: %v", err)
}
found = true
}
}
if _, err := pw.Write(append(line, '\n')); err != nil {
return
}
}
_, err := io.Copy(pw, originalBody)
if err != nil && err != io.EOF {
logger.Errorf("Failed to copy response body: %v", err)
}
}()

return nil
}

// Start starts the transparent proxy.
func (p *TransparentProxy) Start(ctx context.Context) error {
p.mutex.Lock()
Expand All @@ -88,6 +241,11 @@ func (p *TransparentProxy) Start(ctx context.Context) error {

// Create a reverse proxy
proxy := httputil.NewSingleHostReverseProxy(targetURL)
proxy.FlushInterval = -1
proxy.Transport = &tracingTransport{base: http.DefaultTransport, p: p}
proxy.ModifyResponse = func(resp *http.Response) error {
return p.modifyForSessionID(resp)
}

// Create a handler that logs requests
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -160,13 +318,18 @@ func (p *TransparentProxy) monitorHealth(parentCtx context.Context) {
logger.Infof("Shutdown initiated, stopping health monitor for %s", p.containerName)
return
case <-ticker.C:
alive := p.healthChecker.CheckHealth(parentCtx)
if alive.Status != healthcheck.StatusHealthy {
logger.Infof("Health check failed for %s; initiating proxy shutdown", p.containerName)
if err := p.Stop(parentCtx); err != nil {
logger.Errorf("Failed to stop proxy for %s: %v", p.containerName, err)
// Perform health check only if mcp server has been initialized
if p.IsServerInitialized {
alive := p.healthChecker.CheckHealth(parentCtx)
if alive.Status != healthcheck.StatusHealthy {
logger.Infof("Health check failed for %s; initiating proxy shutdown", p.containerName)
if err := p.Stop(parentCtx); err != nil {
logger.Errorf("Failed to stop proxy for %s: %v", p.containerName, err)
}
return
}
return
} else {
logger.Infof("MCP server not initialized yet, skipping health check for %s", p.containerName)
}
}
}
Expand Down
126 changes: 126 additions & 0 deletions pkg/transport/proxy/transparent/transparent_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package transparent

import (
"bufio"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/stacklok/toolhive/pkg/logger"
)

func init() {
logger.Initialize() // ensure logging doesn't panic
}

func TestStreamingSessionIDDetection(t *testing.T) {
t.Parallel()
proxy := NewTransparentProxy("127.0.0.1", 0, "test", "http://example.com", nil)
target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
w.WriteHeader(200)

// Simulate SSE lines
w.Write([]byte("data: hello\n"))
w.Write([]byte("data: sessionId=ABC123\n"))
w.(http.Flusher).Flush()

time.Sleep(10 * time.Millisecond)
w.Write([]byte("data: more\n"))
}))
defer target.Close()

// set up reverse proxy using ModifyResponse
parsedURL, _ := http.NewRequest("GET", target.URL, nil)
proxyURL := httputil.NewSingleHostReverseProxy(parsedURL.URL)
proxyURL.FlushInterval = -1
proxyURL.Transport = &tracingTransport{base: http.DefaultTransport, p: proxy}
proxyURL.ModifyResponse = proxy.modifyForSessionID

// hit the proxy
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", target.URL, nil)
proxyURL.ServeHTTP(rec, req)

// read all SSE lines
sc := bufio.NewScanner(rec.Body)
var bodyLines []string
for sc.Scan() {
bodyLines = append(bodyLines, sc.Text())
}
assert.Contains(t, bodyLines, "data: sessionId=ABC123")

// side-effect: proxy should have seen session
assert.True(t, proxy.IsServerInitialized, "server should have been initialized")
_, ok := proxy.sessionManager.Get("ABC123")
assert.True(t, ok, "sessionManager should have stored ABC123")
}

func createBasicProxy(p *TransparentProxy, targetURL *url.URL) *httputil.ReverseProxy {
proxy := httputil.NewSingleHostReverseProxy(targetURL)
proxy.Director = func(r *http.Request) {
r.URL.Scheme = targetURL.Scheme
r.URL.Host = targetURL.Host
r.Host = targetURL.Host
}
proxy.FlushInterval = -1
proxy.Transport = &tracingTransport{base: http.DefaultTransport, p: p}
proxy.ModifyResponse = p.modifyForSessionID
return proxy
}

func TestNoSessionIDInNonSSE(t *testing.T) {
t.Parallel()

p := NewTransparentProxy("127.0.0.1", 0, "test", "", nil)

target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
// Set both content-type and also optionally MCP header to test behavior
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
w.Write([]byte(`{"hello": "world"}`))
}))
defer target.Close()

targetURL, _ := url.Parse(target.URL)
proxy := createBasicProxy(p, targetURL)

rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", target.URL, nil)
proxy.ServeHTTP(rec, req)

assert.False(t, p.IsServerInitialized, "server should not be initialized for application/json")
_, ok := p.sessionManager.Get("XYZ789")
assert.False(t, ok, "no session should be added")
}

func TestHeaderBasedSessionInitialization(t *testing.T) {
t.Parallel()

p := NewTransparentProxy("127.0.0.1", 0, "test", "", nil)

target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
// Set both content-type and also optionally MCP header to test behavior
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Mcp-Session-Id", "XYZ789")
w.WriteHeader(200)
w.Write([]byte(`{"hello": "world"}`))
}))
defer target.Close()

targetURL, _ := url.Parse(target.URL)
proxy := createBasicProxy(p, targetURL)

rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", target.URL, nil)
proxy.ServeHTTP(rec, req)

assert.True(t, p.IsServerInitialized, "server should not be initialized for application/json")
_, ok := p.sessionManager.Get("XYZ789")
assert.True(t, ok, "no session should be added")
}
Loading
Loading