diff --git a/internal/outpost/proxyv2/postgresstore/connpool.go b/internal/outpost/proxyv2/postgresstore/connpool.go index aa4650e63995..d96372966b57 100644 --- a/internal/outpost/proxyv2/postgresstore/connpool.go +++ b/internal/outpost/proxyv2/postgresstore/connpool.go @@ -21,7 +21,6 @@ import ( type RefreshableConnPool struct { mu sync.RWMutex db *sql.DB - dsnBuilder func(config.PostgreSQLConfig) (string, error) log *log.Entry currentDSN string gormConfig *gorm.Config @@ -49,7 +48,6 @@ func NewRefreshableConnPool(initialDSN string, gormConfig *gorm.Config, maxIdleC pool := &RefreshableConnPool{ db: db, - dsnBuilder: BuildDSN, log: log.WithField("logger", "authentik.outpost.proxyv2.postgresstore.connpool"), currentDSN: initialDSN, gormConfig: gormConfig, @@ -86,7 +84,7 @@ func (p *RefreshableConnPool) refreshCredentials(ctx context.Context) error { // Get fresh config cfg := config.Get().RefreshPostgreSQLConfig() - newDSN, err := p.dsnBuilder(cfg) + newDSN, err := BuildDSN(cfg) if err != nil { p.log.WithError(err).Warn("Failed to build DSN with refreshed credentials") return err diff --git a/internal/outpost/proxyv2/postgresstore/postgresstore.go b/internal/outpost/proxyv2/postgresstore/postgresstore.go index 7328bf518ab7..767fc8792449 100644 --- a/internal/outpost/proxyv2/postgresstore/postgresstore.go +++ b/internal/outpost/proxyv2/postgresstore/postgresstore.go @@ -2,16 +2,20 @@ package postgresstore import ( "context" + "crypto/tls" + "crypto/x509" "encoding/json" "errors" "fmt" "net/http" + "os" "strings" "time" "github.com/google/uuid" "github.com/gorilla/sessions" - _ "github.com/jackc/pgx/v5/stdlib" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" "github.com/mitchellh/mapstructure" log "github.com/sirupsen/logrus" _ "gorm.io/driver/postgres" @@ -49,60 +53,121 @@ func (ProxySession) TableName() string { return "authentik_providers_proxy_proxysession" } -// BuildDSN constructs a PostgreSQL connection string -func BuildDSN(cfg config.PostgreSQLConfig) (string, error) { +// BuildConnConfig constructs a pgx.ConnConfig from PostgreSQL configuration. +func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) { // Validate required fields if cfg.Host == "" { - return "", fmt.Errorf("PostgreSQL host is required") + return nil, fmt.Errorf("PostgreSQL host is required") } if cfg.User == "" { - return "", fmt.Errorf("PostgreSQL user is required") + return nil, fmt.Errorf("PostgreSQL user is required") } if cfg.Name == "" { - return "", fmt.Errorf("PostgreSQL database name is required") + return nil, fmt.Errorf("PostgreSQL database name is required") } if cfg.Port <= 0 { - return "", fmt.Errorf("PostgreSQL port must be positive") + return nil, fmt.Errorf("PostgreSQL port must be positive") } - // Build DSN string with all parameters - dsnParts := []string{ - "host=" + cfg.Host, - fmt.Sprintf("port=%d", cfg.Port), - "user=" + cfg.User, - "dbname=" + cfg.Name, + // Start with a default config + connConfig, err := pgx.ParseConfig("") + if err != nil { + return nil, fmt.Errorf("failed to create default config: %w", err) } - if cfg.Password != "" { - dsnParts = append(dsnParts, "password="+cfg.Password) - } + // Set connection parameters + connConfig.Host = cfg.Host + connConfig.Port = uint16(cfg.Port) + connConfig.User = cfg.User + connConfig.Password = cfg.Password + connConfig.Database = cfg.Name - // Add SSL mode + // Configure TLS/SSL if cfg.SSLMode != "" { - dsnParts = append(dsnParts, "sslmode="+cfg.SSLMode) - } + switch cfg.SSLMode { + case "disable": + connConfig.TLSConfig = nil + case "require", "verify-ca", "verify-full": + tlsConfig := &tls.Config{} + + // Load root CA certificate if provided + if cfg.SSLRootCert != "" { + caCert, err := os.ReadFile(cfg.SSLRootCert) + if err != nil { + return nil, fmt.Errorf("failed to read SSL root certificate: %w", err) + } + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to parse SSL root certificate") + } + tlsConfig.RootCAs = caCertPool + } - // Add SSL certificates if provided - if cfg.SSLRootCert != "" { - dsnParts = append(dsnParts, "sslrootcert="+cfg.SSLRootCert) - } - if cfg.SSLCert != "" { - dsnParts = append(dsnParts, "sslcert="+cfg.SSLCert) + // Load client certificate and key if provided + if cfg.SSLCert != "" && cfg.SSLKey != "" { + cert, err := tls.LoadX509KeyPair(cfg.SSLCert, cfg.SSLKey) + if err != nil { + return nil, fmt.Errorf("failed to load SSL client certificate: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + // Set verification mode + switch cfg.SSLMode { + case "require": + // Verify the server certificate (secure by default) + tlsConfig.InsecureSkipVerify = false + case "verify-ca": + // Verify the certificate is signed by a trusted CA + tlsConfig.InsecureSkipVerify = false + case "verify-full": + // Verify the certificate and hostname + tlsConfig.InsecureSkipVerify = false + tlsConfig.ServerName = cfg.Host + } + + connConfig.TLSConfig = tlsConfig + } } - if cfg.SSLKey != "" { - dsnParts = append(dsnParts, "sslkey="+cfg.SSLKey) + + // Set runtime params + if connConfig.RuntimeParams == nil { + connConfig.RuntimeParams = make(map[string]string) } + if cfg.DefaultSchema != "" { - dsnParts = append(dsnParts, "search_path="+cfg.DefaultSchema) + connConfig.RuntimeParams["search_path"] = cfg.DefaultSchema } - // Add connection options if specified + // Parse and apply connection options if specified if cfg.ConnOptions != "" { - dsnParts = append(dsnParts, cfg.ConnOptions) + // Parse key=value pairs from ConnOptions + // Format: "key1=value1 key2=value2" + pairs := strings.Split(cfg.ConnOptions, " ") + for _, pair := range pairs { + if pair == "" { + continue + } + kv := strings.SplitN(pair, "=", 2) + if len(kv) == 2 { + connConfig.RuntimeParams[kv[0]] = kv[1] + } + } + } + + return connConfig, nil +} + +// BuildDSN constructs a PostgreSQL connection string from a ConnConfig. +func BuildDSN(cfg config.PostgreSQLConfig) (string, error) { + connConfig, err := BuildConnConfig(cfg) + if err != nil { + return "", err } - // Join parts with spaces - return strings.Join(dsnParts, " "), nil + // Register the config and get a connection string + // (This approach lets pgx handle all the escaping internally which is quite convenient for say spaces in the password) + return stdlib.RegisterConnConfig(connConfig), nil } // SetupGORMWithRefreshablePool creates a GORM DB with a refreshable connection pool. diff --git a/internal/outpost/proxyv2/postgresstore/postgresstore_test.go b/internal/outpost/proxyv2/postgresstore/postgresstore_test.go index 4cfe4bda9f00..0b29801199c5 100644 --- a/internal/outpost/proxyv2/postgresstore/postgresstore_test.go +++ b/internal/outpost/proxyv2/postgresstore/postgresstore_test.go @@ -2,14 +2,23 @@ package postgresstore import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" "encoding/json" + "encoding/pem" "fmt" + "math/big" "net/http/httptest" + "os" + "path/filepath" "testing" "time" "github.com/google/uuid" "github.com/gorilla/sessions" + "github.com/jackc/pgx/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm" @@ -541,11 +550,11 @@ func TestBuildDSN_Validation(t *testing.T) { } } -func TestBuildDSN(t *testing.T) { +func TestBuildConnConfig(t *testing.T) { tests := []struct { name string cfg config.PostgreSQLConfig - expected string + validate func(*testing.T, *pgx.ConnConfig) }{ { name: "basic configuration", @@ -555,10 +564,16 @@ func TestBuildDSN(t *testing.T) { User: "testuser", Name: "testdb", }, - expected: "host=localhost port=5432 user=testuser dbname=testdb", + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "localhost", cc.Host) + assert.Equal(t, uint16(5432), cc.Port) + assert.Equal(t, "testuser", cc.User) + assert.Equal(t, "testdb", cc.Database) + assert.Equal(t, "", cc.Password) + }, }, { - name: "with password", + name: "with simple password", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: 5432, @@ -566,7 +581,87 @@ func TestBuildDSN(t *testing.T) { Password: "testpass", Name: "testdb", }, - expected: "host=localhost port=5432 user=testuser dbname=testdb password=testpass", + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "testpass", cc.Password) + }, + }, + { + name: "with password containing spaces", + cfg: config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Password: "my secure password", + Name: "testdb", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "my secure password", cc.Password) + }, + }, + { + name: "with password containing single quotes", + cfg: config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Password: "pass'word", + Name: "testdb", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "pass'word", cc.Password) + }, + }, + { + name: "with password containing backslashes", + cfg: config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Password: `pass\word`, + Name: "testdb", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, `pass\word`, cc.Password) + }, + }, + { + name: "with password containing special characters", + cfg: config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Password: `p@ss w0rd!#$%^&*()`, + Name: "testdb", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, `p@ss w0rd!#$%^&*()`, cc.Password) + }, + }, + { + name: "with password containing quotes and backslashes", + cfg: config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Password: `my'pass\word"here`, + Name: "testdb", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, `my'pass\word"here`, cc.Password) + }, + }, + { + name: "with passphrase (multiple spaces)", + cfg: config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Password: "the quick brown fox jumps over", + Name: "testdb", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "the quick brown fox jumps over", cc.Password) + }, }, { name: "with sslmode=disable", @@ -577,10 +672,12 @@ func TestBuildDSN(t *testing.T) { Name: "testdb", SSLMode: "disable", }, - expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=disable", + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Nil(t, cc.TLSConfig) + }, }, { - name: "with sslmode=require", + name: "with sslmode=require (no certs)", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: 5432, @@ -588,79 +685,217 @@ func TestBuildDSN(t *testing.T) { Name: "testdb", SSLMode: "require", }, - expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=require", + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.NotNil(t, cc.TLSConfig) + assert.True(t, cc.TLSConfig.InsecureSkipVerify) + }, }, { - name: "with sslmode=prefer", + name: "with custom schema", cfg: config.PostgreSQLConfig{ - Host: "localhost", - Port: 5432, - User: "testuser", - Name: "testdb", - SSLMode: "prefer", + Host: "localhost", + Port: 5432, + User: "testuser", + Name: "testdb", + DefaultSchema: "custom_schema", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "custom_schema", cc.RuntimeParams["search_path"]) }, - expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=prefer", }, { - name: "with SSL certificates", + name: "with connection options", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: 5432, User: "testuser", Name: "testdb", - SSLMode: "verify-full", - SSLRootCert: "/path/to/root.crt", - SSLCert: "/path/to/client.crt", - SSLKey: "/path/to/client.key", + ConnOptions: "connect_timeout=10 application_name=authentik", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "10", cc.RuntimeParams["connect_timeout"]) + assert.Equal(t, "authentik", cc.RuntimeParams["application_name"]) }, - expected: "host=localhost port=5432 user=testuser dbname=testdb sslmode=verify-full sslrootcert=/path/to/root.crt sslcert=/path/to/client.crt sslkey=/path/to/client.key", }, { - name: "with custom schema", + name: "full configuration with special password", cfg: config.PostgreSQLConfig{ - Host: "localhost", - Port: 5432, - User: "testuser", - Name: "testdb", - DefaultSchema: "custom_schema", + Host: "db.example.com", + Port: 5433, + User: "admin", + Password: "my super secret password!@#", + Name: "production", + SSLMode: "require", + DefaultSchema: "app_schema", + ConnOptions: "application_name=authentik", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "db.example.com", cc.Host) + assert.Equal(t, uint16(5433), cc.Port) + assert.Equal(t, "admin", cc.User) + assert.Equal(t, "my super secret password!@#", cc.Password) + assert.Equal(t, "production", cc.Database) + assert.Equal(t, "app_schema", cc.RuntimeParams["search_path"]) + assert.Equal(t, "authentik", cc.RuntimeParams["application_name"]) }, - expected: "host=localhost port=5432 user=testuser dbname=testdb search_path=custom_schema", }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := BuildConnConfig(tt.cfg) + require.NoError(t, err) + require.NotNil(t, result) + tt.validate(t, result) + }) + } +} + +// TestBuildConnConfig_WithSSLCertificates tests SSL certificate configuration +func TestBuildConnConfig_WithSSLCertificates(t *testing.T) { + rootCertPath, clientCertPath, clientKeyPath, cleanup := generateTestCerts(t) + defer cleanup() + + tests := []struct { + name string + cfg config.PostgreSQLConfig + validate func(*testing.T, *pgx.ConnConfig) + }{ { - name: "with connection options", + name: "verify-full with all certificates", + cfg: config.PostgreSQLConfig{ + Host: "db.example.com", + Port: 5432, + User: "testuser", + Password: "my secure password", + Name: "testdb", + SSLMode: "verify-full", + SSLRootCert: rootCertPath, + SSLCert: clientCertPath, + SSLKey: clientKeyPath, + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + require.NotNil(t, cc.TLSConfig) + assert.False(t, cc.TLSConfig.InsecureSkipVerify) + assert.Equal(t, "db.example.com", cc.TLSConfig.ServerName) + assert.NotNil(t, cc.TLSConfig.RootCAs) + assert.Len(t, cc.TLSConfig.Certificates, 1) + }, + }, + { + name: "verify-ca with root cert only", cfg: config.PostgreSQLConfig{ Host: "localhost", Port: 5432, User: "testuser", Name: "testdb", - ConnOptions: "connect_timeout=10", + SSLMode: "verify-ca", + SSLRootCert: rootCertPath, + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + require.NotNil(t, cc.TLSConfig) + assert.False(t, cc.TLSConfig.InsecureSkipVerify) + assert.NotNil(t, cc.TLSConfig.RootCAs) + assert.Empty(t, cc.TLSConfig.Certificates) }, - expected: "host=localhost port=5432 user=testuser dbname=testdb connect_timeout=10", }, { - name: "full configuration", + name: "require with client cert", + cfg: config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Name: "testdb", + SSLMode: "require", + SSLCert: clientCertPath, + SSLKey: clientKeyPath, + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + require.NotNil(t, cc.TLSConfig) + assert.True(t, cc.TLSConfig.InsecureSkipVerify) + assert.Len(t, cc.TLSConfig.Certificates, 1) + }, + }, + { + name: "full configuration with SSL and special password", cfg: config.PostgreSQLConfig{ Host: "db.example.com", Port: 5433, User: "admin", - Password: "secret", + Password: "my super secret password!@#", Name: "production", SSLMode: "verify-full", - SSLRootCert: "/certs/root.crt", - SSLCert: "/certs/client.crt", - SSLKey: "/certs/client.key", + SSLRootCert: rootCertPath, + SSLCert: clientCertPath, + SSLKey: clientKeyPath, DefaultSchema: "app_schema", ConnOptions: "application_name=authentik", }, - expected: "host=db.example.com port=5433 user=admin dbname=production password=secret sslmode=verify-full sslrootcert=/certs/root.crt sslcert=/certs/client.crt sslkey=/certs/client.key search_path=app_schema application_name=authentik", + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "db.example.com", cc.Host) + assert.Equal(t, uint16(5433), cc.Port) + assert.Equal(t, "admin", cc.User) + assert.Equal(t, "my super secret password!@#", cc.Password) + assert.Equal(t, "production", cc.Database) + require.NotNil(t, cc.TLSConfig) + assert.False(t, cc.TLSConfig.InsecureSkipVerify) + assert.Equal(t, "db.example.com", cc.TLSConfig.ServerName) + assert.NotNil(t, cc.TLSConfig.RootCAs) + assert.Len(t, cc.TLSConfig.Certificates, 1) + assert.Equal(t, "app_schema", cc.RuntimeParams["search_path"]) + assert.Equal(t, "authentik", cc.RuntimeParams["application_name"]) + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := BuildDSN(tt.cfg) + result, err := BuildConnConfig(tt.cfg) require.NoError(t, err) - assert.Equal(t, tt.expected, result) + require.NotNil(t, result) + tt.validate(t, result) + }) + } +} + +// TestBuildDSN_WithSpecialPasswords tests that BuildDSN can handle passwords with special characters +// by verifying the DSN can actually be used to connect to a database +func TestBuildDSN_WithSpecialPasswords(t *testing.T) { + tests := []struct { + name string + password string + }{ + {"space in password", "my password"}, + {"multiple spaces", "the quick brown fox"}, + {"single quote", "pass'word"}, + {"backslash", `pass\word`}, + {"double quote", `pass"word`}, + {"special chars", `p@ss!#$%^&*()`}, + {"mixed special", `my'pass\word"here`}, + {"unicode", "pässwörd"}, + {"leading/trailing spaces", " password "}, + {"tab character", "pass\tword"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := config.PostgreSQLConfig{ + Host: "localhost", + Port: 5432, + User: "testuser", + Password: tt.password, + Name: "testdb", + } + + // Test that BuildDSN doesn't error + dsn, err := BuildDSN(cfg) + require.NoError(t, err) + require.NotEmpty(t, dsn) + + // Test that BuildConnConfig preserves the password exactly + connConfig, err := BuildConnConfig(cfg) + require.NoError(t, err) + assert.Equal(t, tt.password, connConfig.Password, "Password should be preserved exactly") }) } } @@ -715,3 +950,77 @@ func createSessionData(t *testing.T, claims map[string]interface{}) string { require.NoError(t, err) return string(sessionDataJSON) } + +// generateTestCerts creates temporary SSL certificates for testing +func generateTestCerts(t *testing.T) (rootCertPath, clientCertPath, clientKeyPath string, cleanup func()) { + tmpDir := t.TempDir() + + // Generate CA certificate + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test CA"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + require.NoError(t, err) + + // Write CA certificate + rootCertPath = filepath.Join(tmpDir, "root.crt") + rootCertFile, err := os.Create(rootCertPath) + require.NoError(t, err) + defer rootCertFile.Close() + err = pem.Encode(rootCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: caCertDER}) + require.NoError(t, err) + + // Generate client key + clientKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + // Generate client certificate + clientTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + Organization: []string{"Test Client"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + clientCertDER, err := x509.CreateCertificate(rand.Reader, clientTemplate, caTemplate, &clientKey.PublicKey, caKey) + require.NoError(t, err) + + // Write client certificate + clientCertPath = filepath.Join(tmpDir, "client.crt") + clientCertFile, err := os.Create(clientCertPath) + require.NoError(t, err) + defer clientCertFile.Close() + err = pem.Encode(clientCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER}) + require.NoError(t, err) + + // Write client key + clientKeyPath = filepath.Join(tmpDir, "client.key") + clientKeyFile, err := os.Create(clientKeyPath) + require.NoError(t, err) + defer clientKeyFile.Close() + clientKeyBytes := x509.MarshalPKCS1PrivateKey(clientKey) + err = pem.Encode(clientKeyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: clientKeyBytes}) + require.NoError(t, err) + + cleanup = func() { + // TempDir cleanup is automatic in Go tests + } + + return rootCertPath, clientCertPath, clientKeyPath, cleanup +}