Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions postgresql/seed.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package postgresql

import (
"database/sql"
"fmt"
"io"
"strings"

"github.com/swiftcarrot/dbx/seed"
)

// ImportFromCSV imports data from a CSV file into a PostgreSQL table
// It uses PostgreSQL's COPY command for efficient bulk data loading
func (pg *PostgreSQL) ImportFromCSV(db *sql.DB, tableName, schemaName string, reader io.Reader, options *seed.CSVImportOptions) error {
// Set default options if not provided
if options == nil {
options = &CSVImportOptions{

Check failure on line 17 in postgresql/seed.go

View workflow job for this annotation

GitHub Actions / Lint

undefined: CSVImportOptions

Check failure on line 17 in postgresql/seed.go

View workflow job for this annotation

GitHub Actions / Test (1.21, postgres:16, mysql:8.0)

undefined: CSVImportOptions
Delimiter: ",",
NullValue: "",
Header: true,
Quote: "\"",
Escape: "\\",
Encoding: "UTF8",
}
}

// Apply defaults for unset options
if options.Delimiter == "" {
options.Delimiter = ","
}
if options.Quote == "" {
options.Quote = "\""
}
if options.Escape == "" {
options.Escape = "\\"
}
if options.Encoding == "" {
options.Encoding = "UTF8"
}

// Qualify the table name with schema if provided
qualifiedTable := tableName
if schemaName != "" && schemaName != "public" {
qualifiedTable = fmt.Sprintf("%s.%s", quoteIdentifier(schemaName), quoteIdentifier(tableName))
} else {
qualifiedTable = quoteIdentifier(tableName)
}

// Start a transaction
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if err != nil {
tx.Rollback()
}
}()

// Build the COPY command
copyCmd := fmt.Sprintf("COPY %s", qualifiedTable)

// Add columns if specified
if len(options.Columns) > 0 {
columnNames := make([]string, len(options.Columns))
for i, col := range options.Columns {
columnNames[i] = quoteIdentifier(col)
}
copyCmd += fmt.Sprintf(" (%s)", strings.Join(columnNames, ", "))
}

copyCmd += " FROM STDIN WITH ("
copyCmd += fmt.Sprintf("FORMAT CSV, DELIMITER '%s'", options.Delimiter)

if options.Header {
copyCmd += ", HEADER"
}

if options.Quote != "" {
copyCmd += fmt.Sprintf(", QUOTE '%s'", options.Quote)
}

if options.Escape != "" {
copyCmd += fmt.Sprintf(", ESCAPE '%s'", options.Escape)
}

if options.NullValue != "" {
copyCmd += fmt.Sprintf(", NULL '%s'", options.NullValue)
}

if options.Encoding != "" {
copyCmd += fmt.Sprintf(", ENCODING '%s'", options.Encoding)
}

copyCmd += ")"

// Get the PostgreSQL db/sql connection
stmt, err := tx.Prepare(copyCmd)
if err != nil {
return fmt.Errorf("failed to prepare COPY statement: %w", err)
}
defer stmt.Close()

// Get a reference to the underlying PostgreSQL connection
_, err = stmt.Exec()
if err != nil {
return fmt.Errorf("failed to execute COPY command: %w", err)
}

// Copy data
_, err = io.Copy(tx.(*sql.Tx), reader)

Check failure on line 111 in postgresql/seed.go

View workflow job for this annotation

GitHub Actions / Lint

invalid operation: tx (variable of type *sql.Tx) is not an interface (typecheck)

Check failure on line 111 in postgresql/seed.go

View workflow job for this annotation

GitHub Actions / Test (1.21, postgres:16, mysql:8.0)

invalid operation: tx (variable of type *sql.Tx) is not an interface
if err != nil {
return fmt.Errorf("failed to copy data: %w", err)
}

// Commit the transaction
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}

return nil
}
117 changes: 117 additions & 0 deletions postgresql/seed_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package postgresql

