Skip to content

Commit a7f4b00

Browse files
committed
fixes from review
1 parent 0e2745f commit a7f4b00

File tree

4 files changed

+161
-134
lines changed

4 files changed

+161
-134
lines changed

pkg/transport/proxy/transparent/transparent_proxy.go

Lines changed: 27 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"net/http/httputil"
1414
"net/url"
1515
"regexp"
16-
"strconv"
1716
"strings"
1817
"sync"
1918
"time"
@@ -56,16 +55,13 @@ type TransparentProxy struct {
5655
// Optional Prometheus metrics handler
5756
prometheusHandler http.Handler
5857

59-
// Sessions for managing client connections
60-
sessions map[string]session.Session
58+
// Sessions for tracking state
59+
sessionManager *session.Manager
6160

62-
// mutex for protecting session access
63-
sessionMutex sync.Mutex
61+
// If mcp server has been initialized
62+
IsServerInitialized bool
6463
}
6564

66-
// TransparentProxySessionID is the session ID used for the transparent proxy.
67-
const TransparentProxySessionID = "transparent-proxy-session"
68-
6965
// NewTransparentProxy creates a new transparent proxy with optional middlewares.
7066
func NewTransparentProxy(
7167
host string,
@@ -83,26 +79,27 @@ func NewTransparentProxy(
8379
middlewares: middlewares,
8480
shutdownCh: make(chan struct{}),
8581
prometheusHandler: prometheusHandler,
86-
sessions: make(map[string]session.Session),
82+
sessionManager: session.NewManager(30 * time.Minute),
8783
}
8884

8985
// Create MCP pinger and health checker
9086
mcpPinger := NewMCPPinger(targetURI)
9187
proxy.healthChecker = healthcheck.NewHealthChecker("sse", mcpPinger)
9288

93-
_, err := proxy.CreateSession(TransparentProxySessionID)
94-
if err != nil {
95-
logger.Errorf("Failed to create session for TransparentProxy: %v", err)
96-
return nil
97-
}
98-
9989
return proxy
10090
}
10191

