Skip to content
18 changes: 13 additions & 5 deletions results.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,24 @@ func NewResults(targetNames []string, testModes []string) *Results {
}
}

// SingleResult represents the verification result from a single target, with the schema:
// SingleResult[schema][table][mode] = test output.
type SingleResult map[string]map[string]map[string]string
// DatabaseResult represents the verification result from a single target database:
// DatabaseResult[schema][table][mode] = test output.
type DatabaseResult map[string]SchemaResult

// SchemaResult represents the verification result from a single schema:
// SchemaResult[table][mode] = test output.
type SchemaResult map[string]TableResult

// TableResult represents the verification result from a single table:
// TableResult[mode] = test output.
type TableResult map[string]string

// AddResult adds a SingleResult from a test on a specific target to the Results object.
func (r *Results) AddResult(targetName string, schemaTableHashes SingleResult) {
func (r *Results) AddResult(targetName string, databaseHashes DatabaseResult) {
r.mutex.Lock()
defer r.mutex.Unlock()

for schema, tables := range schemaTableHashes {
for schema, tables := range databaseHashes {
if _, ok := r.content[schema]; !ok {
r.content[schema] = make(map[string]map[string]map[string][]string)
}
Expand Down
225 changes: 128 additions & 97 deletions verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgverify

import (
"context"
"sync"

"github.com/jackc/pgx/pgtype"
"github.com/jackc/pgx/v4"
Expand Down Expand Up @@ -30,7 +31,7 @@ func (c Config) Verify(ctx context.Context, targets []*pgx.ConnConfig) (*Results

// First check that we can connect to every specified target database.
targetNames := make([]string, len(targets))
conns := make(map[int]*pgx.Conn)
connConfs := make(map[int]*pgx.ConnConfig)

for i, target := range targets {
pgxLoggerFields := logrus.Fields{
Expand All @@ -52,29 +53,23 @@ func (c Config) Verify(ctx context.Context, targets []*pgx.ConnConfig) (*Results

target.LogLevel = pgx.LogLevelError

conn, err := pgx.ConnectConfig(ctx, target)
if err != nil {
return finalResults, err
}
defer conn.Close(ctx)
conns[i] = conn
connConfs[i] = target
}

finalResults = NewResults(targetNames, c.TestModes)

// Then query each target database in parallel to generate table hashes.
var doneChannels []chan struct{}
wg := &sync.WaitGroup{}

for i, conn := range conns {
done := make(chan struct{})
go c.runTestsOnTarget(ctx, targetNames[i], conn, finalResults, done)
doneChannels = append(doneChannels, done)
}
for i, connConf := range connConfs {
wg.Add(1)

for _, done := range doneChannels {
<-done
go c.runTestsOnTarget(ctx, targetNames[i], connConf, finalResults, wg)
}

// Wait for queries to complete
wg.Wait()

// Compare final results
reportErrors := finalResults.CheckForErrors()
if len(reportErrors) > 0 {
Expand All @@ -86,26 +81,40 @@ func (c Config) Verify(ctx context.Context, targets []*pgx.ConnConfig) (*Results
return finalResults, nil
}

func (c Config) runTestsOnTarget(ctx context.Context, targetName string, conn *pgx.Conn, finalResults *Results, done chan struct{}) {
func (c Config) runTestsOnTarget(ctx context.Context, targetName string, connConf *pgx.ConnConfig, finalResults *Results, wg *sync.WaitGroup) {
defer wg.Done()

logger := c.Logger.WithField("target", targetName)

conn, err := pgx.ConnectConfig(ctx, connConf)
if err != nil {
logger.WithError(err).Error("failed to connect to target")

return
}

defer conn.Close(ctx)

schemaTableHashes, err := c.fetchTargetTableNames(ctx, conn)
if err != nil {
logger.WithError(err).Error("failed to fetch target tables")
close(done)

return
}

schemaTableHashes = c.runTestQueriesOnTarget(ctx, logger, conn, schemaTableHashes)
for schemaName, schemaHashes := range schemaTableHashes {
for tableName := range schemaHashes {
wg.Add(1)

go c.runTestQueriesOnTable(ctx, logger, connConf, targetName, schemaName, tableName, finalResults, wg)
}
}

finalResults.AddResult(targetName, schemaTableHashes)
logger.Info("Table hashes computed")
close(done)
}

func (c Config) fetchTargetTableNames(ctx context.Context, conn *pgx.Conn) (SingleResult, error) {
schemaTableHashes := make(SingleResult)
func (c Config) fetchTargetTableNames(ctx context.Context, conn *pgx.Conn) (DatabaseResult, error) {
schemaTableHashes := make(DatabaseResult)

rows, err := conn.Query(ctx, buildGetTablesQuery(c.IncludeSchemas, c.ExcludeSchemas, c.IncludeTables, c.ExcludeTables))
if err != nil {
Expand All @@ -119,10 +128,10 @@ func (c Config) fetchTargetTableNames(ctx context.Context, conn *pgx.Conn) (Sing
}

if _, ok := schemaTableHashes[schema.String]; !ok {
schemaTableHashes[schema.String] = make(map[string]map[string]string)
schemaTableHashes[schema.String] = make(SchemaResult)
}

schemaTableHashes[schema.String][table.String] = make(map[string]string)
schemaTableHashes[schema.String][table.String] = make(TableResult)

for _, testMode := range c.TestModes {
schemaTableHashes[schema.String][table.String][testMode] = defaultErrorOutput
Expand Down Expand Up @@ -152,111 +161,133 @@ func (c Config) validColumnTarget(columnName string) bool {
return false
}

func (c Config) runTestQueriesOnTarget(ctx context.Context, logger *logrus.Entry, conn *pgx.Conn, schemaTableHashes SingleResult) SingleResult {
for schemaName, tables := range schemaTableHashes {
for tableName := range tables {
tableLogger := logger.WithField("table", tableName).WithField("schema", schemaName)
tableLogger.Info("Computing hash")
func (c Config) runTestQueriesOnTable(ctx context.Context, logger *logrus.Entry, connConf *pgx.ConnConfig, targetName, schemaName, tableName string, finalResults *Results, wg *sync.WaitGroup) {
defer wg.Done()

rows, err := conn.Query(ctx, buildGetColumsQuery(schemaName, tableName))
if err != nil {
tableLogger.WithError(err).Error("Failed to query column names, data types")
tableLogger := logger.WithField("table", tableName).WithField("schema", schemaName)
tableLogger.Info("Computing hash")

continue
}
conn, err := pgx.ConnectConfig(ctx, connConf)
if err != nil {
logger.WithError(err).Error("failed to connect to target")

allTableColumns := make(map[string]column)
return
}

for rows.Next() {
var columnName, dataType, constraintName, constraintType pgtype.Text
defer conn.Close(ctx)

err := rows.Scan(&columnName, &dataType, &constraintName, &constraintType)
if err != nil {
tableLogger.WithError(err).Error("Failed to parse column names, data types from query response")
rows, err := conn.Query(ctx, buildGetColumsQuery(schemaName, tableName))
if err != nil {
tableLogger.WithError(err).Error("Failed to query column names, data types")

continue
}
return
}

existing, ok := allTableColumns[columnName.String]
if ok {
existing.constraints = append(existing.constraints, constraintType.String)
allTableColumns[columnName.String] = existing
} else {
allTableColumns[columnName.String] = column{columnName.String, dataType.String, []string{constraintType.String}}
}
}
allTableColumns := make(map[string]column)

var tableColumns []column
for rows.Next() {
var columnName, dataType, constraintName, constraintType pgtype.Text

var primaryKeyColumnNames []string
err := rows.Scan(&columnName, &dataType, &constraintName, &constraintType)
if err != nil {
tableLogger.WithError(err).Error("Failed to parse column names, data types from query response")

for _, col := range allTableColumns {
if col.IsPrimaryKey() {
primaryKeyColumnNames = append(primaryKeyColumnNames, col.name)
}
continue
}

if c.validColumnTarget(col.name) {
tableColumns = append(tableColumns, col)
}
}
existing, ok := allTableColumns[columnName.String]
if ok {
existing.constraints = append(existing.constraints, constraintType.String)
allTableColumns[columnName.String] = existing
} else {
allTableColumns[columnName.String] = column{columnName.String, dataType.String, []string{constraintType.String}}
}
}

if len(primaryKeyColumnNames) == 0 {
tableLogger.Error("No primary keys found")
var tableColumns []column

continue
}
var primaryKeyColumnNames []string

tableLogger.WithFields(logrus.Fields{
"primary_keys": primaryKeyColumnNames,
"columns": tableColumns,
}).Info("Determined columns to hash")
for _, col := range allTableColumns {
if col.IsPrimaryKey() {
primaryKeyColumnNames = append(primaryKeyColumnNames, col.name)
}

for _, testMode := range c.TestModes {
testLogger := tableLogger.WithField("test", testMode)
if c.validColumnTarget(col.name) {
tableColumns = append(tableColumns, col)
}
}

var query string
if len(primaryKeyColumnNames) == 0 {
tableLogger.Error("No primary keys found")

switch testMode {
case TestModeFull:
query = buildFullHashQuery(c, schemaName, tableName, tableColumns)
case TestModeBookend:
query = buildBookendHashQuery(c, schemaName, tableName, tableColumns, c.BookendLimit)
case TestModeSparse:
query = buildSparseHashQuery(c, schemaName, tableName, tableColumns, c.SparseMod)
case TestModeRowCount:
query = buildRowCountQuery(schemaName, tableName)
}
return
}

testLogger.Debugf("Generated query: %s", query)
tableLogger.WithFields(logrus.Fields{
"primary_keys": primaryKeyColumnNames,
"columns": tableColumns,
}).Info("Determined columns to hash")

for _, testMode := range c.TestModes {
testLogger := tableLogger.WithField("test", testMode)

var query string

switch testMode {
case TestModeFull:
query = buildFullHashQuery(c, schemaName, tableName, tableColumns)
case TestModeBookend:
query = buildBookendHashQuery(c, schemaName, tableName, tableColumns, c.BookendLimit)
case TestModeSparse:
query = buildSparseHashQuery(c, schemaName, tableName, tableColumns, c.SparseMod)
case TestModeRowCount:
query = buildRowCountQuery(schemaName, tableName)
}

testOutput, err := runTestOnTable(ctx, conn, query)
if err != nil {
testLogger.WithError(err).Error("Failed to compute hash")
testLogger.Debugf("Generated query: %s", query)

continue
}
wg.Add(1)

schemaTableHashes[schemaName][tableName][testMode] = testOutput
testLogger.Infof("Hash computed: %s", testOutput)
}
}
go runTestOnTable(ctx, testLogger, connConf, targetName, schemaName, tableName, testMode, query, finalResults, wg)
}

return schemaTableHashes
}

func runTestOnTable(ctx context.Context, conn *pgx.Conn, query string) (string, error) {
func runTestOnTable(ctx context.Context, logger *logrus.Entry, connConf *pgx.ConnConfig, targetName, schemaName, tableName, testMode, query string, finalResults *Results, wg *sync.WaitGroup) {
defer wg.Done()

conn, err := pgx.ConnectConfig(ctx, connConf)
if err != nil {
logger.WithError(err).Error("failed to connect to target")

return
}

defer conn.Close(ctx)

row := conn.QueryRow(ctx, query)

var testOutputString string

var testOutput pgtype.Text
if err := row.Scan(&testOutput); err != nil {
switch err {
case pgx.ErrNoRows:
return "no rows", nil
testOutputString = "no rows"
default:
return "", errors.Wrap(err, "failed to scan test output")
logger.WithError(err).Error("failed to scan test output")

return
}
} else {
testOutputString = testOutput.String
}

return testOutput.String, nil
logger.Infof("Hash computed: %s", testOutputString)

databaseResults := make(DatabaseResult)
databaseResults[schemaName] = make(SchemaResult)
databaseResults[schemaName][tableName] = make(TableResult)
databaseResults[schemaName][tableName][testMode] = testOutputString
finalResults.AddResult(targetName, databaseResults)
}