Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 75 additions & 40 deletions pkg/transport/proxy/httpsse/http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

"github.com/stacklok/toolhive/pkg/healthcheck"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/transport/session"
"github.com/stacklok/toolhive/pkg/transport/ssecommon"
"github.com/stacklok/toolhive/pkg/transport/types"
)
Expand Down Expand Up @@ -63,9 +64,8 @@ type HTTPSSEProxy struct {
// Optional Prometheus metrics handler
prometheusHandler http.Handler

// SSE clients
sseClients map[string]*ssecommon.SSEClient
sseClientsMutex sync.RWMutex
// Session manager for SSE clients
sessionManager *session.Manager

// Pending messages for SSE clients
pendingMessages []*ssecommon.PendingSSEMessage
Expand All @@ -86,14 +86,19 @@ type HTTPSSEProxy struct {
func NewHTTPSSEProxy(
host string, port int, containerName string, prometheusHandler http.Handler, middlewares ...types.MiddlewareFunction,
) *HTTPSSEProxy {
// Create a factory for SSE sessions
sseFactory := func(id string) session.Session {
return session.NewSSESession(id)
}

proxy := &HTTPSSEProxy{
middlewares: middlewares,
host: host,
port: port,
containerName: containerName,
shutdownCh: make(chan struct{}),
messageCh: make(chan jsonrpc2.Message, 100),
sseClients: make(map[string]*ssecommon.SSEClient),
sessionManager: session.NewManager(30*time.Minute, sseFactory),
pendingMessages: []*ssecommon.PendingSSEMessage{},
prometheusHandler: prometheusHandler,
closedClients: make(map[string]bool),
Expand Down Expand Up @@ -190,6 +195,19 @@ func (p *HTTPSSEProxy) Stop(ctx context.Context) error {
// Signal shutdown
close(p.shutdownCh)

// Stop the session manager cleanup routine
if p.sessionManager != nil {
p.sessionManager.Stop()
}

// Disconnect all active sessions
p.sessionManager.Range(func(_, value interface{}) bool {
if sess, ok := value.(*session.SSESession); ok {
sess.Disconnect()
}
return true
})

// Stop the HTTP server
if p.server != nil {
return p.server.Shutdown(ctx)
Expand Down Expand Up @@ -226,10 +244,12 @@ func (p *HTTPSSEProxy) ForwardResponseToClients(_ context.Context, msg jsonrpc2.
// Create an SSE message
sseMsg := ssecommon.NewSSEMessage("message", string(data))

// Check if there are any connected clients
p.sseClientsMutex.RLock()
hasClients := len(p.sseClients) > 0
p.sseClientsMutex.RUnlock()
// Check if there are any connected clients by checking session count
hasClients := false
p.sessionManager.Range(func(_, _ interface{}) bool {
hasClients = true
return false // Stop iteration after finding first session
})

if hasClients {
// Send the message to all connected clients
Expand Down Expand Up @@ -258,13 +278,19 @@ func (p *HTTPSSEProxy) handleSSEConnection(w http.ResponseWriter, r *http.Reques
// Create a channel for sending messages to this client
messageCh := make(chan string, 100)

// Register the client
p.sseClientsMutex.Lock()
p.sseClients[clientID] = &ssecommon.SSEClient{
// Create SSE client info
clientInfo := &ssecommon.SSEClient{
MessageCh: messageCh,
CreatedAt: time.Now(),
}
p.sseClientsMutex.Unlock()

// Create and register the SSE session
sseSession := session.NewSSESessionWithClient(clientID, clientInfo)
if err := p.sessionManager.AddSession(sseSession); err != nil {
logger.Errorf("Failed to add SSE session: %v", err)
http.Error(w, "Failed to create session", http.StatusInternalServerError)
return
}

// Process any pending messages for this client
p.processPendingMessages(clientID, messageCh)
Expand Down Expand Up @@ -347,10 +373,7 @@ func (p *HTTPSSEProxy) handlePostRequest(w http.ResponseWriter, r *http.Request)
}

// Check if the session exists
p.sseClientsMutex.RLock()
_, exists := p.sseClients[sessionID]
p.sseClientsMutex.RUnlock()

_, exists := p.sessionManager.Get(sessionID)
if !exists {
http.Error(w, "Could not find session", http.StatusNotFound)
return
Expand Down Expand Up @@ -391,21 +414,34 @@ func (p *HTTPSSEProxy) sendSSEEvent(msg *ssecommon.SSEMessage) error {
// Convert the message to an SSE-formatted string
sseString := msg.ToSSEString()

// Hold the lock while sending to ensure channels aren't closed during send
// This is a read lock, so multiple sends can happen concurrently
p.sseClientsMutex.RLock()
defer p.sseClientsMutex.RUnlock()
// Iterate through all sessions and send to SSE sessions
p.sessionManager.Range(func(key, value interface{}) bool {
clientID, ok := key.(string)
if !ok {
return true // Continue iteration
}

sess, ok := value.(session.Session)
if !ok {
return true // Continue iteration
}

for clientID, client := range p.sseClients {
select {
case client.MessageCh <- sseString:
// Message sent successfully
default:
// Channel is full, skip this client
// Don't remove the client here - let the disconnect monitor handle it
logger.Debugf("Client %s channel full, skipping message", clientID)
// Check if this is an SSE session
if sseSession, ok := sess.(*session.SSESession); ok {
// Try to send the message
if err := sseSession.SendMessage(sseString); err != nil {
// Log the error but continue sending to other clients
switch err {
case session.ErrSessionDisconnected:
logger.Debugf("Client %s is disconnected, skipping message", clientID)
case session.ErrMessageChannelFull:
logger.Debugf("Client %s channel full, skipping message", clientID)
}
}
}
}

return true // Continue iteration
})

return nil
}
Expand All @@ -421,21 +457,20 @@ func (p *HTTPSSEProxy) removeClient(clientID string) {
p.closedClients[clientID] = true
p.closedClientsMutex.Unlock()

// Remove from clients map and get the client
// Use write lock to ensure no sends happen during removal
p.sseClientsMutex.Lock()
client, exists := p.sseClients[clientID]
if exists {
delete(p.sseClients, clientID)
// Get the session from the manager
sess, exists := p.sessionManager.Get(clientID)
if !exists {
return
}
p.sseClientsMutex.Unlock()

// Close the channel after removing from map
// This ensures no goroutine will try to send to it
if exists && client != nil {
close(client.MessageCh)
// If it's an SSE session, disconnect it
if sseSession, ok := sess.(*session.SSESession); ok {
sseSession.Disconnect()
}

// Remove the session from the manager
p.sessionManager.Delete(clientID)

// Clean up closed clients map periodically (prevent memory leak)
p.closedClientsMutex.Lock()
if len(p.closedClients) > 1000 {
Expand Down
Loading
Loading