Skip to content

Commit 2cf2cb8

Browse files
committed
Enhance: Support cookie flow for mcp-ui
Signed-off-by: Daishan Peng <[email protected]> wip Signed-off-by: Daishan Peng <[email protected]>
1 parent 95fc784 commit 2cf2cb8

File tree

19 files changed

+1396
-193
lines changed

19 files changed

+1396
-193
lines changed

cmd/root.go

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
package cmd
2+
3+
import (
4+
"fmt"
5+
"log"
6+
"net/http"
7+
"net/url"
8+
9+
"github.com/gptscript-ai/cmd"
10+
"github.com/obot-platform/mcp-oauth-proxy/pkg/proxy"
11+
"github.com/obot-platform/mcp-oauth-proxy/pkg/types"
12+
"github.com/spf13/cobra"
13+
)
14+
15+
var (
16+
version = "dev"
17+
buildTime = "unknown"
18+
)
19+
20+
// RootCmd represents the base command when called without any subcommands
21+
type RootCmd struct {
22+
// Database configuration
23+
DatabaseDSN string `name:"database-dsn" env:"DATABASE_DSN" usage:"Database connection string (PostgreSQL or SQLite file path). If empty, uses SQLite at data/oauth_proxy.db"`
24+
25+
// OAuth Provider configuration
26+
OAuthClientID string `name:"oauth-client-id" env:"OAUTH_CLIENT_ID" usage:"OAuth client ID from your OAuth provider" required:"true"`
27+
OAuthClientSecret string `name:"oauth-client-secret" env:"OAUTH_CLIENT_SECRET" usage:"OAuth client secret from your OAuth provider" required:"true"`
28+
OAuthAuthorizeURL string `name:"oauth-authorize-url" env:"OAUTH_AUTHORIZE_URL" usage:"Authorization endpoint URL from your OAuth provider (e.g., https://accounts.google.com)" required:"true"`
29+
30+
// Scopes and MCP configuration
31+
ScopesSupported string `name:"scopes-supported" env:"SCOPES_SUPPORTED" usage:"Comma-separated list of supported OAuth scopes (e.g., 'openid,profile,email')" required:"true"`
32+
MCPServerURL string `name:"mcp-server-url" env:"MCP_SERVER_URL" usage:"URL of the MCP server to proxy requests to" required:"true"`
33+
34+
// Security configuration
35+
EncryptionKey string `name:"encryption-key" env:"ENCRYPTION_KEY" usage:"Base64-encoded 32-byte AES-256 key for encrypting sensitive data (optional)"`
36+
37+
// Server configuration
38+
Port string `name:"port" env:"PORT" usage:"Port to run the server on" default:"8080"`
39+
Host string `name:"host" env:"HOST" usage:"Host to bind the server to" default:"localhost"`
40+
41+
// Logging
42+
Verbose bool `name:"verbose,v" usage:"Enable verbose logging"`
43+
Version bool `name:"version" usage:"Show version information"`
44+
45+
Mode string `name:"mode" env:"MODE" usage:"Mode to run the server in" default:"proxy"`
46+
}
47+
48+
const (
49+
ModeProxy = "proxy"
50+
ModeForwardAuth = "forward_auth"
51+
)
52+
53+
func (c *RootCmd) Run(cobraCmd *cobra.Command, args []string) error {
54+
if c.Version {
55+
fmt.Printf("MCP OAuth Proxy\n")
56+
fmt.Printf("Version: %s\n", version)
57+
fmt.Printf("Built: %s\n", buildTime)
58+
return nil
59+
}
60+
61+
// Configure logging
62+
if c.Verbose {
63+
log.SetFlags(log.LstdFlags | log.Lshortfile)
64+
log.Println("Verbose logging enabled")
65+
}
66+
67+
// Convert CLI config to internal config format
68+
config := &types.Config{
69+
DatabaseDSN: c.DatabaseDSN,
70+
OAuthClientID: c.OAuthClientID,
71+
OAuthClientSecret: c.OAuthClientSecret,
72+
OAuthAuthorizeURL: c.OAuthAuthorizeURL,
73+
ScopesSupported: c.ScopesSupported,
74+
MCPServerURL: c.MCPServerURL,
75+
EncryptionKey: c.EncryptionKey,
76+
Mode: c.Mode,
77+
}
78+
79+
// Validate configuration
80+
if err := c.validateConfig(); err != nil {
81+
return fmt.Errorf("configuration validation failed: %w", err)
82+
}
83+
84+
// Create OAuth proxy
85+
oauthProxy, err := proxy.NewOAuthProxy(config)
86+
if err != nil {
87+
return fmt.Errorf("failed to create OAuth proxy: %w", err)
88+
}
89+
defer func() {
90+
if err := oauthProxy.Close(); err != nil {
91+
log.Printf("Error closing database: %v", err)
92+
}
93+
}()
94+
95+
// Get HTTP handler
96+
handler := oauthProxy.GetHandler()
97+
98+
// Start server
99+
address := fmt.Sprintf("%s:%s", c.Host, c.Port)
100+
log.Printf("Starting OAuth proxy server on %s", address)
101+
log.Printf("OAuth Provider: %s", c.OAuthAuthorizeURL)
102+
log.Printf("MCP Server: %s", c.MCPServerURL)
103+
log.Printf("Database: %s", c.getDatabaseType())
104+
105+
return http.ListenAndServe(address, handler)
106+
}
107+
108+
func (c *RootCmd) validateConfig() error {
109+
if c.OAuthClientID == "" {
110+
return fmt.Errorf("oauth-client-id is required")
111+
}
112+
if c.OAuthClientSecret == "" {
113+
return fmt.Errorf("oauth-client-secret is required")
114+
}
115+
if c.OAuthAuthorizeURL == "" {
116+
return fmt.Errorf("oauth-authorize-url is required")
117+
}
118+
if c.ScopesSupported == "" {
119+
return fmt.Errorf("scopes-supported is required")
120+
}
121+
if c.MCPServerURL == "" {
122+
return fmt.Errorf("mcp-server-url is required")
123+
}
124+
if c.Mode == ModeProxy {
125+
if u, err := url.Parse(c.MCPServerURL); err != nil || u.Scheme != "http" && u.Scheme != "https" {
126+
return fmt.Errorf("invalid MCP server URL: %w", err)
127+
} else if u.Path != "" && u.Path != "/" || u.RawQuery != "" || u.Fragment != "" {
128+
return fmt.Errorf("MCP server URL must not contain a path, query, or fragment")
129+
}
130+
}
131+
return nil
132+
}
133+
134+
func (c *RootCmd) getDatabaseType() string {
135+
if c.DatabaseDSN == "" {
136+
return "SQLite (data/oauth_proxy.db)"
137+
}
138+
if len(c.DatabaseDSN) > 10 && (c.DatabaseDSN[:11] == "postgres://" || c.DatabaseDSN[:14] == "postgresql://") {
139+
return "PostgreSQL"
140+
}
141+
return fmt.Sprintf("SQLite (%s)", c.DatabaseDSN)
142+
}
143+
144+
// Customizer interface implementation for additional command customization
145+
func (c *RootCmd) Customize(cobraCmd *cobra.Command) {
146+
cobraCmd.Use = "mcp-oauth-proxy"
147+
cobraCmd.Short = "OAuth 2.1 proxy server for MCP (Model Context Protocol)"
148+
cobraCmd.Long = `MCP OAuth Proxy is a comprehensive OAuth 2.1 proxy server that provides
149+
OAuth authorization server functionality with PostgreSQL/SQLite storage.
150+
151+
This proxy supports multiple OAuth providers (Google, Microsoft, GitHub) and
152+
proxies requests to MCP servers with user context headers.
153+
154+
Examples:
155+
# Start with environment variables
156+
export OAUTH_CLIENT_ID="your-google-client-id"
157+
export OAUTH_CLIENT_SECRET="your-secret"
158+
export OAUTH_AUTHORIZE_URL="https://accounts.google.com"
159+
export SCOPES_SUPPORTED="openid,profile,email"
160+
export MCP_SERVER_URL="http://localhost:3000"
161+
mcp-oauth-proxy
162+
163+
# Start with CLI flags
164+
mcp-oauth-proxy \
165+
--oauth-client-id="your-google-client-id" \
166+
--oauth-client-secret="your-secret" \
167+
--oauth-authorize-url="https://accounts.google.com" \
168+
--scopes-supported="openid,profile,email" \
169+
--mcp-server-url="http://localhost:3000"
170+
171+
# Use PostgreSQL database
172+
mcp-oauth-proxy \
173+
--database-dsn="postgres://user:pass@localhost:5432/oauth_db?sslmode=disable" \
174+
--oauth-client-id="your-client-id" \
175+
# ... other required flags
176+
177+
Configuration:
178+
Configuration values are loaded in this order (later values override earlier ones):
179+
1. Default values
180+
2. Environment variables
181+
3. Command line flags
182+
183+
Database Support:
184+
- PostgreSQL: Full ACID compliance, recommended for production
185+
- SQLite: Zero configuration, perfect for development and small deployments`
186+
187+
cobraCmd.Version = version
188+
}
189+
190+
// Execute is the main entry point for the CLI
191+
func Execute() error {
192+
rootCmd := &RootCmd{}
193+
cobraCmd := cmd.Command(rootCmd)
194+
return cobraCmd.Execute()
195+
}

