From 02c902f0c8cdd7ddd5df927f05ecea9bfff65101 Mon Sep 17 00:00:00 2001 From: taskbot Date: Wed, 16 Jul 2025 12:19:21 +0200 Subject: [PATCH 1/9] feat: add session management for proxy Adds a session interface to proxy, and initialize with the extracted session id from transport messages Only enable healthchecks when we detect that the server has been initialized Closes: #1078 --- .../proxy/transparent/transparent_proxy.go | 173 +++++++++++++++++- pkg/transport/session/session.go | 40 ++++ 2 files changed, 205 insertions(+), 8 deletions(-) create mode 100644 pkg/transport/session/session.go diff --git a/pkg/transport/proxy/transparent/transparent_proxy.go b/pkg/transport/proxy/transparent/transparent_proxy.go index e3988ce53..636801765 100644 --- a/pkg/transport/proxy/transparent/transparent_proxy.go +++ b/pkg/transport/proxy/transparent/transparent_proxy.go @@ -3,11 +3,18 @@ package transparent import ( + "bufio" + "bytes" "context" "fmt" + "io" + "mime" "net/http" "net/http/httputil" "net/url" + "regexp" + "strconv" + "strings" "sync" "time" @@ -15,6 +22,7 @@ import ( "github.com/stacklok/toolhive/pkg/healthcheck" "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -47,8 +55,17 @@ type TransparentProxy struct { // Optional Prometheus metrics handler prometheusHandler http.Handler + + // Sessions for managing client connections + sessions map[string]session.Session + + // mutex for protecting session access + sessionMutex sync.Mutex } +// TransparentProxySessionID is the session ID used for the transparent proxy. +const TransparentProxySessionID = "transparent-proxy-session" + // NewTransparentProxy creates a new transparent proxy with optional middlewares. func NewTransparentProxy( host string, @@ -66,15 +83,115 @@ func NewTransparentProxy( middlewares: middlewares, shutdownCh: make(chan struct{}), prometheusHandler: prometheusHandler, + sessions: make(map[string]session.Session), } // Create MCP pinger and health checker mcpPinger := NewMCPPinger(targetURI) proxy.healthChecker = healthcheck.NewHealthChecker("sse", mcpPinger) + _, err := proxy.CreateSession(TransparentProxySessionID) + if err != nil { + logger.Errorf("Failed to create session for TransparentProxy: %v", err) + return nil + } + return proxy } +func (p *TransparentProxy) handleModifyResponse(res *http.Response) error { + // Log headers + if sid := res.Header.Get("Mcp-Session-Id"); sid != "" { + logger.Infof("🆔 Streamable session ID from header: %s", sid) + } + + // Handle streaming (SSE) + if ct, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); ct == "text/event-stream" { + pr, pw := io.Pipe() + orig := res.Body + res.Body = pr + + go func() { + defer pw.Close() + scanner := bufio.NewScanner(orig) + re := regexp.MustCompile(`sessionId=([\w-]+)`) // Capture UUID-like IDs + + for scanner.Scan() { + line := scanner.Text() + + if matches := re.FindStringSubmatch(line); len(matches) == 2 { + sessionID := matches[1] + + // set session id for proxy + extractedSession, ok := p.GetSession(TransparentProxySessionID) + if !ok { + logger.Errorf("Failed to get session for TransparentProxy") + continue + } + extractedSession.SetIsInitialized(true) + extractedSession.SetMCPSessionID(sessionID) + } + _, err := pw.Write([]byte(line + "\n")) + if err != nil { + logger.Errorf("Failed to write to pipe: %v", err) + } + } + }() + return nil + } + + // Handle non-streaming (JSON) bodies + body, err := io.ReadAll(res.Body) + if err != nil { + return err + } + err = res.Body.Close() + if err != nil { + logger.Errorf("Failed to close response body: %v", err) + } + + text := string(body) + logger.Infof("HTTP response body: %s", text) + + // Parse sessionId if embedded in JSON + rejson := regexp.MustCompile(`"sessionId"\s*:\s*"([\w-]+)"`) + if matches := rejson.FindStringSubmatch(text); len(matches) == 2 { + sessionID := matches[1] + logger.Infof("🆔 Captured sessionId from JSON: %s", sessionID) + } + + res.Body = io.NopCloser(bytes.NewReader(body)) + res.ContentLength = int64(len(body)) + res.Header.Set("Content-Length", strconv.Itoa(len(body))) + return nil +} + +func (p *TransparentProxy) handleAndDetectInitialize(w http.ResponseWriter, r *http.Request, proxy *httputil.ReverseProxy) { + logger.Infof("Transparent proxy: %s %s -> %s", r.Method, r.URL.Path, p.targetURI) + + if r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/mcp") { + // Read the body for inspection without consuming it + body, err := io.ReadAll(r.Body) + if err != nil { + logger.Errorf("Error reading request body: %v", err) + } else { + if bytes.Contains(body, []byte(`"method":"initialize"`)) { + logger.Infof("🔧 Detected initialize request to %s", r.URL.Path) + extractedSession, ok := p.GetSession(TransparentProxySessionID) + if ok { + extractedSession.SetIsInitialized(true) + } else { + logger.Errorf("No session found to mark initialized") + } + } + r.Body = io.NopCloser(bytes.NewReader(body)) + r.ContentLength = int64(len(body)) + } + } + + proxy.ServeHTTP(w, r) +} + // Start starts the transparent proxy. func (p *TransparentProxy) Start(ctx context.Context) error { p.mutex.Lock() @@ -88,11 +205,11 @@ func (p *TransparentProxy) Start(ctx context.Context) error { // Create a reverse proxy proxy := httputil.NewSingleHostReverseProxy(targetURL) + proxy.ModifyResponse = p.handleModifyResponse // Create a handler that logs requests handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - logger.Infof("Transparent proxy: %s %s -> %s", r.Method, r.URL.Path, targetURL) - proxy.ServeHTTP(w, r) + p.handleAndDetectInitialize(w, r, proxy) }) // Create a mux to handle both proxy and health endpoints @@ -160,14 +277,25 @@ func (p *TransparentProxy) monitorHealth(parentCtx context.Context) { logger.Infof("Shutdown initiated, stopping health monitor for %s", p.containerName) return case <-ticker.C: - alive := p.healthChecker.CheckHealth(parentCtx) - if alive.Status != healthcheck.StatusHealthy { - logger.Infof("Health check failed for %s; initiating proxy shutdown", p.containerName) - if err := p.Stop(parentCtx); err != nil { - logger.Errorf("Failed to stop proxy for %s: %v", p.containerName, err) - } + // Perform health check only if mcp server has been initialized + extractedSession, isSession := p.GetSession(TransparentProxySessionID) + if !isSession { + logger.Errorf("Failed to get session for health check") return } + + if extractedSession.IsInitialized() { + alive := p.healthChecker.CheckHealth(parentCtx) + if alive.Status != healthcheck.StatusHealthy { + logger.Infof("Health check failed for %s; initiating proxy shutdown", p.containerName) + if err := p.Stop(parentCtx); err != nil { + logger.Errorf("Failed to stop proxy for %s: %v", p.containerName, err) + } + return + } + } else { + logger.Infof("Session %s is not initialized, cannot start healthcheck", extractedSession.ID()) + } } } } @@ -219,3 +347,32 @@ func (*TransparentProxy) SendMessageToDestination(_ jsonrpc2.Message) error { func (*TransparentProxy) ForwardResponseToClients(_ context.Context, _ jsonrpc2.Message) error { return fmt.Errorf("ForwardResponseToClients not implemented for TransparentProxy") } + +// CreateSession creates a new session for the transparent proxy. +func (p *TransparentProxy) CreateSession(id string) (session.Session, error) { + logger.Infof("Creating session with ID: %s", id) + if id == "" { + return nil, fmt.Errorf("session ID cannot be empty") + } + p.sessionMutex.Lock() + defer p.sessionMutex.Unlock() + if _, found := p.sessions[id]; found { + return nil, fmt.Errorf("session %s exists", id) + } + s := &session.ProxySession{Id: id} + s.Init() + p.sessions[id] = s + return s, nil +} + +// GetSession retrieves a session by ID. +func (p *TransparentProxy) GetSession(id string) (session.Session, bool) { + logger.Infof("Retrieving session with ID: %s", id) + if id == "" { + return nil, false + } + p.sessionMutex.Lock() + defer p.sessionMutex.Unlock() + s, ok := p.sessions[id] + return s, ok +} diff --git a/pkg/transport/session/session.go b/pkg/transport/session/session.go new file mode 100644 index 000000000..8152250c3 --- /dev/null +++ b/pkg/transport/session/session.go @@ -0,0 +1,40 @@ +// Package session provides an interface and implementation for managing sessions in the transport layer. +package session + +// Session defines the interface for transport sessions. +type Session interface { + ID() string + Init() + IsInitialized() bool + SetIsInitialized(initialized bool) + MCPSessionID() string + SetMCPSessionID(mcpID string) +} + +// ProxySession implements the Session interface for transport sessions. +type ProxySession struct { + Id string + initialized bool + mcpSessionID string +} + +// ID returns the session ID. +func (s *ProxySession) ID() string { return s.Id } + +// SetIsInitialized sets the initialized state of the session. +func (s *ProxySession) SetIsInitialized(initialized bool) { s.initialized = initialized } + +// IsInitialized returns whether the session is initialized. +func (s *ProxySession) IsInitialized() bool { return s.initialized } + +// MCPSessionID returns the MCP session ID. +func (s *ProxySession) MCPSessionID() string { return s.mcpSessionID } + +// SetMCPSessionID sets the MCP session ID. +func (s *ProxySession) SetMCPSessionID(mcpID string) { s.mcpSessionID = mcpID } + +// Init initializes the session, setting it to uninitialized state. +func (s *ProxySession) Init() { + s.initialized = false + s.mcpSessionID = "" +} From 49efb3c38ce721af10a44738d3b6bdc17117563d Mon Sep 17 00:00:00 2001 From: taskbot Date: Thu, 17 Jul 2025 11:57:27 +0200 Subject: [PATCH 2/9] fixes from review --- .../proxy/transparent/transparent_proxy.go | 128 +++++------------- pkg/transport/session/manager.go | 106 +++++++++++++++ pkg/transport/session/proxy_session.go | 28 ++++ pkg/transport/session/session.go | 40 ------ 4 files changed, 167 insertions(+), 135 deletions(-) create mode 100644 pkg/transport/session/manager.go create mode 100644 pkg/transport/session/proxy_session.go delete mode 100644 pkg/transport/session/session.go diff --git a/pkg/transport/proxy/transparent/transparent_proxy.go b/pkg/transport/proxy/transparent/transparent_proxy.go index 636801765..5ceaef0bb 100644 --- a/pkg/transport/proxy/transparent/transparent_proxy.go +++ b/pkg/transport/proxy/transparent/transparent_proxy.go @@ -13,7 +13,6 @@ import ( "net/http/httputil" "net/url" "regexp" - "strconv" "strings" "sync" "time" @@ -56,16 +55,13 @@ type TransparentProxy struct { // Optional Prometheus metrics handler prometheusHandler http.Handler - // Sessions for managing client connections - sessions map[string]session.Session + // Sessions for tracking state + sessionManager *session.Manager - // mutex for protecting session access - sessionMutex sync.Mutex + // If mcp server has been initialized + IsServerInitialized bool } -// TransparentProxySessionID is the session ID used for the transparent proxy. -const TransparentProxySessionID = "transparent-proxy-session" - // NewTransparentProxy creates a new transparent proxy with optional middlewares. func NewTransparentProxy( host string, @@ -83,30 +79,36 @@ func NewTransparentProxy( middlewares: middlewares, shutdownCh: make(chan struct{}), prometheusHandler: prometheusHandler, - sessions: make(map[string]session.Session), + sessionManager: session.NewManager(30 * time.Minute), } // Create MCP pinger and health checker mcpPinger := NewMCPPinger(targetURI) proxy.healthChecker = healthcheck.NewHealthChecker("sse", mcpPinger) - _, err := proxy.CreateSession(TransparentProxySessionID) - if err != nil { - logger.Errorf("Failed to create session for TransparentProxy: %v", err) - return nil - } - return proxy } +var sessionIDRegex = regexp.MustCompile(`sessionId=([\w-]+)`) + func (p *TransparentProxy) handleModifyResponse(res *http.Response) error { - // Log headers if sid := res.Header.Get("Mcp-Session-Id"); sid != "" { - logger.Infof("🆔 Streamable session ID from header: %s", sid) + logger.Infof("Detected Mcp-Session-Id header: %s", sid) + if _, ok := p.sessionManager.Get(sid); !ok { + if _, err := p.sessionManager.AddWithID(sid); err != nil { + logger.Errorf("Failed to create session from header %s: %v", sid, err) + } + } + p.IsServerInitialized = true } // Handle streaming (SSE) - if ct, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); ct == "text/event-stream" { + ct, _, err := mime.ParseMediaType(res.Header.Get("Content-Type")) + if err != nil { + logger.Errorf("Failed to parse Content-Type: %v", err) + return err + } + if ct == "text/event-stream" { pr, pw := io.Pipe() orig := res.Body res.Body = pr @@ -114,22 +116,21 @@ func (p *TransparentProxy) handleModifyResponse(res *http.Response) error { go func() { defer pw.Close() scanner := bufio.NewScanner(orig) - re := regexp.MustCompile(`sessionId=([\w-]+)`) // Capture UUID-like IDs - for scanner.Scan() { line := scanner.Text() - if matches := re.FindStringSubmatch(line); len(matches) == 2 { + if matches := sessionIDRegex.FindStringSubmatch(line); len(matches) == 2 { sessionID := matches[1] - - // set session id for proxy - extractedSession, ok := p.GetSession(TransparentProxySessionID) + _, ok := p.sessionManager.Get(sessionID) if !ok { - logger.Errorf("Failed to get session for TransparentProxy") - continue + var err error + _, err = p.sessionManager.AddWithID(sessionID) + if err != nil { + logger.Errorf("Failed to create session %s: %v", sessionID, err) + continue + } } - extractedSession.SetIsInitialized(true) - extractedSession.SetMCPSessionID(sessionID) + p.IsServerInitialized = true } _, err := pw.Write([]byte(line + "\n")) if err != nil { @@ -140,29 +141,6 @@ func (p *TransparentProxy) handleModifyResponse(res *http.Response) error { return nil } - // Handle non-streaming (JSON) bodies - body, err := io.ReadAll(res.Body) - if err != nil { - return err - } - err = res.Body.Close() - if err != nil { - logger.Errorf("Failed to close response body: %v", err) - } - - text := string(body) - logger.Infof("HTTP response body: %s", text) - - // Parse sessionId if embedded in JSON - rejson := regexp.MustCompile(`"sessionId"\s*:\s*"([\w-]+)"`) - if matches := rejson.FindStringSubmatch(text); len(matches) == 2 { - sessionID := matches[1] - logger.Infof("🆔 Captured sessionId from JSON: %s", sessionID) - } - - res.Body = io.NopCloser(bytes.NewReader(body)) - res.ContentLength = int64(len(body)) - res.Header.Set("Content-Length", strconv.Itoa(len(body))) return nil } @@ -176,13 +154,8 @@ func (p *TransparentProxy) handleAndDetectInitialize(w http.ResponseWriter, r *h logger.Errorf("Error reading request body: %v", err) } else { if bytes.Contains(body, []byte(`"method":"initialize"`)) { - logger.Infof("🔧 Detected initialize request to %s", r.URL.Path) - extractedSession, ok := p.GetSession(TransparentProxySessionID) - if ok { - extractedSession.SetIsInitialized(true) - } else { - logger.Errorf("No session found to mark initialized") - } + logger.Infof("Detected initialize request to %s", r.URL.Path) + p.IsServerInitialized = true } r.Body = io.NopCloser(bytes.NewReader(body)) r.ContentLength = int64(len(body)) @@ -278,13 +251,7 @@ func (p *TransparentProxy) monitorHealth(parentCtx context.Context) { return case <-ticker.C: // Perform health check only if mcp server has been initialized - extractedSession, isSession := p.GetSession(TransparentProxySessionID) - if !isSession { - logger.Errorf("Failed to get session for health check") - return - } - - if extractedSession.IsInitialized() { + if p.IsServerInitialized { alive := p.healthChecker.CheckHealth(parentCtx) if alive.Status != healthcheck.StatusHealthy { logger.Infof("Health check failed for %s; initiating proxy shutdown", p.containerName) @@ -294,7 +261,7 @@ func (p *TransparentProxy) monitorHealth(parentCtx context.Context) { return } } else { - logger.Infof("Session %s is not initialized, cannot start healthcheck", extractedSession.ID()) + logger.Infof("MCP server not initialized yet, skipping health check for %s", p.containerName) } } } @@ -347,32 +314,3 @@ func (*TransparentProxy) SendMessageToDestination(_ jsonrpc2.Message) error { func (*TransparentProxy) ForwardResponseToClients(_ context.Context, _ jsonrpc2.Message) error { return fmt.Errorf("ForwardResponseToClients not implemented for TransparentProxy") } - -// CreateSession creates a new session for the transparent proxy. -func (p *TransparentProxy) CreateSession(id string) (session.Session, error) { - logger.Infof("Creating session with ID: %s", id) - if id == "" { - return nil, fmt.Errorf("session ID cannot be empty") - } - p.sessionMutex.Lock() - defer p.sessionMutex.Unlock() - if _, found := p.sessions[id]; found { - return nil, fmt.Errorf("session %s exists", id) - } - s := &session.ProxySession{Id: id} - s.Init() - p.sessions[id] = s - return s, nil -} - -// GetSession retrieves a session by ID. -func (p *TransparentProxy) GetSession(id string) (session.Session, bool) { - logger.Infof("Retrieving session with ID: %s", id) - if id == "" { - return nil, false - } - p.sessionMutex.Lock() - defer p.sessionMutex.Unlock() - s, ok := p.sessions[id] - return s, ok -} diff --git a/pkg/transport/session/manager.go b/pkg/transport/session/manager.go new file mode 100644 index 000000000..bb5cc7c81 --- /dev/null +++ b/pkg/transport/session/manager.go @@ -0,0 +1,106 @@ +// Package session provides a session manager with TTL cleanup. +package session + +import ( + "fmt" + "sync" + "time" +) + +// Session interface +type Session interface { + ID() string + CreatedAt() time.Time + UpdatedAt() time.Time + Touch() +} + +// Manager holds sessions with TTL cleanup. +type Manager struct { + sessions map[string]Session + mu sync.RWMutex + ttl time.Duration + stopCh chan struct{} +} + +// NewManager creates a session manager with TTL and starts cleanup worker. +func NewManager(ttl time.Duration) *Manager { + m := &Manager{ + sessions: make(map[string]Session), + ttl: ttl, + stopCh: make(chan struct{}), + } + go m.cleanupRoutine() + return m +} + +func (m *Manager) cleanupRoutine() { + ticker := time.NewTicker(m.ttl / 2) + defer ticker.Stop() + for { + select { + case <-ticker.C: + m.CleanupExpired() + case <-m.stopCh: + return + } + } +} + +// AddWithID creates (and adds) a new session with the provided ID. +// Returns error if ID is empty or already exists. +func (m *Manager) AddWithID(id string) (Session, error) { + if id == "" { + return nil, fmt.Errorf("session ID cannot be empty") + } + + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.sessions[id]; exists { + return nil, fmt.Errorf("session ID %q already exists", id) + } + + s := NewProxySession(id) + m.sessions[id] = s + return s, nil +} + +// Get retrieves a session by ID. Returns (session, true) if found, +// and also updates its UpdatedAt timestamp. +func (m *Manager) Get(id string) (Session, bool) { + m.mu.RLock() + s, ok := m.sessions[id] + m.mu.RUnlock() + + if !ok { + return nil, false + } + + s.Touch() + return s, true +} + +// Delete removes a session by ID. +func (m *Manager) Delete(id string) { + m.mu.Lock() + delete(m.sessions, id) + m.mu.Unlock() +} + +// CleanupExpired removes sessions that have not been updated within the TTL. +func (m *Manager) CleanupExpired() { + cutoff := time.Now().Add(-m.ttl) + m.mu.Lock() + defer m.mu.Unlock() + for id, s := range m.sessions { + if s.UpdatedAt().Before(cutoff) { + delete(m.sessions, id) + } + } +} + +// Stop stops the cleanup worker. +func (m *Manager) Stop() { + close(m.stopCh) +} diff --git a/pkg/transport/session/proxy_session.go b/pkg/transport/session/proxy_session.go new file mode 100644 index 000000000..9f767b5d1 --- /dev/null +++ b/pkg/transport/session/proxy_session.go @@ -0,0 +1,28 @@ +package session + +import "time" + +// ProxySession implements the Session interface for proxy sessions. +type ProxySession struct { + id string + created time.Time + updated time.Time +} + +// NewProxySession creates a new ProxySession with the given ID. +func NewProxySession(id string) *ProxySession { + now := time.Now() + return &ProxySession{id: id, created: now, updated: now} +} + +// ID returns the session ID. +func (s *ProxySession) ID() string { return s.id } + +// CreatedAt returns the creation time of the session. +func (s *ProxySession) CreatedAt() time.Time { return s.created } + +// UpdatedAt returns the last updated time of the session. +func (s *ProxySession) UpdatedAt() time.Time { return s.updated } + +// Touch updates the session's last updated time to the current time. +func (s *ProxySession) Touch() { s.updated = time.Now() } diff --git a/pkg/transport/session/session.go b/pkg/transport/session/session.go deleted file mode 100644 index 8152250c3..000000000 --- a/pkg/transport/session/session.go +++ /dev/null @@ -1,40 +0,0 @@ -// Package session provides an interface and implementation for managing sessions in the transport layer. -package session - -// Session defines the interface for transport sessions. -type Session interface { - ID() string - Init() - IsInitialized() bool - SetIsInitialized(initialized bool) - MCPSessionID() string - SetMCPSessionID(mcpID string) -} - -// ProxySession implements the Session interface for transport sessions. -type ProxySession struct { - Id string - initialized bool - mcpSessionID string -} - -// ID returns the session ID. -func (s *ProxySession) ID() string { return s.Id } - -// SetIsInitialized sets the initialized state of the session. -func (s *ProxySession) SetIsInitialized(initialized bool) { s.initialized = initialized } - -// IsInitialized returns whether the session is initialized. -func (s *ProxySession) IsInitialized() bool { return s.initialized } - -// MCPSessionID returns the MCP session ID. -func (s *ProxySession) MCPSessionID() string { return s.mcpSessionID } - -// SetMCPSessionID sets the MCP session ID. -func (s *ProxySession) SetMCPSessionID(mcpID string) { s.mcpSessionID = mcpID } - -// Init initializes the session, setting it to uninitialized state. -func (s *ProxySession) Init() { - s.initialized = false - s.mcpSessionID = "" -} From 99483790b9bbb6445feb58f36bd5d9ec2c9aa5ab Mon Sep 17 00:00:00 2001 From: taskbot Date: Thu, 17 Jul 2025 13:07:57 +0200 Subject: [PATCH 3/9] fix tests --- pkg/transport/proxy/transparent/transparent_proxy.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pkg/transport/proxy/transparent/transparent_proxy.go b/pkg/transport/proxy/transparent/transparent_proxy.go index 5ceaef0bb..4bc0a71db 100644 --- a/pkg/transport/proxy/transparent/transparent_proxy.go +++ b/pkg/transport/proxy/transparent/transparent_proxy.go @@ -100,13 +100,14 @@ func (p *TransparentProxy) handleModifyResponse(res *http.Response) error { } } p.IsServerInitialized = true + return nil } // Handle streaming (SSE) ct, _, err := mime.ParseMediaType(res.Header.Get("Content-Type")) if err != nil { - logger.Errorf("Failed to parse Content-Type: %v", err) - return err + logger.Warnf("Invalid Content-Type header, defaulting behavior: %v", err) + ct = "" // or choose a fallback } if ct == "text/event-stream" { pr, pw := io.Pipe() From 1f6d97a3dc3087fa330ad88a1e5ea07b2ec0ab8a Mon Sep 17 00:00:00 2001 From: taskbot Date: Fri, 18 Jul 2025 11:25:55 +0200 Subject: [PATCH 4/9] refactor session management --- .../proxy/transparent/transparent_proxy.go | 176 ++++++++++++------ pkg/transport/session/manager.go | 4 +- 2 files changed, 119 insertions(+), 61 deletions(-) diff --git a/pkg/transport/proxy/transparent/transparent_proxy.go b/pkg/transport/proxy/transparent/transparent_proxy.go index 4bc0a71db..ee1bf316a 100644 --- a/pkg/transport/proxy/transparent/transparent_proxy.go +++ b/pkg/transport/proxy/transparent/transparent_proxy.go @@ -6,6 +6,7 @@ import ( "bufio" "bytes" "context" + "encoding/json" "fmt" "io" "mime" @@ -89,81 +90,137 @@ func NewTransparentProxy( return proxy } -var sessionIDRegex = regexp.MustCompile(`sessionId=([\w-]+)`) +type tracingTransport struct { + base http.RoundTripper + p *TransparentProxy +} + +func (t *tracingTransport) setServerInitialized() { + if !t.p.IsServerInitialized { + t.p.mutex.Lock() + t.p.IsServerInitialized = true + t.p.mutex.Unlock() + logger.Infof("Server was initialized successfully for %s", t.p.containerName) + } +} + +func (t *tracingTransport) forward(req *http.Request) (*http.Response, error) { + tr := t.base + if tr == nil { + tr = http.DefaultTransport + } + return tr.RoundTrip(req) +} + +func (t *tracingTransport) watchEventStream(r io.Reader, w *io.PipeWriter) { + defer w.Close() + + scanner := bufio.NewScanner(r) + sessionRe := regexp.MustCompile(`sessionId=([0-9a-fA-F-]+)|\"sessionId\"\s*:\s*\"([^\"]+)\"`) -func (p *TransparentProxy) handleModifyResponse(res *http.Response) error { - if sid := res.Header.Get("Mcp-Session-Id"); sid != "" { - logger.Infof("Detected Mcp-Session-Id header: %s", sid) - if _, ok := p.sessionManager.Get(sid); !ok { - if _, err := p.sessionManager.AddWithID(sid); err != nil { - logger.Errorf("Failed to create session from header %s: %v", sid, err) + for scanner.Scan() { + line := scanner.Text() + + if m := sessionRe.FindStringSubmatch(line); m != nil { + sid := m[1] + if sid == "" { + sid = m[2] + } + + if _, ok := t.p.sessionManager.Get(sid); !ok { + _, err := t.p.sessionManager.AddWithID(sid) + if err != nil { + logger.Errorf("Failed to create session from event stream: %v", err) + } } + t.setServerInitialized() } - p.IsServerInitialized = true - return nil } - // Handle streaming (SSE) - ct, _, err := mime.ParseMediaType(res.Header.Get("Content-Type")) + _, err := io.Copy(io.Discard, r) if err != nil { - logger.Warnf("Invalid Content-Type header, defaulting behavior: %v", err) - ct = "" // or choose a fallback + logger.Errorf("Failed to copy event stream: %v", err) } - if ct == "text/event-stream" { - pr, pw := io.Pipe() - orig := res.Body - res.Body = pr - - go func() { - defer pw.Close() - scanner := bufio.NewScanner(orig) - for scanner.Scan() { - line := scanner.Text() - - if matches := sessionIDRegex.FindStringSubmatch(line); len(matches) == 2 { - sessionID := matches[1] - _, ok := p.sessionManager.Get(sessionID) - if !ok { - var err error - _, err = p.sessionManager.AddWithID(sessionID) - if err != nil { - logger.Errorf("Failed to create session %s: %v", sessionID, err) - continue - } - } - p.IsServerInitialized = true - } - _, err := pw.Write([]byte(line + "\n")) - if err != nil { - logger.Errorf("Failed to write to pipe: %v", err) +} + +func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + reqBody := readRequestBody(req) + + path := req.URL.Path + isMCP := strings.HasPrefix(path, "/mcp") + isJSON := strings.Contains(req.Header.Get("Content-Type"), "application/json") + sawInitialize := false + + if isMCP && isJSON && len(reqBody) > 0 { + sawInitialize = t.detectInitialize(reqBody) + } + + resp, err := t.forward(req) + if err != nil { + logger.Errorf("Failed to forward request: %v", err) + return nil, err + } + + if resp.StatusCode == http.StatusOK { + // check if we saw a valid mcp header + ct := resp.Header.Get("Mcp-Session-Id") + if ct != "" { + logger.Infof("Detected Mcp-Session-Id header: %s", ct) + if _, ok := t.p.sessionManager.Get(ct); !ok { + if _, err := t.p.sessionManager.AddWithID(ct); err != nil { + logger.Errorf("Failed to create session from header %s: %v", ct, err) } } - }() - return nil + t.setServerInitialized() + return resp, nil + } + // status was ok and we saw an initialize call + if sawInitialize && !t.p.IsServerInitialized { + t.setServerInitialized() + return resp, nil + } + ct = resp.Header.Get("Content-Type") + mediaType, _, _ := mime.ParseMediaType(ct) + if mediaType == "text/event-stream" { + originalBody := resp.Body + pr, pw := io.Pipe() + tee := io.TeeReader(originalBody, pw) + resp.Body = pr + + go t.watchEventStream(tee, pw) + } } - return nil + return resp, nil } -func (p *TransparentProxy) handleAndDetectInitialize(w http.ResponseWriter, r *http.Request, proxy *httputil.ReverseProxy) { - logger.Infof("Transparent proxy: %s %s -> %s", r.Method, r.URL.Path, p.targetURI) - - if r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/mcp") { - // Read the body for inspection without consuming it - body, err := io.ReadAll(r.Body) +func readRequestBody(req *http.Request) []byte { + reqBody := []byte{} + if req.Body != nil { + buf, err := io.ReadAll(req.Body) if err != nil { - logger.Errorf("Error reading request body: %v", err) + logger.Errorf("Failed to read request body: %v", err) } else { - if bytes.Contains(body, []byte(`"method":"initialize"`)) { - logger.Infof("Detected initialize request to %s", r.URL.Path) - p.IsServerInitialized = true - } - r.Body = io.NopCloser(bytes.NewReader(body)) - r.ContentLength = int64(len(body)) + reqBody = buf } + req.Body = io.NopCloser(bytes.NewReader(reqBody)) } + return reqBody +} - proxy.ServeHTTP(w, r) +func (t *tracingTransport) detectInitialize(body []byte) bool { + var rpc struct { + Method string `json:"method"` + } + if err := json.Unmarshal(body, &rpc); err != nil { + logger.Errorf("Failed to parse JSON-RPC body: %v", err) + return false + } + if rpc.Method == "initialize" { + logger.Infof("Detected initialize method call for %s", t.p.containerName) + return true + } + return false } // Start starts the transparent proxy. @@ -179,11 +236,12 @@ func (p *TransparentProxy) Start(ctx context.Context) error { // Create a reverse proxy proxy := httputil.NewSingleHostReverseProxy(targetURL) - proxy.ModifyResponse = p.handleModifyResponse + proxy.Transport = &tracingTransport{base: http.DefaultTransport, p: p} // Create a handler that logs requests handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - p.handleAndDetectInitialize(w, r, proxy) + logger.Infof("Transparent proxy: %s %s -> %s", r.Method, r.URL.Path, targetURL) + proxy.ServeHTTP(w, r) }) // Create a mux to handle both proxy and health endpoints diff --git a/pkg/transport/session/manager.go b/pkg/transport/session/manager.go index bb5cc7c81..f4892e554 100644 --- a/pkg/transport/session/manager.go +++ b/pkg/transport/session/manager.go @@ -71,13 +71,13 @@ func (m *Manager) AddWithID(id string) (Session, error) { func (m *Manager) Get(id string) (Session, bool) { m.mu.RLock() s, ok := m.sessions[id] - m.mu.RUnlock() - if !ok { return nil, false } s.Touch() + m.mu.RUnlock() + return s, true } From 591f815563bcade8dee27878a1a1f37f8d200059 Mon Sep 17 00:00:00 2001 From: taskbot Date: Fri, 18 Jul 2025 12:56:36 +0200 Subject: [PATCH 5/9] fixes from review --- pkg/transport/proxy/transparent/transparent_proxy.go | 7 +++---- pkg/transport/session/manager.go | 8 ++++---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/pkg/transport/proxy/transparent/transparent_proxy.go b/pkg/transport/proxy/transparent/transparent_proxy.go index ee1bf316a..7a8be733a 100644 --- a/pkg/transport/proxy/transparent/transparent_proxy.go +++ b/pkg/transport/proxy/transparent/transparent_proxy.go @@ -126,14 +126,13 @@ func (t *tracingTransport) watchEventStream(r io.Reader, w *io.PipeWriter) { if sid == "" { sid = m[2] } - + t.setServerInitialized() if _, ok := t.p.sessionManager.Get(sid); !ok { - _, err := t.p.sessionManager.AddWithID(sid) + err := t.p.sessionManager.AddWithID(sid) if err != nil { logger.Errorf("Failed to create session from event stream: %v", err) } } - t.setServerInitialized() } } @@ -167,7 +166,7 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) if ct != "" { logger.Infof("Detected Mcp-Session-Id header: %s", ct) if _, ok := t.p.sessionManager.Get(ct); !ok { - if _, err := t.p.sessionManager.AddWithID(ct); err != nil { + if err := t.p.sessionManager.AddWithID(ct); err != nil { logger.Errorf("Failed to create session from header %s: %v", ct, err) } } diff --git a/pkg/transport/session/manager.go b/pkg/transport/session/manager.go index f4892e554..42f90ee33 100644 --- a/pkg/transport/session/manager.go +++ b/pkg/transport/session/manager.go @@ -49,21 +49,21 @@ func (m *Manager) cleanupRoutine() { // AddWithID creates (and adds) a new session with the provided ID. // Returns error if ID is empty or already exists. -func (m *Manager) AddWithID(id string) (Session, error) { +func (m *Manager) AddWithID(id string) error { if id == "" { - return nil, fmt.Errorf("session ID cannot be empty") + return fmt.Errorf("session ID cannot be empty") } m.mu.Lock() defer m.mu.Unlock() if _, exists := m.sessions[id]; exists { - return nil, fmt.Errorf("session ID %q already exists", id) + return fmt.Errorf("session ID %q already exists", id) } s := NewProxySession(id) m.sessions[id] = s - return s, nil + return nil } // Get retrieves a session by ID. Returns (session, true) if found, From eb0c1a6f9e807a178e905204b7b32e2a7ecc942a Mon Sep 17 00:00:00 2001 From: taskbot Date: Mon, 21 Jul 2025 12:30:38 +0200 Subject: [PATCH 6/9] add tests --- .../proxy/transparent/transparent_proxy.go | 95 ++++++++------ .../proxy/transparent/transparent_test.go | 120 ++++++++++++++++++ pkg/transport/session/manager.go | 55 +++----- pkg/transport/session/manager_test.go | 111 ++++++++++++++++ pkg/transport/session/proxy_session.go | 2 +- 5 files changed, 306 insertions(+), 77 deletions(-) create mode 100644 pkg/transport/proxy/transparent/transparent_test.go create mode 100644 pkg/transport/session/manager_test.go diff --git a/pkg/transport/proxy/transparent/transparent_proxy.go b/pkg/transport/proxy/transparent/transparent_proxy.go index 7a8be733a..936852a00 100644 --- a/pkg/transport/proxy/transparent/transparent_proxy.go +++ b/pkg/transport/proxy/transparent/transparent_proxy.go @@ -112,36 +112,6 @@ func (t *tracingTransport) forward(req *http.Request) (*http.Response, error) { return tr.RoundTrip(req) } -func (t *tracingTransport) watchEventStream(r io.Reader, w *io.PipeWriter) { - defer w.Close() - - scanner := bufio.NewScanner(r) - sessionRe := regexp.MustCompile(`sessionId=([0-9a-fA-F-]+)|\"sessionId\"\s*:\s*\"([^\"]+)\"`) - - for scanner.Scan() { - line := scanner.Text() - - if m := sessionRe.FindStringSubmatch(line); m != nil { - sid := m[1] - if sid == "" { - sid = m[2] - } - t.setServerInitialized() - if _, ok := t.p.sessionManager.Get(sid); !ok { - err := t.p.sessionManager.AddWithID(sid) - if err != nil { - logger.Errorf("Failed to create session from event stream: %v", err) - } - } - } - } - - _, err := io.Copy(io.Discard, r) - if err != nil { - logger.Errorf("Failed to copy event stream: %v", err) - } -} - func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) { reqBody := readRequestBody(req) @@ -166,28 +136,23 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) if ct != "" { logger.Infof("Detected Mcp-Session-Id header: %s", ct) if _, ok := t.p.sessionManager.Get(ct); !ok { + fmt.Println("i get session id") if err := t.p.sessionManager.AddWithID(ct); err != nil { + fmt.Println("i add session") logger.Errorf("Failed to create session from header %s: %v", ct, err) } + fmt.Println("i set server initialized") } + fmt.Println("i set server initialized") t.setServerInitialized() return resp, nil } // status was ok and we saw an initialize call if sawInitialize && !t.p.IsServerInitialized { + fmt.Println("here") t.setServerInitialized() return resp, nil } - ct = resp.Header.Get("Content-Type") - mediaType, _, _ := mime.ParseMediaType(ct) - if mediaType == "text/event-stream" { - originalBody := resp.Body - pr, pw := io.Pipe() - tee := io.TeeReader(originalBody, pw) - resp.Body = pr - - go t.watchEventStream(tee, pw) - } } return resp, nil @@ -222,6 +187,52 @@ func (t *tracingTransport) detectInitialize(body []byte) bool { return false } +var sessionRe = regexp.MustCompile(`sessionId=([0-9A-Fa-f-]+)|"sessionId"\s*:\s*"([^"]+)"`) + +func (p *TransparentProxy) modifyForSessionID(resp *http.Response) error { + mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if mediaType != "text/event-stream" { + return nil + } + + pr, pw := io.Pipe() + originalBody := resp.Body + resp.Body = pr + + go func() { + defer pw.Close() + scanner := bufio.NewScanner(originalBody) + found := false + + for scanner.Scan() { + line := scanner.Bytes() + if !found { + if m := sessionRe.FindSubmatch(line); m != nil { + sid := string(m[1]) + if sid == "" { + sid = string(m[2]) + } + p.IsServerInitialized = true + err := p.sessionManager.AddWithID(sid) + if err != nil { + logger.Errorf("Failed to create session from SSE line: %v", err) + } + found = true + } + } + if _, err := pw.Write(append(line, '\n')); err != nil { + return + } + } + _, err := io.Copy(pw, originalBody) + if err != nil && err != io.EOF { + logger.Errorf("Failed to copy response body: %v", err) + } + }() + + return nil +} + // Start starts the transparent proxy. func (p *TransparentProxy) Start(ctx context.Context) error { p.mutex.Lock() @@ -235,7 +246,11 @@ func (p *TransparentProxy) Start(ctx context.Context) error { // Create a reverse proxy proxy := httputil.NewSingleHostReverseProxy(targetURL) + proxy.FlushInterval = -1 proxy.Transport = &tracingTransport{base: http.DefaultTransport, p: p} + proxy.ModifyResponse = func(resp *http.Response) error { + return p.modifyForSessionID(resp) + } // Create a handler that logs requests handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/transport/proxy/transparent/transparent_test.go b/pkg/transport/proxy/transparent/transparent_test.go new file mode 100644 index 000000000..cad26a8d9 --- /dev/null +++ b/pkg/transport/proxy/transparent/transparent_test.go @@ -0,0 +1,120 @@ +package transparent + +import ( + "bufio" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "testing" + "time" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stretchr/testify/assert" +) + +func init() { + logger.Initialize() // ensure logging doesn't panic +} + +func TestStreamingSessionIDDetection(t *testing.T) { + proxy := NewTransparentProxy("127.0.0.1", 0, "test", "http://example.com", nil) + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") + w.WriteHeader(200) + + // Simulate SSE lines + w.Write([]byte("data: hello\n")) + w.Write([]byte("data: sessionId=ABC123\n")) + w.(http.Flusher).Flush() + + time.Sleep(10 * time.Millisecond) + w.Write([]byte("data: more\n")) + })) + defer target.Close() + + // set up reverse proxy using ModifyResponse + parsedURL, _ := http.NewRequest("GET", target.URL, nil) + proxyURL := httputil.NewSingleHostReverseProxy(parsedURL.URL) + proxyURL.FlushInterval = -1 + proxyURL.Transport = &tracingTransport{base: http.DefaultTransport, p: proxy} + proxyURL.ModifyResponse = proxy.modifyForSessionID + + // hit the proxy + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", target.URL, nil) + proxyURL.ServeHTTP(rec, req) + + // read all SSE lines + sc := bufio.NewScanner(rec.Body) + var bodyLines []string + for sc.Scan() { + bodyLines = append(bodyLines, sc.Text()) + } + assert.Contains(t, bodyLines, "data: sessionId=ABC123") + + // side-effect: proxy should have seen session + assert.True(t, proxy.IsServerInitialized, "server should have been initialized") + _, ok := proxy.sessionManager.Get("ABC123") + assert.True(t, ok, "sessionManager should have stored ABC123") +} + +func createBasicProxy(p *TransparentProxy, targetURL *url.URL) *httputil.ReverseProxy { + proxy := httputil.NewSingleHostReverseProxy(targetURL) + proxy.Director = func(r *http.Request) { + r.URL.Scheme = targetURL.Scheme + r.URL.Host = targetURL.Host + r.Host = targetURL.Host + } + proxy.FlushInterval = -1 + proxy.Transport = &tracingTransport{base: http.DefaultTransport, p: p} + proxy.ModifyResponse = p.modifyForSessionID + return proxy +} + +func TestNoSessionIDInNonSSE(t *testing.T) { + p := NewTransparentProxy("127.0.0.1", 0, "test", "", nil) + + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Set both content-type and also optionally MCP header to test behavior + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + w.Write([]byte(`{"hello": "world"}`)) + })) + defer target.Close() + + targetURL, _ := url.Parse(target.URL) + proxy := createBasicProxy(p, targetURL) + + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", target.URL, nil) + proxy.ServeHTTP(rec, req) + + assert.False(t, p.IsServerInitialized, "server should not be initialized for application/json") + _, ok := p.sessionManager.Get("XYZ789") + assert.False(t, ok, "no session should be added") +} + +func TestHeaderBasedSessionInitialization(t *testing.T) { + p := NewTransparentProxy("127.0.0.1", 0, "test", "", nil) + + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Set both content-type and also optionally MCP header to test behavior + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Mcp-Session-Id", "XYZ789") + w.WriteHeader(200) + w.Write([]byte(`{"hello": "world"}`)) + })) + defer target.Close() + + targetURL, _ := url.Parse(target.URL) + proxy := createBasicProxy(p, targetURL) + + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", target.URL, nil) + proxy.ServeHTTP(rec, req) + + assert.True(t, p.IsServerInitialized, "server should not be initialized for application/json") + _, ok := p.sessionManager.Get("XYZ789") + assert.True(t, ok, "no session should be added") +} diff --git a/pkg/transport/session/manager.go b/pkg/transport/session/manager.go index 42f90ee33..6384c2be3 100644 --- a/pkg/transport/session/manager.go +++ b/pkg/transport/session/manager.go @@ -17,8 +17,7 @@ type Session interface { // Manager holds sessions with TTL cleanup. type Manager struct { - sessions map[string]Session - mu sync.RWMutex + sessions sync.Map ttl time.Duration stopCh chan struct{} } @@ -26,9 +25,8 @@ type Manager struct { // NewManager creates a session manager with TTL and starts cleanup worker. func NewManager(ttl time.Duration) *Manager { m := &Manager{ - sessions: make(map[string]Session), - ttl: ttl, - stopCh: make(chan struct{}), + ttl: ttl, + stopCh: make(chan struct{}), } go m.cleanupRoutine() return m @@ -40,7 +38,14 @@ func (m *Manager) cleanupRoutine() { for { select { case <-ticker.C: - m.CleanupExpired() + cutoff := time.Now().Add(-m.ttl) + m.sessions.Range(func(key, val any) bool { + sess := val.(Session) + if sess.UpdatedAt().Before(cutoff) { + m.sessions.Delete(key) + } + return true + }) case <-m.stopCh: return } @@ -53,51 +58,29 @@ func (m *Manager) AddWithID(id string) error { if id == "" { return fmt.Errorf("session ID cannot be empty") } - - m.mu.Lock() - defer m.mu.Unlock() - - if _, exists := m.sessions[id]; exists { + // Use LoadOrStore: returns existing if already present + _, loaded := m.sessions.LoadOrStore(id, NewProxySession(id)) + if loaded { return fmt.Errorf("session ID %q already exists", id) } - - s := NewProxySession(id) - m.sessions[id] = s return nil } // Get retrieves a session by ID. Returns (session, true) if found, // and also updates its UpdatedAt timestamp. func (m *Manager) Get(id string) (Session, bool) { - m.mu.RLock() - s, ok := m.sessions[id] + v, ok := m.sessions.Load(id) if !ok { return nil, false } - - s.Touch() - m.mu.RUnlock() - - return s, true + sess := v.(Session) + sess.Touch() + return sess, true } // Delete removes a session by ID. func (m *Manager) Delete(id string) { - m.mu.Lock() - delete(m.sessions, id) - m.mu.Unlock() -} - -// CleanupExpired removes sessions that have not been updated within the TTL. -func (m *Manager) CleanupExpired() { - cutoff := time.Now().Add(-m.ttl) - m.mu.Lock() - defer m.mu.Unlock() - for id, s := range m.sessions { - if s.UpdatedAt().Before(cutoff) { - delete(m.sessions, id) - } - } + m.sessions.Delete(id) } // Stop stops the cleanup worker. diff --git a/pkg/transport/session/manager_test.go b/pkg/transport/session/manager_test.go new file mode 100644 index 000000000..7b9d6b6a8 --- /dev/null +++ b/pkg/transport/session/manager_test.go @@ -0,0 +1,111 @@ +package session + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAddAndGetWithStubSession(t *testing.T) { + orig := NewProxySession + NewProxySession = func(id string) *ProxySession { + ts := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + return &ProxySession{id: id, created: ts, updated: ts} + } + defer func() { NewProxySession = orig }() + + m := NewManager(1 * time.Hour) + defer m.Stop() + + require.NoError(t, m.AddWithID("foo")) + + sess, ok := m.Get("foo") + require.True(t, ok) + assert.Equal(t, "foo", sess.ID()) +} + +func TestAddDuplicate(t *testing.T) { + m := NewManager(time.Hour) + defer m.Stop() + + err := m.AddWithID("dup") + assert.NoError(t, err) + + err2 := m.AddWithID("dup") + assert.Error(t, err2) + assert.Contains(t, err2.Error(), "already exists") +} + +func TestDeleteSession(t *testing.T) { + m := NewManager(time.Hour) + defer m.Stop() + + require.NoError(t, m.AddWithID("del")) + m.Delete("del") + + _, ok := m.Get("del") + assert.False(t, ok) +} + +func TestGetUpdatesTimestamp(t *testing.T) { + orig := NewProxySession + NewProxySession = func(id string) *ProxySession { + ts := time.Now().Add(-1 * time.Minute) + return &ProxySession{id: id, created: ts, updated: ts} + } + defer func() { NewProxySession = orig }() + + m := NewManager(1 * time.Hour) + defer m.Stop() + + require.NoError(t, m.AddWithID("touchme")) + s1, _ := m.Get("touchme") + t0 := s1.UpdatedAt() + + time.Sleep(5 * time.Millisecond) + s2, _ := m.Get("touchme") + t1 := s2.UpdatedAt() + + assert.True(t, t1.After(t0), "UpdatedAt should update on Get()") +} + +func TestCleanupExpired(t *testing.T) { + ttl := 50 * time.Millisecond + orig := NewProxySession + NewProxySession = func(id string) *ProxySession { + return &ProxySession{ + id: id, + created: time.Now(), + updated: time.Now(), + } + } + defer func() { NewProxySession = orig }() + + m := NewManager(ttl) + defer m.Stop() + + require.NoError(t, m.AddWithID("old")) + time.Sleep(ttl * 2) // allow old to expire + + require.NoError(t, m.AddWithID("new")) + time.Sleep(ttl) // let cleanup execute + + _, okOld := m.Get("old") + _, okNew := m.Get("new") + assert.False(t, okOld, "expired session should be cleaned") + assert.True(t, okNew, "recent session should remain") +} + +func TestStopDisablesCleanup(t *testing.T) { + ttl := 50 * time.Millisecond + m := NewManager(ttl) + m.Stop() // stop cleanup upfront + + require.NoError(t, m.AddWithID("stay")) + time.Sleep(ttl * 2) + + _, ok := m.Get("stay") + assert.True(t, ok, "session should persist after Stop()") +} diff --git a/pkg/transport/session/proxy_session.go b/pkg/transport/session/proxy_session.go index 9f767b5d1..e0ad4c15d 100644 --- a/pkg/transport/session/proxy_session.go +++ b/pkg/transport/session/proxy_session.go @@ -10,7 +10,7 @@ type ProxySession struct { } // NewProxySession creates a new ProxySession with the given ID. -func NewProxySession(id string) *ProxySession { +var NewProxySession = func(id string) *ProxySession { now := time.Now() return &ProxySession{id: id, created: now, updated: now} } From b4846a1d0ff3b0b69ef75174659e7a0087e4c0bd Mon Sep 17 00:00:00 2001 From: taskbot Date: Mon, 21 Jul 2025 12:42:14 +0200 Subject: [PATCH 7/9] fix lint --- .../proxy/transparent/transparent_test.go | 14 ++++++++++---- pkg/transport/session/manager_test.go | 12 ++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/pkg/transport/proxy/transparent/transparent_test.go b/pkg/transport/proxy/transparent/transparent_test.go index cad26a8d9..2c63019ce 100644 --- a/pkg/transport/proxy/transparent/transparent_test.go +++ b/pkg/transport/proxy/transparent/transparent_test.go @@ -9,8 +9,9 @@ import ( "testing" "time" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stretchr/testify/assert" + + "github.com/stacklok/toolhive/pkg/logger" ) func init() { @@ -18,8 +19,9 @@ func init() { } func TestStreamingSessionIDDetection(t *testing.T) { + t.Parallel() proxy := NewTransparentProxy("127.0.0.1", 0, "test", "http://example.com", nil) - target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") w.WriteHeader(200) @@ -73,9 +75,11 @@ func createBasicProxy(p *TransparentProxy, targetURL *url.URL) *httputil.Reverse } func TestNoSessionIDInNonSSE(t *testing.T) { + t.Parallel() + p := NewTransparentProxy("127.0.0.1", 0, "test", "", nil) - target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { // Set both content-type and also optionally MCP header to test behavior w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) @@ -96,9 +100,11 @@ func TestNoSessionIDInNonSSE(t *testing.T) { } func TestHeaderBasedSessionInitialization(t *testing.T) { + t.Parallel() + p := NewTransparentProxy("127.0.0.1", 0, "test", "", nil) - target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { // Set both content-type and also optionally MCP header to test behavior w.Header().Set("Content-Type", "application/json") w.Header().Set("Mcp-Session-Id", "XYZ789") diff --git a/pkg/transport/session/manager_test.go b/pkg/transport/session/manager_test.go index 7b9d6b6a8..203d13075 100644 --- a/pkg/transport/session/manager_test.go +++ b/pkg/transport/session/manager_test.go @@ -9,6 +9,8 @@ import ( ) func TestAddAndGetWithStubSession(t *testing.T) { + t.Parallel() + orig := NewProxySession NewProxySession = func(id string) *ProxySession { ts := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) @@ -27,6 +29,8 @@ func TestAddAndGetWithStubSession(t *testing.T) { } func TestAddDuplicate(t *testing.T) { + t.Parallel() + m := NewManager(time.Hour) defer m.Stop() @@ -39,6 +43,8 @@ func TestAddDuplicate(t *testing.T) { } func TestDeleteSession(t *testing.T) { + t.Parallel() + m := NewManager(time.Hour) defer m.Stop() @@ -50,6 +56,8 @@ func TestDeleteSession(t *testing.T) { } func TestGetUpdatesTimestamp(t *testing.T) { + t.Parallel() + orig := NewProxySession NewProxySession = func(id string) *ProxySession { ts := time.Now().Add(-1 * time.Minute) @@ -72,6 +80,8 @@ func TestGetUpdatesTimestamp(t *testing.T) { } func TestCleanupExpired(t *testing.T) { + t.Parallel() + ttl := 50 * time.Millisecond orig := NewProxySession NewProxySession = func(id string) *ProxySession { @@ -99,6 +109,8 @@ func TestCleanupExpired(t *testing.T) { } func TestStopDisablesCleanup(t *testing.T) { + t.Parallel() + ttl := 50 * time.Millisecond m := NewManager(ttl) m.Stop() // stop cleanup upfront From 9f62a12e741302e3c564f5f30372adc5f5d49bc1 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Mota Date: Mon, 21 Jul 2025 12:54:03 +0200 Subject: [PATCH 8/9] Update pkg/transport/session/proxy_session.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pkg/transport/session/proxy_session.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/transport/session/proxy_session.go b/pkg/transport/session/proxy_session.go index e0ad4c15d..9f767b5d1 100644 --- a/pkg/transport/session/proxy_session.go +++ b/pkg/transport/session/proxy_session.go @@ -10,7 +10,7 @@ type ProxySession struct { } // NewProxySession creates a new ProxySession with the given ID. -var NewProxySession = func(id string) *ProxySession { +func NewProxySession(id string) *ProxySession { now := time.Now() return &ProxySession{id: id, created: now, updated: now} } From ed1e12f6e6a61a0cf58c22b7865f4cf4765f4cf9 Mon Sep 17 00:00:00 2001 From: taskbot Date: Mon, 21 Jul 2025 13:10:05 +0200 Subject: [PATCH 9/9] fixes from copilot --- .../proxy/transparent/transparent_proxy.go | 25 ++-- pkg/transport/session/manager.go | 38 +++++- pkg/transport/session/manager_test.go | 117 ++++++++++-------- 3 files changed, 109 insertions(+), 71 deletions(-) diff --git a/pkg/transport/proxy/transparent/transparent_proxy.go b/pkg/transport/proxy/transparent/transparent_proxy.go index 936852a00..c17c34d0b 100644 --- a/pkg/transport/proxy/transparent/transparent_proxy.go +++ b/pkg/transport/proxy/transparent/transparent_proxy.go @@ -80,7 +80,7 @@ func NewTransparentProxy( middlewares: middlewares, shutdownCh: make(chan struct{}), prometheusHandler: prometheusHandler, - sessionManager: session.NewManager(30 * time.Minute), + sessionManager: session.NewManager(30*time.Minute, session.NewProxySession), } // Create MCP pinger and health checker @@ -95,12 +95,12 @@ type tracingTransport struct { p *TransparentProxy } -func (t *tracingTransport) setServerInitialized() { - if !t.p.IsServerInitialized { - t.p.mutex.Lock() - t.p.IsServerInitialized = true - t.p.mutex.Unlock() - logger.Infof("Server was initialized successfully for %s", t.p.containerName) +func (p *TransparentProxy) setServerInitialized() { + if !p.IsServerInitialized { + p.mutex.Lock() + p.IsServerInitialized = true + p.mutex.Unlock() + logger.Infof("Server was initialized successfully for %s", p.containerName) } } @@ -136,21 +136,16 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) if ct != "" { logger.Infof("Detected Mcp-Session-Id header: %s", ct) if _, ok := t.p.sessionManager.Get(ct); !ok { - fmt.Println("i get session id") if err := t.p.sessionManager.AddWithID(ct); err != nil { - fmt.Println("i add session") logger.Errorf("Failed to create session from header %s: %v", ct, err) } - fmt.Println("i set server initialized") } - fmt.Println("i set server initialized") - t.setServerInitialized() + t.p.setServerInitialized() return resp, nil } // status was ok and we saw an initialize call if sawInitialize && !t.p.IsServerInitialized { - fmt.Println("here") - t.setServerInitialized() + t.p.setServerInitialized() return resp, nil } } @@ -212,7 +207,7 @@ func (p *TransparentProxy) modifyForSessionID(resp *http.Response) error { if sid == "" { sid = string(m[2]) } - p.IsServerInitialized = true + p.setServerInitialized() err := p.sessionManager.AddWithID(sid) if err != nil { logger.Errorf("Failed to create session from SSE line: %v", err) diff --git a/pkg/transport/session/manager.go b/pkg/transport/session/manager.go index 6384c2be3..de5c512c5 100644 --- a/pkg/transport/session/manager.go +++ b/pkg/transport/session/manager.go @@ -20,13 +20,19 @@ type Manager struct { sessions sync.Map ttl time.Duration stopCh chan struct{} + factory Factory } +// Factory defines a function type for creating new sessions. +type Factory func(id string) *ProxySession + // NewManager creates a session manager with TTL and starts cleanup worker. -func NewManager(ttl time.Duration) *Manager { +func NewManager(ttl time.Duration, factory Factory) *Manager { m := &Manager{ - ttl: ttl, - stopCh: make(chan struct{}), + sessions: sync.Map{}, + ttl: ttl, + stopCh: make(chan struct{}), + factory: factory, } go m.cleanupRoutine() return m @@ -40,7 +46,11 @@ func (m *Manager) cleanupRoutine() { case <-ticker.C: cutoff := time.Now().Add(-m.ttl) m.sessions.Range(func(key, val any) bool { - sess := val.(Session) + sess, ok := val.(Session) + if !ok { + // Skip invalid value + return true + } if sess.UpdatedAt().Before(cutoff) { m.sessions.Delete(key) } @@ -59,7 +69,8 @@ func (m *Manager) AddWithID(id string) error { return fmt.Errorf("session ID cannot be empty") } // Use LoadOrStore: returns existing if already present - _, loaded := m.sessions.LoadOrStore(id, NewProxySession(id)) + session := m.factory(id) + _, loaded := m.sessions.LoadOrStore(id, session) if loaded { return fmt.Errorf("session ID %q already exists", id) } @@ -73,7 +84,11 @@ func (m *Manager) Get(id string) (Session, bool) { if !ok { return nil, false } - sess := v.(Session) + sess, ok := v.(Session) + if !ok { + return nil, false // Invalid session type + } + sess.Touch() return sess, true } @@ -87,3 +102,14 @@ func (m *Manager) Delete(id string) { func (m *Manager) Stop() { close(m.stopCh) } + +func (m *Manager) cleanupExpiredOnce() { + cutoff := time.Now().Add(-m.ttl) + m.sessions.Range(func(key, val any) bool { + sess := val.(Session) + if sess.UpdatedAt().Before(cutoff) { + m.sessions.Delete(key) + } + return true + }) +} diff --git a/pkg/transport/session/manager_test.go b/pkg/transport/session/manager_test.go index 203d13075..5223c03c0 100644 --- a/pkg/transport/session/manager_test.go +++ b/pkg/transport/session/manager_test.go @@ -1,6 +1,7 @@ package session import ( + "sync" "testing" "time" @@ -8,116 +9,132 @@ import ( "github.com/stretchr/testify/require" ) -func TestAddAndGetWithStubSession(t *testing.T) { - t.Parallel() +// stubFactory returns ProxySessions with fixed timestamps and records IDs. +type stubFactory struct { + mu sync.Mutex + createdIDs []string + fixedTime time.Time +} - orig := NewProxySession - NewProxySession = func(id string) *ProxySession { - ts := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) - return &ProxySession{id: id, created: ts, updated: ts} +func (f *stubFactory) New(id string) *ProxySession { + f.mu.Lock() + defer f.mu.Unlock() + f.createdIDs = append(f.createdIDs, id) + return &ProxySession{ + id: id, + created: f.fixedTime, + updated: f.fixedTime, } - defer func() { NewProxySession = orig }() +} - m := NewManager(1 * time.Hour) +func TestAddAndGetWithStubSession(t *testing.T) { + t.Parallel() + now := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + factory := &stubFactory{fixedTime: now} + + m := NewManager(time.Hour, factory.New) defer m.Stop() require.NoError(t, m.AddWithID("foo")) sess, ok := m.Get("foo") - require.True(t, ok) + require.True(t, ok, "session foo should exist") assert.Equal(t, "foo", sess.ID()) + assert.Contains(t, factory.createdIDs, "foo") } func TestAddDuplicate(t *testing.T) { t.Parallel() + factory := &stubFactory{fixedTime: time.Now()} - m := NewManager(time.Hour) + m := NewManager(time.Hour, factory.New) defer m.Stop() - err := m.AddWithID("dup") - assert.NoError(t, err) + require.NoError(t, m.AddWithID("dup")) - err2 := m.AddWithID("dup") - assert.Error(t, err2) - assert.Contains(t, err2.Error(), "already exists") + err := m.AddWithID("dup") + assert.Error(t, err) + assert.Contains(t, err.Error(), "already exists") } func TestDeleteSession(t *testing.T) { t.Parallel() + factory := &stubFactory{fixedTime: time.Now()} - m := NewManager(time.Hour) + m := NewManager(time.Hour, factory.New) defer m.Stop() require.NoError(t, m.AddWithID("del")) m.Delete("del") _, ok := m.Get("del") - assert.False(t, ok) + assert.False(t, ok, "deleted session should not be found") } func TestGetUpdatesTimestamp(t *testing.T) { t.Parallel() + oldTime := time.Now().Add(-1 * time.Minute) + factory := &stubFactory{fixedTime: oldTime} - orig := NewProxySession - NewProxySession = func(id string) *ProxySession { - ts := time.Now().Add(-1 * time.Minute) - return &ProxySession{id: id, created: ts, updated: ts} - } - defer func() { NewProxySession = orig }() - - m := NewManager(1 * time.Hour) + m := NewManager(time.Hour, factory.New) defer m.Stop() require.NoError(t, m.AddWithID("touchme")) - s1, _ := m.Get("touchme") + s1, ok := m.Get("touchme") + require.True(t, ok) t0 := s1.UpdatedAt() - time.Sleep(5 * time.Millisecond) - s2, _ := m.Get("touchme") + time.Sleep(10 * time.Millisecond) + s2, ok2 := m.Get("touchme") + require.True(t, ok2) t1 := s2.UpdatedAt() - assert.True(t, t1.After(t0), "UpdatedAt should update on Get()") + assert.True(t, t1.After(t0), "UpdatedAt should update on repeated Get()") } - -func TestCleanupExpired(t *testing.T) { +func TestCleanupExpired_ManualTrigger(t *testing.T) { t.Parallel() + // Stub factory: all sessions start with UpdatedAt = `now` + now := time.Now() + factory := &stubFactory{fixedTime: now} ttl := 50 * time.Millisecond - orig := NewProxySession - NewProxySession = func(id string) *ProxySession { - return &ProxySession{ - id: id, - created: time.Now(), - updated: time.Now(), - } - } - defer func() { NewProxySession = orig }() - m := NewManager(ttl) + m := NewManager(ttl, factory.New) defer m.Stop() require.NoError(t, m.AddWithID("old")) - time.Sleep(ttl * 2) // allow old to expire - require.NoError(t, m.AddWithID("new")) - time.Sleep(ttl) // let cleanup execute + // Retrieve and expire session manually + sess, ok := m.Get("old") + require.True(t, ok) + ps := sess.(*ProxySession) + ps.updated = now.Add(-ttl * 2) + // Run cleanup manually + m.cleanupExpiredOnce() + + // Now it should be gone _, okOld := m.Get("old") + assert.False(t, okOld, "expired session should have been cleaned") + + // Add fresh session and assert it remains after cleanup + require.NoError(t, m.AddWithID("new")) + m.cleanupExpiredOnce() _, okNew := m.Get("new") - assert.False(t, okOld, "expired session should be cleaned") - assert.True(t, okNew, "recent session should remain") + assert.True(t, okNew, "new session should still exist after cleanup") } func TestStopDisablesCleanup(t *testing.T) { t.Parallel() - ttl := 50 * time.Millisecond - m := NewManager(ttl) - m.Stop() // stop cleanup upfront + factory := &stubFactory{fixedTime: time.Now()} + + m := NewManager(ttl, factory.New) + m.Stop() // disable cleanup before any session expires require.NoError(t, m.AddWithID("stay")) time.Sleep(ttl * 2) _, ok := m.Get("stay") - assert.True(t, ok, "session should persist after Stop()") + assert.True(t, ok, "session should still be present even after Stop() and TTL elapsed") }