From e5437bbf8594a635829574113bd89876889f766c Mon Sep 17 00:00:00 2001
From: Zuo Wang
Date: Tue, 6 May 2025 23:06:38 +0800
Subject: [PATCH] add seed/csv

---
Review notes: ImportFromCSV now parses the CSV in Go and writes rows through a
prepared, parameterized INSERT inside one transaction. The blob ids on the
"index" lines below predate this revision.

 postgresql/seed.go      | 170 ++++++++++++++++++++++++++++++++++++++++
 postgresql/seed_test.go | 137 ++++++++++++++++++++++++++++++++
 seed/csv.go             |  22 ++++++
 3 files changed, 329 insertions(+)
 create mode 100644 postgresql/seed.go
 create mode 100644 postgresql/seed_test.go
 create mode 100644 seed/csv.go

diff --git a/postgresql/seed.go b/postgresql/seed.go
new file mode 100644
index 0000000..ec212de
--- /dev/null
+++ b/postgresql/seed.go
@@ -0,0 +1,170 @@
+package postgresql
+
+import (
+	"database/sql"
+	"encoding/csv"
+	"errors"
+	"fmt"
+	"io"
+	"strings"
+	"unicode/utf8"
+
+	"github.com/swiftcarrot/dbx/seed"
+)
+
+// ImportFromCSV imports CSV data from reader into a PostgreSQL table.
+//
+// database/sql offers no portable way to stream raw bytes into
+// COPY ... FROM STDIN, so the CSV is parsed with encoding/csv and every
+// record is written through one prepared, parameterized INSERT. All rows are
+// inserted in a single transaction: any error rolls the whole import back.
+//
+// The table name is qualified with schemaName unless it is empty or "public".
+// A nil options value means comma-delimited input with a header row. Target
+// columns come from options.Columns, or from the header row when Columns is
+// empty; with neither, values are assigned to the table's columns in order.
+// A field equal to options.NullValue is stored as NULL; encoding/csv cannot
+// tell quoted from unquoted fields, so with the default NullValue ("") a
+// quoted empty field is stored as NULL as well.
+//
+// Only double-quote quoting with doubled-quote escaping and UTF-8 input are
+// supported; any other Quote, Escape or Encoding is rejected with an error
+// instead of being silently ignored. The caller's options are never modified.
+func (pg *PostgreSQL) ImportFromCSV(db *sql.DB, tableName, schemaName string, reader io.Reader, options *seed.CSVImportOptions) error {
+	opts := seed.CSVImportOptions{Header: true}
+	if options != nil {
+		opts = *options
+	}
+
+	csvReader, err := newCSVImportReader(reader, opts)
+	if err != nil {
+		return err
+	}
+
+	columns := opts.Columns
+	if opts.Header {
+		header, err := csvReader.Read()
+		if errors.Is(err, io.EOF) {
+			return nil // empty input: nothing to import
+		}
+		if err != nil {
+			return fmt.Errorf("failed to read CSV header: %w", err)
+		}
+		if len(columns) == 0 {
+			columns = header
+		}
+	}
+
+	record, err := csvReader.Read()
+	if errors.Is(err, io.EOF) {
+		return nil // no data rows: nothing to import
+	}
+	if err != nil {
+		return fmt.Errorf("failed to read CSV record: %w", err)
+	}
+	if len(columns) > 0 && len(columns) != len(record) {
+		return fmt.Errorf("CSV record has %d fields but %d columns were given", len(record), len(columns))
+	}
+
+	tx, err := db.Begin()
+	if err != nil {
+		return fmt.Errorf("failed to begin transaction: %w", err)
+	}
+	// Once Commit has succeeded this is a no-op that returns sql.ErrTxDone.
+	defer func() { _ = tx.Rollback() }()
+
+	stmt, err := tx.Prepare(csvInsertSQL(csvTargetTable(schemaName, tableName), columns, len(record)))
+	if err != nil {
+		return fmt.Errorf("failed to prepare INSERT statement: %w", err)
+	}
+	defer stmt.Close()
+
+	args := make([]interface{}, len(record))
+	for {
+		for i, field := range record {
+			if field == opts.NullValue {
+				args[i] = nil
+			} else {
+				args[i] = field
+			}
+		}
+		if _, err := stmt.Exec(args...); err != nil {
+			return fmt.Errorf("failed to insert CSV record: %w", err)
+		}
+
+		record, err = csvReader.Read()
+		if errors.Is(err, io.EOF) {
+			break
+		}
+		if err != nil {
+			return fmt.Errorf("failed to read CSV record: %w", err)
+		}
+	}
+
+	if err := tx.Commit(); err != nil {
+		return fmt.Errorf("failed to commit transaction: %w", err)
+	}
+	return nil
+}
+
+// newCSVImportReader returns a csv.Reader configured from opts, or an error
+// when opts asks for a CSV dialect that encoding/csv cannot parse.
+func newCSVImportReader(r io.Reader, opts seed.CSVImportOptions) (*csv.Reader, error) {
+	delimiter := opts.Delimiter
+	if delimiter == "" {
+		delimiter = ","
+	}
+	comma, size := utf8.DecodeRuneInString(delimiter)
+	if comma == utf8.RuneError || size != len(delimiter) {
+		return nil, fmt.Errorf("invalid CSV delimiter %q: must be a single character", delimiter)
+	}
+	if opts.Quote != "" && opts.Quote != `"` {
+		return nil, fmt.Errorf("unsupported CSV quote character %q", opts.Quote)
+	}
+	if opts.Escape != "" && opts.Escape != `"` {
+		return nil, fmt.Errorf("unsupported CSV escape character %q", opts.Escape)
+	}
+	switch strings.ToUpper(opts.Encoding) {
+	case "", "UTF8", "UTF-8":
+	default:
+		return nil, fmt.Errorf("unsupported CSV encoding %q", opts.Encoding)
+	}
+
+	csvReader := csv.NewReader(r)
+	csvReader.Comma = comma
+	return csvReader, nil
+}
+
+// csvTargetTable quotes tableName, prefixing the quoted schema unless
+// schemaName is empty or "public".
+func csvTargetTable(schemaName, tableName string) string {
+	if schemaName == "" || schemaName == "public" {
+		return quoteIdentifier(tableName)
+	}
+	return quoteIdentifier(schemaName) + "." + quoteIdentifier(tableName)
+}
+
+// csvInsertSQL builds "INSERT INTO table [(columns)] VALUES ($1, ..., $n)".
+func csvInsertSQL(table string, columns []string, n int) string {
+	var b strings.Builder
+	b.WriteString("INSERT INTO ")
+	b.WriteString(table)
+	if len(columns) > 0 {
+		quoted := make([]string, len(columns))
+		for i, column := range columns {
+			quoted[i] = quoteIdentifier(column)
+		}
+		b.WriteString(" (")
+		b.WriteString(strings.Join(quoted, ", "))
+		b.WriteString(")")
+	}
+	b.WriteString(" VALUES (")
+	for i := 1; i <= n; i++ {
+		if i > 1 {
+			b.WriteString(", ")
+		}
+		fmt.Fprintf(&b, "$%d", i)
+	}
+	b.WriteString(")")
+	return b.String()
+}
diff --git a/postgresql/seed_test.go b/postgresql/seed_test.go
new file mode 100644
index 0000000..ef9967e
--- /dev/null
+++ b/postgresql/seed_test.go
@@ -0,0 +1,137 @@
+package postgresql
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+	"github.com/swiftcarrot/dbx/internal/testutil"
+	"github.com/swiftcarrot/dbx/seed"
+)
+
+func TestImportFromCSV(t *testing.T) {
+	db, err := testutil.GetPGTestConn()
+	require.NoError(t, err)
+
+	_, err = db.Exec(`
+		CREATE TABLE csv_import_test (
+			id integer,
+			name varchar(50),
+			email varchar(100),
+			active boolean,
+			score numeric(5,2)
+		)
+	`)
+	require.NoError(t, err)
+
+	t.Cleanup(func() {
+		_, err := db.Exec(`DROP TABLE IF EXISTS csv_import_test`)
+		require.NoError(t, err)
+	})
+
+	t.Run("BasicImport", func(t *testing.T) {
+		csvData := `id,name,email,active,score
+1,John Doe,john@example.com,true,85.50
+2,Jane Smith,jane@example.com,false,92.75
+3,Bob Johnson,bob@example.com,true,78.25`
+
+		reader := strings.NewReader(csvData)
+		pg := New()
+
+		err = pg.ImportFromCSV(db, "csv_import_test", "public", reader, &seed.CSVImportOptions{
+			Delimiter: ",",
+			Header:    true,
+			Columns:   []string{"id", "name", "email", "active", "score"},
+		})
+		require.NoError(t, err)
+
+		rows, err := db.Query("SELECT * FROM csv_import_test ORDER BY id")
+		require.NoError(t, err)
+		defer rows.Close()
+
+		var results []struct {
+			ID     int
+			Name   string
+			Email  string
+			Active bool
+			Score  float64
+		}
+
+		for rows.Next() {
+			var r struct {
+				ID     int
+				Name   string
+				Email  string
+				Active bool
+				Score  float64
+			}
+			err := rows.Scan(&r.ID, &r.Name, &r.Email, &r.Active, &r.Score)
+			require.NoError(t, err)
+			results = append(results, r)
+		}
+		require.NoError(t, rows.Err())
+
+		require.Equal(t, []struct {
+			ID     int
+			Name   string
+			Email  string
+			Active bool
+			Score  float64
+		}{
+			{ID: 1, Name: "John Doe", Email: "john@example.com", Active: true, Score: 85.50},
+			{ID: 2, Name: "Jane Smith", Email: "jane@example.com", Active: false, Score: 92.75},
+			{ID: 3, Name: "Bob Johnson", Email: "bob@example.com", Active: true, Score: 78.25},
+		}, results)
+
+		_, err = db.Exec("DELETE FROM csv_import_test")
+		require.NoError(t, err)
+	})
+
+	t.Run("CustomDelimiterNoHeaders", func(t *testing.T) {
+		csvData := `4|Alice Wilson|alice@example.com|true|91.40
+5|Charlie Brown|charlie@example.com|false|68.30`
+
+		reader := strings.NewReader(csvData)
+		pg := New()
+
+		err = pg.ImportFromCSV(db, "csv_import_test", "public", reader, &seed.CSVImportOptions{
+			Delimiter: "|",
+			Header:    false,
+			Columns:   []string{"id", "name", "email", "active", "score"},
+		})
+		require.NoError(t, err)
+
+		var count int
+		err = db.QueryRow("SELECT COUNT(*) FROM csv_import_test").Scan(&count)
+		require.NoError(t, err)
+		require.Equal(t, 2, count)
+
+		var id int
+		var name string
+		var active bool
+		err = db.QueryRow("SELECT id, name, active FROM csv_import_test WHERE email = 'charlie@example.com'").Scan(&id, &name, &active)
+		require.NoError(t, err)
+		require.Equal(t, 5, id)
+		require.Equal(t, "Charlie Brown", name)
+		require.Equal(t, false, active)
+	})
+
+	t.Run("NilOptionsUsesHeaderAndNull", func(t *testing.T) {
+		pg := New()
+
+		err = pg.ImportFromCSV(db, "csv_import_test", "", strings.NewReader("id,name\n6,\n"), nil)
+		require.NoError(t, err)
+
+		var nameIsNull bool
+		err = db.QueryRow("SELECT name IS NULL FROM csv_import_test WHERE id = 6").Scan(&nameIsNull)
+		require.NoError(t, err)
+		require.True(t, nameIsNull)
+	})
+
+	t.Run("UnsupportedQuote", func(t *testing.T) {
+		pg := New()
+
+		err := pg.ImportFromCSV(db, "csv_import_test", "public", strings.NewReader("1\n"), &seed.CSVImportOptions{Quote: "'"})
+		require.Error(t, err)
+	})
+}
diff --git a/seed/csv.go b/seed/csv.go
new file mode 100644
index 0000000..002f287
--- /dev/null
+++ b/seed/csv.go
@@ -0,0 +1,22 @@
+package seed
+
+// CSVImportOptions contains options for importing CSV data.
+type CSVImportOptions struct {
+	// Delimiter specifies the single-character field delimiter (default: ",")
+	Delimiter string
+	// NullValue specifies the string that represents NULL values (default: "")
+	NullValue string
+	// Header indicates whether the CSV file includes a header row
+	// (treated as true when nil options are passed to an importer)
+	Header bool
+	// Quote specifies the quote character; only the double quote is supported
+	// (default: double quote)
+	Quote string
+	// Escape specifies the escape character; only the double quote is supported
+	// (default: double quote)
+	Escape string
+	// Encoding specifies the file encoding; only "UTF8" is supported (default: "UTF8")
+	Encoding string
+	// Columns specifies the target column names (default: use header row if present)
+	Columns []string
+}