Skip to content

Commit b840f98

Browse files
committed
Fix: fix grant and refresh token cleanup
Signed-off-by: Daishan Peng <[email protected]>
1 parent fa27c2e commit b840f98

File tree

4 files changed

+446
-48
lines changed

4 files changed

+446
-48
lines changed

database/database.go

Lines changed: 156 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,17 @@ type Database struct {
2323

2424
// TokenData represents stored token data for OAuth 2.1 compliance
2525
type TokenData struct {
26-
AccessToken string
27-
RefreshToken string
28-
ClientID string
29-
UserID string
30-
GrantID string
31-
Scope string
32-
ExpiresAt time.Time
33-
CreatedAt time.Time
34-
Revoked bool
35-
RevokedAt *time.Time
26+
AccessToken string
27+
RefreshToken string
28+
ClientID string
29+
UserID string
30+
GrantID string
31+
Scope string
32+
ExpiresAt time.Time
33+
RefreshTokenExpiresAt time.Time
34+
CreatedAt time.Time
35+
Revoked bool
36+
RevokedAt *time.Time
3637
}
3738

3839
// ClientInfo represents OAuth client registration information
@@ -116,8 +117,23 @@ func NewDatabase(dsn string) (*Database, error) {
116117
return database, nil
117118
}
118119

119-
// setupSchema creates the necessary tables
120+
// setupSchema creates the necessary tables and handles migrations
120121
func (d *Database) setupSchema() error {
122+
// First, create tables if they don't exist
123+
if err := d.createTables(); err != nil {
124+
return err
125+
}
126+
127+
// Then run migrations
128+
if err := d.runMigrations(); err != nil {
129+
return err
130+
}
131+
132+
return nil
133+
}
134+
135+
// createTables creates the base tables
136+
func (d *Database) createTables() error {
121137
var queries []string
122138

123139
if d.dbType == "postgres" {
@@ -254,6 +270,86 @@ func (d *Database) setupSchema() error {
254270
return nil
255271
}
256272

273+
// runMigrations handles database schema migrations
274+
func (d *Database) runMigrations() error {
275+
// Migration 1: Add refresh_token_expires_at column to access_tokens table
276+
if err := d.migrateAddRefreshTokenExpiration(); err != nil {
277+
return fmt.Errorf("failed to run migration 1: %w", err)
278+
}
279+
280+
return nil
281+
}
282+
283+
// migrateAddRefreshTokenExpiration adds the refresh_token_expires_at column
284+
func (d *Database) migrateAddRefreshTokenExpiration() error {
285+
// Check if the column already exists
286+
columnExists, err := d.columnExists("access_tokens", "refresh_token_expires_at")
287+
if err != nil {
288+
return fmt.Errorf("failed to check if column exists: %w", err)
289+
}
290+
291+
if columnExists {
292+
return nil // Column already exists, no migration needed
293+
}
294+
295+
// Add the column
296+
var query string
297+
if d.dbType == "postgres" {
298+
query = `ALTER TABLE access_tokens ADD COLUMN refresh_token_expires_at TIMESTAMPTZ`
299+
} else {
300+
query = `ALTER TABLE access_tokens ADD COLUMN refresh_token_expires_at DATETIME`
301+
}
302+
303+
if _, err := d.db.Exec(query); err != nil {
304+
return fmt.Errorf("failed to add refresh_token_expires_at column: %w", err)
305+
}
306+
307+
// Update existing records to have a default expiration (30 days from now)
308+
updateQuery := `UPDATE access_tokens SET refresh_token_expires_at = ? WHERE refresh_token_expires_at IS NULL`
309+
if d.dbType == "postgres" {
310+
updateQuery = `UPDATE access_tokens SET refresh_token_expires_at = $1 WHERE refresh_token_expires_at IS NULL`
311+
}
312+
313+
defaultExpiration := time.Now().Add(30 * 24 * time.Hour)
314+
if _, err := d.db.Exec(updateQuery, defaultExpiration); err != nil {
315+
return fmt.Errorf("failed to update existing records: %w", err)
316+
}
317+
318+
// Make the column NOT NULL (SQLite doesn't support ALTER COLUMN SET NOT NULL directly)
319+
if d.dbType == "postgres" {
320+
query = `ALTER TABLE access_tokens ALTER COLUMN refresh_token_expires_at SET NOT NULL`
321+
if _, err := d.db.Exec(query); err != nil {
322+
return fmt.Errorf("failed to make refresh_token_expires_at NOT NULL: %w", err)
323+
}
324+
}
325+
326+
return nil
327+
}
328+
329+
// columnExists checks if a column exists in a table
330+
func (d *Database) columnExists(tableName, columnName string) (bool, error) {
331+
var query string
332+
if d.dbType == "postgres" {
333+
query = `
334+
SELECT COUNT(*) FROM information_schema.columns
335+
WHERE table_name = $1 AND column_name = $2
336+
`
337+
} else {
338+
query = `
339+
SELECT COUNT(*) FROM pragma_table_info(?)
340+
WHERE name = ?
341+
`
342+
}
343+
344+
var count int
345+
err := d.db.QueryRow(query, tableName, columnName).Scan(&count)
346+
if err != nil {
347+
return false, err
348+
}
349+
350+
return count > 0, nil
351+
}
352+
257353
// GetClient retrieves a client by ID
258354
func (d *Database) GetClient(clientID string) (*ClientInfo, error) {
259355
var query string
@@ -522,21 +618,26 @@ func (d *Database) StoreToken(data *TokenData) error {
522618
var query string
523619
if d.dbType == "postgres" {
524620
query = `
525-
INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at)
526-
VALUES ($1, $2, $3, $4, $5, $6, $7)
621+
INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at)
622+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
527623
`
528624
} else {
529625
query = `
530-
INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at)
531-
VALUES (?, ?, ?, ?, ?, ?, ?)
626+
INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at)
627+
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
532628
`
533629
}
534630

535631
// Hash the refresh token for secure storage
536632
hashedAccessToken := hashToken(data.AccessToken)
537633
hashedRefreshToken := hashToken(data.RefreshToken)
538634

539-
_, err := d.db.Exec(query, hashedAccessToken, hashedRefreshToken, data.ClientID, data.UserID, data.GrantID, data.Scope, data.ExpiresAt)
635+
// Set refresh token expiration to 30 days from now if not already set
636+
if data.RefreshTokenExpiresAt.IsZero() {
637+
data.RefreshTokenExpiresAt = time.Now().Add(30 * 24 * time.Hour)
638+
}
639+
640+
_, err := d.db.Exec(query, hashedAccessToken, hashedRefreshToken, data.ClientID, data.UserID, data.GrantID, data.Scope, data.ExpiresAt, data.RefreshTokenExpiresAt)
540641
return err
541642
}
542643

@@ -545,12 +646,12 @@ func (d *Database) GetToken(accessToken string) (*TokenData, error) {
545646
var query string
546647
if d.dbType == "postgres" {
547648
query = `
548-
SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at
649+
SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at
549650
FROM access_tokens WHERE access_token = $1
550651
`
551652
} else {
552653
query = `
553-
SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at
654+
SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at
554655
FROM access_tokens WHERE access_token = ?
555656
`
556657
}
@@ -567,6 +668,7 @@ func (d *Database) GetToken(accessToken string) (*TokenData, error) {
567668
&data.GrantID,
568669
&data.Scope,
569670
&data.ExpiresAt,
671+
&data.RefreshTokenExpiresAt,
570672
&data.CreatedAt,
571673
&data.Revoked,
572674
&revokedAt,
@@ -588,13 +690,13 @@ func (d *Database) GetTokenByRefreshToken(refreshToken string) (*TokenData, erro
588690
var query string
589691
if d.dbType == "postgres" {
590692
query = `
591-
SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at
592-
FROM access_tokens WHERE refresh_token = $1
693+
SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at
694+
FROM access_tokens WHERE refresh_token = $1 AND refresh_token_expires_at > NOW()
593695
`
594696
} else {
595697
query = `
596-
SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at
597-
FROM access_tokens WHERE refresh_token = ?
698+
SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at
699+
FROM access_tokens WHERE refresh_token = ? AND refresh_token_expires_at > datetime('now')
598700
`
599701
}
600702

@@ -611,6 +713,7 @@ func (d *Database) GetTokenByRefreshToken(refreshToken string) (*TokenData, erro
611713
&data.GrantID,
612714
&data.Scope,
613715
&data.ExpiresAt,
716+
&data.RefreshTokenExpiresAt,
614717
&data.CreatedAt,
615718
&data.Revoked,
616719
&revokedAt,
@@ -630,6 +733,29 @@ func (d *Database) GetTokenByRefreshToken(refreshToken string) (*TokenData, erro
630733
return &data, nil
631734
}
632735

736+
// IsRefreshTokenExpired checks if a refresh token is expired
737+
func (d *Database) IsRefreshTokenExpired(refreshToken string) (bool, error) {
738+
var query string
739+
if d.dbType == "postgres" {
740+
query = `
741+
SELECT refresh_token_expires_at FROM access_tokens WHERE refresh_token = $1
742+
`
743+
} else {
744+
query = `
745+
SELECT refresh_token_expires_at FROM access_tokens WHERE refresh_token = ?
746+
`
747+
}
748+
749+
hashedRefreshToken := hashToken(refreshToken)
750+
var expiresAt time.Time
751+
err := d.db.QueryRow(query, hashedRefreshToken).Scan(&expiresAt)
752+
if err != nil {
753+
return false, err
754+
}
755+
756+
return time.Now().After(expiresAt), nil
757+
}
758+
633759
// RevokeToken revokes an access token
634760
func (d *Database) RevokeToken(token string) error {
635761
hashedToken := hashToken(token)
@@ -685,22 +811,27 @@ func (d *Database) CleanupExpiredTokens() error {
685811

686812
if d.dbType == "postgres" {
687813
queries = []string{
688-
`DELETE FROM access_tokens WHERE expires_at < NOW() OR revoked = TRUE`,
814+
`DELETE FROM access_tokens WHERE (expires_at < NOW() AND refresh_token_expires_at < NOW()) OR revoked = TRUE`,
689815
`DELETE FROM authorization_codes WHERE expires_at < NOW()`,
690816
`DELETE FROM grants WHERE expires_at < EXTRACT(EPOCH FROM NOW())`,
691817
}
692818
} else {
693819
queries = []string{
694-
`DELETE FROM access_tokens WHERE expires_at < datetime('now') OR revoked = 1`,
820+
`DELETE FROM access_tokens WHERE (expires_at < datetime('now') AND refresh_token_expires_at < datetime('now')) OR revoked = 1`,
695821
`DELETE FROM authorization_codes WHERE expires_at < datetime('now')`,
696822
`DELETE FROM grants WHERE expires_at < strftime('%s', 'now')`,
697823
}
698824
}
699825

700826
for _, query := range queries {
701-
if _, err := d.db.Exec(query); err != nil {
827+
result, err := d.db.Exec(query)
828+
if err != nil {
702829
return fmt.Errorf("failed to cleanup expired tokens: %w", err)
703830
}
831+
rowsAffected, _ := result.RowsAffected()
832+
if rowsAffected > 0 {
833+
fmt.Printf("Deleted %d expired rows for query %s\n", rowsAffected, query)
834+
}
704835
}
705836

706837
return nil

0 commit comments

Comments
 (0)