Skip to content

Commit dc85ef6

Browse files
committed
Fix: properly store renewed access token into grant table
Signed-off-by: Daishan Peng <[email protected]>
1 parent 363c91e commit dc85ef6

File tree

3 files changed

+136
-1
lines changed

3 files changed

+136
-1
lines changed

database/database.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,52 @@ func (d *Database) StoreGrant(grant *Grant) error {
499499
return err
500500
}
501501

502+
// UpdateGrant updates an existing grant's properties
503+
func (d *Database) UpdateGrant(grant *Grant) error {
504+
var query string
505+
if d.dbType == "postgres" {
506+
query = `
507+
UPDATE grants
508+
SET scope = $1, metadata = $2, props = $3, expires_at = $4
509+
WHERE id = $5 AND user_id = $6
510+
`
511+
} else {
512+
query = `
513+
UPDATE grants
514+
SET scope = ?, metadata = ?, props = ?, expires_at = ?
515+
WHERE id = ? AND user_id = ?
516+
`
517+
}
518+
519+
scope, _ := json.Marshal(grant.Scope)
520+
metadata, _ := json.Marshal(grant.Metadata)
521+
props, _ := json.Marshal(grant.Props)
522+
523+
result, err := d.db.Exec(query,
524+
scope,
525+
metadata,
526+
props,
527+
grant.ExpiresAt,
528+
grant.ID,
529+
grant.UserID,
530+
)
531+
if err != nil {
532+
return err
533+
}
534+
535+
// Check if any rows were affected
536+
rowsAffected, err := result.RowsAffected()
537+
if err != nil {
538+
return err
539+
}
540+
541+
if rowsAffected == 0 {
542+
return fmt.Errorf("grant not found: id=%s, user_id=%s", grant.ID, grant.UserID)
543+
}
544+
545+
return nil
546+
}
547+
502548
// GetGrant retrieves a grant by ID and user ID
503549
func (d *Database) GetGrant(grantID, userID string) (*Grant, error) {
504550
var query string

main.go

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,80 @@ func (p *OAuthProxy) decryptPropsIfNeeded(props map[string]interface{}) (map[str
327327
return result, nil
328328
}
329329

330+
// updateGrant updates a grant with new token information
331+
func (p *OAuthProxy) updateGrant(grantID, userID string, newTokenInfo *providers.TokenInfo) error {
332+
// Get the existing grant
333+
grant, err := p.db.GetGrant(grantID, userID)
334+
if err != nil {
335+
return fmt.Errorf("failed to get grant: %w", err)
336+
}
337+
338+
// Prepare sensitive props data
339+
sensitiveProps := map[string]interface{}{
340+
"access_token": newTokenInfo.AccessToken,
341+
"refresh_token": newTokenInfo.RefreshToken,
342+
"expires_at": newTokenInfo.ExpireAt,
343+
}
344+
345+
// Add existing user info if available
346+
if grant.Props != nil {
347+
if email, ok := grant.Props["email"].(string); ok {
348+
sensitiveProps["email"] = email
349+
}
350+
if name, ok := grant.Props["name"].(string); ok {
351+
sensitiveProps["name"] = name
352+
}
353+
if userID, ok := grant.Props["user_id"].(string); ok {
354+
sensitiveProps["user_id"] = userID
355+
}
356+
}
357+
358+
// Initialize props map
359+
props := make(map[string]interface{})
360+
361+
// Check if encryption is enabled
362+
if p.encryptionKey != "" {
363+
// Decode the encryption key from base64
364+
encryptionKey, err := base64.StdEncoding.DecodeString(p.encryptionKey)
365+
if err != nil {
366+
return fmt.Errorf("failed to decode encryption key: %w", err)
367+
}
368+
369+
// Validate key length (must be 32 bytes for AES-256)
370+
if len(encryptionKey) != 32 {
371+
return fmt.Errorf("invalid encryption key length: %d bytes (expected 32)", len(encryptionKey))
372+
}
373+
374+
// Encrypt the sensitive props data
375+
encryptedProps, err := encryptData(sensitiveProps, encryptionKey)
376+
if err != nil {
377+
return fmt.Errorf("failed to encrypt props data: %w", err)
378+
}
379+
380+
// Store encrypted data
381+
props["encrypted_data"] = encryptedProps.Data
382+
props["iv"] = encryptedProps.IV
383+
props["algorithm"] = encryptedProps.Algorithm
384+
props["encrypted"] = true
385+
} else {
386+
// Store data in plain text if no encryption key is provided
387+
for key, value := range sensitiveProps {
388+
props[key] = value
389+
}
390+
props["encrypted"] = false
391+
}
392+
393+
// Update the grant with new props
394+
grant.Props = props
395+
396+
// Update the grant in the database
397+
if err := p.db.UpdateGrant(grant); err != nil {
398+
return fmt.Errorf("failed to update grant: %w", err)
399+
}
400+
401+
return nil
402+
}
403+
330404
// databaseAdapter adapts the database to the tokens.Database interface
331405
type databaseAdapter struct {
332406
db *database.Database
@@ -984,7 +1058,17 @@ func (p *OAuthProxy) mcpProxyHandler(c *gin.Context) {
9841058
return
9851059
}
9861060

987-
// Update the token info with the new access token
1061+
// Update the grant with new token information
1062+
if err := p.updateGrant(tokenInfo.GrantID, tokenInfo.UserID, newTokenInfo); err != nil {
1063+
log.Printf("Failed to update grant: %v", err)
1064+
c.JSON(http.StatusInternalServerError, gin.H{
1065+
"error": "server_error",
1066+
"error_description": "Failed to update grant with new token",
1067+
})
1068+
return
1069+
}
1070+
1071+
// Update the token info with the new access token for the current request
9881072
tokenInfo.Props["access_token"] = newTokenInfo.AccessToken
9891073
if newTokenInfo.RefreshToken != "" {
9901074
tokenInfo.Props["refresh_token"] = newTokenInfo.RefreshToken
@@ -1042,6 +1126,7 @@ func (p *OAuthProxy) mcpProxyHandler(c *gin.Context) {
10421126
},
10431127
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
10441128
log.Printf("Proxy error: %v", err)
1129+
c.Abort()
10451130
},
10461131
}
10471132

tokens/jwt.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type Database interface {
2323

2424
type TokenClaims struct {
2525
UserID string `json:"user_id"`
26+
GrantID string `json:"grant_id"`
2627
Props map[string]interface{} `json:"props,omitempty"`
2728
ExpiresAt time.Time `json:"expires_at"`
2829
}
@@ -95,6 +96,7 @@ func (tm *TokenManager) ValidateAccessToken(tokenString string) (*TokenClaims, e
9596
// Create TokenClaims with the grant's props
9697
claims := &TokenClaims{
9798
UserID: userID,
99+
GrantID: grantID,
98100
Props: grant.Props,
99101
ExpiresAt: tokenData.ExpiresAt,
100102
}
@@ -111,6 +113,7 @@ func (tm *TokenManager) GetTokenInfo(tokenString string) (*TokenInfo, error) {
111113

112114
return &TokenInfo{
113115
UserID: claims.UserID,
116+
GrantID: claims.GrantID,
114117
Props: claims.Props,
115118
ExpiresAt: claims.ExpiresAt,
116119
}, nil
@@ -119,6 +122,7 @@ func (tm *TokenManager) GetTokenInfo(tokenString string) (*TokenInfo, error) {
119122
// TokenInfo represents token information
120123
type TokenInfo struct {
121124
UserID string
125+
GrantID string
122126
Props map[string]interface{}
123127
ExpiresAt time.Time
124128
}

0 commit comments

Comments
 (0)