Skip to content

Fix: fix grant and refresh token cleanup #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 156 additions & 25 deletions database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,17 @@ type Database struct {

// TokenData represents stored token data for OAuth 2.1 compliance
type TokenData struct {
AccessToken string
RefreshToken string
ClientID string
UserID string
GrantID string
Scope string
ExpiresAt time.Time
CreatedAt time.Time
Revoked bool
RevokedAt *time.Time
AccessToken string
RefreshToken string
ClientID string
UserID string
GrantID string
Scope string
ExpiresAt time.Time
RefreshTokenExpiresAt time.Time
CreatedAt time.Time
Revoked bool
RevokedAt *time.Time
}

// ClientInfo represents OAuth client registration information
Expand Down Expand Up @@ -116,8 +117,23 @@ func NewDatabase(dsn string) (*Database, error) {
return database, nil
}

// setupSchema creates the necessary tables
// setupSchema creates the necessary tables and handles migrations
func (d *Database) setupSchema() error {
// First, create tables if they don't exist
if err := d.createTables(); err != nil {
return err
}

// Then run migrations
if err := d.runMigrations(); err != nil {
return err
}

return nil
}

// createTables creates the base tables
func (d *Database) createTables() error {
var queries []string

if d.dbType == "postgres" {
Expand Down Expand Up @@ -254,6 +270,86 @@ func (d *Database) setupSchema() error {
return nil
}

// runMigrations handles database schema migrations
func (d *Database) runMigrations() error {
// Migration 1: Add refresh_token_expires_at column to access_tokens table
if err := d.migrateAddRefreshTokenExpiration(); err != nil {
return fmt.Errorf("failed to run migration 1: %w", err)
}

return nil
}

// migrateAddRefreshTokenExpiration adds the refresh_token_expires_at column
func (d *Database) migrateAddRefreshTokenExpiration() error {
// Check if the column already exists
columnExists, err := d.columnExists("access_tokens", "refresh_token_expires_at")
if err != nil {
return fmt.Errorf("failed to check if column exists: %w", err)
}

if columnExists {
return nil // Column already exists, no migration needed
}

// Add the column
var query string
if d.dbType == "postgres" {
query = `ALTER TABLE access_tokens ADD COLUMN refresh_token_expires_at TIMESTAMPTZ`
} else {
query = `ALTER TABLE access_tokens ADD COLUMN refresh_token_expires_at DATETIME`
}

if _, err := d.db.Exec(query); err != nil {
return fmt.Errorf("failed to add refresh_token_expires_at column: %w", err)
}

// Update existing records to have a default expiration (30 days from now)
updateQuery := `UPDATE access_tokens SET refresh_token_expires_at = ? WHERE refresh_token_expires_at IS NULL`
if d.dbType == "postgres" {
updateQuery = `UPDATE access_tokens SET refresh_token_expires_at = $1 WHERE refresh_token_expires_at IS NULL`
}

defaultExpiration := time.Now().Add(30 * 24 * time.Hour)
if _, err := d.db.Exec(updateQuery, defaultExpiration); err != nil {
return fmt.Errorf("failed to update existing records: %w", err)
}

// Make the column NOT NULL (SQLite doesn't support ALTER COLUMN SET NOT NULL directly)
if d.dbType == "postgres" {
query = `ALTER TABLE access_tokens ALTER COLUMN refresh_token_expires_at SET NOT NULL`
if _, err := d.db.Exec(query); err != nil {
return fmt.Errorf("failed to make refresh_token_expires_at NOT NULL: %w", err)
}
}

return nil
}

// columnExists checks if a column exists in a table
func (d *Database) columnExists(tableName, columnName string) (bool, error) {
var query string
if d.dbType == "postgres" {
query = `
SELECT COUNT(*) FROM information_schema.columns
WHERE table_name = $1 AND column_name = $2
`
} else {
query = `
SELECT COUNT(*) FROM pragma_table_info(?)
WHERE name = ?
`
}

var count int
err := d.db.QueryRow(query, tableName, columnName).Scan(&count)
if err != nil {
return false, err
}

return count > 0, nil
}

// GetClient retrieves a client by ID
func (d *Database) GetClient(clientID string) (*ClientInfo, error) {
var query string
Expand Down Expand Up @@ -522,21 +618,26 @@ func (d *Database) StoreToken(data *TokenData) error {
var query string
if d.dbType == "postgres" {
query = `
INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`
} else {
query = `
INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
INSERT INTO access_tokens (access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`
}

// Hash the refresh token for secure storage
hashedAccessToken := hashToken(data.AccessToken)
hashedRefreshToken := hashToken(data.RefreshToken)

_, err := d.db.Exec(query, hashedAccessToken, hashedRefreshToken, data.ClientID, data.UserID, data.GrantID, data.Scope, data.ExpiresAt)
// Set refresh token expiration to 30 days from now if not already set
if data.RefreshTokenExpiresAt.IsZero() {
data.RefreshTokenExpiresAt = time.Now().Add(30 * 24 * time.Hour)
}

_, err := d.db.Exec(query, hashedAccessToken, hashedRefreshToken, data.ClientID, data.UserID, data.GrantID, data.Scope, data.ExpiresAt, data.RefreshTokenExpiresAt)
return err
}

Expand All @@ -545,12 +646,12 @@ func (d *Database) GetToken(accessToken string) (*TokenData, error) {
var query string
if d.dbType == "postgres" {
query = `
SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at
SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at
FROM access_tokens WHERE access_token = $1
`
} else {
query = `
SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at
SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at
FROM access_tokens WHERE access_token = ?
`
}
Expand All @@ -567,6 +668,7 @@ func (d *Database) GetToken(accessToken string) (*TokenData, error) {
&data.GrantID,
&data.Scope,
&data.ExpiresAt,
&data.RefreshTokenExpiresAt,
&data.CreatedAt,
&data.Revoked,
&revokedAt,
Expand All @@ -588,13 +690,13 @@ func (d *Database) GetTokenByRefreshToken(refreshToken string) (*TokenData, erro
var query string
if d.dbType == "postgres" {
query = `
SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at
FROM access_tokens WHERE refresh_token = $1
SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at
FROM access_tokens WHERE refresh_token = $1 AND refresh_token_expires_at > NOW()
`
} else {
query = `
SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, created_at, revoked, revoked_at
FROM access_tokens WHERE refresh_token = ?
SELECT access_token, refresh_token, client_id, user_id, grant_id, scope, expires_at, refresh_token_expires_at, created_at, revoked, revoked_at
FROM access_tokens WHERE refresh_token = ? AND refresh_token_expires_at > datetime('now')
`
}

Expand All @@ -611,6 +713,7 @@ func (d *Database) GetTokenByRefreshToken(refreshToken string) (*TokenData, erro
&data.GrantID,
&data.Scope,
&data.ExpiresAt,
&data.RefreshTokenExpiresAt,
&data.CreatedAt,
&data.Revoked,
&revokedAt,
Expand All @@ -630,6 +733,29 @@ func (d *Database) GetTokenByRefreshToken(refreshToken string) (*TokenData, erro
return &data, nil
}

// IsRefreshTokenExpired checks if a refresh token is expired
func (d *Database) IsRefreshTokenExpired(refreshToken string) (bool, error) {
var query string
if d.dbType == "postgres" {
query = `
SELECT refresh_token_expires_at FROM access_tokens WHERE refresh_token = $1
`
} else {
query = `
SELECT refresh_token_expires_at FROM access_tokens WHERE refresh_token = ?
`
}

hashedRefreshToken := hashToken(refreshToken)
var expiresAt time.Time
err := d.db.QueryRow(query, hashedRefreshToken).Scan(&expiresAt)
if err != nil {
return false, err
}

return time.Now().After(expiresAt), nil
}

// RevokeToken revokes an access token
func (d *Database) RevokeToken(token string) error {
hashedToken := hashToken(token)
Expand Down Expand Up @@ -685,22 +811,27 @@ func (d *Database) CleanupExpiredTokens() error {

if d.dbType == "postgres" {
queries = []string{
`DELETE FROM access_tokens WHERE expires_at < NOW() OR revoked = TRUE`,
`DELETE FROM access_tokens WHERE (expires_at < NOW() AND refresh_token_expires_at < NOW()) OR revoked = TRUE`,
`DELETE FROM authorization_codes WHERE expires_at < NOW()`,
`DELETE FROM grants WHERE expires_at < EXTRACT(EPOCH FROM NOW())`,
}
} else {
queries = []string{
`DELETE FROM access_tokens WHERE expires_at < datetime('now') OR revoked = 1`,
`DELETE FROM access_tokens WHERE (expires_at < datetime('now') AND refresh_token_expires_at < datetime('now')) OR revoked = 1`,
`DELETE FROM authorization_codes WHERE expires_at < datetime('now')`,
`DELETE FROM grants WHERE expires_at < strftime('%s', 'now')`,
}
}

for _, query := range queries {
if _, err := d.db.Exec(query); err != nil {
result, err := d.db.Exec(query)
if err != nil {
return fmt.Errorf("failed to cleanup expired tokens: %w", err)
}
rowsAffected, _ := result.RowsAffected()
if rowsAffected > 0 {
fmt.Printf("Deleted %d expired rows for query %s\n", rowsAffected, query)
}
}

return nil
Expand Down
Loading
Loading