From 86a1cd657e8b895daac572c4279886da53d0a8ac Mon Sep 17 00:00:00 2001 From: Daishan Peng Date: Thu, 17 Jul 2025 14:42:47 -0700 Subject: [PATCH] Fix: fix grant and refresh token cleanup Signed-off-by: Daishan Peng --- database/database.go | 181 ++++++++++++++++++++---- database/database_test.go | 288 +++++++++++++++++++++++++++++++++++--- database/sqlite_test.go | 13 ++ main.go | 15 +- 4 files changed, 449 insertions(+), 48 deletions(-) diff --git a/database/database.go b/database/database.go index 54ecc2f..e82acd8 100644 --- a/database/database.go +++ b/database/database.go @@ -23,16 +23,17 @@ type Database struct { // TokenData represents stored token data for OAuth 2.1 compliance type TokenData struct { - AccessToken string - RefreshToken string - ClientID string - UserID string - GrantID string - Scope string - ExpiresAt time.Time - CreatedAt time.Time - Revoked bool - RevokedAt *time.Time + AccessToken string + RefreshToken string + ClientID string + UserID string + GrantID string + Scope string + ExpiresAt time.Time + RefreshTokenExpiresAt time.Time + CreatedAt time.Time + Revoked bool + RevokedAt *time.Time } // ClientInfo represents OAuth client registration information @@ -116,8 +117,23 @@ func NewDatabase(dsn string) (*Database, error) { return database, nil } -// setupSchema creates the necessary tables +// setupSchema creates the necessary tables and handles migrations func (d *Database) setupSchema() error { + // First, create tables if they don't exist + if err := d.createTables(); err != nil { + return err + } + + // Then run migrations + if err := d.runMigrations(); err != nil { + return err + } + + return nil +} + +// createTables creates the base tables +func (d *Database) createTables() error { var queries []string if d.dbType == "postgres" { @@ -254,6 +270,86 @@ func (d *Database) setupSchema() error { return nil } +// runMigrations handles database schema migrations +func (d *Database) runMigrations() error { + // Migration 1: Add refresh_token_expires_at column to access_tokens table + if err := d.migrateAddRefreshTokenExpiration(); err != nil { + return fmt.Errorf("failed to run migration 1: %w", err) + } + + return nil +} + +// migrateAddRefreshTokenExpiration adds the refresh_token_expires_at column +func (d *Database) migrateAddRefreshTokenExpiration() error { + // Check if the column already exists + columnExists, err := d.columnExists("access_tokens", "refresh_token_expires_at") + if err != nil { + return fmt.Errorf("failed to check if column exists: %w", err) + } + + if columnExists { + return nil // Column already exists, no migration needed + } + + // Add the column + var query string + if d.dbType == "postgres" { + query = `ALTER TABLE access_tokens ADD COLUMN refresh_token_expires_at TIMESTAMPTZ` + } else { + query = `ALTER TABLE access_tokens ADD COLUMN refresh_token_expires_at DATETIME` + } + + if _, err := d.db.Exec(query); err != nil { + return fmt.Errorf("failed to add refresh_token_expires_at column: %w", err) + } + + // Update existing records to have a default expiration (30 days from now) + updateQuery := `UPDATE access_tokens SET refresh_token_expires_at = ? WHERE refresh_token_expires_at IS NULL` + if d.dbType == "postgres" { + updateQuery = `UPDATE access_tokens SET refresh_token_expires_at = $1 WHERE refresh_token_expires_at IS NULL` + } + + defaultExpiration := time.Now().Add(30 * 24 * time.Hour) + if _, err := d.db.Exec(updateQuery, defaultExpiration); err != nil { + return fmt.Errorf("failed to update existing records: %w", err) + } + + // Make the column NOT NULL (SQLite doesn't support ALTER COLUMN SET NOT NULL directly) + if d.dbType == "postgres" { + query = `ALTER TABLE access_tokens ALTER COLUMN refresh_token_expires_at SET NOT NULL` + if _, err := d.db.Exec(query); err != nil { + return fmt.Errorf("failed to make refresh_token_expires_at NOT NULL: %w", err) + } + } + + return nil +} + +// columnExists checks if a column exists in a table +func (d *Database) columnExists(tableName, columnName string) (bool, error) { + var query string + if d.dbType == "postgres" { + query = ` + SELECT COUNT(*) FROM information_schema.columns + WHERE table_name = $1 AND column_name = $2 + ` + } else { + query = ` + SELECT COUNT(*) FROM pragma_table_info(?) + WHERE name = ? + ` + } + + var count int + err := d.db.QueryRow(query, tableName, columnName).Scan(&count) + if err != nil { + return false, err + } + + return count > 0, nil +} + // GetClient retrieves a client by ID func (d *Database) GetClient(clientID string) (*ClientInfo, error) { var query string @@ -522,13 +618,13 @@ func (d *Database) StoreToken(data *TokenData) error { var query string if d.dbType == "postgres" { query = ` - INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at) - VALUES ($1, $2, $3, $4, $5, $6, $7) + INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ` } else { query = ` - INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at) - VALUES (?, ?, ?, ?, ?, ?, ?) + INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) ` } @@ -536,7 +632,12 @@ func (d *Database) StoreToken(data *TokenData) error { hashedAccessToken := hashToken(data.AccessToken) hashedRefreshToken := hashToken(data.RefreshToken) - _, err := d.db.Exec(query, hashedAccessToken, hashedRefreshToken, data.ClientID, data.UserID, data.GrantID, data.Scope, data.ExpiresAt) + // Set refresh token expiration to 30 days from now if not already set + if data.RefreshTokenExpiresAt.IsZero() { + data.RefreshTokenExpiresAt = time.Now().Add(30 * 24 * time.Hour) + } + + _, err := d.db.Exec(query, hashedAccessToken, hashedRefreshToken, data.ClientID, data.UserID, data.GrantID, data.Scope, data.ExpiresAt, data.RefreshTokenExpiresAt) return err } @@ -545,12 +646,12 @@ func (d *Database) GetToken(accessToken string) (*TokenData, error) { var query string if d.dbType == "postgres" { query = ` - SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at + SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at FROM access_tokens WHERE access_token = $1 ` } else { query = ` - SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at + SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at FROM access_tokens WHERE access_token = ? ` } @@ -567,6 +668,7 @@ func (d *Database) GetToken(accessToken string) (*TokenData, error) { &data.GrantID, &data.Scope, &data.ExpiresAt, + &data.RefreshTokenExpiresAt, &data.CreatedAt, &data.Revoked, &revokedAt, @@ -588,13 +690,13 @@ func (d *Database) GetTokenByRefreshToken(refreshToken string) (*TokenData, erro var query string if d.dbType == "postgres" { query = ` - SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at - FROM access_tokens WHERE refresh_token = $1 + SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at + FROM access_tokens WHERE refresh_token = $1 AND refresh_token_expires_at > NOW() ` } else { query = ` - SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at - FROM access_tokens WHERE refresh_token = ? + SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at + FROM access_tokens WHERE refresh_token = ? AND refresh_token_expires_at > datetime('now') ` } @@ -611,6 +713,7 @@ func (d *Database) GetTokenByRefreshToken(refreshToken string) (*TokenData, erro &data.GrantID, &data.Scope, &data.ExpiresAt, + &data.RefreshTokenExpiresAt, &data.CreatedAt, &data.Revoked, &revokedAt, @@ -630,6 +733,29 @@ func (d *Database) GetTokenByRefreshToken(refreshToken string) (*TokenData, erro return &data, nil } +// IsRefreshTokenExpired checks if a refresh token is expired +func (d *Database) IsRefreshTokenExpired(refreshToken string) (bool, error) { + var query string + if d.dbType == "postgres" { + query = ` + SELECT refresh_token_expires_at FROM access_tokens WHERE refresh_token = $1 + ` + } else { + query = ` + SELECT refresh_token_expires_at FROM access_tokens WHERE refresh_token = ? + ` + } + + hashedRefreshToken := hashToken(refreshToken) + var expiresAt time.Time + err := d.db.QueryRow(query, hashedRefreshToken).Scan(&expiresAt) + if err != nil { + return false, err + } + + return time.Now().After(expiresAt), nil +} + // RevokeToken revokes an access token func (d *Database) RevokeToken(token string) error { hashedToken := hashToken(token) @@ -685,22 +811,27 @@ func (d *Database) CleanupExpiredTokens() error { if d.dbType == "postgres" { queries = []string{ - `DELETE FROM access_tokens WHERE expires_at < NOW() OR revoked = TRUE`, + `DELETE FROM access_tokens WHERE (expires_at < NOW() AND refresh_token_expires_at < NOW()) OR revoked = TRUE`, `DELETE FROM authorization_codes WHERE expires_at < NOW()`, `DELETE FROM grants WHERE expires_at < EXTRACT(EPOCH FROM NOW())`, } } else { queries = []string{ - `DELETE FROM access_tokens WHERE expires_at < datetime('now') OR revoked = 1`, + `DELETE FROM access_tokens WHERE (expires_at < datetime('now') AND refresh_token_expires_at < datetime('now')) OR revoked = 1`, `DELETE FROM authorization_codes WHERE expires_at < datetime('now')`, `DELETE FROM grants WHERE expires_at < strftime('%s', 'now')`, } } for _, query := range queries { - if _, err := d.db.Exec(query); err != nil { + result, err := d.db.Exec(query) + if err != nil { return fmt.Errorf("failed to cleanup expired tokens: %w", err) } + rowsAffected, _ := result.RowsAffected() + if rowsAffected > 0 { + fmt.Printf("Deleted %d expired rows for query %s\n", rowsAffected, query) + } } return nil diff --git a/database/database_test.go b/database/database_test.go index c4f27c2..18cd668 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -52,6 +52,14 @@ func TestDatabaseOperations(t *testing.T) { t.Run("TestCleanupOperations", func(t *testing.T) { testCleanupOperations(t, db) }) + + t.Run("TestRefreshTokenExpiration", func(t *testing.T) { + testRefreshTokenExpiration(t, db) + }) + + t.Run("TestDatabaseMigration", func(t *testing.T) { + testDatabaseMigration(t, db) + }) } func testClientOperations(t *testing.T, db *Database) { @@ -182,15 +190,16 @@ func testTokenOperations(t *testing.T, db *Database) { // Test storing token tokenData := &TokenData{ - AccessToken: accessTokenData, - RefreshToken: refreshTokenData, - ClientID: "test_client_db", - UserID: "test_user_123", - GrantID: grantID, - Scope: "read write admin", - ExpiresAt: time.Now().Add(1 * time.Hour), - CreatedAt: time.Now(), - Revoked: false, + AccessToken: accessTokenData, + RefreshToken: refreshTokenData, + ClientID: "test_client_db", + UserID: "test_user_123", + GrantID: grantID, + Scope: "read write admin", + ExpiresAt: time.Now().Add(1 * time.Hour), + RefreshTokenExpiresAt: time.Now().Add(30 * 24 * time.Hour), // 30 days + CreatedAt: time.Now(), + Revoked: false, } err = db.StoreToken(tokenData) @@ -206,12 +215,14 @@ func testTokenOperations(t *testing.T, db *Database) { assert.Equal(t, tokenData.GrantID, retrievedToken.GrantID) assert.Equal(t, tokenData.Scope, retrievedToken.Scope) assert.False(t, retrievedToken.Revoked) + assert.True(t, retrievedToken.RefreshTokenExpiresAt.After(time.Now().Add(29*24*time.Hour))) // Should be ~30 days // Test retrieving token by refresh token refreshToken, err := db.GetTokenByRefreshToken(refreshTokenData) require.NoError(t, err) assert.Equal(t, hashToken(tokenData.AccessToken), refreshToken.AccessToken) assert.Equal(t, tokenData.RefreshToken, refreshToken.RefreshToken) + assert.True(t, refreshToken.RefreshTokenExpiresAt.After(time.Now().Add(29*24*time.Hour))) // Should be ~30 days // Test revoking token err = db.RevokeToken(accessTokenData) @@ -303,15 +314,16 @@ func testCleanupOperations(t *testing.T, db *Database) { // Create expired tokens expiredToken := &TokenData{ - AccessToken: accessTokenData, - RefreshToken: refreshTokenData, - ClientID: "test_client_db", - UserID: userID, - GrantID: grantID, - Scope: "read write", - ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired - CreatedAt: time.Now().Add(-2 * time.Hour), - Revoked: false, + AccessToken: accessTokenData, + RefreshToken: refreshTokenData, + ClientID: "test_client_db", + UserID: userID, + GrantID: grantID, + Scope: "read write", + ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired + RefreshTokenExpiresAt: time.Now().Add(-1 * time.Hour), // Also expired + CreatedAt: time.Now().Add(-2 * time.Hour), + Revoked: false, } err = db.StoreToken(expiredToken) @@ -321,8 +333,8 @@ func testCleanupOperations(t *testing.T, db *Database) { err = db.CleanupExpiredTokens() require.NoError(t, err) - // Verify expired token is cleaned up - _, err = db.GetToken("expired_token_123") + // Verify expired token is cleaned up (both access and refresh tokens expired) + _, err = db.GetToken(accessTokenData) assert.Error(t, err) } @@ -359,3 +371,239 @@ func TestTokenHashing(t *testing.T) { assert.NotEmpty(t, hash1) assert.Len(t, hash1, 44) // Base64 encoded SHA256 hash length } + +func testRefreshTokenExpiration(t *testing.T, db *Database) { + grantID, err := generateRandomString(16) + require.NoError(t, err) + + clientID, err := generateRandomString(16) + require.NoError(t, err) + + grant := &Grant{ + ID: grantID, + ClientID: clientID, + UserID: "test_user_123", + Scope: []string{"read", "write", "admin"}, + Metadata: map[string]interface{}{"provider": "test", "ip": "127.0.0.1"}, + ExpiresAt: time.Now().Add(1 * time.Hour).Unix(), + } + + err = db.StoreGrant(grant) + require.NoError(t, err) + + // Test 1: Store token with default refresh token expiration (30 days) + accessToken1, err := generateRandomString(16) + require.NoError(t, err) + refreshToken1, err := generateRandomString(16) + require.NoError(t, err) + + tokenData1 := &TokenData{ + AccessToken: accessToken1, + RefreshToken: refreshToken1, + ClientID: clientID, + UserID: "test_user_123", + GrantID: grantID, + Scope: "read write admin", + ExpiresAt: time.Now().Add(1 * time.Hour), + // RefreshTokenExpiresAt will be set automatically to 30 days + CreatedAt: time.Now(), + Revoked: false, + } + + err = db.StoreToken(tokenData1) + require.NoError(t, err) + + // Verify refresh token expiration was set to 30 days + retrievedToken1, err := db.GetToken(accessToken1) + require.NoError(t, err) + assert.True(t, retrievedToken1.RefreshTokenExpiresAt.After(time.Now().Add(29*24*time.Hour))) + assert.True(t, retrievedToken1.RefreshTokenExpiresAt.Before(time.Now().Add(31*24*time.Hour))) + + // Test 2: Store token with custom refresh token expiration + accessToken2, err := generateRandomString(16) + require.NoError(t, err) + refreshToken2, err := generateRandomString(16) + require.NoError(t, err) + + customExpiration := time.Now().Add(7 * 24 * time.Hour) // 7 days + tokenData2 := &TokenData{ + AccessToken: accessToken2, + RefreshToken: refreshToken2, + ClientID: clientID, + UserID: "test_user_123", + GrantID: grantID, + Scope: "read write admin", + ExpiresAt: time.Now().Add(1 * time.Hour), + RefreshTokenExpiresAt: customExpiration, + CreatedAt: time.Now(), + Revoked: false, + } + + err = db.StoreToken(tokenData2) + require.NoError(t, err) + + // Verify custom expiration was preserved + retrievedToken2, err := db.GetToken(accessToken2) + require.NoError(t, err) + // Use tolerance-based comparison instead of exact equality due to database precision differences + timeDiff := retrievedToken2.RefreshTokenExpiresAt.Sub(customExpiration) + assert.True(t, timeDiff >= -time.Second && timeDiff <= time.Second, + "Refresh token expiration should be within 1 second of expected time, got difference of %v", timeDiff) + + // Test 3: Test refresh token expiration check + expired, err := db.IsRefreshTokenExpired(refreshToken1) + require.NoError(t, err) + assert.False(t, expired) + + // Test 4: Test cleanup with mixed expiration scenarios + // Create token with expired access token but valid refresh token + accessToken3, err := generateRandomString(16) + require.NoError(t, err) + refreshToken3, err := generateRandomString(16) + require.NoError(t, err) + + tokenData3 := &TokenData{ + AccessToken: accessToken3, + RefreshToken: refreshToken3, + ClientID: clientID, + UserID: "test_user_123", + GrantID: grantID, + Scope: "read write admin", + ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired access token + RefreshTokenExpiresAt: time.Now().Add(1 * time.Hour), // Valid refresh token + CreatedAt: time.Now().Add(-2 * time.Hour), + Revoked: false, + } + + err = db.StoreToken(tokenData3) + require.NoError(t, err) + + // Create token with both expired access and refresh tokens + accessToken4, err := generateRandomString(16) + require.NoError(t, err) + refreshToken4, err := generateRandomString(16) + require.NoError(t, err) + + tokenData4 := &TokenData{ + AccessToken: accessToken4, + RefreshToken: refreshToken4, + ClientID: clientID, + UserID: "test_user_123", + GrantID: grantID, + Scope: "read write admin", + ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired access token + RefreshTokenExpiresAt: time.Now().Add(-1 * time.Hour), // Expired refresh token + CreatedAt: time.Now().Add(-2 * time.Hour), + Revoked: false, + } + + err = db.StoreToken(tokenData4) + require.NoError(t, err) + + // Run cleanup + err = db.CleanupExpiredTokens() + require.NoError(t, err) + + // Verify that only the token with both expired access and refresh tokens was cleaned up + _, err = db.GetToken(accessToken3) // Should still exist (valid refresh token) + assert.NoError(t, err) + + _, err = db.GetToken(accessToken4) // Should be cleaned up (both expired) + assert.Error(t, err) + + // Test 5: Test that expired refresh tokens cannot be used + // Create a token with expired refresh token + accessToken5, err := generateRandomString(16) + require.NoError(t, err) + refreshToken5, err := generateRandomString(16) + require.NoError(t, err) + + tokenData5 := &TokenData{ + AccessToken: accessToken5, + RefreshToken: refreshToken5, + ClientID: clientID, + UserID: "test_user_123", + GrantID: grantID, + Scope: "read write admin", + ExpiresAt: time.Now().Add(1 * time.Hour), + RefreshTokenExpiresAt: time.Now().Add(-1 * time.Hour), // Expired refresh token + CreatedAt: time.Now().Add(-2 * time.Hour), + Revoked: false, + } + + err = db.StoreToken(tokenData5) + require.NoError(t, err) + + // Try to get token by expired refresh token + _, err = db.GetTokenByRefreshToken(refreshToken5) + assert.Error(t, err, "Should return error when trying to get token by expired refresh token") + + // Verify the token still exists when accessed by access token + retrievedToken5, err := db.GetToken(accessToken5) + require.NoError(t, err) + assert.Equal(t, hashToken(accessToken5), retrievedToken5.AccessToken) +} + +func testDatabaseMigration(t *testing.T, db *Database) { + // Test that the migration system works correctly + + // Test 1: Check if column exists function works + exists, err := db.columnExists("access_tokens", "refresh_token_expires_at") + require.NoError(t, err) + assert.True(t, exists, "refresh_token_expires_at column should exist after migration") + + // Test 2: Check that a non-existent column returns false + exists, err = db.columnExists("access_tokens", "non_existent_column") + require.NoError(t, err) + assert.False(t, exists, "non-existent column should return false") + + // Test 3: Test migration idempotency (running migration again should not fail) + err = db.migrateAddRefreshTokenExpiration() + require.NoError(t, err, "Migration should be idempotent and not fail when run again") + + // Test 4: Verify that existing tokens have refresh token expiration set + // Create a grant first + grantID, err := generateRandomString(16) + require.NoError(t, err) + + grant := &Grant{ + ID: grantID, + ClientID: "test_client_migration", + UserID: "test_user_migration", + Scope: []string{"read", "write"}, + Metadata: map[string]interface{}{"provider": "test"}, + CreatedAt: time.Now().Unix(), + ExpiresAt: time.Now().Add(time.Hour).Unix(), + } + + err = db.StoreGrant(grant) + require.NoError(t, err) + + // Store a token + accessToken, err := generateRandomString(16) + require.NoError(t, err) + refreshToken, err := generateRandomString(16) + require.NoError(t, err) + + tokenData := &TokenData{ + AccessToken: accessToken, + RefreshToken: refreshToken, + ClientID: "test_client_migration", + UserID: "test_user_migration", + GrantID: grantID, + Scope: "read write", + ExpiresAt: time.Now().Add(time.Hour), + // RefreshTokenExpiresAt will be set automatically + } + + err = db.StoreToken(tokenData) + require.NoError(t, err) + + // Verify the token has refresh token expiration set + retrievedToken, err := db.GetToken(accessToken) + require.NoError(t, err) + assert.True(t, retrievedToken.RefreshTokenExpiresAt.After(time.Now().Add(29*24*time.Hour)), + "Refresh token should expire in approximately 30 days") + assert.True(t, retrievedToken.RefreshTokenExpiresAt.Before(time.Now().Add(31*24*time.Hour)), + "Refresh token should expire in approximately 30 days") +} diff --git a/database/sqlite_test.go b/database/sqlite_test.go index eaf637e..6876a9e 100644 --- a/database/sqlite_test.go +++ b/database/sqlite_test.go @@ -93,6 +93,7 @@ func TestSQLiteDatabase(t *testing.T) { GrantID: "test_grant_sqlite", Scope: "openid profile", ExpiresAt: time.Now().Add(time.Hour), + // RefreshTokenExpiresAt will be set automatically to 30 days } // Store token @@ -108,6 +109,18 @@ func TestSQLiteDatabase(t *testing.T) { assert.Equal(t, tokenData.UserID, retrievedToken.UserID) assert.Equal(t, tokenData.GrantID, retrievedToken.GrantID) assert.Equal(t, tokenData.Scope, retrievedToken.Scope) + assert.True(t, retrievedToken.RefreshTokenExpiresAt.After(time.Now().Add(29*24*time.Hour))) // Should be ~30 days + + // Test retrieving by refresh token + refreshToken, err := db.GetTokenByRefreshToken("test_refresh_token_sqlite") + require.NoError(t, err) + assert.Equal(t, tokenData.ClientID, refreshToken.ClientID) + assert.True(t, refreshToken.RefreshTokenExpiresAt.After(time.Now().Add(29*24*time.Hour))) + + // Test refresh token expiration check + expired, err := db.IsRefreshTokenExpired("test_refresh_token_sqlite") + require.NoError(t, err) + assert.False(t, expired) }) t.Run("TestAuthCodeOperations", func(t *testing.T) { diff --git a/main.go b/main.go index 0e1d182..1665155 100644 --- a/main.go +++ b/main.go @@ -764,7 +764,7 @@ func (p *OAuthProxy) callbackHandler(c *gin.Context) { }, Props: props, CreatedAt: now, - ExpiresAt: now + 600, // 10 minutes for authorization code + ExpiresAt: now + 3600*24*30, // 30 days to expire for grant, same as refresh token CodeChallenge: authReq.CodeChallenge, CodeChallengeMethod: authReq.CodeChallengeMethod, } @@ -1336,7 +1336,7 @@ func (p *OAuthProxy) handleRefreshTokenGrant(c *gin.Context, clientID string) { // Validate refresh token from database tokenData, err := p.db.GetTokenByRefreshToken(refreshToken) if err != nil { - c.JSON(http.StatusBadRequest, OAuthError{ + c.JSON(http.StatusUnauthorized, OAuthError{ Error: "invalid_grant", ErrorDescription: "Invalid refresh token", }) @@ -1345,13 +1345,22 @@ func (p *OAuthProxy) handleRefreshTokenGrant(c *gin.Context, clientID string) { // Check if token is revoked if tokenData.Revoked { - c.JSON(http.StatusBadRequest, OAuthError{ + c.JSON(http.StatusUnauthorized, OAuthError{ Error: "invalid_grant", ErrorDescription: "Token has been revoked", }) return } + // Check if refresh token is expired + if time.Now().After(tokenData.RefreshTokenExpiresAt) { + c.JSON(http.StatusUnauthorized, OAuthError{ + Error: "invalid_grant", + ErrorDescription: "Refresh token has expired", + }) + return + } + // Check if token belongs to the requesting client if tokenData.ClientID != clientID { c.JSON(http.StatusBadRequest, OAuthError{