@@ -23,16 +23,17 @@ type Database struct {
23
23
24
24
// TokenData represents stored token data for OAuth 2.1 compliance
25
25
type TokenData struct {
26
- AccessToken string
27
- RefreshToken string
28
- ClientID string
29
- UserID string
30
- GrantID string
31
- Scope string
32
- ExpiresAt time.Time
33
- CreatedAt time.Time
34
- Revoked bool
35
- RevokedAt * time.Time
26
+ AccessToken string
27
+ RefreshToken string
28
+ ClientID string
29
+ UserID string
30
+ GrantID string
31
+ Scope string
32
+ ExpiresAt time.Time
33
+ RefreshTokenExpiresAt time.Time
34
+ CreatedAt time.Time
35
+ Revoked bool
36
+ RevokedAt * time.Time
36
37
}
37
38
38
39
// ClientInfo represents OAuth client registration information
@@ -116,8 +117,23 @@ func NewDatabase(dsn string) (*Database, error) {
116
117
return database , nil
117
118
}
118
119
119
- // setupSchema creates the necessary tables
120
+ // setupSchema creates the necessary tables and handles migrations
120
121
func (d * Database ) setupSchema () error {
122
+ // First, create tables if they don't exist
123
+ if err := d .createTables (); err != nil {
124
+ return err
125
+ }
126
+
127
+ // Then run migrations
128
+ if err := d .runMigrations (); err != nil {
129
+ return err
130
+ }
131
+
132
+ return nil
133
+ }
134
+
135
+ // createTables creates the base tables
136
+ func (d * Database ) createTables () error {
121
137
var queries []string
122
138
123
139
if d .dbType == "postgres" {
@@ -254,6 +270,86 @@ func (d *Database) setupSchema() error {
254
270
return nil
255
271
}
256
272
273
+ // runMigrations handles database schema migrations
274
+ func (d * Database ) runMigrations () error {
275
+ // Migration 1: Add refresh_token_expires_at column to access_tokens table
276
+ if err := d .migrateAddRefreshTokenExpiration (); err != nil {
277
+ return fmt .Errorf ("failed to run migration 1: %w" , err )
278
+ }
279
+
280
+ return nil
281
+ }
282
+
283
+ // migrateAddRefreshTokenExpiration adds the refresh_token_expires_at column
284
+ func (d * Database ) migrateAddRefreshTokenExpiration () error {
285
+ // Check if the column already exists
286
+ columnExists , err := d .columnExists ("access_tokens" , "refresh_token_expires_at" )
287
+ if err != nil {
288
+ return fmt .Errorf ("failed to check if column exists: %w" , err )
289
+ }
290
+
291
+ if columnExists {
292
+ return nil // Column already exists, no migration needed
293
+ }
294
+
295
+ // Add the column
296
+ var query string
297
+ if d .dbType == "postgres" {
298
+ query = `ALTER TABLE access_tokens ADD COLUMN refresh_token_expires_at TIMESTAMPTZ`
299
+ } else {
300
+ query = `ALTER TABLE access_tokens ADD COLUMN refresh_token_expires_at DATETIME`
301
+ }
302
+
303
+ if _ , err := d .db .Exec (query ); err != nil {
304
+ return fmt .Errorf ("failed to add refresh_token_expires_at column: %w" , err )
305
+ }
306
+
307
+ // Update existing records to have a default expiration (30 days from now)
308
+ updateQuery := `UPDATE access_tokens SET refresh_token_expires_at = ? WHERE refresh_token_expires_at IS NULL`
309
+ if d .dbType == "postgres" {
310
+ updateQuery = `UPDATE access_tokens SET refresh_token_expires_at = $1 WHERE refresh_token_expires_at IS NULL`
311
+ }
312
+
313
+ defaultExpiration := time .Now ().Add (30 * 24 * time .Hour )
314
+ if _ , err := d .db .Exec (updateQuery , defaultExpiration ); err != nil {
315
+ return fmt .Errorf ("failed to update existing records: %w" , err )
316
+ }
317
+
318
+ // Make the column NOT NULL (SQLite doesn't support ALTER COLUMN SET NOT NULL directly)
319
+ if d .dbType == "postgres" {
320
+ query = `ALTER TABLE access_tokens ALTER COLUMN refresh_token_expires_at SET NOT NULL`
321
+ if _ , err := d .db .Exec (query ); err != nil {
322
+ return fmt .Errorf ("failed to make refresh_token_expires_at NOT NULL: %w" , err )
323
+ }
324
+ }
325
+
326
+ return nil
327
+ }
328
+
329
+ // columnExists checks if a column exists in a table
330
+ func (d * Database ) columnExists (tableName , columnName string ) (bool , error ) {
331
+ var query string
332
+ if d .dbType == "postgres" {
333
+ query = `
334
+ SELECT COUNT(*) FROM information_schema.columns
335
+ WHERE table_name = $1 AND column_name = $2
336
+ `
337
+ } else {
338
+ query = `
339
+ SELECT COUNT(*) FROM pragma_table_info(?)
340
+ WHERE name = ?
341
+ `
342
+ }
343
+
344
+ var count int
345
+ err := d .db .QueryRow (query , tableName , columnName ).Scan (& count )
346
+ if err != nil {
347
+ return false , err
348
+ }
349
+
350
+ return count > 0 , nil
351
+ }
352
+
257
353
// GetClient retrieves a client by ID
258
354
func (d * Database ) GetClient (clientID string ) (* ClientInfo , error ) {
259
355
var query string
@@ -522,21 +618,26 @@ func (d *Database) StoreToken(data *TokenData) error {
522
618
var query string
523
619
if d .dbType == "postgres" {
524
620
query = `
525
- INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at)
526
- VALUES ($1, $2, $3, $4, $5, $6, $7)
621
+ INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at )
622
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8 )
527
623
`
528
624
} else {
529
625
query = `
530
- INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at)
531
- VALUES (?, ?, ?, ?, ?, ?, ?)
626
+ INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at )
627
+ VALUES (?, ?, ?, ?, ?, ?, ?, ? )
532
628
`
533
629
}
534
630
535
631
// Hash the refresh token for secure storage
536
632
hashedAccessToken := hashToken (data .AccessToken )
537
633
hashedRefreshToken := hashToken (data .RefreshToken )
538
634
539
- _ , err := d .db .Exec (query , hashedAccessToken , hashedRefreshToken , data .ClientID , data .UserID , data .GrantID , data .Scope , data .ExpiresAt )
635
+ // Set refresh token expiration to 30 days from now if not already set
636
+ if data .RefreshTokenExpiresAt .IsZero () {
637
+ data .RefreshTokenExpiresAt = time .Now ().Add (30 * 24 * time .Hour )
638
+ }
639
+
640
+ _ , err := d .db .Exec (query , hashedAccessToken , hashedRefreshToken , data .ClientID , data .UserID , data .GrantID , data .Scope , data .ExpiresAt , data .RefreshTokenExpiresAt )
540
641
return err
541
642
}
542
643
@@ -545,12 +646,12 @@ func (d *Database) GetToken(accessToken string) (*TokenData, error) {
545
646
var query string
546
647
if d .dbType == "postgres" {
547
648
query = `
548
- SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at
649
+ SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at
549
650
FROM access_tokens WHERE access_token = $1
550
651
`
551
652
} else {
552
653
query = `
553
- SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at
654
+ SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at
554
655
FROM access_tokens WHERE access_token = ?
555
656
`
556
657
}
@@ -567,6 +668,7 @@ func (d *Database) GetToken(accessToken string) (*TokenData, error) {
567
668
& data .GrantID ,
568
669
& data .Scope ,
569
670
& data .ExpiresAt ,
671
+ & data .RefreshTokenExpiresAt ,
570
672
& data .CreatedAt ,
571
673
& data .Revoked ,
572
674
& revokedAt ,
@@ -588,13 +690,13 @@ func (d *Database) GetTokenByRefreshToken(refreshToken string) (*TokenData, erro
588
690
var query string
589
691
if d .dbType == "postgres" {
590
692
query = `
591
- SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at
592
- FROM access_tokens WHERE refresh_token = $1
693
+ SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at
694
+ FROM access_tokens WHERE refresh_token = $1 AND refresh_token_expires_at > NOW()
593
695
`
594
696
} else {
595
697
query = `
596
- SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at
597
- FROM access_tokens WHERE refresh_token = ?
698
+ SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at
699
+ FROM access_tokens WHERE refresh_token = ? AND refresh_token_expires_at > datetime('now')
598
700
`
599
701
}
600
702
@@ -611,6 +713,7 @@ func (d *Database) GetTokenByRefreshToken(refreshToken string) (*TokenData, erro
611
713
& data .GrantID ,
612
714
& data .Scope ,
613
715
& data .ExpiresAt ,
716
+ & data .RefreshTokenExpiresAt ,
614
717
& data .CreatedAt ,
615
718
& data .Revoked ,
616
719
& revokedAt ,
@@ -630,6 +733,29 @@ func (d *Database) GetTokenByRefreshToken(refreshToken string) (*TokenData, erro
630
733
return & data , nil
631
734
}
632
735
736
+ // IsRefreshTokenExpired checks if a refresh token is expired
737
+ func (d * Database ) IsRefreshTokenExpired (refreshToken string ) (bool , error ) {
738
+ var query string
739
+ if d .dbType == "postgres" {
740
+ query = `
741
+ SELECT refresh_token_expires_at FROM access_tokens WHERE refresh_token = $1
742
+ `
743
+ } else {
744
+ query = `
745
+ SELECT refresh_token_expires_at FROM access_tokens WHERE refresh_token = ?
746
+ `
747
+ }
748
+
749
+ hashedRefreshToken := hashToken (refreshToken )
750
+ var expiresAt time.Time
751
+ err := d .db .QueryRow (query , hashedRefreshToken ).Scan (& expiresAt )
752
+ if err != nil {
753
+ return false , err
754
+ }
755
+
756
+ return time .Now ().After (expiresAt ), nil
757
+ }
758
+
633
759
// RevokeToken revokes an access token
634
760
func (d * Database ) RevokeToken (token string ) error {
635
761
hashedToken := hashToken (token )
@@ -685,22 +811,27 @@ func (d *Database) CleanupExpiredTokens() error {
685
811
686
812
if d .dbType == "postgres" {
687
813
queries = []string {
688
- `DELETE FROM access_tokens WHERE expires_at < NOW() OR revoked = TRUE` ,
814
+ `DELETE FROM access_tokens WHERE ( expires_at < NOW() AND refresh_token_expires_at < NOW() ) OR revoked = TRUE` ,
689
815
`DELETE FROM authorization_codes WHERE expires_at < NOW()` ,
690
816
`DELETE FROM grants WHERE expires_at < EXTRACT(EPOCH FROM NOW())` ,
691
817
}
692
818
} else {
693
819
queries = []string {
694
- `DELETE FROM access_tokens WHERE expires_at < datetime('now') OR revoked = 1` ,
820
+ `DELETE FROM access_tokens WHERE ( expires_at < datetime('now') AND refresh_token_expires_at < datetime('now') ) OR revoked = 1` ,
695
821
`DELETE FROM authorization_codes WHERE expires_at < datetime('now')` ,
696
822
`DELETE FROM grants WHERE expires_at < strftime('%s', 'now')` ,
697
823
}
698
824
}
699
825
700
826
for _ , query := range queries {
701
- if _ , err := d .db .Exec (query ); err != nil {
827
+ result , err := d .db .Exec (query )
828
+ if err != nil {
702
829
return fmt .Errorf ("failed to cleanup expired tokens: %w" , err )
703
830
}
831
+ rowsAffected , _ := result .RowsAffected ()
832
+ if rowsAffected > 0 {
833
+ fmt .Printf ("Deleted %d expired rows for query %s\n " , rowsAffected , query )
834
+ }
704
835
}
705
836
706
837
return nil
0 commit comments