
Commit efacaef

Refactor session management to support multiple storage backends (#1677)
Unified session management across all transport types by migrating HTTPSSEProxy from direct map storage to the centralized session manager. Extended the session interface to support type differentiation and metadata storage, enabling future support for distributed session storage backends such as Redis/Valkey.

Key changes:

- Added session types (MCP, SSE, Streamable) for better session handling
- Created SSESession type with SSE-specific functionality
- Migrated HTTPSSEProxy to use the session manager with a proper TTL
- Updated the factory pattern to support different session types
- Fixed all tests to work with the new session management
- Maintained backward compatibility with existing code

Signed-off-by: Juan Antonio Osorio <[email protected]>
1 parent 580f417 commit efacaef
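
The commit message describes extending the session interface with type differentiation and metadata storage, but the diff below only shows the proxy side of the change. Here is a minimal sketch of what that interface shape could look like; the method names and `SessionType` constants are assumptions for illustration, not copied from the repository:

```go
package session

import "time"

// SessionType distinguishes the transports a session can back; the
// commit mentions MCP, SSE, and Streamable types (names assumed).
type SessionType string

const (
	TypeMCP        SessionType = "mcp"
	TypeSSE        SessionType = "sse"
	TypeStreamable SessionType = "streamable"
)

// Session sketches the extended interface: identity, type
// differentiation, liveness, and metadata that a distributed
// backend such as Redis/Valkey could serialize.
type Session interface {
	ID() string
	Type() SessionType
	CreatedAt() time.Time
	Metadata(key string) (string, bool)
	SetMetadata(key, value string)
}
```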

File tree

6 files changed: +451 −115 lines

pkg/transport/proxy/httpsse/http_proxy.go

75 additions, 40 deletions
```diff
@@ -17,6 +17,7 @@ import (
 
 	"github.com/stacklok/toolhive/pkg/healthcheck"
 	"github.com/stacklok/toolhive/pkg/logger"
+	"github.com/stacklok/toolhive/pkg/transport/session"
 	"github.com/stacklok/toolhive/pkg/transport/ssecommon"
 	"github.com/stacklok/toolhive/pkg/transport/types"
 )
```
```diff
@@ -63,9 +64,8 @@ type HTTPSSEProxy struct {
 	// Optional Prometheus metrics handler
 	prometheusHandler http.Handler
 
-	// SSE clients
-	sseClients      map[string]*ssecommon.SSEClient
-	sseClientsMutex sync.RWMutex
+	// Session manager for SSE clients
+	sessionManager *session.Manager
 
 	// Pending messages for SSE clients
 	pendingMessages []*ssecommon.PendingSSEMessage
```
```diff
@@ -86,14 +86,19 @@
 func NewHTTPSSEProxy(
 	host string, port int, containerName string, prometheusHandler http.Handler, middlewares ...types.MiddlewareFunction,
 ) *HTTPSSEProxy {
+	// Create a factory for SSE sessions
+	sseFactory := func(id string) session.Session {
+		return session.NewSSESession(id)
+	}
+
 	proxy := &HTTPSSEProxy{
 		middlewares:       middlewares,
 		host:              host,
 		port:              port,
 		containerName:     containerName,
 		shutdownCh:        make(chan struct{}),
 		messageCh:         make(chan jsonrpc2.Message, 100),
-		sseClients:        make(map[string]*ssecommon.SSEClient),
+		sessionManager:    session.NewManager(30*time.Minute, sseFactory),
 		pendingMessages:   []*ssecommon.PendingSSEMessage{},
 		prometheusHandler: prometheusHandler,
 		closedClients:     make(map[string]bool),
```
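
The constructor wires a 30-minute TTL and a factory closure into the manager, so the proxy never has to know which concrete session type it stores. Continuing the sketch above, a TTL-driven manager built on `sync.Map` could look roughly like this; the internal fields, one-minute janitor interval, and eviction rule are assumptions:

```go
package session

import (
	"sync"
	"time"
)

// Factory builds a session of the appropriate concrete type,
// matching the closure passed to NewManager in the hunk above.
type Factory func(id string) Session

// Manager sketch: a TTL-based session store around sync.Map.
type Manager struct {
	sessions sync.Map      // session ID -> Session
	ttl      time.Duration // lifetime before eviction
	factory  Factory       // used by a GetOrCreate-style accessor (omitted here)
	stopCh   chan struct{} // stops the cleanup goroutine
}

func NewManager(ttl time.Duration, factory Factory) *Manager {
	m := &Manager{ttl: ttl, factory: factory, stopCh: make(chan struct{})}
	go m.cleanupLoop()
	return m
}

// Stop halts the background cleanup; the proxy calls this from its
// own Stop method in the next hunk.
func (m *Manager) Stop() { close(m.stopCh) }

func (m *Manager) Get(id string) (Session, bool) {
	v, ok := m.sessions.Load(id)
	if !ok {
		return nil, false
	}
	s, ok := v.(Session)
	return s, ok
}

func (m *Manager) Delete(id string) { m.sessions.Delete(id) }

// Range exposes the sync.Map iteration directly, which is why the
// proxy's callbacks receive untyped key/value pairs.
func (m *Manager) Range(f func(key, value interface{}) bool) {
	m.sessions.Range(f)
}

func (m *Manager) cleanupLoop() {
	ticker := time.NewTicker(time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			// Evict sessions older than the TTL (a real
			// implementation would track last activity).
			m.sessions.Range(func(k, v interface{}) bool {
				if s, ok := v.(Session); ok && time.Since(s.CreatedAt()) > m.ttl {
					m.sessions.Delete(k)
				}
				return true
			})
		case <-m.stopCh:
			return
		}
	}
}
```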
```diff
@@ -190,6 +195,19 @@ func (p *HTTPSSEProxy) Stop(ctx context.Context) error {
 	// Signal shutdown
 	close(p.shutdownCh)
 
+	// Stop the session manager cleanup routine
+	if p.sessionManager != nil {
+		p.sessionManager.Stop()
+	}
+
+	// Disconnect all active sessions
+	p.sessionManager.Range(func(_, value interface{}) bool {
+		if sess, ok := value.(*session.SSESession); ok {
+			sess.Disconnect()
+		}
+		return true
+	})
+
 	// Stop the HTTP server
 	if p.server != nil {
 		return p.server.Shutdown(ctx)
```
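
Shutdown now covers both the manager's cleanup goroutine and every live SSE stream before the HTTP server is torn down. A hypothetical call site, assuming a `proxy` variable and the package's `logger`:

```go
// Give in-flight requests ten seconds to drain before forcing shutdown.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := proxy.Stop(ctx); err != nil {
	logger.Errorf("failed to stop HTTP SSE proxy: %v", err)
}
```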
```diff
@@ -226,10 +244,12 @@ func (p *HTTPSSEProxy) ForwardResponseToClients(_ context.Context, msg jsonrpc2.
 	// Create an SSE message
 	sseMsg := ssecommon.NewSSEMessage("message", string(data))
 
-	// Check if there are any connected clients
-	p.sseClientsMutex.RLock()
-	hasClients := len(p.sseClients) > 0
-	p.sseClientsMutex.RUnlock()
+	// Check if there are any connected clients by checking session count
+	hasClients := false
+	p.sessionManager.Range(func(_, _ interface{}) bool {
+		hasClients = true
+		return false // Stop iteration after finding first session
+	})
 
 	if hasClients {
 		// Send the message to all connected clients
```
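
Returning `false` from the `Range` callback stops iteration, so this probe touches at most one entry instead of counting every session. If the manager's API grew a convenience method for this (hypothetical; the diff inlines the idiom instead), it would be the same few lines:

```go
// HasSessions reports whether at least one session exists.
func (m *Manager) HasSessions() bool {
	found := false
	m.sessions.Range(func(_, _ interface{}) bool {
		found = true
		return false // stop after the first entry
	})
	return found
}
```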
```diff
@@ -258,13 +278,19 @@ func (p *HTTPSSEProxy) handleSSEConnection(w http.ResponseWriter, r *http.Reques
 	// Create a channel for sending messages to this client
 	messageCh := make(chan string, 100)
 
-	// Register the client
-	p.sseClientsMutex.Lock()
-	p.sseClients[clientID] = &ssecommon.SSEClient{
+	// Create SSE client info
+	clientInfo := &ssecommon.SSEClient{
 		MessageCh: messageCh,
 		CreatedAt: time.Now(),
 	}
-	p.sseClientsMutex.Unlock()
+
+	// Create and register the SSE session
+	sseSession := session.NewSSESessionWithClient(clientID, clientInfo)
+	if err := p.sessionManager.AddSession(sseSession); err != nil {
+		logger.Errorf("Failed to add SSE session: %v", err)
+		http.Error(w, "Failed to create session", http.StatusInternalServerError)
+		return
+	}
 
 	// Process any pending messages for this client
 	p.processPendingMessages(clientID, messageCh)
```
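
`AddSession` returning an error suggests the manager rejects some registrations, most plausibly duplicate IDs. Continuing the Manager sketch, the guard could be as small as this (the `LoadOrStore` approach and error text are assumptions):

```go
// AddSession sketch: reject a session whose ID is already present.
func (m *Manager) AddSession(s Session) error {
	if _, loaded := m.sessions.LoadOrStore(s.ID(), s); loaded {
		return fmt.Errorf("session %q already exists", s.ID())
	}
	return nil
}
```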
```diff
@@ -347,10 +373,7 @@ func (p *HTTPSSEProxy) handlePostRequest(w http.ResponseWriter, r *http.Request)
 	}
 
 	// Check if the session exists
-	p.sseClientsMutex.RLock()
-	_, exists := p.sseClients[sessionID]
-	p.sseClientsMutex.RUnlock()
-
+	_, exists := p.sessionManager.Get(sessionID)
 	if !exists {
 		http.Error(w, "Could not find session", http.StatusNotFound)
 		return
```
```diff
@@ -391,21 +414,34 @@ func (p *HTTPSSEProxy) sendSSEEvent(msg *ssecommon.SSEMessage) error {
 	// Convert the message to an SSE-formatted string
 	sseString := msg.ToSSEString()
 
-	// Hold the lock while sending to ensure channels aren't closed during send
-	// This is a read lock, so multiple sends can happen concurrently
-	p.sseClientsMutex.RLock()
-	defer p.sseClientsMutex.RUnlock()
+	// Iterate through all sessions and send to SSE sessions
+	p.sessionManager.Range(func(key, value interface{}) bool {
+		clientID, ok := key.(string)
+		if !ok {
+			return true // Continue iteration
+		}
+
+		sess, ok := value.(session.Session)
+		if !ok {
+			return true // Continue iteration
+		}
 
-	for clientID, client := range p.sseClients {
-		select {
-		case client.MessageCh <- sseString:
-			// Message sent successfully
-		default:
-			// Channel is full, skip this client
-			// Don't remove the client here - let the disconnect monitor handle it
-			logger.Debugf("Client %s channel full, skipping message", clientID)
+		// Check if this is an SSE session
+		if sseSession, ok := sess.(*session.SSESession); ok {
+			// Try to send the message
+			if err := sseSession.SendMessage(sseString); err != nil {
+				// Log the error but continue sending to other clients
+				switch err {
+				case session.ErrSessionDisconnected:
+					logger.Debugf("Client %s is disconnected, skipping message", clientID)
+				case session.ErrMessageChannelFull:
+					logger.Debugf("Client %s channel full, skipping message", clientID)
+				}
+			}
 		}
-	}
+
+		return true // Continue iteration
+	})
 
 	return nil
 }
```
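
`SendMessage` replaces the old inline `select` on the client channel, turning the two failure modes into sentinel errors the caller can distinguish. A sketch of an `SSESession` that would behave as this hunk expects; the error names and method signatures come from the diff, while the struct layout and mutex guard are assumptions:

```go
package session

import (
	"errors"
	"sync"

	"github.com/stacklok/toolhive/pkg/transport/ssecommon"
)

// Sentinel errors the send loop switches on; the names appear in
// the diff, the definitions here are assumptions.
var (
	ErrSessionDisconnected = errors.New("session is disconnected")
	ErrMessageChannelFull  = errors.New("session message channel is full")
)

// SSESession sketch: wraps the existing ssecommon.SSEClient so the
// generic manager can store it behind the Session interface.
// (Other Session methods are omitted for brevity.)
type SSESession struct {
	id           string
	client       *ssecommon.SSEClient
	mu           sync.Mutex
	disconnected bool
}

func NewSSESessionWithClient(id string, client *ssecommon.SSEClient) *SSESession {
	return &SSESession{id: id, client: client}
}

func (s *SSESession) ID() string { return s.id }

// SendMessage performs the same non-blocking send the old map-based
// loop did inline, but reports failures as typed errors. The lock
// ensures the channel cannot be closed mid-send.
func (s *SSESession) SendMessage(msg string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.disconnected {
		return ErrSessionDisconnected
	}
	select {
	case s.client.MessageCh <- msg:
		return nil
	default:
		return ErrMessageChannelFull
	}
}

// Disconnect marks the session dead and closes the client channel
// exactly once, replacing the old close(client.MessageCh).
func (s *SSESession) Disconnect() {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.disconnected {
		return
	}
	s.disconnected = true
	close(s.client.MessageCh)
}
```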
```diff
@@ -421,21 +457,20 @@ func (p *HTTPSSEProxy) removeClient(clientID string) {
 	p.closedClients[clientID] = true
 	p.closedClientsMutex.Unlock()
 
-	// Remove from clients map and get the client
-	// Use write lock to ensure no sends happen during removal
-	p.sseClientsMutex.Lock()
-	client, exists := p.sseClients[clientID]
-	if exists {
-		delete(p.sseClients, clientID)
+	// Get the session from the manager
+	sess, exists := p.sessionManager.Get(clientID)
+	if !exists {
+		return
 	}
-	p.sseClientsMutex.Unlock()
 
-	// Close the channel after removing from map
-	// This ensures no goroutine will try to send to it
-	if exists && client != nil {
-		close(client.MessageCh)
+	// If it's an SSE session, disconnect it
+	if sseSession, ok := sess.(*session.SSESession); ok {
+		sseSession.Disconnect()
 	}
 
+	// Remove the session from the manager
+	p.sessionManager.Delete(clientID)
+
 	// Clean up closed clients map periodically (prevent memory leak)
 	p.closedClientsMutex.Lock()
 	if len(p.closedClients) > 1000 {
```
