@@ -23,16 +23,17 @@ type Database struct {
2323
2424// TokenData represents stored token data for OAuth 2.1 compliance
2525type TokenData struct {
26- AccessToken string
27- RefreshToken string
28- ClientID string
29- UserID string
30- GrantID string
31- Scope string
32- ExpiresAt time.Time
33- CreatedAt time.Time
34- Revoked bool
35- RevokedAt * time.Time
26+ AccessToken string
27+ RefreshToken string
28+ ClientID string
29+ UserID string
30+ GrantID string
31+ Scope string
32+ ExpiresAt time.Time
33+ RefreshTokenExpiresAt time.Time
34+ CreatedAt time.Time
35+ Revoked bool
36+ RevokedAt * time.Time
3637}
3738
3839// ClientInfo represents OAuth client registration information
@@ -116,8 +117,23 @@ func NewDatabase(dsn string) (*Database, error) {
116117 return database , nil
117118}
118119
119- // setupSchema creates the necessary tables
120+ // setupSchema creates the necessary tables and handles migrations
120121func (d * Database ) setupSchema () error {
122+ // First, create tables if they don't exist
123+ if err := d .createTables (); err != nil {
124+ return err
125+ }
126+
127+ // Then run migrations
128+ if err := d .runMigrations (); err != nil {
129+ return err
130+ }
131+
132+ return nil
133+ }
134+
135+ // createTables creates the base tables
136+ func (d * Database ) createTables () error {
121137 var queries []string
122138
123139 if d .dbType == "postgres" {
@@ -254,6 +270,86 @@ func (d *Database) setupSchema() error {
254270 return nil
255271}
256272
273+ // runMigrations handles database schema migrations
274+ func (d * Database ) runMigrations () error {
275+ // Migration 1: Add refresh_token_expires_at column to access_tokens table
276+ if err := d .migrateAddRefreshTokenExpiration (); err != nil {
277+ return fmt .Errorf ("failed to run migration 1: %w" , err )
278+ }
279+
280+ return nil
281+ }
282+
283+ // migrateAddRefreshTokenExpiration adds the refresh_token_expires_at column
284+ func (d * Database ) migrateAddRefreshTokenExpiration () error {
285+ // Check if the column already exists
286+ columnExists , err := d .columnExists ("access_tokens" , "refresh_token_expires_at" )
287+ if err != nil {
288+ return fmt .Errorf ("failed to check if column exists: %w" , err )
289+ }
290+
291+ if columnExists {
292+ return nil // Column already exists, no migration needed
293+ }
294+
295+ // Add the column
296+ var query string
297+ if d .dbType == "postgres" {
298+ query = `ALTER TABLE access_tokens ADD COLUMN refresh_token_expires_at TIMESTAMPTZ`
299+ } else {
300+ query = `ALTER TABLE access_tokens ADD COLUMN refresh_token_expires_at DATETIME`
301+ }
302+
303+ if _ , err := d .db .Exec (query ); err != nil {
304+ return fmt .Errorf ("failed to add refresh_token_expires_at column: %w" , err )
305+ }
306+
307+ // Update existing records to have a default expiration (30 days from now)
308+ updateQuery := `UPDATE access_tokens SET refresh_token_expires_at = ? WHERE refresh_token_expires_at IS NULL`
309+ if d .dbType == "postgres" {
310+ updateQuery = `UPDATE access_tokens SET refresh_token_expires_at = $1 WHERE refresh_token_expires_at IS NULL`
311+ }
312+
313+ defaultExpiration := time .Now ().Add (30 * 24 * time .Hour )
314+ if _ , err := d .db .Exec (updateQuery , defaultExpiration ); err != nil {
315+ return fmt .Errorf ("failed to update existing records: %w" , err )
316+ }
317+
318+ // Make the column NOT NULL (SQLite doesn't support ALTER COLUMN SET NOT NULL directly)
319+ if d .dbType == "postgres" {
320+ query = `ALTER TABLE access_tokens ALTER COLUMN refresh_token_expires_at SET NOT NULL`
321+ if _ , err := d .db .Exec (query ); err != nil {
322+ return fmt .Errorf ("failed to make refresh_token_expires_at NOT NULL: %w" , err )
323+ }
324+ }
325+
326+ return nil
327+ }
328+
329+ // columnExists checks if a column exists in a table
330+ func (d * Database ) columnExists (tableName , columnName string ) (bool , error ) {
331+ var query string
332+ if d .dbType == "postgres" {
333+ query = `
334+ SELECT COUNT(*) FROM information_schema.columns
335+ WHERE table_name = $1 AND column_name = $2
336+ `
337+ } else {
338+ query = `
339+ SELECT COUNT(*) FROM pragma_table_info(?)
340+ WHERE name = ?
341+ `
342+ }
343+
344+ var count int
345+ err := d .db .QueryRow (query , tableName , columnName ).Scan (& count )
346+ if err != nil {
347+ return false , err
348+ }
349+
350+ return count > 0 , nil
351+ }
352+
257353// GetClient retrieves a client by ID
258354func (d * Database ) GetClient (clientID string ) (* ClientInfo , error ) {
259355 var query string
@@ -522,21 +618,26 @@ func (d *Database) StoreToken(data *TokenData) error {
522618 var query string
523619 if d .dbType == "postgres" {
524620 query = `
525- INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at)
526- VALUES ($1, $2, $3, $4, $5, $6, $7)
621+ INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at )
622+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8 )
527623 `
528624 } else {
529625 query = `
530- INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at)
531- VALUES (?, ?, ?, ?, ?, ?, ?)
626+ INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at )
627+ VALUES (?, ?, ?, ?, ?, ?, ?, ? )
532628 `
533629 }
534630
535631 // Hash the refresh token for secure storage
536632 hashedAccessToken := hashToken (data .AccessToken )
537633 hashedRefreshToken := hashToken (data .RefreshToken )
538634
539- _ , err := d .db .Exec (query , hashedAccessToken , hashedRefreshToken , data .ClientID , data .UserID , data .GrantID , data .Scope , data .ExpiresAt )
635+ // Set refresh token expiration to 30 days from now if not already set
636+ if data .RefreshTokenExpiresAt .IsZero () {
637+ data .RefreshTokenExpiresAt = time .Now ().Add (30 * 24 * time .Hour )
638+ }
639+
640+ _ , err := d .db .Exec (query , hashedAccessToken , hashedRefreshToken , data .ClientID , data .UserID , data .GrantID , data .Scope , data .ExpiresAt , data .RefreshTokenExpiresAt )
540641 return err
541642}
542643
@@ -545,12 +646,12 @@ func (d *Database) GetToken(accessToken string) (*TokenData, error) {
545646 var query string
546647 if d .dbType == "postgres" {
547648 query = `
548- SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at
649+ SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at
549650 FROM access_tokens WHERE access_token = $1
550651 `
551652 } else {
552653 query = `
553- SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at
654+ SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at
554655 FROM access_tokens WHERE access_token = ?
555656 `
556657 }
@@ -567,6 +668,7 @@ func (d *Database) GetToken(accessToken string) (*TokenData, error) {
567668 & data .GrantID ,
568669 & data .Scope ,
569670 & data .ExpiresAt ,
671+ & data .RefreshTokenExpiresAt ,
570672 & data .CreatedAt ,
571673 & data .Revoked ,
572674 & revokedAt ,
@@ -588,13 +690,13 @@ func (d *Database) GetTokenByRefreshToken(refreshToken string) (*TokenData, erro
588690 var query string
589691 if d .dbType == "postgres" {
590692 query = `
591- SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at
592- FROM access_tokens WHERE refresh_token = $1
693+ SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at
694+ FROM access_tokens WHERE refresh_token = $1 AND refresh_token_expires_at > NOW()
593695 `
594696 } else {
595697 query = `
596- SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at
597- FROM access_tokens WHERE refresh_token = ?
698+ SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at
699+ FROM access_tokens WHERE refresh_token = ? AND refresh_token_expires_at > datetime('now')
598700 `
599701 }
600702
@@ -611,6 +713,7 @@ func (d *Database) GetTokenByRefreshToken(refreshToken string) (*TokenData, erro
611713 & data .GrantID ,
612714 & data .Scope ,
613715 & data .ExpiresAt ,
716+ & data .RefreshTokenExpiresAt ,
614717 & data .CreatedAt ,
615718 & data .Revoked ,
616719 & revokedAt ,
@@ -630,6 +733,29 @@ func (d *Database) GetTokenByRefreshToken(refreshToken string) (*TokenData, erro
630733 return & data , nil
631734}
632735
736+ // IsRefreshTokenExpired checks if a refresh token is expired
737+ func (d * Database ) IsRefreshTokenExpired (refreshToken string ) (bool , error ) {
738+ var query string
739+ if d .dbType == "postgres" {
740+ query = `
741+ SELECT refresh_token_expires_at FROM access_tokens WHERE refresh_token = $1
742+ `
743+ } else {
744+ query = `
745+ SELECT refresh_token_expires_at FROM access_tokens WHERE refresh_token = ?
746+ `
747+ }
748+
749+ hashedRefreshToken := hashToken (refreshToken )
750+ var expiresAt time.Time
751+ err := d .db .QueryRow (query , hashedRefreshToken ).Scan (& expiresAt )
752+ if err != nil {
753+ return false , err
754+ }
755+
756+ return time .Now ().After (expiresAt ), nil
757+ }
758+
633759// RevokeToken revokes an access token
634760func (d * Database ) RevokeToken (token string ) error {
635761 hashedToken := hashToken (token )
@@ -685,22 +811,27 @@ func (d *Database) CleanupExpiredTokens() error {
685811
686812 if d .dbType == "postgres" {
687813 queries = []string {
688- `DELETE FROM access_tokens WHERE expires_at < NOW() OR revoked = TRUE` ,
814+ `DELETE FROM access_tokens WHERE ( expires_at < NOW() AND refresh_token_expires_at < NOW() ) OR revoked = TRUE` ,
689815 `DELETE FROM authorization_codes WHERE expires_at < NOW()` ,
690816 `DELETE FROM grants WHERE expires_at < EXTRACT(EPOCH FROM NOW())` ,
691817 }
692818 } else {
693819 queries = []string {
694- `DELETE FROM access_tokens WHERE expires_at < datetime('now') OR revoked = 1` ,
820+ `DELETE FROM access_tokens WHERE ( expires_at < datetime('now') AND refresh_token_expires_at < datetime('now') ) OR revoked = 1` ,
695821 `DELETE FROM authorization_codes WHERE expires_at < datetime('now')` ,
696822 `DELETE FROM grants WHERE expires_at < strftime('%s', 'now')` ,
697823 }
698824 }
699825
700826 for _ , query := range queries {
701- if _ , err := d .db .Exec (query ); err != nil {
827+ result , err := d .db .Exec (query )
828+ if err != nil {
702829 return fmt .Errorf ("failed to cleanup expired tokens: %w" , err )
703830 }
831+ rowsAffected , _ := result .RowsAffected ()
832+ if rowsAffected > 0 {
833+ fmt .Printf ("Deleted %d expired rows for query %s\n " , rowsAffected , query )
834+ }
704835 }
705836
706837 return nil
0 commit comments