Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions internal/outpost/proxyv2/postgresstore/connpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
type RefreshableConnPool struct {
mu sync.RWMutex
db *sql.DB
dsnBuilder func(config.PostgreSQLConfig) (string, error)
log *log.Entry
currentDSN string
gormConfig *gorm.Config
Expand Down Expand Up @@ -49,7 +48,6 @@ func NewRefreshableConnPool(initialDSN string, gormConfig *gorm.Config, maxIdleC

pool := &RefreshableConnPool{
db: db,
dsnBuilder: BuildDSN,
log: log.WithField("logger", "authentik.outpost.proxyv2.postgresstore.connpool"),
currentDSN: initialDSN,
gormConfig: gormConfig,
Expand Down Expand Up @@ -86,7 +84,7 @@ func (p *RefreshableConnPool) refreshCredentials(ctx context.Context) error {

// Get fresh config
cfg := config.Get().RefreshPostgreSQLConfig()
newDSN, err := p.dsnBuilder(cfg)
newDSN, err := BuildDSN(cfg)
if err != nil {
p.log.WithError(err).Warn("Failed to build DSN with refreshed credentials")
return err
Expand Down
129 changes: 97 additions & 32 deletions internal/outpost/proxyv2/postgresstore/postgresstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@ package postgresstore

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"strings"
"time"

"github.com/google/uuid"
"github.com/gorilla/sessions"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/stdlib"
"github.com/mitchellh/mapstructure"
log "github.com/sirupsen/logrus"
_ "gorm.io/driver/postgres"
Expand Down Expand Up @@ -49,60 +53,121 @@ func (ProxySession) TableName() string {
return "authentik_providers_proxy_proxysession"
}

// BuildDSN constructs a PostgreSQL connection string
func BuildDSN(cfg config.PostgreSQLConfig) (string, error) {
// BuildConnConfig constructs a pgx.ConnConfig from PostgreSQL configuration.
func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
// Validate required fields
if cfg.Host == "" {
return "", fmt.Errorf("PostgreSQL host is required")
return nil, fmt.Errorf("PostgreSQL host is required")
}
if cfg.User == "" {
return "", fmt.Errorf("PostgreSQL user is required")
return nil, fmt.Errorf("PostgreSQL user is required")
}
if cfg.Name == "" {
return "", fmt.Errorf("PostgreSQL database name is required")
return nil, fmt.Errorf("PostgreSQL database name is required")
}
if cfg.Port <= 0 {
return "", fmt.Errorf("PostgreSQL port must be positive")
return nil, fmt.Errorf("PostgreSQL port must be positive")
}

// Build DSN string with all parameters
dsnParts := []string{
"host=" + cfg.Host,
fmt.Sprintf("port=%d", cfg.Port),
"user=" + cfg.User,
"dbname=" + cfg.Name,
// Start with a default config
connConfig, err := pgx.ParseConfig("")
if err != nil {
return nil, fmt.Errorf("failed to create default config: %w", err)
}

if cfg.Password != "" {
dsnParts = append(dsnParts, "password="+cfg.Password)
}
// Set connection parameters
connConfig.Host = cfg.Host
connConfig.Port = uint16(cfg.Port)
connConfig.User = cfg.User
connConfig.Password = cfg.Password
connConfig.Database = cfg.Name

// Add SSL mode
// Configure TLS/SSL
if cfg.SSLMode != "" {
dsnParts = append(dsnParts, "sslmode="+cfg.SSLMode)
}
switch cfg.SSLMode {
case "disable":
connConfig.TLSConfig = nil
case "require", "verify-ca", "verify-full":
tlsConfig := &tls.Config{}

// Load root CA certificate if provided
if cfg.SSLRootCert != "" {
caCert, err := os.ReadFile(cfg.SSLRootCert)
if err != nil {
return nil, fmt.Errorf("failed to read SSL root certificate: %w", err)
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, fmt.Errorf("failed to parse SSL root certificate")
}
tlsConfig.RootCAs = caCertPool
}

// Add SSL certificates if provided
if cfg.SSLRootCert != "" {
dsnParts = append(dsnParts, "sslrootcert="+cfg.SSLRootCert)
}
if cfg.SSLCert != "" {
dsnParts = append(dsnParts, "sslcert="+cfg.SSLCert)
// Load client certificate and key if provided
if cfg.SSLCert != "" && cfg.SSLKey != "" {
cert, err := tls.LoadX509KeyPair(cfg.SSLCert, cfg.SSLKey)
if err != nil {
return nil, fmt.Errorf("failed to load SSL client certificate: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
}

// Set verification mode
switch cfg.SSLMode {
case "require":
// Verify the server certificate (secure by default)
tlsConfig.InsecureSkipVerify = false

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this change, as-is, is a breaking change that does not conform to the Postgres or Authentik documentation. If SSL mode is set to require, the documented behavior is to not check certificate validity:

This change instead makes "require" equivalent to "verify-ca". If this is intended to be the case, it should be documented as a breaking change.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, I need to change that, Good catch.

case "verify-ca":
// Verify the certificate is signed by a trusted CA
tlsConfig.InsecureSkipVerify = false
case "verify-full":
// Verify the certificate and hostname
tlsConfig.InsecureSkipVerify = false
tlsConfig.ServerName = cfg.Host
}

connConfig.TLSConfig = tlsConfig
}
}
if cfg.SSLKey != "" {
dsnParts = append(dsnParts, "sslkey="+cfg.SSLKey)

// Set runtime params
if connConfig.RuntimeParams == nil {
connConfig.RuntimeParams = make(map[string]string)
}

if cfg.DefaultSchema != "" {
dsnParts = append(dsnParts, "search_path="+cfg.DefaultSchema)
connConfig.RuntimeParams["search_path"] = cfg.DefaultSchema
}

// Add connection options if specified
// Parse and apply connection options if specified
if cfg.ConnOptions != "" {
dsnParts = append(dsnParts, cfg.ConnOptions)
// Parse key=value pairs from ConnOptions
// Format: "key1=value1 key2=value2"
pairs := strings.Split(cfg.ConnOptions, " ")
for _, pair := range pairs {
if pair == "" {
continue
}
kv := strings.SplitN(pair, "=", 2)
if len(kv) == 2 {
connConfig.RuntimeParams[kv[0]] = kv[1]
}
}
}

return connConfig, nil
}

// BuildDSN constructs a PostgreSQL connection string from a ConnConfig.
func BuildDSN(cfg config.PostgreSQLConfig) (string, error) {
connConfig, err := BuildConnConfig(cfg)
if err != nil {
return "", err
}

// Join parts with spaces
return strings.Join(dsnParts, " "), nil
// Register the config and get a connection string
// (This approach lets pgx handle all the escaping internally which is quite convenient for say spaces in the password)
return stdlib.RegisterConnConfig(connConfig), nil
}

// SetupGORMWithRefreshablePool creates a GORM DB with a refreshable connection pool.
Expand Down
Loading
Loading