@@ -13,7 +13,6 @@ import (
13
13
"net/http/httputil"
14
14
"net/url"
15
15
"regexp"
16
- "strconv"
17
16
"strings"
18
17
"sync"
19
18
"time"
@@ -56,16 +55,13 @@ type TransparentProxy struct {
56
55
// Optional Prometheus metrics handler
57
56
prometheusHandler http.Handler
58
57
59
- // Sessions for managing client connections
60
- sessions map [ string ] session.Session
58
+ // Sessions for tracking state
59
+ sessionManager * session.Manager
61
60
62
- // mutex for protecting session access
63
- sessionMutex sync. Mutex
61
+ // If mcp server has been initialized
62
+ IsServerInitialized bool
64
63
}
65
64
66
- // TransparentProxySessionID is the session ID used for the transparent proxy.
67
- const TransparentProxySessionID = "transparent-proxy-session"
68
-
69
65
// NewTransparentProxy creates a new transparent proxy with optional middlewares.
70
66
func NewTransparentProxy (
71
67
host string ,
@@ -83,53 +79,58 @@ func NewTransparentProxy(
83
79
middlewares : middlewares ,
84
80
shutdownCh : make (chan struct {}),
85
81
prometheusHandler : prometheusHandler ,
86
- sessions : make ( map [ string ]session. Session ),
82
+ sessionManager : session . NewManager ( 30 * time . Minute ),
87
83
}
88
84
89
85
// Create MCP pinger and health checker
90
86
mcpPinger := NewMCPPinger (targetURI )
91
87
proxy .healthChecker = healthcheck .NewHealthChecker ("sse" , mcpPinger )
92
88
93
- _ , err := proxy .CreateSession (TransparentProxySessionID )
94
- if err != nil {
95
- logger .Errorf ("Failed to create session for TransparentProxy: %v" , err )
96
- return nil
97
- }
98
-
99
89
return proxy
100
90
}
101
91
92
+ var sessionIDRegex = regexp .MustCompile (`sessionId=([\w-]+)` )
93
+
102
94
func (p * TransparentProxy ) handleModifyResponse (res * http.Response ) error {
103
- // Log headers
104
95
if sid := res .Header .Get ("Mcp-Session-Id" ); sid != "" {
105
- logger .Infof ("🆔 Streamable session ID from header: %s" , sid )
96
+ logger .Infof ("Detected Mcp-Session-Id header: %s" , sid )
97
+ if _ , ok := p .sessionManager .Get (sid ); ! ok {
98
+ if _ , err := p .sessionManager .AddWithID (sid ); err != nil {
99
+ logger .Errorf ("Failed to create session from header %s: %v" , sid , err )
100
+ }
101
+ }
102
+ p .IsServerInitialized = true
106
103
}
107
104
108
105
// Handle streaming (SSE)
109
- if ct , _ , _ := mime .ParseMediaType (res .Header .Get ("Content-Type" )); ct == "text/event-stream" {
106
+ ct , _ , err := mime .ParseMediaType (res .Header .Get ("Content-Type" ))
107
+ if err != nil {
108
+ logger .Errorf ("Failed to parse Content-Type: %v" , err )
109
+ return err
110
+ }
111
+ if ct == "text/event-stream" {
110
112
pr , pw := io .Pipe ()
111
113
orig := res .Body
112
114
res .Body = pr
113
115
114
116
go func () {
115
117
defer pw .Close ()
116
118
scanner := bufio .NewScanner (orig )
117
- re := regexp .MustCompile (`sessionId=([\w-]+)` ) // Capture UUID-like IDs
118
-
119
119
for scanner .Scan () {
120
120
line := scanner .Text ()
121
121
122
- if matches := re .FindStringSubmatch (line ); len (matches ) == 2 {
122
+ if matches := sessionIDRegex .FindStringSubmatch (line ); len (matches ) == 2 {
123
123
sessionID := matches [1 ]
124
-
125
- // set session id for proxy
126
- extractedSession , ok := p .GetSession (TransparentProxySessionID )
124
+ _ , ok := p .sessionManager .Get (sessionID )
127
125
if ! ok {
128
- logger .Errorf ("Failed to get session for TransparentProxy" )
129
- continue
126
+ var err error
127
+ _ , err = p .sessionManager .AddWithID (sessionID )
128
+ if err != nil {
129
+ logger .Errorf ("Failed to create session %s: %v" , sessionID , err )
130
+ continue
131
+ }
130
132
}
131
- extractedSession .SetIsInitialized (true )
132
- extractedSession .SetMCPSessionID (sessionID )
133
+ p .IsServerInitialized = true
133
134
}
134
135
_ , err := pw .Write ([]byte (line + "\n " ))
135
136
if err != nil {
@@ -140,29 +141,6 @@ func (p *TransparentProxy) handleModifyResponse(res *http.Response) error {
140
141
return nil
141
142
}
142
143
143
- // Handle non-streaming (JSON) bodies
144
- body , err := io .ReadAll (res .Body )
145
- if err != nil {
146
- return err
147
- }
148
- err = res .Body .Close ()
149
- if err != nil {
150
- logger .Errorf ("Failed to close response body: %v" , err )
151
- }
152
-
153
- text := string (body )
154
- logger .Infof ("HTTP response body: %s" , text )
155
-
156
- // Parse sessionId if embedded in JSON
157
- rejson := regexp .MustCompile (`"sessionId"\s*:\s*"([\w-]+)"` )
158
- if matches := rejson .FindStringSubmatch (text ); len (matches ) == 2 {
159
- sessionID := matches [1 ]
160
- logger .Infof ("🆔 Captured sessionId from JSON: %s" , sessionID )
161
- }
162
-
163
- res .Body = io .NopCloser (bytes .NewReader (body ))
164
- res .ContentLength = int64 (len (body ))
165
- res .Header .Set ("Content-Length" , strconv .Itoa (len (body )))
166
144
return nil
167
145
}
168
146
@@ -176,13 +154,8 @@ func (p *TransparentProxy) handleAndDetectInitialize(w http.ResponseWriter, r *h
176
154
logger .Errorf ("Error reading request body: %v" , err )
177
155
} else {
178
156
if bytes .Contains (body , []byte (`"method":"initialize"` )) {
179
- logger .Infof ("🔧 Detected initialize request to %s" , r .URL .Path )
180
- extractedSession , ok := p .GetSession (TransparentProxySessionID )
181
- if ok {
182
- extractedSession .SetIsInitialized (true )
183
- } else {
184
- logger .Errorf ("No session found to mark initialized" )
185
- }
157
+ logger .Infof ("Detected initialize request to %s" , r .URL .Path )
158
+ p .IsServerInitialized = true
186
159
}
187
160
r .Body = io .NopCloser (bytes .NewReader (body ))
188
161
r .ContentLength = int64 (len (body ))
@@ -278,13 +251,7 @@ func (p *TransparentProxy) monitorHealth(parentCtx context.Context) {
278
251
return
279
252
case <- ticker .C :
280
253
// Perform health check only if mcp server has been initialized
281
- extractedSession , isSession := p .GetSession (TransparentProxySessionID )
282
- if ! isSession {
283
- logger .Errorf ("Failed to get session for health check" )
284
- return
285
- }
286
-
287
- if extractedSession .IsInitialized () {
254
+ if p .IsServerInitialized {
288
255
alive := p .healthChecker .CheckHealth (parentCtx )
289
256
if alive .Status != healthcheck .StatusHealthy {
290
257
logger .Infof ("Health check failed for %s; initiating proxy shutdown" , p .containerName )
@@ -294,7 +261,7 @@ func (p *TransparentProxy) monitorHealth(parentCtx context.Context) {
294
261
return
295
262
}
296
263
} else {
297
- logger .Infof ("Session %s is not initialized, cannot start healthcheck " , extractedSession . ID () )
264
+ logger .Infof ("MCP server not initialized yet, skipping health check for %s " , p . containerName )
298
265
}
299
266
}
300
267
}
@@ -347,32 +314,3 @@ func (*TransparentProxy) SendMessageToDestination(_ jsonrpc2.Message) error {
347
314
func (* TransparentProxy ) ForwardResponseToClients (_ context.Context , _ jsonrpc2.Message ) error {
348
315
return fmt .Errorf ("ForwardResponseToClients not implemented for TransparentProxy" )
349
316
}
350
-
351
- // CreateSession creates a new session for the transparent proxy.
352
- func (p * TransparentProxy ) CreateSession (id string ) (session.Session , error ) {
353
- logger .Infof ("Creating session with ID: %s" , id )
354
- if id == "" {
355
- return nil , fmt .Errorf ("session ID cannot be empty" )
356
- }
357
- p .sessionMutex .Lock ()
358
- defer p .sessionMutex .Unlock ()
359
- if _ , found := p .sessions [id ]; found {
360
- return nil , fmt .Errorf ("session %s exists" , id )
361
- }
362
- s := & session.ProxySession {Id : id }
363
- s .Init ()
364
- p .sessions [id ] = s
365
- return s , nil
366
- }
367
-
368
- // GetSession retrieves a session by ID.
369
- func (p * TransparentProxy ) GetSession (id string ) (session.Session , bool ) {
370
- logger .Infof ("Retrieving session with ID: %s" , id )
371
- if id == "" {
372
- return nil , false
373
- }
374
- p .sessionMutex .Lock ()
375
- defer p .sessionMutex .Unlock ()
376
- s , ok := p .sessions [id ]
377
- return s , ok
378
- }
0 commit comments