6
6
"bufio"
7
7
"bytes"
8
8
"context"
9
+ "encoding/json"
9
10
"fmt"
10
11
"io"
11
12
"mime"
@@ -89,81 +90,137 @@ func NewTransparentProxy(
89
90
return proxy
90
91
}
91
92
92
- var sessionIDRegex = regexp .MustCompile (`sessionId=([\w-]+)` )
93
+ type tracingTransport struct {
94
+ base http.RoundTripper
95
+ p * TransparentProxy
96
+ }
97
+
98
+ func (t * tracingTransport ) setServerInitialized () {
99
+ if ! t .p .IsServerInitialized {
100
+ t .p .mutex .Lock ()
101
+ t .p .IsServerInitialized = true
102
+ t .p .mutex .Unlock ()
103
+ logger .Infof ("Server was initialized successfully for %s" , t .p .containerName )
104
+ }
105
+ }
106
+
107
+ func (t * tracingTransport ) forward (req * http.Request ) (* http.Response , error ) {
108
+ tr := t .base
109
+ if tr == nil {
110
+ tr = http .DefaultTransport
111
+ }
112
+ return tr .RoundTrip (req )
113
+ }
114
+
115
+ func (t * tracingTransport ) watchEventStream (r io.Reader , w * io.PipeWriter ) {
116
+ defer w .Close ()
117
+
118
+ scanner := bufio .NewScanner (r )
119
+ sessionRe := regexp .MustCompile (`sessionId=([0-9a-fA-F-]+)|\"sessionId\"\s*:\s*\"([^\"]+)\"` )
93
120
94
- func (p * TransparentProxy ) handleModifyResponse (res * http.Response ) error {
95
- if sid := res .Header .Get ("Mcp-Session-Id" ); sid != "" {
96
- logger .Infof ("Detected Mcp-Session-Id header: %s" , sid )
97
- if _ , ok := p .sessionManager .Get (sid ); ! ok {
98
- if _ , err := p .sessionManager .AddWithID (sid ); err != nil {
99
- logger .Errorf ("Failed to create session from header %s: %v" , sid , err )
121
+ for scanner .Scan () {
122
+ line := scanner .Text ()
123
+
124
+ if m := sessionRe .FindStringSubmatch (line ); m != nil {
125
+ sid := m [1 ]
126
+ if sid == "" {
127
+ sid = m [2 ]
128
+ }
129
+
130
+ if _ , ok := t .p .sessionManager .Get (sid ); ! ok {
131
+ _ , err := t .p .sessionManager .AddWithID (sid )
132
+ if err != nil {
133
+ logger .Errorf ("Failed to create session from event stream: %v" , err )
134
+ }
100
135
}
136
+ t .setServerInitialized ()
101
137
}
102
- p .IsServerInitialized = true
103
- return nil
104
138
}
105
139
106
- // Handle streaming (SSE)
107
- ct , _ , err := mime .ParseMediaType (res .Header .Get ("Content-Type" ))
140
+ _ , err := io .Copy (io .Discard , r )
108
141
if err != nil {
109
- logger .Warnf ("Invalid Content-Type header, defaulting behavior: %v" , err )
110
- ct = "" // or choose a fallback
142
+ logger .Errorf ("Failed to copy event stream: %v" , err )
111
143
}
112
- if ct == "text/event-stream" {
113
- pr , pw := io .Pipe ()
114
- orig := res .Body
115
- res .Body = pr
116
-
117
- go func () {
118
- defer pw .Close ()
119
- scanner := bufio .NewScanner (orig )
120
- for scanner .Scan () {
121
- line := scanner .Text ()
122
-
123
- if matches := sessionIDRegex .FindStringSubmatch (line ); len (matches ) == 2 {
124
- sessionID := matches [1 ]
125
- _ , ok := p .sessionManager .Get (sessionID )
126
- if ! ok {
127
- var err error
128
- _ , err = p .sessionManager .AddWithID (sessionID )
129
- if err != nil {
130
- logger .Errorf ("Failed to create session %s: %v" , sessionID , err )
131
- continue
132
- }
133
- }
134
- p .IsServerInitialized = true
135
- }
136
- _ , err := pw .Write ([]byte (line + "\n " ))
137
- if err != nil {
138
- logger .Errorf ("Failed to write to pipe: %v" , err )
144
+ }
145
+
146
+ func (t * tracingTransport ) RoundTrip (req * http.Request ) (* http.Response , error ) {
147
+ reqBody := readRequestBody (req )
148
+
149
+ path := req .URL .Path
150
+ isMCP := strings .HasPrefix (path , "/mcp" )
151
+ isJSON := strings .Contains (req .Header .Get ("Content-Type" ), "application/json" )
152
+ sawInitialize := false
153
+
154
+ if isMCP && isJSON && len (reqBody ) > 0 {
155
+ sawInitialize = t .detectInitialize (reqBody )
156
+ }
157
+
158
+ resp , err := t .forward (req )
159
+ if err != nil {
160
+ logger .Errorf ("Failed to forward request: %v" , err )
161
+ return nil , err
162
+ }
163
+
164
+ if resp .StatusCode == http .StatusOK {
165
+ // check if we saw a valid mcp header
166
+ ct := resp .Header .Get ("Mcp-Session-Id" )
167
+ if ct != "" {
168
+ logger .Infof ("Detected Mcp-Session-Id header: %s" , ct )
169
+ if _ , ok := t .p .sessionManager .Get (ct ); ! ok {
170
+ if _ , err := t .p .sessionManager .AddWithID (ct ); err != nil {
171
+ logger .Errorf ("Failed to create session from header %s: %v" , ct , err )
139
172
}
140
173
}
141
- }()
142
- return nil
174
+ t .setServerInitialized ()
175
+ return resp , nil
176
+ }
177
+ // status was ok and we saw an initialize call
178
+ if sawInitialize && ! t .p .IsServerInitialized {
179
+ t .setServerInitialized ()
180
+ return resp , nil
181
+ }
182
+ ct = resp .Header .Get ("Content-Type" )
183
+ mediaType , _ , _ := mime .ParseMediaType (ct )
184
+ if mediaType == "text/event-stream" {
185
+ originalBody := resp .Body
186
+ pr , pw := io .Pipe ()
187
+ tee := io .TeeReader (originalBody , pw )
188
+ resp .Body = pr
189
+
190
+ go t .watchEventStream (tee , pw )
191
+ }
143
192
}
144
193
145
- return nil
194
+ return resp , nil
146
195
}
147
196
148
- func (p * TransparentProxy ) handleAndDetectInitialize (w http.ResponseWriter , r * http.Request , proxy * httputil.ReverseProxy ) {
149
- logger .Infof ("Transparent proxy: %s %s -> %s" , r .Method , r .URL .Path , p .targetURI )
150
-
151
- if r .Method == http .MethodPost && strings .HasSuffix (r .URL .Path , "/mcp" ) {
152
- // Read the body for inspection without consuming it
153
- body , err := io .ReadAll (r .Body )
197
+ func readRequestBody (req * http.Request ) []byte {
198
+ reqBody := []byte {}
199
+ if req .Body != nil {
200
+ buf , err := io .ReadAll (req .Body )
154
201
if err != nil {
155
- logger .Errorf ("Error reading request body: %v" , err )
202
+ logger .Errorf ("Failed to read request body: %v" , err )
156
203
} else {
157
- if bytes .Contains (body , []byte (`"method":"initialize"` )) {
158
- logger .Infof ("Detected initialize request to %s" , r .URL .Path )
159
- p .IsServerInitialized = true
160
- }
161
- r .Body = io .NopCloser (bytes .NewReader (body ))
162
- r .ContentLength = int64 (len (body ))
204
+ reqBody = buf
163
205
}
206
+ req .Body = io .NopCloser (bytes .NewReader (reqBody ))
164
207
}
208
+ return reqBody
209
+ }
165
210
166
- proxy .ServeHTTP (w , r )
211
+ func (t * tracingTransport ) detectInitialize (body []byte ) bool {
212
+ var rpc struct {
213
+ Method string `json:"method"`
214
+ }
215
+ if err := json .Unmarshal (body , & rpc ); err != nil {
216
+ logger .Errorf ("Failed to parse JSON-RPC body: %v" , err )
217
+ return false
218
+ }
219
+ if rpc .Method == "initialize" {
220
+ logger .Infof ("Detected initialize method call for %s" , t .p .containerName )
221
+ return true
222
+ }
223
+ return false
167
224
}
168
225
169
226
// Start starts the transparent proxy.
@@ -179,11 +236,12 @@ func (p *TransparentProxy) Start(ctx context.Context) error {
179
236
180
237
// Create a reverse proxy
181
238
proxy := httputil .NewSingleHostReverseProxy (targetURL )
182
- proxy .ModifyResponse = p . handleModifyResponse
239
+ proxy .Transport = & tracingTransport { base : http . DefaultTransport , p : p }
183
240
184
241
// Create a handler that logs requests
185
242
handler := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
186
- p .handleAndDetectInitialize (w , r , proxy )
243
+ logger .Infof ("Transparent proxy: %s %s -> %s" , r .Method , r .URL .Path , targetURL )
244
+ proxy .ServeHTTP (w , r )
187
245
})
188
246
189
247
// Create a mux to handle both proxy and health endpoints
0 commit comments