92+
var sessionIDRegex = regexp.MustCompile(`sessionId=([\w-]+)`)
93+
10294
func (p *TransparentProxy) handleModifyResponse(res *http.Response) error {
103-
// Log headers
10495
if sid := res.Header.Get("Mcp-Session-Id"); sid != "" {
105-
logger.Infof("🆔 Streamable session ID from header: %s", sid)
96+
logger.Infof("Detected Mcp-Session-Id header: %s", sid)
97+
if _, ok := p.sessionManager.Get(sid); !ok {
98+
if _, err := p.sessionManager.AddWithID(sid); err != nil {
99+
logger.Errorf("Failed to create session from header %s: %v", sid, err)
100+
}
101+
}
102+
p.IsServerInitialized = true
106103
}
107104

108105
// Handle streaming (SSE)
@@ -114,22 +111,21 @@ func (p *TransparentProxy) handleModifyResponse(res *http.Response) error {
114111
go func() {
115112
defer pw.Close()
116113
scanner := bufio.NewScanner(orig)
117-
re := regexp.MustCompile(`sessionId=([\w-]+)`) // Capture UUID-like IDs
118-
119114
for scanner.Scan() {
120115
line := scanner.Text()
121116

122-
if matches := re.FindStringSubmatch(line); len(matches) == 2 {
117+
if matches := sessionIDRegex.FindStringSubmatch(line); len(matches) == 2 {
123118
sessionID := matches[1]
124-
125-
// set session id for proxy
126-
extractedSession, ok := p.GetSession(TransparentProxySessionID)
119+
_, ok := p.sessionManager.Get(sessionID)
127120
if !ok {
128-
logger.Errorf("Failed to get session for TransparentProxy")
129-
continue
121+
var err error
122+
_, err = p.sessionManager.AddWithID(sessionID)
123+
if err != nil {
124+
logger.Errorf("Failed to create session %s: %v", sessionID, err)
125+
continue
126+
}
130127
}
131-
extractedSession.SetIsInitialized(true)
132-
extractedSession.SetMCPSessionID(sessionID)
128+
p.IsServerInitialized = true
133129
}
134130
_, err := pw.Write([]byte(line + "\n"))
135131
if err != nil {
@@ -140,29 +136,6 @@ func (p *TransparentProxy) handleModifyResponse(res *http.Response) error {
140136
return nil
141137
}
142138

143-
// Handle non-streaming (JSON) bodies
144-
body, err := io.ReadAll(res.Body)
145-
if err != nil {
146-
return err
147-
}
148-
err = res.Body.Close()
149-
if err != nil {
150-
logger.Errorf("Failed to close response body: %v", err)
151-
}
152-
153-
text := string(body)
154-
logger.Infof("HTTP response body: %s", text)
155-
156-
// Parse sessionId if embedded in JSON
157-
rejson := regexp.MustCompile(`"sessionId"\s*:\s*"([\w-]+)"`)
158-
if matches := rejson.FindStringSubmatch(text); len(matches) == 2 {
159-
sessionID := matches[1]
160-
logger.Infof("🆔 Captured sessionId from JSON: %s", sessionID)
161-
}
162-
163-
res.Body = io.NopCloser(bytes.NewReader(body))
164-
res.ContentLength = int64(len(body))
165-
res.Header.Set("Content-Length", strconv.Itoa(len(body)))
166139
return nil
167140
}
168141

@@ -176,13 +149,8 @@ func (p *TransparentProxy) handleAndDetectInitialize(w http.ResponseWriter, r *h
176149
logger.Errorf("Error reading request body: %v", err)
177150
} else {
178151
if bytes.Contains(body, []byte(`"method":"initialize"`)) {
179-
logger.Infof("🔧 Detected initialize request to %s", r.URL.Path)
180-
extractedSession, ok := p.GetSession(TransparentProxySessionID)
181-
if ok {
182-
extractedSession.SetIsInitialized(true)
183-
} else {
184-
logger.Errorf("No session found to mark initialized")
185-
}
152+
logger.Infof("Detected initialize request to %s", r.URL.Path)
153+
p.IsServerInitialized = true
186154
}
187155
r.Body = io.NopCloser(bytes.NewReader(body))
188156
r.ContentLength = int64(len(body))
@@ -278,13 +246,7 @@ func (p *TransparentProxy) monitorHealth(parentCtx context.Context) {
278246
return
279247
case <-ticker.C:
280248
// Perform health check only if mcp server has been initialized
281-
extractedSession, isSession := p.GetSession(TransparentProxySessionID)
282-
if !isSession {
283-
logger.Errorf("Failed to get session for health check")
284-
return
285-
}
286-
287-
if extractedSession.IsInitialized() {
249+
if p.IsServerInitialized {
288250
alive := p.healthChecker.CheckHealth(parentCtx)
289251
if alive.Status != healthcheck.StatusHealthy {
290252
logger.Infof("Health check failed for %s; initiating proxy shutdown", p.containerName)
@@ -294,7 +256,7 @@ func (p *TransparentProxy) monitorHealth(parentCtx context.Context) {
294256
return
295257
}
296258
} else {
297-
logger.Infof("Session %s is not initialized, cannot start healthcheck", extractedSession.ID())
259+
logger.Infof("MCP server not initialized yet, skipping health check for %s", p.containerName)
298260
}
299261
}
300262
}
@@ -347,32 +309,3 @@ func (*TransparentProxy) SendMessageToDestination(_ jsonrpc2.Message) error {
347309
func (*TransparentProxy) ForwardResponseToClients(_ context.Context, _ jsonrpc2.Message) error {
348310
return fmt.Errorf("ForwardResponseToClients not implemented for TransparentProxy")
349311
}
350-
351-
// CreateSession creates a new session for the transparent proxy.
352-
func (p *TransparentProxy) CreateSession(id string) (session.Session, error) {
353-
logger.Infof("Creating session with ID: %s", id)
354-
if id == "" {
355-
return nil, fmt.Errorf("session ID cannot be empty")
356-
}
357-
p.sessionMutex.Lock()
358-
defer p.sessionMutex.Unlock()
359-
if _, found := p.sessions[id]; found {
360-
return nil, fmt.Errorf("session %s exists", id)
361-
}
362-
s := &session.ProxySession{Id: id}
363-
s.Init()
364-
p.sessions[id] = s
365-
return s, nil
366-
}
367-
368-
// GetSession retrieves a session by ID.
369-
func (p *TransparentProxy) GetSession(id string) (session.Session, bool) {
370-
logger.Infof("Retrieving session with ID: %s", id)
371-
if id == "" {
372-
return nil, false
373-
}
374-
p.sessionMutex.Lock()
375-
defer p.sessionMutex.Unlock()
376-
s, ok := p.sessions[id]
377-
return s, ok
378-
}

pkg/transport/session/manager.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// Package session provides a session manager with TTL cleanup.
2+
package session
3+
4+
import (
5+
"fmt"
6+
"sync"
7+
"time"
8+
)
9+
10+
// Session interface
11+
type Session interface {
12+
ID() string
13+
CreatedAt() time.Time
14+
UpdatedAt() time.Time
15+
Touch()
16+
}
17+
18+
// Manager holds sessions with TTL cleanup.
19+
type Manager struct {
20+
sessions map[string]Session
21+
mu sync.RWMutex
22+
ttl time.Duration
23+
stopCh chan struct{}
24+
}
25+
26+
// NewManager creates a session manager with TTL and starts cleanup worker.
27+
func NewManager(ttl time.Duration) *Manager {
28+
m := &Manager{
29+
sessions: make(map[string]Session),
30+
ttl: ttl,
31+
stopCh: make(chan struct{}),
32+
}
33+
go m.cleanupRoutine()
34+
return m
35+
}
36+
37+
func (m *Manager) cleanupRoutine() {
38+
ticker := time.NewTicker(m.ttl / 2)
39+
defer ticker.Stop()
40+
for {
41+
select {
42+
case <-ticker.C:
43+
m.CleanupExpired()
44+
case <-m.stopCh:
45+
return
46+
}
47+
}
48+
}
49+
50+
// AddWithID creates (and adds) a new session with the provided ID.
51+
// Returns error if ID is empty or already exists.
52+
func (m *Manager) AddWithID(id string) (Session, error) {
53+
if id == "" {
54+
return nil, fmt.Errorf("session ID cannot be empty")
55+
}
56+
57+
m.mu.Lock()
58+
defer m.mu.Unlock()
59+
60+
if _, exists := m.sessions[id]; exists {
61+
return nil, fmt.Errorf("session ID %q already exists", id)
62+
}
63+
64+
s := NewProxySession(id)
65+
m.sessions[id] = s
66+
return s, nil
67+
}
68+
69+
// Get retrieves a session by ID. Returns (session, true) if found,
70+
// and also updates its UpdatedAt timestamp.
71+
func (m *Manager) Get(id string) (Session, bool) {
72+
m.mu.RLock()
73+
s, ok := m.sessions[id]
74+
m.mu.RUnlock()
75+
76+
if !ok {
77+
return nil, false
78+
}
79+
80+
s.Touch()
81+
return s, true
82+
}
83+
84+
// Delete removes a session by ID.
85+
func (m *Manager) Delete(id string) {
86+
m.mu.Lock()
87+
delete(m.sessions, id)
88+
m.mu.Unlock()
89+
}
90+
91+
// CleanupExpired removes sessions that have not been updated within the TTL.
92+
func (m *Manager) CleanupExpired() {
93+
cutoff := time.Now().Add(-m.ttl)
94+
m.mu.Lock()
95+
defer m.mu.Unlock()
96+
for id, s := range m.sessions {
97+
if s.UpdatedAt().Before(cutoff) {
98+
delete(m.sessions, id)
99+
}
100+
}
101+
}
102+
103+
// Stop stops the cleanup worker.
104+
func (m *Manager) Stop() {
105+
close(m.stopCh)
106+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package session
2+
3+
import "time"
4+
5+
// ProxySession implements the Session interface for proxy sessions.
6+
type ProxySession struct {
7+
id string
8+
created time.Time
9+
updated time.Time
10+
}
11+
12+
// NewProxySession creates a new ProxySession with the given ID.
13+
func NewProxySession(id string) *ProxySession {
14+
now := time.Now()
15+
return &ProxySession{id: id, created: now, updated: now}
16+
}
17+
18+
// ID returns the session ID.
19+
func (s *ProxySession) ID() string { return s.id }
20+
21+
// CreatedAt returns the creation time of the session.
22+
func (s *ProxySession) CreatedAt() time.Time { return s.created }
23+
24+
// UpdatedAt returns the last updated time of the session.
25+
func (s *ProxySession) UpdatedAt() time.Time { return s.updated }
26+
27+
// Touch updates the session's last updated time to the current time.
28+
func (s *ProxySession) Touch() { s.updated = time.Now() }

pkg/transport/session/session.go

Lines changed: 0 additions & 40 deletions
This file was deleted.

0 commit comments

Comments
 (0)