import (
"strings"
"testing"

"github.com/stretchr/testify/require"
"github.com/swiftcarrot/dbx/internal/testutil"
"github.com/swiftcarrot/dbx/seed"
)

func TestImportFromCSV(t *testing.T) {
db, err := testutil.GetPGTestConn()
require.NoError(t, err)

_, err = db.Exec(`
CREATE TABLE csv_import_test (
id integer,
name varchar(50),
email varchar(100),
active boolean,
score numeric(5,2)
)
`)
require.NoError(t, err)

t.Cleanup(func() {
_, err := db.Exec(`DROP TABLE IF EXISTS csv_import_test`)
require.NoError(t, err)
})

t.Run("BasicImport", func(t *testing.T) {
csvData := `id,name,email,active,score
1,John Doe,[email protected],true,85.50
2,Jane Smith,[email protected],false,92.75
3,Bob Johnson,[email protected],true,78.25`

reader := strings.NewReader(csvData)
pg := New()

err = pg.ImportFromCSV(db, "csv_import_test", "public", reader, &seed.CSVImportOptions{
Delimiter: ",",
Header: true,
Columns: []string{"id", "name", "email", "active", "score"},
})
require.NoError(t, err)

rows, err := db.Query("SELECT * FROM csv_import_test ORDER BY id")
require.NoError(t, err)
defer rows.Close()

var results []struct {
ID int
Name string
Email string
Active bool
Score float64
}

for rows.Next() {
var r struct {
ID int
Name string
Email string
Active bool
Score float64
}
err := rows.Scan(&r.ID, &r.Name, &r.Email, &r.Active, &r.Score)
require.NoError(t, err)
results = append(results, r)
}

require.Equal(t, []struct {
ID int
Name string
Email string
Active bool
Score float64
}{
{ID: 1, Name: "John Doe", Email: "[email protected]", Active: true, Score: 85.50},
{ID: 2, Name: "Jane Smith", Email: "[email protected]", Active: false, Score: 92.75},
{ID: 3, Name: "Bob Johnson", Email: "[email protected]", Active: true, Score: 78.25},
}, results)

_, err = db.Exec("DELETE FROM csv_import_test")
require.NoError(t, err)
})

t.Run("CustomDelimiterNoHeaders", func(t *testing.T) {
csvData := `4|Alice Wilson|[email protected]|true|91.40
5|Charlie Brown|[email protected]|false|68.30`

reader := strings.NewReader(csvData)
pg := New()

err = pg.ImportFromCSV(db, "csv_import_test", "public", reader, &seed.CSVImportOptions{
Delimiter: "|",
Header: false,
Columns: []string{"id", "name", "email", "active", "score"},
})
require.NoError(t, err)

var count int
err = db.QueryRow("SELECT COUNT(*) FROM csv_import_test").Scan(&count)
require.NoError(t, err)
require.Equal(t, 2, count)

var id int
var name string
var active bool
err = db.QueryRow("SELECT id, name, active FROM csv_import_test WHERE email = '[email protected]'").Scan(&id, &name, &active)
require.NoError(t, err)
require.Equal(t, 5, id)
require.Equal(t, "Charlie Brown", name)
require.Equal(t, false, active)
})
}
19 changes: 19 additions & 0 deletions seed/csv.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package seed

// CSVImportOptions contains options for importing CSV data
type CSVImportOptions struct {
// Delimiter specifies the field delimiter (default: ",")
Delimiter string
// NullValue specifies the string that represents NULL values (default: "")
NullValue string
// Header indicates whether the CSV file includes a header row (default: true)
Header bool
// Quote specifies the quote character (default: double quote)
Quote string
// Escape specifies the escape character (default: backslash)
Escape string
// Encoding specifies the file encoding (default: "UTF8")
Encoding string
// Columns specifies the target column names (default: use header row if present)
Columns []string
}
Loading