Skip to content

Commit ed1e12f

Browse files
committed
fixes from copilot
1 parent 9f62a12 commit ed1e12f

File tree

3 files changed

+109
-71
lines changed

3 files changed

+109
-71
lines changed

pkg/transport/proxy/transparent/transparent_proxy.go

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func NewTransparentProxy(
8080
middlewares: middlewares,
8181
shutdownCh: make(chan struct{}),
8282
prometheusHandler: prometheusHandler,
83-
sessionManager: session.NewManager(30 * time.Minute),
83+
sessionManager: session.NewManager(30*time.Minute, session.NewProxySession),
8484
}
8585

8686
// Create MCP pinger and health checker
@@ -95,12 +95,12 @@ type tracingTransport struct {
9595
p *TransparentProxy
9696
}
9797

98-
func (t *tracingTransport) setServerInitialized() {
99-
if !t.p.IsServerInitialized {
100-
t.p.mutex.Lock()
101-
t.p.IsServerInitialized = true
102-
t.p.mutex.Unlock()
103-
logger.Infof("Server was initialized successfully for %s", t.p.containerName)
98+
func (p *TransparentProxy) setServerInitialized() {
99+
if !p.IsServerInitialized {
100+
p.mutex.Lock()
101+
p.IsServerInitialized = true
102+
p.mutex.Unlock()
103+
logger.Infof("Server was initialized successfully for %s", p.containerName)
104104
}
105105
}
106106

@@ -136,21 +136,16 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error)
136136
if ct != "" {
137137
logger.Infof("Detected Mcp-Session-Id header: %s", ct)
138138
if _, ok := t.p.sessionManager.Get(ct); !ok {
139-
fmt.Println("i get session id")
140139
if err := t.p.sessionManager.AddWithID(ct); err != nil {
141-
fmt.Println("i add session")
142140
logger.Errorf("Failed to create session from header %s: %v", ct, err)
143141
}
144-
fmt.Println("i set server initialized")
145142
}
146-
fmt.Println("i set server initialized")
147-
t.setServerInitialized()
143+
t.p.setServerInitialized()
148144
return resp, nil
149145
}
150146
// status was ok and we saw an initialize call
151147
if sawInitialize && !t.p.IsServerInitialized {
152-
fmt.Println("here")
153-
t.setServerInitialized()
148+
t.p.setServerInitialized()
154149
return resp, nil
155150
}
156151
}
@@ -212,7 +207,7 @@ func (p *TransparentProxy) modifyForSessionID(resp *http.Response) error {
212207
if sid == "" {
213208
sid = string(m[2])
214209
}
215-
p.IsServerInitialized = true
210+
p.setServerInitialized()
216211
err := p.sessionManager.AddWithID(sid)
217212
if err != nil {
218213
logger.Errorf("Failed to create session from SSE line: %v", err)

pkg/transport/session/manager.go

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,19 @@ type Manager struct {
2020
sessions sync.Map
2121
ttl time.Duration
2222
stopCh chan struct{}
23+
factory Factory
2324
}
2425

26+
// Factory defines a function type for creating new sessions.
27+
type Factory func(id string) *ProxySession
28+
2529
// NewManager creates a session manager with TTL and starts cleanup worker.
26-
func NewManager(ttl time.Duration) *Manager {
30+
func NewManager(ttl time.Duration, factory Factory) *Manager {
2731
m := &Manager{
28-
ttl: ttl,
29-
stopCh: make(chan struct{}),
32+
sessions: sync.Map{},
33+
ttl: ttl,
34+
stopCh: make(chan struct{}),
35+
factory: factory,
3036
}
3137
go m.cleanupRoutine()
3238
return m
@@ -40,7 +46,11 @@ func (m *Manager) cleanupRoutine() {
4046
case <-ticker.C:
4147
cutoff := time.Now().Add(-m.ttl)
4248
m.sessions.Range(func(key, val any) bool {
43-
sess := val.(Session)
49+
sess, ok := val.(Session)
50+
if !ok {
51+
// Skip invalid value
52+
return true
53+
}
4454
if sess.UpdatedAt().Before(cutoff) {
4555
m.sessions.Delete(key)
4656
}
@@ -59,7 +69,8 @@ func (m *Manager) AddWithID(id string) error {
5969
return fmt.Errorf("session ID cannot be empty")
6070
}
6171
// Use LoadOrStore: returns existing if already present
62-
_, loaded := m.sessions.LoadOrStore(id, NewProxySession(id))
72+
session := m.factory(id)
73+
_, loaded := m.sessions.LoadOrStore(id, session)
6374
if loaded {
6475
return fmt.Errorf("session ID %q already exists", id)
6576
}
@@ -73,7 +84,11 @@ func (m *Manager) Get(id string) (Session, bool) {
7384
if !ok {
7485
return nil, false
7586
}
76-
sess := v.(Session)
87+
sess, ok := v.(Session)
88+
if !ok {
89+
return nil, false // Invalid session type
90+
}
91+
7792
sess.Touch()
7893
return sess, true
7994
}
@@ -87,3 +102,14 @@ func (m *Manager) Delete(id string) {
87102
func (m *Manager) Stop() {
88103
close(m.stopCh)
89104
}
105+
106+
func (m *Manager) cleanupExpiredOnce() {
107+
cutoff := time.Now().Add(-m.ttl)
108+
m.sessions.Range(func(key, val any) bool {
109+
sess := val.(Session)
110+
if sess.UpdatedAt().Before(cutoff) {
111+
m.sessions.Delete(key)
112+
}
113+
return true
114+
})
115+
}

pkg/transport/session/manager_test.go

Lines changed: 67 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,123 +1,140 @@
11
package session
22

33
import (
4+
"sync"
45
"testing"
56
"time"
67

78
"github.com/stretchr/testify/assert"
89
"github.com/stretchr/testify/require"
910
)
1011

11-
func TestAddAndGetWithStubSession(t *testing.T) {
12-
t.Parallel()
12+
// stubFactory returns ProxySessions with fixed timestamps and records IDs.
13+
type stubFactory struct {
14+
mu sync.Mutex
15+
createdIDs []string
16+
fixedTime time.Time
17+
}
1318

14-
orig := NewProxySession
15-
NewProxySession = func(id string) *ProxySession {
16-
ts := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
17-
return &ProxySession{id: id, created: ts, updated: ts}
19+
func (f *stubFactory) New(id string) *ProxySession {
20+
f.mu.Lock()
21+
defer f.mu.Unlock()
22+
f.createdIDs = append(f.createdIDs, id)
23+
return &ProxySession{
24+
id: id,
25+
created: f.fixedTime,
26+
updated: f.fixedTime,
1827
}
19-
defer func() { NewProxySession = orig }()
28+
}
2029

21-
m := NewManager(1 * time.Hour)
30+
func TestAddAndGetWithStubSession(t *testing.T) {
31+
t.Parallel()
32+
now := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
33+
factory := &stubFactory{fixedTime: now}
34+
35+
m := NewManager(time.Hour, factory.New)
2236
defer m.Stop()
2337

2438
require.NoError(t, m.AddWithID("foo"))
2539

2640
sess, ok := m.Get("foo")
27-
require.True(t, ok)
41+
require.True(t, ok, "session foo should exist")
2842
assert.Equal(t, "foo", sess.ID())
43+
assert.Contains(t, factory.createdIDs, "foo")
2944
}
3045

3146
func TestAddDuplicate(t *testing.T) {
3247
t.Parallel()
48+
factory := &stubFactory{fixedTime: time.Now()}
3349

34-
m := NewManager(time.Hour)
50+
m := NewManager(time.Hour, factory.New)
3551
defer m.Stop()
3652

37-
err := m.AddWithID("dup")
38-
assert.NoError(t, err)
53+
require.NoError(t, m.AddWithID("dup"))
3954

40-
err2 := m.AddWithID("dup")
41-
assert.Error(t, err2)
42-
assert.Contains(t, err2.Error(), "already exists")
55+
err := m.AddWithID("dup")
56+
assert.Error(t, err)
57+
assert.Contains(t, err.Error(), "already exists")
4358
}
4459

4560
func TestDeleteSession(t *testing.T) {
4661
t.Parallel()
62+
factory := &stubFactory{fixedTime: time.Now()}
4763

48-
m := NewManager(time.Hour)
64+
m := NewManager(time.Hour, factory.New)
4965
defer m.Stop()
5066

5167
require.NoError(t, m.AddWithID("del"))
5268
m.Delete("del")
5369

5470
_, ok := m.Get("del")
55-
assert.False(t, ok)
71+
assert.False(t, ok, "deleted session should not be found")
5672
}
5773

5874
func TestGetUpdatesTimestamp(t *testing.T) {
5975
t.Parallel()
76+
oldTime := time.Now().Add(-1 * time.Minute)
77+
factory := &stubFactory{fixedTime: oldTime}
6078

61-
orig := NewProxySession
62-
NewProxySession = func(id string) *ProxySession {
63-
ts := time.Now().Add(-1 * time.Minute)
64-
return &ProxySession{id: id, created: ts, updated: ts}
65-
}
66-
defer func() { NewProxySession = orig }()
67-
68-
m := NewManager(1 * time.Hour)
79+
m := NewManager(time.Hour, factory.New)
6980
defer m.Stop()
7081

7182
require.NoError(t, m.AddWithID("touchme"))
72-
s1, _ := m.Get("touchme")
83+
s1, ok := m.Get("touchme")
84+
require.True(t, ok)
7385
t0 := s1.UpdatedAt()
7486

75-
time.Sleep(5 * time.Millisecond)
76-
s2, _ := m.Get("touchme")
87+
time.Sleep(10 * time.Millisecond)
88+
s2, ok2 := m.Get("touchme")
89+
require.True(t, ok2)
7790
t1 := s2.UpdatedAt()
7891

79-
assert.True(t, t1.After(t0), "UpdatedAt should update on Get()")
92+
assert.True(t, t1.After(t0), "UpdatedAt should update on repeated Get()")
8093
}
81-
82-
func TestCleanupExpired(t *testing.T) {
94+
func TestCleanupExpired_ManualTrigger(t *testing.T) {
8395
t.Parallel()
8496

97+
// Stub factory: all sessions start with UpdatedAt = `now`
98+
now := time.Now()
99+
factory := &stubFactory{fixedTime: now}
85100
ttl := 50 * time.Millisecond
86-
orig := NewProxySession
87-
NewProxySession = func(id string) *ProxySession {
88-
return &ProxySession{
89-
id: id,
90-
created: time.Now(),
91-
updated: time.Now(),
92-
}
93-
}
94-
defer func() { NewProxySession = orig }()
95101

96-
m := NewManager(ttl)
102+
m := NewManager(ttl, factory.New)
97103
defer m.Stop()
98104

99105
require.NoError(t, m.AddWithID("old"))
100-
time.Sleep(ttl * 2) // allow old to expire
101106

102-
require.NoError(t, m.AddWithID("new"))
103-
time.Sleep(ttl) // let cleanup execute
107+
// Retrieve and expire session manually
108+
sess, ok := m.Get("old")
109+
require.True(t, ok)
110+
ps := sess.(*ProxySession)
111+
ps.updated = now.Add(-ttl * 2)
104112

113+
// Run cleanup manually
114+
m.cleanupExpiredOnce()
115+
116+
// Now it should be gone
105117
_, okOld := m.Get("old")
118+
assert.False(t, okOld, "expired session should have been cleaned")
119+
120+
// Add fresh session and assert it remains after cleanup
121+
require.NoError(t, m.AddWithID("new"))
122+
m.cleanupExpiredOnce()
106123
_, okNew := m.Get("new")
107-
assert.False(t, okOld, "expired session should be cleaned")
108-
assert.True(t, okNew, "recent session should remain")
124+
assert.True(t, okNew, "new session should still exist after cleanup")
109125
}
110126

111127
func TestStopDisablesCleanup(t *testing.T) {
112128
t.Parallel()
113-
114129
ttl := 50 * time.Millisecond
115-
m := NewManager(ttl)
116-
m.Stop() // stop cleanup upfront
130+
factory := &stubFactory{fixedTime: time.Now()}
131+
132+
m := NewManager(ttl, factory.New)
133+
m.Stop() // disable cleanup before any session expires
117134

118135
require.NoError(t, m.AddWithID("stay"))
119136
time.Sleep(ttl * 2)
120137

121138
_, ok := m.Get("stay")
122-
assert.True(t, ok, "session should persist after Stop()")
139+
assert.True(t, ok, "session should still be present even after Stop() and TTL elapsed")
123140
}

0 commit comments

Comments
 (0)