From 92c5c77a7d1217205572d4e31bcca30e2d2b6d9b Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Mon, 1 Sep 2025 10:36:18 +0300 Subject: [PATCH] Refactor session management to support multiple storage backends Unified session management across all transport types by migrating HTTPSSEProxy from direct map storage to use the centralized session manager. Extended the session interface to support type differentiation and metadata storage, enabling future support for distributed session storage backends like Redis/Valkey. Key changes: - Added session types (MCP, SSE, Streamable) for better session handling - Created SSESession type with SSE-specific functionality - Migrated HTTPSSEProxy to use session manager with proper TTL - Updated factory pattern to support different session types - Fixed all tests to work with the new session management - Maintained backward compatibility with existing code Signed-off-by: Juan Antonio Osorio --- pkg/transport/proxy/httpsse/http_proxy.go | 115 ++++++++++------ .../proxy/httpsse/http_proxy_test.go | 129 +++++++++--------- pkg/transport/session/errors.go | 21 +++ pkg/transport/session/manager.go | 84 +++++++++++- pkg/transport/session/proxy_session.go | 117 +++++++++++++++- pkg/transport/session/sse_session.go | 100 ++++++++++++++ 6 files changed, 451 insertions(+), 115 deletions(-) create mode 100644 pkg/transport/session/errors.go create mode 100644 pkg/transport/session/sse_session.go diff --git a/pkg/transport/proxy/httpsse/http_proxy.go b/pkg/transport/proxy/httpsse/http_proxy.go index 9e2d4e46a..2021cf8c5 100644 --- a/pkg/transport/proxy/httpsse/http_proxy.go +++ b/pkg/transport/proxy/httpsse/http_proxy.go @@ -17,6 +17,7 @@ import ( "github.com/stacklok/toolhive/pkg/healthcheck" "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/transport/ssecommon" "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -63,9 +64,8 @@ type HTTPSSEProxy struct { // Optional Prometheus metrics handler prometheusHandler http.Handler - // SSE clients - sseClients map[string]*ssecommon.SSEClient - sseClientsMutex sync.RWMutex + // Session manager for SSE clients + sessionManager *session.Manager // Pending messages for SSE clients pendingMessages []*ssecommon.PendingSSEMessage @@ -86,6 +86,11 @@ type HTTPSSEProxy struct { func NewHTTPSSEProxy( host string, port int, containerName string, prometheusHandler http.Handler, middlewares ...types.MiddlewareFunction, ) *HTTPSSEProxy { + // Create a factory for SSE sessions + sseFactory := func(id string) session.Session { + return session.NewSSESession(id) + } + proxy := &HTTPSSEProxy{ middlewares: middlewares, host: host, @@ -93,7 +98,7 @@ func NewHTTPSSEProxy( containerName: containerName, shutdownCh: make(chan struct{}), messageCh: make(chan jsonrpc2.Message, 100), - sseClients: make(map[string]*ssecommon.SSEClient), + sessionManager: session.NewManager(30*time.Minute, sseFactory), pendingMessages: []*ssecommon.PendingSSEMessage{}, prometheusHandler: prometheusHandler, closedClients: make(map[string]bool), @@ -190,6 +195,19 @@ func (p *HTTPSSEProxy) Stop(ctx context.Context) error { // Signal shutdown close(p.shutdownCh) + // Stop the session manager cleanup routine + if p.sessionManager != nil { + p.sessionManager.Stop() + } + + // Disconnect all active sessions + p.sessionManager.Range(func(_, value interface{}) bool { + if sess, ok := value.(*session.SSESession); ok { + sess.Disconnect() + } + return true + }) + // Stop the HTTP server if p.server != nil { return p.server.Shutdown(ctx) @@ -226,10 +244,12 @@ func (p *HTTPSSEProxy) ForwardResponseToClients(_ context.Context, msg jsonrpc2. // Create an SSE message sseMsg := ssecommon.NewSSEMessage("message", string(data)) - // Check if there are any connected clients - p.sseClientsMutex.RLock() - hasClients := len(p.sseClients) > 0 - p.sseClientsMutex.RUnlock() + // Check if there are any connected clients by checking session count + hasClients := false + p.sessionManager.Range(func(_, _ interface{}) bool { + hasClients = true + return false // Stop iteration after finding first session + }) if hasClients { // Send the message to all connected clients @@ -258,13 +278,19 @@ func (p *HTTPSSEProxy) handleSSEConnection(w http.ResponseWriter, r *http.Reques // Create a channel for sending messages to this client messageCh := make(chan string, 100) - // Register the client - p.sseClientsMutex.Lock() - p.sseClients[clientID] = &ssecommon.SSEClient{ + // Create SSE client info + clientInfo := &ssecommon.SSEClient{ MessageCh: messageCh, CreatedAt: time.Now(), } - p.sseClientsMutex.Unlock() + + // Create and register the SSE session + sseSession := session.NewSSESessionWithClient(clientID, clientInfo) + if err := p.sessionManager.AddSession(sseSession); err != nil { + logger.Errorf("Failed to add SSE session: %v", err) + http.Error(w, "Failed to create session", http.StatusInternalServerError) + return + } // Process any pending messages for this client p.processPendingMessages(clientID, messageCh) @@ -347,10 +373,7 @@ func (p *HTTPSSEProxy) handlePostRequest(w http.ResponseWriter, r *http.Request) } // Check if the session exists - p.sseClientsMutex.RLock() - _, exists := p.sseClients[sessionID] - p.sseClientsMutex.RUnlock() - + _, exists := p.sessionManager.Get(sessionID) if !exists { http.Error(w, "Could not find session", http.StatusNotFound) return @@ -391,21 +414,34 @@ func (p *HTTPSSEProxy) sendSSEEvent(msg *ssecommon.SSEMessage) error { // Convert the message to an SSE-formatted string sseString := msg.ToSSEString() - // Hold the lock while sending to ensure channels aren't closed during send - // This is a read lock, so multiple sends can happen concurrently - p.sseClientsMutex.RLock() - defer p.sseClientsMutex.RUnlock() + // Iterate through all sessions and send to SSE sessions + p.sessionManager.Range(func(key, value interface{}) bool { + clientID, ok := key.(string) + if !ok { + return true // Continue iteration + } + + sess, ok := value.(session.Session) + if !ok { + return true // Continue iteration + } - for clientID, client := range p.sseClients { - select { - case client.MessageCh <- sseString: - // Message sent successfully - default: - // Channel is full, skip this client - // Don't remove the client here - let the disconnect monitor handle it - logger.Debugf("Client %s channel full, skipping message", clientID) + // Check if this is an SSE session + if sseSession, ok := sess.(*session.SSESession); ok { + // Try to send the message + if err := sseSession.SendMessage(sseString); err != nil { + // Log the error but continue sending to other clients + switch err { + case session.ErrSessionDisconnected: + logger.Debugf("Client %s is disconnected, skipping message", clientID) + case session.ErrMessageChannelFull: + logger.Debugf("Client %s channel full, skipping message", clientID) + } + } } - } + + return true // Continue iteration + }) return nil } @@ -421,21 +457,20 @@ func (p *HTTPSSEProxy) removeClient(clientID string) { p.closedClients[clientID] = true p.closedClientsMutex.Unlock() - // Remove from clients map and get the client - // Use write lock to ensure no sends happen during removal - p.sseClientsMutex.Lock() - client, exists := p.sseClients[clientID] - if exists { - delete(p.sseClients, clientID) + // Get the session from the manager + sess, exists := p.sessionManager.Get(clientID) + if !exists { + return } - p.sseClientsMutex.Unlock() - // Close the channel after removing from map - // This ensures no goroutine will try to send to it - if exists && client != nil { - close(client.MessageCh) + // If it's an SSE session, disconnect it + if sseSession, ok := sess.(*session.SSESession); ok { + sseSession.Disconnect() } + // Remove the session from the manager + p.sessionManager.Delete(clientID) + // Clean up closed clients map periodically (prevent memory leak) p.closedClientsMutex.Lock() if len(p.closedClients) > 1000 { diff --git a/pkg/transport/proxy/httpsse/http_proxy_test.go b/pkg/transport/proxy/httpsse/http_proxy_test.go index 1f5bf927a..8766f69e2 100644 --- a/pkg/transport/proxy/httpsse/http_proxy_test.go +++ b/pkg/transport/proxy/httpsse/http_proxy_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/exp/jsonrpc2" + "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/transport/ssecommon" ) @@ -30,7 +31,7 @@ func TestNewHTTPSSEProxy(t *testing.T) { assert.Equal(t, 8080, proxy.port) assert.Equal(t, "test-container", proxy.containerName) assert.NotNil(t, proxy.messageCh) - assert.NotNil(t, proxy.sseClients) + assert.NotNil(t, proxy.sessionManager) assert.NotNil(t, proxy.closedClients) assert.NotNil(t, proxy.healthChecker) } @@ -94,24 +95,23 @@ func TestSendMessageToDestination_ChannelFull(t *testing.T) { func TestRemoveClient(t *testing.T) { proxy := NewHTTPSSEProxy("localhost", 8080, "test-container", nil) - // Create a client + // Create a client session clientID := "test-client-1" - messageCh := make(chan string, 10) - - proxy.sseClientsMutex.Lock() - proxy.sseClients[clientID] = &ssecommon.SSEClient{ - MessageCh: messageCh, + clientInfo := &ssecommon.SSEClient{ + MessageCh: make(chan string, 10), CreatedAt: time.Now(), } - proxy.sseClientsMutex.Unlock() + + // Add session to manager + sseSession := session.NewSSESessionWithClient(clientID, clientInfo) + err := proxy.sessionManager.AddSession(sseSession) + require.NoError(t, err) // Remove the client once proxy.removeClient(clientID) - // Verify client was removed - proxy.sseClientsMutex.RLock() - _, exists := proxy.sseClients[clientID] - proxy.sseClientsMutex.RUnlock() + // Verify client was removed from session manager + _, exists := proxy.sessionManager.Get(clientID) assert.False(t, exists) // Verify client is marked as closed @@ -132,18 +132,19 @@ func TestRemoveClient(t *testing.T) { func TestConcurrentClientRemoval(t *testing.T) { proxy := NewHTTPSSEProxy("localhost", 8080, "test-container", nil) - // Create multiple clients + // Create multiple client sessions numClients := 100 for i := 0; i < numClients; i++ { clientID := fmt.Sprintf("client-%d", i) - messageCh := make(chan string, 10) - - proxy.sseClientsMutex.Lock() - proxy.sseClients[clientID] = &ssecommon.SSEClient{ - MessageCh: messageCh, + clientInfo := &ssecommon.SSEClient{ + MessageCh: make(chan string, 10), CreatedAt: time.Now(), } - proxy.sseClientsMutex.Unlock() + + // Add session to manager + sseSession := session.NewSSESessionWithClient(clientID, clientInfo) + err := proxy.sessionManager.AddSession(sseSession) + require.NoError(t, err) } // Concurrently remove all clients from multiple goroutines @@ -170,9 +171,7 @@ func TestConcurrentClientRemoval(t *testing.T) { }) // Verify all clients are removed - proxy.sseClientsMutex.RLock() - assert.Empty(t, proxy.sseClients) - proxy.sseClientsMutex.RUnlock() + assert.Equal(t, 0, proxy.sessionManager.Count()) } // TestForwardResponseToClients tests forwarding responses to connected clients @@ -182,16 +181,18 @@ func TestForwardResponseToClients(t *testing.T) { proxy := NewHTTPSSEProxy("localhost", 8080, "test-container", nil) ctx := context.Background() - // Create a client + // Create a client session clientID := testClientID messageCh := make(chan string, 10) - - proxy.sseClientsMutex.Lock() - proxy.sseClients[clientID] = &ssecommon.SSEClient{ + clientInfo := &ssecommon.SSEClient{ MessageCh: messageCh, CreatedAt: time.Now(), } - proxy.sseClientsMutex.Unlock() + + // Add session to manager + sseSession := session.NewSSESessionWithClient(clientID, clientInfo) + err := proxy.sessionManager.AddSession(sseSession) + require.NoError(t, err) // Create a test response response, err := jsonrpc2.NewResponse(jsonrpc2.StringID("test"), "test result", nil) @@ -238,30 +239,30 @@ func TestForwardResponseToClients_NoClients(t *testing.T) { func TestSendSSEEvent_ChannelFull(t *testing.T) { proxy := NewHTTPSSEProxy("localhost", 8080, "test-container", nil) - // Create a client with a small buffer + // Create a client session with a small buffer clientID := testClientID messageCh := make(chan string, 1) - - proxy.sseClientsMutex.Lock() - proxy.sseClients[clientID] = &ssecommon.SSEClient{ + clientInfo := &ssecommon.SSEClient{ MessageCh: messageCh, CreatedAt: time.Now(), } - proxy.sseClientsMutex.Unlock() + + // Add session to manager + sseSession := session.NewSSESessionWithClient(clientID, clientInfo) + err := proxy.sessionManager.AddSession(sseSession) + require.NoError(t, err) // Fill the channel messageCh <- "blocking message" // Try to send another message msg := ssecommon.NewSSEMessage("test", "test data") - err := proxy.sendSSEEvent(msg) - assert.NoError(t, err) + err2 := proxy.sendSSEEvent(msg) + assert.NoError(t, err2) // In the improved implementation, we don't remove clients with full channels // We just skip sending to them and let the disconnect monitor handle cleanup - proxy.sseClientsMutex.RLock() - _, exists := proxy.sseClients[clientID] - proxy.sseClientsMutex.RUnlock() + _, exists := proxy.sessionManager.Get(clientID) assert.True(t, exists, "Client should still exist even with full channel") // Clean up @@ -322,10 +323,7 @@ func TestHandleSSEConnection(t *testing.T) { // Verify a client was registered time.Sleep(100 * time.Millisecond) // Give time for registration - proxy.sseClientsMutex.RLock() - numClients := len(proxy.sseClients) - proxy.sseClientsMutex.RUnlock() - assert.Equal(t, 1, numClients) + assert.Equal(t, 1, proxy.sessionManager.Count()) } // TestHandlePostRequest tests handling of POST requests @@ -334,16 +332,17 @@ func TestHandleSSEConnection(t *testing.T) { func TestHandlePostRequest(t *testing.T) { proxy := NewHTTPSSEProxy("localhost", 8080, "test-container", nil) - // Create a client + // Create a client session sessionID := "test-session" - messageCh := make(chan string, 10) - - proxy.sseClientsMutex.Lock() - proxy.sseClients[sessionID] = &ssecommon.SSEClient{ - MessageCh: messageCh, + clientInfo := &ssecommon.SSEClient{ + MessageCh: make(chan string, 10), CreatedAt: time.Now(), } - proxy.sseClientsMutex.Unlock() + + // Add session to manager + sseSession := session.NewSSESessionWithClient(sessionID, clientInfo) + err := proxy.sessionManager.AddSession(sseSession) + require.NoError(t, err) // Create a valid JSON-RPC message msg, err := jsonrpc2.NewCall(jsonrpc2.StringID("test"), "test.method", nil) @@ -413,17 +412,18 @@ func TestHandlePostRequest_InvalidSession(t *testing.T) { func TestRWMutexUsage(t *testing.T) { proxy := NewHTTPSSEProxy("localhost", 8080, "test-container", nil) - // Add multiple clients + // Add multiple client sessions for i := 0; i < 10; i++ { clientID := fmt.Sprintf("client-%d", i) - messageCh := make(chan string, 10) - - proxy.sseClientsMutex.Lock() - proxy.sseClients[clientID] = &ssecommon.SSEClient{ - MessageCh: messageCh, + clientInfo := &ssecommon.SSEClient{ + MessageCh: make(chan string, 10), CreatedAt: time.Now(), } - proxy.sseClientsMutex.Unlock() + + // Add session to manager + sseSession := session.NewSSESessionWithClient(clientID, clientInfo) + err := proxy.sessionManager.AddSession(sseSession) + require.NoError(t, err) } // Test concurrent reads (should not block each other) @@ -434,10 +434,8 @@ func TestRWMutexUsage(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - proxy.sseClientsMutex.RLock() - _ = len(proxy.sseClients) + _ = proxy.sessionManager.Count() time.Sleep(10 * time.Millisecond) // Simulate some work - proxy.sseClientsMutex.RUnlock() }() } @@ -455,17 +453,18 @@ func TestRWMutexUsage(t *testing.T) { func TestClosedClientsCleanup(t *testing.T) { proxy := NewHTTPSSEProxy("localhost", 8080, "test-container", nil) - // Add many closed clients to trigger cleanup + // Add many closed client sessions to trigger cleanup for i := 0; i < 1100; i++ { clientID := fmt.Sprintf("client-%d", i) - messageCh := make(chan string, 1) - - proxy.sseClientsMutex.Lock() - proxy.sseClients[clientID] = &ssecommon.SSEClient{ - MessageCh: messageCh, + clientInfo := &ssecommon.SSEClient{ + MessageCh: make(chan string, 1), CreatedAt: time.Now(), } - proxy.sseClientsMutex.Unlock() + + // Add session to manager + sseSession := session.NewSSESessionWithClient(clientID, clientInfo) + err := proxy.sessionManager.AddSession(sseSession) + require.NoError(t, err) // Remove the client proxy.removeClient(clientID) diff --git a/pkg/transport/session/errors.go b/pkg/transport/session/errors.go new file mode 100644 index 000000000..7a8aab6bb --- /dev/null +++ b/pkg/transport/session/errors.go @@ -0,0 +1,21 @@ +package session + +import "errors" + +// Common session errors +var ( + // ErrSessionDisconnected is returned when trying to send to a disconnected session + ErrSessionDisconnected = errors.New("session is disconnected") + + // ErrMessageChannelFull is returned when the message channel is full + ErrMessageChannelFull = errors.New("message channel is full") + + // ErrSessionNotFound is returned when a session cannot be found + ErrSessionNotFound = errors.New("session not found") + + // ErrSessionAlreadyExists is returned when trying to create a session with an existing ID + ErrSessionAlreadyExists = errors.New("session already exists") + + // ErrInvalidSessionType is returned when an invalid session type is provided + ErrInvalidSessionType = errors.New("invalid session type") +) diff --git a/pkg/transport/session/manager.go b/pkg/transport/session/manager.go index de5c512c5..66ecc8971 100644 --- a/pkg/transport/session/manager.go +++ b/pkg/transport/session/manager.go @@ -24,20 +24,65 @@ type Manager struct { } // Factory defines a function type for creating new sessions. -type Factory func(id string) *ProxySession +// It now returns the Session interface to support different session types. +type Factory func(id string) Session + +// LegacyFactory is the old factory type for backward compatibility +type LegacyFactory func(id string) *ProxySession // NewManager creates a session manager with TTL and starts cleanup worker. -func NewManager(ttl time.Duration, factory Factory) *Manager { +// It accepts either the new Factory or the legacy factory for backward compatibility. +func NewManager(ttl time.Duration, factory interface{}) *Manager { + var f Factory + + switch factoryFunc := factory.(type) { + case Factory: + f = factoryFunc + case LegacyFactory: + // Wrap legacy factory to return Session interface + f = func(id string) Session { + return factoryFunc(id) + } + case func(id string) *ProxySession: + // Also support direct function for backward compatibility + f = func(id string) Session { + return factoryFunc(id) + } + default: + // Default to creating basic ProxySession + f = func(id string) Session { + return NewProxySession(id) + } + } + m := &Manager{ sessions: sync.Map{}, ttl: ttl, stopCh: make(chan struct{}), - factory: factory, + factory: f, } go m.cleanupRoutine() return m } +// NewTypedManager creates a session manager for a specific session type. +func NewTypedManager(ttl time.Duration, sessionType SessionType) *Manager { + factory := func(id string) Session { + switch sessionType { + case SessionTypeSSE: + return NewSSESession(id) + case SessionTypeMCP: + return NewProxySession(id) + case SessionTypeStreamable: + return NewTypedProxySession(id, sessionType) + default: + return NewTypedProxySession(id, sessionType) + } + } + + return NewManager(ttl, factory) +} + func (m *Manager) cleanupRoutine() { ticker := time.NewTicker(m.ttl / 2) defer ticker.Stop() @@ -77,6 +122,23 @@ func (m *Manager) AddWithID(id string) error { return nil } +// AddSession adds an existing session to the manager. +// This is useful when you need to create a session with specific properties. +func (m *Manager) AddSession(session Session) error { + if session == nil { + return fmt.Errorf("session cannot be nil") + } + if session.ID() == "" { + return fmt.Errorf("session ID cannot be empty") + } + + _, loaded := m.sessions.LoadOrStore(session.ID(), session) + if loaded { + return fmt.Errorf("session ID %q already exists", session.ID()) + } + return nil +} + // Get retrieves a session by ID. Returns (session, true) if found, // and also updates its UpdatedAt timestamp. func (m *Manager) Get(id string) (Session, bool) { @@ -103,6 +165,22 @@ func (m *Manager) Stop() { close(m.stopCh) } +// Range calls f sequentially for each key and value present in the map. +// If f returns false, range stops the iteration. +func (m *Manager) Range(f func(key, value interface{}) bool) { + m.sessions.Range(f) +} + +// Count returns the number of active sessions. +func (m *Manager) Count() int { + count := 0 + m.sessions.Range(func(_, _ interface{}) bool { + count++ + return true + }) + return count +} + func (m *Manager) cleanupExpiredOnce() { cutoff := time.Now().Add(-m.ttl) m.sessions.Range(func(key, val any) bool { diff --git a/pkg/transport/session/proxy_session.go b/pkg/transport/session/proxy_session.go index 9f767b5d1..da86bf8f6 100644 --- a/pkg/transport/session/proxy_session.go +++ b/pkg/transport/session/proxy_session.go @@ -1,18 +1,59 @@ package session -import "time" +import ( + "sync" + "time" +) + +// SessionType represents the type of session +// +//revive:disable-next-line:exported +type SessionType string + +const ( + // SessionTypeMCP represents a standard MCP session + SessionTypeMCP SessionType = "mcp" + // SessionTypeSSE represents an SSE (Server-Sent Events) session + SessionTypeSSE SessionType = "sse" + // SessionTypeStreamable represents a streamable HTTP session + SessionTypeStreamable SessionType = "streamable" +) // ProxySession implements the Session interface for proxy sessions. +// It now includes support for session types, metadata, and custom data. type ProxySession struct { - id string - created time.Time - updated time.Time + id string + created time.Time + updated time.Time + sessType SessionType + data interface{} + metadata map[string]string + mu sync.RWMutex // Protect concurrent access to metadata and data } // NewProxySession creates a new ProxySession with the given ID. +// It defaults to SessionTypeMCP for backward compatibility. func NewProxySession(id string) *ProxySession { now := time.Now() - return &ProxySession{id: id, created: now, updated: now} + return &ProxySession{ + id: id, + created: now, + updated: now, + sessType: SessionTypeMCP, + metadata: make(map[string]string), + } +} + +// NewTypedProxySession creates a new ProxySession with the given ID and type. +func NewTypedProxySession(id string, sessType SessionType) *ProxySession { + now := time.Now() + return &ProxySession{ + id: id, + created: now, + updated: now, + sessType: sessType, + metadata: make(map[string]string), + } } // ID returns the session ID. @@ -22,7 +63,69 @@ func (s *ProxySession) ID() string { return s.id } func (s *ProxySession) CreatedAt() time.Time { return s.created } // UpdatedAt returns the last updated time of the session. -func (s *ProxySession) UpdatedAt() time.Time { return s.updated } +func (s *ProxySession) UpdatedAt() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() + return s.updated +} // Touch updates the session's last updated time to the current time. -func (s *ProxySession) Touch() { s.updated = time.Now() } +func (s *ProxySession) Touch() { + s.mu.Lock() + defer s.mu.Unlock() + s.updated = time.Now() +} + +// Type returns the session type. +func (s *ProxySession) Type() SessionType { return s.sessType } + +// GetData returns the session-specific data. +func (s *ProxySession) GetData() interface{} { + s.mu.RLock() + defer s.mu.RUnlock() + return s.data +} + +// SetData sets the session-specific data. +func (s *ProxySession) SetData(data interface{}) { + s.mu.Lock() + defer s.mu.Unlock() + s.data = data +} + +// GetMetadata returns all metadata as a map. +func (s *ProxySession) GetMetadata() map[string]string { + s.mu.RLock() + defer s.mu.RUnlock() + // Return a copy to prevent external modification + result := make(map[string]string, len(s.metadata)) + for k, v := range s.metadata { + result[k] = v + } + return result +} + +// SetMetadata sets a metadata key-value pair. +func (s *ProxySession) SetMetadata(key, value string) { + s.mu.Lock() + defer s.mu.Unlock() + if s.metadata == nil { + s.metadata = make(map[string]string) + } + s.metadata[key] = value +} + +// GetMetadataValue gets a specific metadata value. +func (s *ProxySession) GetMetadataValue(key string) (string, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + value, ok := s.metadata[key] + return value, ok +} + +// DeleteMetadata removes a metadata key. +func (s *ProxySession) DeleteMetadata(key string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.metadata, key) +} diff --git a/pkg/transport/session/sse_session.go b/pkg/transport/session/sse_session.go new file mode 100644 index 000000000..e45e16aa9 --- /dev/null +++ b/pkg/transport/session/sse_session.go @@ -0,0 +1,100 @@ +package session + +import ( + "time" + + "github.com/stacklok/toolhive/pkg/transport/ssecommon" +) + +// SSESession represents an SSE (Server-Sent Events) session. +// It embeds ProxySession and adds SSE-specific functionality. +type SSESession struct { + *ProxySession + + // SSE-specific fields + MessageCh chan string + ClientInfo *ssecommon.SSEClient + IsConnected bool +} + +// NewSSESession creates a new SSE session with the given ID. +func NewSSESession(id string) *SSESession { + return &SSESession{ + ProxySession: NewTypedProxySession(id, SessionTypeSSE), + MessageCh: make(chan string, 100), + IsConnected: true, + } +} + +// NewSSESessionWithClient creates a new SSE session with the given ID and client info. +func NewSSESessionWithClient(id string, client *ssecommon.SSEClient) *SSESession { + sess := NewSSESession(id) + sess.ClientInfo = client + if client != nil { + sess.MessageCh = client.MessageCh + } + return sess +} + +// Disconnect marks the session as disconnected and closes the message channel. +func (s *SSESession) Disconnect() { + s.mu.Lock() + defer s.mu.Unlock() + + if s.IsConnected { + s.IsConnected = false + if s.MessageCh != nil { + close(s.MessageCh) + } + } +} + +// SendMessage sends a message to the SSE client if connected. +func (s *SSESession) SendMessage(msg string) error { + s.mu.RLock() + defer s.mu.RUnlock() + + if !s.IsConnected { + return ErrSessionDisconnected + } + + select { + case s.MessageCh <- msg: + return nil + default: + return ErrMessageChannelFull + } +} + +// GetClientInfo returns the SSE client information. +func (s *SSESession) GetClientInfo() *ssecommon.SSEClient { + s.mu.RLock() + defer s.mu.RUnlock() + return s.ClientInfo +} + +// SetClientInfo sets the SSE client information. +func (s *SSESession) SetClientInfo(client *ssecommon.SSEClient) { + s.mu.Lock() + defer s.mu.Unlock() + s.ClientInfo = client + if client != nil && client.MessageCh != nil { + s.MessageCh = client.MessageCh + } +} + +// GetConnectionStatus returns whether the SSE session is connected. +func (s *SSESession) GetConnectionStatus() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.IsConnected +} + +// GetCreatedAt returns when the SSE session was created. +// This is useful for tracking connection duration. +func (s *SSESession) GetCreatedAt() time.Time { + if s.ClientInfo != nil { + return s.ClientInfo.CreatedAt + } + return s.CreatedAt() +}