Skip to content

Commit 293c715

Browse files
committed
refactor session management
1 parent 5211d82 commit 293c715

File tree

2 files changed

+122
-64
lines changed

2 files changed

+122
-64
lines changed

pkg/transport/proxy/transparent/transparent_proxy.go

Lines changed: 117 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"bufio"
77
"bytes"
88
"context"
9+
"encoding/json"
910
"fmt"
1011
"io"
1112
"mime"
@@ -89,81 +90,137 @@ func NewTransparentProxy(
8990
return proxy
9091
}
9192

92-
var sessionIDRegex = regexp.MustCompile(`sessionId=([\w-]+)`)
93+
type tracingTransport struct {
94+
base http.RoundTripper
95+
p *TransparentProxy
96+
}
97+
98+
func (t *tracingTransport) setServerInitialized() {
99+
if !t.p.IsServerInitialized {
100+
t.p.mutex.Lock()
101+
t.p.IsServerInitialized = true
102+
t.p.mutex.Unlock()
103+
logger.Infof("Server was initialized successfully for %s", t.p.containerName)
104+
}
105+
}
106+
107+
func (t *tracingTransport) forward(req *http.Request) (*http.Response, error) {
108+
tr := t.base
109+
if tr == nil {
110+
tr = http.DefaultTransport
111+
}
112+
return tr.RoundTrip(req)
113+
}
114+
115+
func (t *tracingTransport) watchEventStream(r io.Reader, w *io.PipeWriter) {
116+
defer w.Close()
117+
118+
scanner := bufio.NewScanner(r)
119+
sessionRe := regexp.MustCompile(`sessionId=([0-9a-fA-F-]+)|\"sessionId\"\s*:\s*\"([^\"]+)\"`)
93120

94-
func (p *TransparentProxy) handleModifyResponse(res *http.Response) error {
95-
if sid := res.Header.Get("Mcp-Session-Id"); sid != "" {
96-
logger.Infof("Detected Mcp-Session-Id header: %s", sid)
97-
if _, ok := p.sessionManager.Get(sid); !ok {
98-
if _, err := p.sessionManager.AddWithID(sid); err != nil {
99-
logger.Errorf("Failed to create session from header %s: %v", sid, err)
121+
for scanner.Scan() {
122+
line := scanner.Text()
123+
124+
if m := sessionRe.FindStringSubmatch(line); m != nil {
125+
sid := m[1]
126+
if sid == "" {
127+
sid = m[2]
128+
}
129+
130+
if _, ok := t.p.sessionManager.Get(sid); !ok {
131+
_, err := t.p.sessionManager.AddWithID(sid)
132+
if err != nil {
133+
logger.Errorf("Failed to create session from event stream: %v", err)
134+
}
100135
}
136+
t.setServerInitialized()
101137
}
102-
p.IsServerInitialized = true
103-
return nil
104138
}
105139

106-
// Handle streaming (SSE)
107-
ct, _, err := mime.ParseMediaType(res.Header.Get("Content-Type"))
140+
_, err := io.Copy(io.Discard, r)
108141
if err != nil {
109-
logger.Warnf("Invalid Content-Type header, defaulting behavior: %v", err)
110-
ct = "" // or choose a fallback
142+
logger.Errorf("Failed to copy event stream: %v", err)
111143
}
112-
if ct == "text/event-stream" {
113-
pr, pw := io.Pipe()
114-
orig := res.Body
115-
res.Body = pr
116-
117-
go func() {
118-
defer pw.Close()
119-
scanner := bufio.NewScanner(orig)
120-
for scanner.Scan() {
121-
line := scanner.Text()
122-
123-
if matches := sessionIDRegex.FindStringSubmatch(line); len(matches) == 2 {
124-
sessionID := matches[1]
125-
_, ok := p.sessionManager.Get(sessionID)
126-
if !ok {
127-
var err error
128-
_, err = p.sessionManager.AddWithID(sessionID)
129-
if err != nil {
130-
logger.Errorf("Failed to create session %s: %v", sessionID, err)
131-
continue
132-
}
133-
}
134-
p.IsServerInitialized = true
135-
}
136-
_, err := pw.Write([]byte(line + "\n"))
137-
if err != nil {
138-
logger.Errorf("Failed to write to pipe: %v", err)
144+
}
145+
146+
func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
147+
reqBody := readRequestBody(req)
148+
149+
path := req.URL.Path
150+
isMCP := strings.HasPrefix(path, "/mcp")
151+
isJSON := strings.Contains(req.Header.Get("Content-Type"), "application/json")
152+
sawInitialize := false
153+
154+
if isMCP && isJSON && len(reqBody) > 0 {
155+
sawInitialize = t.detectInitialize(reqBody)
156+
}
157+
158+
resp, err := t.forward(req)
159+
if err != nil {
160+
logger.Errorf("Failed to forward request: %v", err)
161+
return nil, err
162+
}
163+
164+
if resp.StatusCode == http.StatusOK {
165+
// check if we saw a valid mcp header
166+
ct := resp.Header.Get("Mcp-Session-Id")
167+
if ct != "" {
168+
logger.Infof("Detected Mcp-Session-Id header: %s", ct)
169+
if _, ok := t.p.sessionManager.Get(ct); !ok {
170+
if _, err := t.p.sessionManager.AddWithID(ct); err != nil {
171+
logger.Errorf("Failed to create session from header %s: %v", ct, err)
139172
}
140173
}
141-
}()
142-
return nil
174+
t.setServerInitialized()
175+
return resp, nil
176+
}
177+
// status was ok and we saw an initialize call
178+
if sawInitialize && !t.p.IsServerInitialized {
179+
t.setServerInitialized()
180+
return resp, nil
181+
}
182+
ct = resp.Header.Get("Content-Type")
183+
mediaType, _, _ := mime.ParseMediaType(ct)
184+
if mediaType == "text/event-stream" {
185+
originalBody := resp.Body
186+
pr, pw := io.Pipe()
187+
tee := io.TeeReader(originalBody, pw)
188+
resp.Body = pr
189+
190+
go t.watchEventStream(tee, pw)
191+
}
143192
}
144193

145-
return nil
194+
return resp, nil
146195
}
147196

148-
func (p *TransparentProxy) handleAndDetectInitialize(w http.ResponseWriter, r *http.Request, proxy *httputil.ReverseProxy) {
149-
logger.Infof("Transparent proxy: %s %s -> %s", r.Method, r.URL.Path, p.targetURI)
150-
151-
if r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/mcp") {
152-
// Read the body for inspection without consuming it
153-
body, err := io.ReadAll(r.Body)
197+
func readRequestBody(req *http.Request) []byte {
198+
reqBody := []byte{}
199+
if req.Body != nil {
200+
buf, err := io.ReadAll(req.Body)
154201
if err != nil {
155-
logger.Errorf("Error reading request body: %v", err)
202+
logger.Errorf("Failed to read request body: %v", err)
156203
} else {
157-
if bytes.Contains(body, []byte(`"method":"initialize"`)) {
158-
logger.Infof("Detected initialize request to %s", r.URL.Path)
159-
p.IsServerInitialized = true
160-
}
161-
r.Body = io.NopCloser(bytes.NewReader(body))
162-
r.ContentLength = int64(len(body))
204+
reqBody = buf
163205
}
206+
req.Body = io.NopCloser(bytes.NewReader(reqBody))
164207
}
208+
return reqBody
209+
}
165210

166-
proxy.ServeHTTP(w, r)
211+
func (t *tracingTransport) detectInitialize(body []byte) bool {
212+
var rpc struct {
213+
Method string `json:"method"`
214+
}
215+
if err := json.Unmarshal(body, &rpc); err != nil {
216+
logger.Errorf("Failed to parse JSON-RPC body: %v", err)
217+
return false
218+
}
219+
if rpc.Method == "initialize" {
220+
logger.Infof("Detected initialize method call for %s", t.p.containerName)
221+
return true
222+
}
223+
return false
167224
}
168225

169226
// Start starts the transparent proxy.
@@ -179,11 +236,12 @@ func (p *TransparentProxy) Start(ctx context.Context) error {
179236

180237
// Create a reverse proxy
181238
proxy := httputil.NewSingleHostReverseProxy(targetURL)
182-
proxy.ModifyResponse = p.handleModifyResponse
239+
proxy.Transport = &tracingTransport{base: http.DefaultTransport, p: p}
183240

184241
// Create a handler that logs requests
185242
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
186-
p.handleAndDetectInitialize(w, r, proxy)
243+
logger.Infof("Transparent proxy: %s %s -> %s", r.Method, r.URL.Path, targetURL)
244+
proxy.ServeHTTP(w, r)
187245
})
188246

189247
// Create a mux to handle both proxy and health endpoints

pkg/transport/session/manager.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ type Session interface {
1717

1818
// Manager holds sessions with TTL cleanup.
1919
type Manager struct {
20-
sessions map[string]Session
20+
sessions map[string]*Session
2121
mu sync.RWMutex
2222
ttl time.Duration
2323
stopCh chan struct{}
@@ -26,7 +26,7 @@ type Manager struct {
2626
// NewManager creates a session manager with TTL and starts cleanup worker.
2727
func NewManager(ttl time.Duration) *Manager {
2828
m := &Manager{
29-
sessions: make(map[string]Session),
29+
sessions: make(map[string]*Session),
3030
ttl: ttl,
3131
stopCh: make(chan struct{}),
3232
}
@@ -68,16 +68,16 @@ func (m *Manager) AddWithID(id string) (Session, error) {
6868

6969
// Get retrieves a session by ID. Returns (session, true) if found,
7070
// and also updates its UpdatedAt timestamp.
71-
func (m *Manager) Get(id string) (Session, bool) {
71+
func (m *Manager) Get(id string) (*Session, bool) {
7272
m.mu.RLock()
7373
s, ok := m.sessions[id]
74-
m.mu.RUnlock()
75-
7674
if !ok {
7775
return nil, false
7876
}
7977

8078
s.Touch()
79+
m.mu.RUnlock()
80+
8181
return s, true
8282
}
8383

0 commit comments

Comments
 (0)