@@ -17,6 +17,7 @@ import (
17
17
18
18
"github.com/stacklok/toolhive/pkg/healthcheck"
19
19
"github.com/stacklok/toolhive/pkg/logger"
20
+ "github.com/stacklok/toolhive/pkg/transport/session"
20
21
"github.com/stacklok/toolhive/pkg/transport/ssecommon"
21
22
"github.com/stacklok/toolhive/pkg/transport/types"
22
23
)
@@ -63,9 +64,8 @@ type HTTPSSEProxy struct {
63
64
// Optional Prometheus metrics handler
64
65
prometheusHandler http.Handler
65
66
66
- // SSE clients
67
- sseClients map [string ]* ssecommon.SSEClient
68
- sseClientsMutex sync.RWMutex
67
+ // Session manager for SSE clients
68
+ sessionManager * session.Manager
69
69
70
70
// Pending messages for SSE clients
71
71
pendingMessages []* ssecommon.PendingSSEMessage
@@ -86,14 +86,19 @@ type HTTPSSEProxy struct {
86
86
func NewHTTPSSEProxy (
87
87
host string , port int , containerName string , prometheusHandler http.Handler , middlewares ... types.MiddlewareFunction ,
88
88
) * HTTPSSEProxy {
89
+ // Create a factory for SSE sessions
90
+ sseFactory := func (id string ) session.Session {
91
+ return session .NewSSESession (id )
92
+ }
93
+
89
94
proxy := & HTTPSSEProxy {
90
95
middlewares : middlewares ,
91
96
host : host ,
92
97
port : port ,
93
98
containerName : containerName ,
94
99
shutdownCh : make (chan struct {}),
95
100
messageCh : make (chan jsonrpc2.Message , 100 ),
96
- sseClients : make ( map [ string ] * ssecommon. SSEClient ),
101
+ sessionManager : session . NewManager ( 30 * time . Minute , sseFactory ),
97
102
pendingMessages : []* ssecommon.PendingSSEMessage {},
98
103
prometheusHandler : prometheusHandler ,
99
104
closedClients : make (map [string ]bool ),
@@ -190,6 +195,19 @@ func (p *HTTPSSEProxy) Stop(ctx context.Context) error {
190
195
// Signal shutdown
191
196
close (p .shutdownCh )
192
197
198
+ // Stop the session manager cleanup routine
199
+ if p .sessionManager != nil {
200
+ p .sessionManager .Stop ()
201
+ }
202
+
203
+ // Disconnect all active sessions
204
+ p .sessionManager .Range (func (_ , value interface {}) bool {
205
+ if sess , ok := value .(* session.SSESession ); ok {
206
+ sess .Disconnect ()
207
+ }
208
+ return true
209
+ })
210
+
193
211
// Stop the HTTP server
194
212
if p .server != nil {
195
213
return p .server .Shutdown (ctx )
@@ -226,10 +244,12 @@ func (p *HTTPSSEProxy) ForwardResponseToClients(_ context.Context, msg jsonrpc2.
226
244
// Create an SSE message
227
245
sseMsg := ssecommon .NewSSEMessage ("message" , string (data ))
228
246
229
- // Check if there are any connected clients
230
- p .sseClientsMutex .RLock ()
231
- hasClients := len (p .sseClients ) > 0
232
- p .sseClientsMutex .RUnlock ()
247
+ // Check if there are any connected clients by checking session count
248
+ hasClients := false
249
+ p .sessionManager .Range (func (_ , _ interface {}) bool {
250
+ hasClients = true
251
+ return false // Stop iteration after finding first session
252
+ })
233
253
234
254
if hasClients {
235
255
// Send the message to all connected clients
@@ -258,13 +278,19 @@ func (p *HTTPSSEProxy) handleSSEConnection(w http.ResponseWriter, r *http.Reques
258
278
// Create a channel for sending messages to this client
259
279
messageCh := make (chan string , 100 )
260
280
261
- // Register the client
262
- p .sseClientsMutex .Lock ()
263
- p .sseClients [clientID ] = & ssecommon.SSEClient {
281
+ // Create SSE client info
282
+ clientInfo := & ssecommon.SSEClient {
264
283
MessageCh : messageCh ,
265
284
CreatedAt : time .Now (),
266
285
}
267
- p .sseClientsMutex .Unlock ()
286
+
287
+ // Create and register the SSE session
288
+ sseSession := session .NewSSESessionWithClient (clientID , clientInfo )
289
+ if err := p .sessionManager .AddSession (sseSession ); err != nil {
290
+ logger .Errorf ("Failed to add SSE session: %v" , err )
291
+ http .Error (w , "Failed to create session" , http .StatusInternalServerError )
292
+ return
293
+ }
268
294
269
295
// Process any pending messages for this client
270
296
p .processPendingMessages (clientID , messageCh )
@@ -347,10 +373,7 @@ func (p *HTTPSSEProxy) handlePostRequest(w http.ResponseWriter, r *http.Request)
347
373
}
348
374
349
375
// Check if the session exists
350
- p .sseClientsMutex .RLock ()
351
- _ , exists := p .sseClients [sessionID ]
352
- p .sseClientsMutex .RUnlock ()
353
-
376
+ _ , exists := p .sessionManager .Get (sessionID )
354
377
if ! exists {
355
378
http .Error (w , "Could not find session" , http .StatusNotFound )
356
379
return
@@ -391,21 +414,34 @@ func (p *HTTPSSEProxy) sendSSEEvent(msg *ssecommon.SSEMessage) error {
391
414
// Convert the message to an SSE-formatted string
392
415
sseString := msg .ToSSEString ()
393
416
394
- // Hold the lock while sending to ensure channels aren't closed during send
395
- // This is a read lock, so multiple sends can happen concurrently
396
- p .sseClientsMutex .RLock ()
397
- defer p .sseClientsMutex .RUnlock ()
417
+ // Iterate through all sessions and send to SSE sessions
418
+ p .sessionManager .Range (func (key , value interface {}) bool {
419
+ clientID , ok := key .(string )
420
+ if ! ok {
421
+ return true // Continue iteration
422
+ }
423
+
424
+ sess , ok := value .(session.Session )
425
+ if ! ok {
426
+ return true // Continue iteration
427
+ }
398
428
399
- for clientID , client := range p .sseClients {
400
- select {
401
- case client .MessageCh <- sseString :
402
- // Message sent successfully
403
- default :
404
- // Channel is full, skip this client
405
- // Don't remove the client here - let the disconnect monitor handle it
406
- logger .Debugf ("Client %s channel full, skipping message" , clientID )
429
+ // Check if this is an SSE session
430
+ if sseSession , ok := sess .(* session.SSESession ); ok {
431
+ // Try to send the message
432
+ if err := sseSession .SendMessage (sseString ); err != nil {
433
+ // Log the error but continue sending to other clients
434
+ switch err {
435
+ case session .ErrSessionDisconnected :
436
+ logger .Debugf ("Client %s is disconnected, skipping message" , clientID )
437
+ case session .ErrMessageChannelFull :
438
+ logger .Debugf ("Client %s channel full, skipping message" , clientID )
439
+ }
440
+ }
407
441
}
408
- }
442
+
443
+ return true // Continue iteration
444
+ })
409
445
410
446
return nil
411
447
}
@@ -421,21 +457,20 @@ func (p *HTTPSSEProxy) removeClient(clientID string) {
421
457
p .closedClients [clientID ] = true
422
458
p .closedClientsMutex .Unlock ()
423
459
424
- // Remove from clients map and get the client
425
- // Use write lock to ensure no sends happen during removal
426
- p .sseClientsMutex .Lock ()
427
- client , exists := p .sseClients [clientID ]
428
- if exists {
429
- delete (p .sseClients , clientID )
460
+ // Get the session from the manager
461
+ sess , exists := p .sessionManager .Get (clientID )
462
+ if ! exists {
463
+ return
430
464
}
431
- p .sseClientsMutex .Unlock ()
432
465
433
- // Close the channel after removing from map
434
- // This ensures no goroutine will try to send to it
435
- if exists && client != nil {
436
- close (client .MessageCh )
466
+ // If it's an SSE session, disconnect it
467
+ if sseSession , ok := sess .(* session.SSESession ); ok {
468
+ sseSession .Disconnect ()
437
469
}
438
470
471
+ // Remove the session from the manager
472
+ p .sessionManager .Delete (clientID )
473
+
439
474
// Clean up closed clients map periodically (prevent memory leak)
440
475
p .closedClientsMutex .Lock ()
441
476
if len (p .closedClients ) > 1000 {
0 commit comments