Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .go-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.14.4
1.24.0
44 changes: 32 additions & 12 deletions postgresql/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import (
"strconv"
"strings"
"sync"
"time"
"unicode"

"github.com/blang/semver"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/retry"
_ "github.com/lib/pq" // PostgreSQL db
"gocloud.dev/gcp"
"gocloud.dev/gcp/cloudsql"
Expand Down Expand Up @@ -178,6 +180,9 @@ type Config struct {
ApplicationName string
Timeout int
ConnectTimeoutSec int
MaxConnRetries int
ConnectionRetryTimeoutSeconds int
ConnMaxLifetimeSeconds int
MaxConns int
ExpectedVersion semver.Version
SSLClientCert *ClientCertificateConfig
Expand Down Expand Up @@ -282,26 +287,40 @@ func (c *Config) getDatabaseUsername() string {
func (c *Client) Connect() (*DBConnection, error) {
dbRegistryLock.Lock()
defer dbRegistryLock.Unlock()
ctx := context.Background()

dsn := c.config.connStr(c.databaseName)
conn, found := dbRegistry[dsn]
if !found {

var db *sql.DB
var err error
if c.config.Scheme == "postgres" {
db, err = sql.Open(proxyDriverName, dsn)
} else if c.config.Scheme == "gcppostgres" && c.config.GCPIAMImpersonateServiceAccount != "" {
db, err = openImpersonatedGCPDBConnection(context.Background(), dsn, c.config.GCPIAMImpersonateServiceAccount)
} else {
db, err = postgres.Open(context.Background(), dsn)
}
retryCount := 0

connectRetryTimeout := time.Duration(c.config.ConnectionRetryTimeoutSeconds) * time.Second
retryError := retry.RetryContext(ctx, connectRetryTimeout, func() *retry.RetryError {
if c.config.Scheme == "postgres" {
db, err = sql.Open(proxyDriverName, dsn)
} else if c.config.Scheme == "gcppostgres" && c.config.GCPIAMImpersonateServiceAccount != "" {
db, err = openImpersonatedGCPDBConnection(ctx, dsn, c.config.GCPIAMImpersonateServiceAccount)
} else {
db, err = postgres.Open(ctx, dsn)
}
if err == nil {
err = db.PingContext(ctx)
}

if err == nil {
err = db.Ping()
}
if err != nil {
errString := strings.Replace(err.Error(), c.config.Password, "XXXX", 2)
retryCount++
if err != nil {
if retryCount >= c.config.MaxConnRetries {
return retry.NonRetryableError(err)
}
return retry.RetryableError(err)
}
return nil
})
if retryError != nil {
errString := strings.Replace(retryError.Error(), c.config.Password, "XXXX", 2)
return nil, fmt.Errorf("error connecting to PostgreSQL server %s (scheme: %s): %s", c.config.Host, c.config.Scheme, errString)
}

Expand All @@ -310,6 +329,7 @@ func (c *Client) Connect() (*DBConnection, error) {
// we don't keep opened connection in case of the db has to be dropped in the plan.
db.SetMaxIdleConns(0)
db.SetMaxOpenConns(c.config.MaxConns)
db.SetConnMaxLifetime(time.Duration(c.config.ConnMaxLifetimeSeconds) * time.Second)

defaultVersion, _ := semver.Parse(defaultExpectedPostgreSQLVersion)
version := &c.config.ExpectedVersion
Expand Down
34 changes: 31 additions & 3 deletions postgresql/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package postgresql
import (
"context"
"fmt"
"os"

"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sts"
"os"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
Expand All @@ -21,8 +22,11 @@ import (
)

const (
defaultProviderMaxOpenConnections = 20
defaultExpectedPostgreSQLVersion = "9.0.0"
defaultProviderMaxOpenConnections = 20
defaultProviderConnMaxLifetimeSeconds = 300
defaultProviderMaxConnRetries = 5
defaultProviderConnectionRetryTimeoutSeconds = 5
defaultExpectedPostgreSQLVersion = "9.0.0"
)

// Provider returns a terraform.ResourceProvider.
Expand Down Expand Up @@ -185,13 +189,34 @@ func Provider() *schema.Provider {
Description: "Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely.",
ValidateFunc: validation.IntAtLeast(-1),
},
"max_conn_retries": {
Type: schema.TypeInt,
Optional: true,
Default: defaultProviderMaxConnRetries,
Description: "Maximum number of connection retries.",
ValidateFunc: validation.IntAtLeast(0),
},
"connection_retry_timeout_seconds": {
Type: schema.TypeInt,
Optional: true,
Default: defaultProviderConnectionRetryTimeoutSeconds,
Description: "Maximum wait for connection retries, in seconds.",
ValidateFunc: validation.IntAtLeast(0),
},
"max_connections": {
Type: schema.TypeInt,
Optional: true,
Default: defaultProviderMaxOpenConnections,
Description: "Maximum number of connections to establish to the database. Zero means unlimited.",
ValidateFunc: validation.IntAtLeast(-1),
},
"conn_max_lifetime_seconds": {
Type: schema.TypeInt,
Optional: true,
Default: defaultProviderConnMaxLifetimeSeconds,
Description: "Maximum lifetime of a connection, in seconds. Zero means unlimited.",
ValidateFunc: validation.IntAtLeast(-1),
},
"expected_version": {
Type: schema.TypeString,
Optional: true,
Expand Down Expand Up @@ -382,7 +407,10 @@ func providerConfigure(d *schema.ResourceData) (any, error) {
SSLMode: sslMode,
ApplicationName: "Terraform provider",
ConnectTimeoutSec: d.Get("connect_timeout").(int),
MaxConnRetries: d.Get("max_conn_retries").(int),
ConnectionRetryTimeoutSeconds: d.Get("connection_retry_timeout_seconds").(int),
MaxConns: d.Get("max_connections").(int),
ConnMaxLifetimeSeconds: d.Get("conn_max_lifetime_seconds").(int),
ExpectedVersion: version,
SSLRootCertPath: d.Get("sslrootcert").(string),
GCPIAMImpersonateServiceAccount: d.Get("gcp_iam_impersonate_service_account").(string),
Expand Down