diff --git a/go.mod b/go.mod index 319a5f9..7d5b9d3 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,8 @@ require ( github.com/google/uuid v1.3.0 github.com/multiprocessio/datastation/runner v0.0.0-20220629131342-6165f9d14a67 github.com/olekukonko/tablewriter v0.0.5 + github.com/pganalyze/pg_query_go/v2 v2.1.0 + github.com/stretchr/testify v1.7.1 ) require ( @@ -38,6 +40,7 @@ require ( github.com/aws/smithy-go v1.9.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/deepmap/oapi-codegen v1.8.2 // indirect github.com/denisenkom/go-mssqldb v0.12.0 // indirect github.com/flosch/pongo2 v0.0.0-20200913210552-0d938eb266f3 // indirect @@ -81,6 +84,7 @@ require ( github.com/pierrec/lz4/v4 v4.1.14 // indirect github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v1.12.2 // indirect github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/common v0.34.0 // indirect diff --git a/go.sum b/go.sum index f1f3a13..7051ddf 100644 --- a/go.sum +++ b/go.sum @@ -479,6 +479,8 @@ github.com/paulmach/orb v0.5.0 h1:sNhJV5ML+mv1F077ljOck/9inorF4ahDO8iNNpHbKHY= github.com/paulmach/orb v0.5.0/go.mod h1:FWRlTgl88VI1RBx/MkrwWDRhQ96ctqMCh8boXhmqB/A= github.com/paulmach/protoscan v0.2.1/go.mod h1:SpcSwydNLrxUGSDvXvO0P7g7AuhJ7lcKfDlhJCDw2gY= github.com/pborman/getopt v0.0.0-20180729010549-6fdd0a2c7117/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= +github.com/pganalyze/pg_query_go/v2 v2.1.0 h1:donwPZ4G/X+kMs7j5eYtKjdziqyOLVp3pkUrzb9lDl8= +github.com/pganalyze/pg_query_go/v2 v2.1.0/go.mod h1:XAxmVqz1tEGqizcQ3YSdN90vCOHBWjJi8URL1er5+cA= github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2dXMnm1mY= github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= diff --git a/main.go b/main.go index f30cc67..228da17 100644 --- a/main.go +++ b/main.go @@ -19,9 +19,9 @@ import ( "strings" "time" - "github.com/chzyer/readline" "github.com/multiprocessio/datastation/runner" + "github.com/chzyer/readline" "github.com/google/uuid" "github.com/olekukonko/tablewriter" ) @@ -336,6 +336,8 @@ type args struct { isInteractive bool convertNumbers bool noSQLiteWriter bool + noFieldsGuess bool + noPrefilter bool } func getArgs() (*args, error) { @@ -344,6 +346,8 @@ func getArgs() (*args, error) { args.noSQLiteWriter = strings.ToLower(os.Getenv("DSQ_NO_SQLITE_WRITER")) == "true" args.convertNumbers = strings.ToLower(os.Getenv("DSQ_CONVERT_NUMBERS")) == "true" args.cacheSettings.Enabled = strings.ToLower(os.Getenv("DSQ_CACHE")) == "true" + args.noFieldsGuess = strings.ToLower(os.Getenv("DSQ_NO_FIELDS_GUESS")) == "true" + args.noPrefilter = strings.ToLower(os.Getenv("DSQ_NO_PREFILTER")) == "true" osArgs := os.Args[1:] for i := 0; i < len(osArgs); i++ { @@ -420,11 +424,21 @@ func getArgs() (*args, error) { continue } - if arg == "--no-sqlite-writer" { + if arg == "-nsw" || arg == "--no-sqlite-writer" { args.noSQLiteWriter = true continue } + if arg == "-np" || arg == "--no-prefilter" { + args.noPrefilter = true + continue + } + + if arg == "-nfg" || arg == "--no-fields-guess" { + args.noFieldsGuess = true + continue + } + args.nonFlagArgs = append(args.nonFlagArgs, arg) } @@ -475,7 +489,7 @@ func _main() error { return nil } - lastNonFlagArg := "" + query := "" files := args.nonFlagArgs // Grab from stdin into local file @@ -500,8 +514,8 @@ func _main() error { // If -f|--file not present, query is the last argument if args.sqlFile == "" { if len(files) > 1 { - lastNonFlagArg = files[len(files)-1] - if strings.Contains(lastNonFlagArg, " ") { + query = files[len(files)-1] + if strings.Contains(query, " ") { files = files[:len(files)-1] } } @@ -512,8 +526,8 @@ func _main() error { return errors.New("Error opening sql file: " + err.Error()) } - lastNonFlagArg = string(content) - if lastNonFlagArg == "" { + query = string(content) + if query == "" { return errors.New("SQL file is empty.") } } @@ -575,7 +589,7 @@ func _main() error { connector.DatabaseConnectorInfo.Database.Database = cachedPath } - justDumpResults := lastNonFlagArg == "" && !args.isInteractive + justDumpResults := query == "" && !args.isInteractive // Check if we can use direct SQLite writer useSQLiteWriter := !args.noSQLiteWriter && !args.schema && !justDumpResults @@ -598,6 +612,7 @@ func _main() error { mtm := runner.MimeType(mt) useSQLiteWriter = useSQLiteWriter && (mtm == runner.CSVMimeType || mtm == runner.TSVMimeType || + mtm == runner.JSONLinesMimeType || mtm == runner.RegexpLinesMimeType) if !useSQLiteWriter { break @@ -614,8 +629,24 @@ func _main() error { } } + var fieldsGuess []string + var prefilter func(map[string]any) bool + if query != "" && (!args.noFieldsGuess || !args.noPrefilter) { + a, ok := parse(rewriteQuery(query, &map[string]string{"0": "t_0"})) + if ok && !args.noFieldsGuess { + fieldsGuess, ok = identifiers(a) + if !ok { + fieldsGuess = nil + } + } + + if ok && !args.noPrefilter { + prefilter = filter(a) + } + } + // When dumping schema, need to injest even if cache is on. - if !args.cacheSettings.CachePresent || !args.cacheSettings.Enabled || lastNonFlagArg == "" { + if !args.cacheSettings.CachePresent || !args.cacheSettings.Enabled || query == "" { for i, file := range files { panelId := uuid.New().String() @@ -624,7 +655,14 @@ func _main() error { var w *runner.ResultWriter if useSQLiteWriter { tableName := fmt.Sprintf("t_%d", i) - sw, err := openSQLiteResultItemWriter(connector.DatabaseConnectorInfo.Database.Database, tableName, convertNumbers) + sw, err := openSQLiteResultItemWriter( + connector.DatabaseConnectorInfo.Database.Database, + tableName, + SQLiteResultItemWriterOptions{ + convertNumbers: convertNumbers, + prefilter: prefilter, + fieldsOverride: fieldsGuess, + }) if err != nil { return err } @@ -644,7 +682,14 @@ func _main() error { } } - panel, err := importFile(project.Id, panelId, file, mimetypeOverride[file], convertNumbers, w, !useSQLiteWriter) + panel, err := importFile( + project.Id, + panelId, + file, + mimetypeOverride[file], + convertNumbers, + w, + !useSQLiteWriter) if err != nil { return err } @@ -678,7 +723,7 @@ func _main() error { return repl(project, &ec, args, files, resolveDM_getPanelToId) } - return runQuery(lastNonFlagArg, project, &ec, args, files, resolveDM_getPanelToId) + return runQuery(query, project, &ec, args, files, resolveDM_getPanelToId) } func main() { diff --git a/scripts/test.py b/scripts/test.py index 80f7e70..b889576 100755 --- a/scripts/test.py +++ b/scripts/test.py @@ -326,7 +326,7 @@ def test(name, to_run, want, fail=False, sort=False, winSkip=False, within_secon test("URL functions", to_run, want=want, sort=True) # URL functions, split_part -to_run = """./dsq testdata/basic_logs.csv 'SELECT split_part(url_host(request), ".", -1) host, count(1) count FROM {} group by host' """ +to_run = """./dsq testdata/basic_logs.csv "SELECT split_part(url_host(request), '.', -1) host, count(1) count FROM {} group by host" """ want = '[{"host":"com","count":2}]' test("URL functions", to_run, want=want, sort=True) diff --git a/sql.go b/sql.go new file mode 100644 index 0000000..91322fb --- /dev/null +++ b/sql.go @@ -0,0 +1,215 @@ +//go:build !windows + +package main + +import ( + "fmt" + "strings" + + "github.com/multiprocessio/datastation/runner" + + q "github.com/pganalyze/pg_query_go/v2" +) + +func parse(query string) (*q.SelectStmt, bool) { + ast, err := q.Parse(query) + if err != nil { + return nil, false + } + + return ast.Stmts[0].Stmt.GetSelectStmt(), true +} + +func getValidIdentifier(n *q.Node) ([]string, bool) { + if n == nil { + return nil, false + } + + // Constants are fine + if n.GetAConst() != nil { + return nil, true + } + + if fc := n.GetFuncCall(); fc != nil { + var fields []string + for _, arg := range fc.Args { + _fields, ok := getValidIdentifier(arg) + if !ok { + return nil, false + } + + fields = append(fields, _fields...) + } + + return fields, true + } + + if e := n.GetAExpr(); e != nil { + l, ok := getValidIdentifier(e.Lexpr) + if !ok { + return nil, false + } + + r, ok := getValidIdentifier(e.Rexpr) + if !ok { + return nil, false + } + + return append(l, r...), true + } + + // Otherwise must be an identifier + cr := n.GetColumnRef() + if cr == nil { + return nil, false + } + + var parts []string + for _, field := range cr.Fields { + s := field.GetString_() + if s == nil { + return nil, false + } + + parts = append(parts, s.Str) + } + + s := strings.Join(parts, ".") + return []string{s}, true +} + +func identifiers(slct *q.SelectStmt) ([]string, bool) { + var fields []string + for _, t := range slct.TargetList { + v := t.GetResTarget().GetVal() + + _fields, ok := getValidIdentifier(v) + if !ok { + return nil, false + } + + fields = append(fields, _fields...) + } + + if len(slct.FromClause) != 1 { + return nil, false + } + rv := slct.FromClause[0].GetRangeVar() + if rv == nil { + return nil, false + } + if rv.GetRelname() == "" { + return nil, false + } + + if slct.WhereClause != nil { + where, ok := getValidIdentifier(slct.WhereClause) + if !ok { + return nil, false + } + fields = append(fields, where...) + } + + return fields, true +} + +func evalNode(n *q.Node, row map[string]any) any { + if n == nil { + return nil + } + + if c := n.GetAConst(); c != nil { + if i := c.GetVal().GetInteger(); i != nil { + return i.String() + } + + if f := c.GetVal().GetFloat(); f != nil { + return f.GetStr() + } + + if s := c.GetVal().GetString_(); s != nil { + return s.Str + } + + // Unsupported const type + return nil + } + + // Filtering on function calls unsupported + if fc := n.GetFuncCall(); fc != nil { + return nil + } + + if e := n.GetAExpr(); e != nil { + if len(e.Name) != 1 { + return nil + } + + _l := evalNode(e.Lexpr, row) + if _l == nil { + return nil + } + + l, ok := _l.(string) + if !ok { + return nil + } + + _r := evalNode(e.Rexpr, row) + if _r == nil { + return nil + } + + r, ok := _r.(string) + if !ok { + return nil + } + + switch e.Name[0].GetString_().Str { + case ">": + return l > r + case "<": + return l < r + case ">=": + return l >= r + case "<=": + return l <= r + case "=": + return l == r + } + } + + // Otherwise must be an identifier + cr := n.GetColumnRef() + if cr == nil { + return nil + } + + if len(cr.Fields) != 1 { + return nil + } + + s := cr.Fields[0].GetString_() + if s == nil { + return nil + } + + v := runner.GetObjectAtPath(row, s.Str) + switch v.(type) { + case string: + return v + default: + return fmt.Sprintf("%#v", v) + } +} + +func filter(slct *q.SelectStmt) func(m map[string]any) bool { + return func(row map[string]any) bool { + if slct.WhereClause == nil { + return false + } + + x := evalNode(slct.WhereClause, row) + return x != true && x != nil + } +} diff --git a/sql_test.go b/sql_test.go new file mode 100644 index 0000000..f6a6441 --- /dev/null +++ b/sql_test.go @@ -0,0 +1,129 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_identifiers(t *testing.T) { + tests := []struct { + in string + idents []string + ok bool + }{ + { + "SELECT a FROM x", + []string{"a"}, + true, + }, + { + "SELECT a, '1' FROM x", + []string{"a"}, + true, + }, + { + "SELECT a, b FROM x", + []string{"a", "b"}, + true, + }, + { + "SELECT avg(b) FROM x", + []string{"b"}, + true, + }, + { + "SELECT avg(b+1) FROM x", + []string{"b"}, + true, + }, + { + "SELECT b+1 FROM x", + []string{"b"}, + true, + }, + { + "SELECT a FROM x WHERE b > 1", + []string{"a", "b"}, + true, + }, + { + "SELECT * FROM x", + nil, + false, + }, + { + "SELECT a FROM x, y", + nil, + false, + }, + } + + for _, test := range tests { + t.Logf("Testing: %s", test.in) + s, ok := parse(test.in) + assert.True(t, ok) + + idents, ok := identifiers(s) + assert.Equal(t, test.idents, idents) + assert.Equal(t, test.ok, ok) + } +} + +func Test_filter(t *testing.T) { + tests := []struct { + query string + inRows []map[string]any + outRows []map[string]any + }{ + { + "SELECT a FROM x", + []map[string]any{ + {"a": 1}, + {"a": 2}, + }, + []map[string]any{ + {"a": 1}, + {"a": 2}, + }, + }, + { + "SELECT a FROM x WHERE avg(a) > 12", + []map[string]any{ + {"a": 1}, + {"a": 2}, + }, + []map[string]any{ + {"a": 1}, + {"a": 2}, + }, + }, + { + "SELECT a FROM x WHERE a = 12", + []map[string]any{ + {"a": 1}, + {"a": 12}, + }, + []map[string]any{ + {"a": 12}, + }, + }, + } + + for _, test := range tests { + t.Logf("Testing: %s", test.query) + s, ok := parse(test.query) + assert.True(t, ok) + + f := filter(s) + var end []map[string]any + for _, r := range test.inRows { + canFilter := f(r) + if !canFilter { + end = append(end, r) + } + } + + assert.Equal(t, test.outRows, end) + } +} diff --git a/sql_windows.go b/sql_windows.go new file mode 100644 index 0000000..648b956 --- /dev/null +++ b/sql_windows.go @@ -0,0 +1,17 @@ +package main + +type SelectStmt struct{} + +func parse(query string) (*SelectStmt, bool) { + return nil, false +} + +func identifiers(slct *SelectStmt) ([]string, bool) { + return nil, false +} + +func filter(slct *SelectStmt) func(m map[string]any) bool { + return func(_ map[string]any) bool { + return false + } +} diff --git a/sqlite.go b/sqlite.go index 5d74445..d686962 100644 --- a/sqlite.go +++ b/sqlite.go @@ -9,18 +9,28 @@ import ( "github.com/multiprocessio/datastation/runner" ) -type SQLiteResultItemWriter struct { - db *sql.DB - fields []string - panelId string - rowBuffer runner.Vector[any] +type SQLiteResultItemWriterOptions struct { convertNumbers bool + prefilter func(map[string]any) bool + fieldsOverride []string +} + +type SQLiteResultItemWriter struct { + tableCreated bool + db *sql.DB + fields []string + panelId string + rowBuffer runner.Vector[any] + + SQLiteResultItemWriterOptions } -func openSQLiteResultItemWriter(f string, panelId string, convertNumbers bool) (runner.ResultItemWriter, error) { +func openSQLiteResultItemWriter(f string, panelId string, opts SQLiteResultItemWriterOptions) (runner.ResultItemWriter, error) { var sw SQLiteResultItemWriter sw.panelId = panelId - sw.convertNumbers = convertNumbers + sw.SQLiteResultItemWriterOptions = opts + + sw.fields = opts.fieldsOverride sw.rowBuffer = runner.Vector[any]{} @@ -127,15 +137,26 @@ func (sw *SQLiteResultItemWriter) WriteRow(r any, written int) error { return fmt.Errorf("Row must be a map, got: %#v", r) } - if len(sw.fields) == 0 { - for key := range m { - sw.fields = append(sw.fields, key) + if sw.prefilter != nil { + canSkip := sw.prefilter(m) + if canSkip { + return nil + } + } + + if !sw.tableCreated { + if len(sw.fields) == 0 { + for key := range m { + sw.fields = append(sw.fields, key) + } } err := sw.createTable() if err != nil { return err } + + sw.tableCreated = true } for _, field := range sw.fields { @@ -147,6 +168,9 @@ func (sw *SQLiteResultItemWriter) WriteRow(r any, written int) error { return err } v = string(bs) + // TODO: don't keep this + case map[string]any: + v = nil } sw.rowBuffer.Append(v) }