go.mod

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@ module github.com/obot-platform/mcp-oauth-proxy
33
go 1.23.0
44

55
require (
6+
github.com/golang-jwt/jwt/v5 v5.3.0
67
github.com/gorilla/handlers v1.5.2
8+
github.com/gptscript-ai/cmd v0.0.0-20250530150401-bc71fddf8070
9+
github.com/spf13/cobra v1.7.0
710
github.com/stretchr/testify v1.10.0
11+
golang.org/x/oauth2 v0.30.0
812
gorm.io/driver/postgres v1.6.0
913
gorm.io/driver/sqlite v1.6.0
1014
gorm.io/gorm v1.30.1
@@ -13,6 +17,7 @@ require (
1317
require (
1418
github.com/davecgh/go-spew v1.1.1 // indirect
1519
github.com/felixge/httpsnoop v1.0.3 // indirect
20+
github.com/inconshreveable/mousetrap v1.1.0 // indirect
1621
github.com/jackc/pgpassfile v1.0.0 // indirect
1722
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
1823
github.com/jackc/pgx/v5 v5.7.5 // indirect
@@ -23,8 +28,8 @@ require (
2328
github.com/mattn/go-sqlite3 v1.14.32 // indirect
2429
github.com/pmezard/go-difflib v1.0.0 // indirect
2530
github.com/rogpeppe/go-internal v1.8.0 // indirect
31+
github.com/spf13/pflag v1.0.5 // indirect
2632
golang.org/x/crypto v0.41.0 // indirect
27-
golang.org/x/oauth2 v0.30.0 // indirect
2833
golang.org/x/sync v0.16.0 // indirect
2934
golang.org/x/text v0.28.0 // indirect
3035
gopkg.in/yaml.v3 v3.0.1 // indirect

go.sum

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
1+
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
12
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
23
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
34
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
45
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
56
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
7+
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
8+
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
69
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
710
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
11+
github.com/gptscript-ai/cmd v0.0.0-20250530150401-bc71fddf8070 h1:xm5ZZFraWFwxyE7TBEncCXArubCDZTwG6s5bpMzqhSY=
12+
github.com/gptscript-ai/cmd v0.0.0-20250530150401-bc71fddf8070/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw=
13+
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
14+
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
815
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
916
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
1017
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -30,6 +37,11 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
3037
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
3138
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
3239
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
40+
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
41+
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
42+
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
43+
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
44+
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
3345
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
3446
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
3547
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=

main.go

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,13 @@
11
package main
22

33
import (
4-
"log"
5-
"net/http"
4+
"os"
65

7-
"github.com/obot-platform/mcp-oauth-proxy/pkg/proxy"
6+
"github.com/obot-platform/mcp-oauth-proxy/cmd"
87
)
98

109
func main() {
11-
// Load configuration from environment variables
12-
config, err := proxy.LoadConfigFromEnv()
13-
if err != nil {
14-
log.Fatalf("Failed to load configuration: %v", err)
10+
if err := cmd.Execute(); err != nil {
11+
os.Exit(1)
1512
}
16-
17-
proxy, err := proxy.NewOAuthProxy(config)
18-
if err != nil {
19-
log.Fatalf("Failed to create OAuth proxy: %v", err)
20-
}
21-
defer func() {
22-
if err := proxy.Close(); err != nil {
23-
log.Printf("Error closing database: %v", err)
24-
}
25-
}()
26-
27-
// Get HTTP handler
28-
handler := proxy.GetHandler()
29-
30-
// Start server
31-
log.Printf("Starting OAuth proxy server on localhost:" + config.Port)
32-
log.Fatal(http.ListenAndServe(":"+config.Port, handler))
3313
}

pkg/db/db.go

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -295,14 +295,6 @@ func (d *Store) RevokeToken(token string) error {
295295
return result.Error
296296
}
297297

298-
// UpdateTokenRefreshToken updates the refresh token for an existing token
299-
func (d *Store) UpdateTokenRefreshToken(accessToken, newRefreshToken string) error {
300-
hashedAccessToken := hashToken(accessToken)
301-
hashedNewRefreshToken := hashToken(newRefreshToken)
302-
303-
return d.db.Model(&types.TokenData{}).Where("access_token = ?", hashedAccessToken).Update("refresh_token", hashedNewRefreshToken).Error
304-
}
305-
306298
// CleanupExpiredTokens removes expired tokens and authorization codes
307299
func (d *Store) CleanupExpiredTokens() error {
308300
now := time.Now()

pkg/db/db_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ func TestDatabaseOperations(t *testing.T) {
2424
if dsn == "" {
2525
t.Skip("Skipping database tests: TEST_DATABASE_DSN is not set")
2626
}
27+
2728
db, err := New(dsn)
2829
if err != nil {
2930
t.Skipf("Skipping database tests: %v", err)
@@ -234,8 +235,6 @@ func testTokenOperations(t *testing.T, db *Store) {
234235
// Test updating refresh token
235236
newRefreshTokenData, err := generateRandomString(16)
236237
require.NoError(t, err)
237-
err = db.UpdateTokenRefreshToken(accessTokenData, newRefreshTokenData)
238-
require.NoError(t, err)
239238

240239
updatedToken, err := db.GetToken(accessTokenData)
241240
require.NoError(t, err)

pkg/encryption/encryption.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,69 @@ func DecryptPropsIfNeeded(encryptionKey []byte, props map[string]any) (map[strin
143143

144144
return result, nil
145145
}
146+
147+
// EncryptString encrypts a string using AES-256-GCM
148+
func EncryptString(encryptionKey []byte, plaintext string) (string, error) {
149+
// Create AES cipher
150+
block, err := aes.NewCipher(encryptionKey)
151+
if err != nil {
152+
return "", fmt.Errorf("failed to create cipher: %w", err)
153+
}
154+
155+
// Create GCM mode
156+
gcm, err := cipher.NewGCM(block)
157+
if err != nil {
158+
return "", fmt.Errorf("failed to create GCM: %w", err)
159+
}
160+
161+
// Generate random IV
162+
iv := make([]byte, gcm.NonceSize())
163+
if _, err := rand.Read(iv); err != nil {
164+
return "", fmt.Errorf("failed to generate IV: %w", err)
165+
}
166+
167+
// Encrypt the data
168+
ciphertext := gcm.Seal(nil, iv, []byte(plaintext), nil)
169+
170+
// Combine IV and ciphertext, then base64 encode
171+
combined := append(iv, ciphertext...)
172+
return base64.StdEncoding.EncodeToString(combined), nil
173+
}
174+
175+
// DecryptString decrypts a string using AES-256-GCM
176+
func DecryptString(encryptionKey []byte, encryptedData string) (string, error) {
177+
// Decode base64 data
178+
combined, err := base64.StdEncoding.DecodeString(encryptedData)
179+
if err != nil {
180+
return "", fmt.Errorf("failed to decode encrypted data: %w", err)
181+
}
182+
183+
// Create AES cipher
184+
block, err := aes.NewCipher(encryptionKey)
185+
if err != nil {
186+
return "", fmt.Errorf("failed to create cipher: %w", err)
187+
}
188+
189+
// Create GCM mode
190+
gcm, err := cipher.NewGCM(block)
191+
if err != nil {
192+
return "", fmt.Errorf("failed to create GCM: %w", err)
193+
}
194+
195+
// Extract IV and ciphertext
196+
ivSize := gcm.NonceSize()
197+
if len(combined) < ivSize {
198+
return "", fmt.Errorf("encrypted data too short")
199+
}
200+
201+
iv := combined[:ivSize]
202+
ciphertext := combined[ivSize:]
203+
204+
// Decrypt the data
205+
plaintext, err := gcm.Open(nil, iv, ciphertext, nil)
206+
if err != nil {
207+
return "", fmt.Errorf("failed to decrypt data: %w", err)
208+
}
209+
210+
return string(plaintext), nil
211+
}

0 commit comments

Comments
 (0)