From e949b19877e946b804d7799b716c94cf17b06bc7 Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov Date: Tue, 22 Apr 2025 20:31:15 +0300 Subject: [PATCH 01/18] Initial YDB support: SELECT logic, some convert tests, and YDB engine integration - Added reserved keywords and parsing logic - CREATE TABLE support and tests - SELECT initial support and tests - Initial YDB engine integration and examples --- docker-compose.yml | 13 + examples/authors/sqlc.yaml | 10 + examples/authors/ydb/db.go | 31 + examples/authors/ydb/db_test.go | 102 ++ examples/authors/ydb/models.go | 15 + examples/authors/ydb/query.sql | 19 + examples/authors/ydb/query.sql.go | 133 ++ examples/authors/ydb/schema.sql | 6 + internal/codegen/golang/go_type.go | 2 + internal/codegen/golang/ydb_type.go | 154 ++ internal/compiler/engine.go | 5 + internal/config/config.go | 1 + internal/engine/ydb/catalog.go | 19 + .../ydb/catalog_tests/create_table_test.go | 166 ++ .../engine/ydb/catalog_tests/select_test.go | 386 ++++ internal/engine/ydb/convert.go | 1616 +++++++++++++++++ internal/engine/ydb/parse.go | 93 + internal/engine/ydb/reserved.go | 301 +++ internal/engine/ydb/stdlib.go | 12 + internal/engine/ydb/utils.go | 143 ++ internal/sql/rewrite/parameters.go | 2 + internal/sqltest/local/ydb.go | 117 ++ 22 files changed, 3346 insertions(+) create mode 100644 examples/authors/ydb/db.go create mode 100644 examples/authors/ydb/db_test.go create mode 100644 examples/authors/ydb/models.go create mode 100644 examples/authors/ydb/query.sql create mode 100644 examples/authors/ydb/query.sql.go create mode 100644 examples/authors/ydb/schema.sql create mode 100644 internal/codegen/golang/ydb_type.go create mode 100644 internal/engine/ydb/catalog.go create mode 100644 internal/engine/ydb/catalog_tests/create_table_test.go create mode 100644 internal/engine/ydb/catalog_tests/select_test.go create mode 100755 internal/engine/ydb/convert.go create mode 100755 internal/engine/ydb/parse.go create mode 100644 internal/engine/ydb/reserved.go create mode 100644 internal/engine/ydb/stdlib.go create mode 100755 internal/engine/ydb/utils.go create mode 100644 internal/sqltest/local/ydb.go diff --git a/docker-compose.yml b/docker-compose.yml index f318d1ed93..e7c66b42ae 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -19,3 +19,16 @@ services: POSTGRES_DB: postgres POSTGRES_PASSWORD: mysecretpassword POSTGRES_USER: postgres + + ydb: + image: ydbplatform/local-ydb:latest + ports: + - "2135:2135" + - "2136:2136" + - "8765:8765" + restart: always + environment: + - YDB_USE_IN_MEMORY_PDISKS=true + - GRPC_TLS_PORT=2135 + - GRPC_PORT=2136 + - MON_PORT=8765 diff --git a/examples/authors/sqlc.yaml b/examples/authors/sqlc.yaml index 57f2319ea1..8d6bc3db28 100644 --- a/examples/authors/sqlc.yaml +++ b/examples/authors/sqlc.yaml @@ -43,6 +43,16 @@ sql: go: package: authors out: sqlite +- name: ydb + schema: ydb/schema.sql + queries: ydb/query.sql + engine: ydb + gen: + go: + package: authors + out: ydb + + rules: - name: postgresql-query-too-costly message: "Too costly" diff --git a/examples/authors/ydb/db.go b/examples/authors/ydb/db.go new file mode 100644 index 0000000000..2bb1bfc27d --- /dev/null +++ b/examples/authors/ydb/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 + +package authors + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/examples/authors/ydb/db_test.go b/examples/authors/ydb/db_test.go new file mode 100644 index 0000000000..181ee64ed1 --- /dev/null +++ b/examples/authors/ydb/db_test.go @@ -0,0 +1,102 @@ +package authors + +import ( + "context" + "testing" + + "github.com/sqlc-dev/sqlc/internal/sqltest/local" + _ "github.com/ydb-platform/ydb-go-sdk/v3" +) + +func TestAuthors(t *testing.T) { + ctx := context.Background() + + test := local.YDB(t, []string{"schema.sql"}) + defer test.DB.Close() + + q := New(test.DB) + + t.Run("ListAuthors", func(t *testing.T) { + authors, err := q.ListAuthors(ctx) + if err != nil { + t.Fatal(err) + } + if len(authors) == 0 { + t.Fatal("expected at least one author, got none") + } + t.Log("Authors:") + for _, a := range authors { + bio := "NULL" + if a.Bio.Valid { + bio = a.Bio.String + } + t.Logf("- ID: %d | Name: %s | Bio: %s", a.ID, a.Name, bio) + } + }) + + t.Run("GetAuthor", func(t *testing.T) { + singleAuthor, err := q.GetAuthor(ctx, 10) + if err != nil { + t.Fatal(err) + } + bio := "NULL" + if singleAuthor.Bio.Valid { + bio = singleAuthor.Bio.String + } + t.Logf("- ID: %d | Name: %s | Bio: %s", singleAuthor.ID, singleAuthor.Name, bio) + }) + + t.Run("GetAuthorByName", func(t *testing.T) { + authors, err := q.GetAuthorsByName(ctx, "Александр Пушкин") + if err != nil { + t.Fatal(err) + } + if len(authors) == 0 { + t.Fatal("expected at least one author with this name, got none") + } + t.Log("Authors with this name:") + for _, a := range authors { + bio := "NULL" + if a.Bio.Valid { + bio = a.Bio.String + } + t.Logf("- ID: %d | Name: %s | Bio: %s", a.ID, a.Name, bio) + } + }) + + t.Run("ListAuthorsWithIdModulo", func(t *testing.T) { + authors, err := q.ListAuthorsWithIdModulo(ctx) + if err != nil { + t.Fatal(err) + } + if len(authors) == 0 { + t.Fatal("expected at least one author with even ID, got none") + } + t.Log("Authors with even IDs:") + for _, a := range authors { + bio := "NULL" + if a.Bio.Valid { + bio = a.Bio.String + } + t.Logf("- ID: %d | Name: %s | Bio: %s", a.ID, a.Name, bio) + } + }) + + t.Run("ListAuthorsWithNullBio", func(t *testing.T) { + authors, err := q.ListAuthorsWithNullBio(ctx) + if err != nil { + t.Fatal(err) + } + if len(authors) == 0 { + t.Fatal("expected at least one author with NULL bio, got none") + } + t.Log("Authors with NULL bio:") + for _, a := range authors { + bio := "NULL" + if a.Bio.Valid { + bio = a.Bio.String + } + t.Logf("- ID: %d | Name: %s | Bio: %s", a.ID, a.Name, bio) + } + }) +} diff --git a/examples/authors/ydb/models.go b/examples/authors/ydb/models.go new file mode 100644 index 0000000000..e899b195b0 --- /dev/null +++ b/examples/authors/ydb/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 + +package authors + +import ( + "database/sql" +) + +type Author struct { + ID uint64 + Name string + Bio sql.NullString +} diff --git a/examples/authors/ydb/query.sql b/examples/authors/ydb/query.sql new file mode 100644 index 0000000000..219d680ba1 --- /dev/null +++ b/examples/authors/ydb/query.sql @@ -0,0 +1,19 @@ +-- name: ListAuthors :many +SELECT * FROM authors; + +-- name: GetAuthor :one +SELECT * FROM authors +WHERE id = $p0; + +-- name: ListAuthorsWithIdModulo :many +SELECT * FROM authors +WHERE id % 2 = 0; + +-- name: GetAuthorsByName :many +SELECT * FROM authors +WHERE name = $p0; + +-- name: ListAuthorsWithNullBio :many +SELECT * FROM authors +WHERE bio IS NULL; + diff --git a/examples/authors/ydb/query.sql.go b/examples/authors/ydb/query.sql.go new file mode 100644 index 0000000000..53ed896128 --- /dev/null +++ b/examples/authors/ydb/query.sql.go @@ -0,0 +1,133 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: query.sql + +package authors + +import ( + "context" +) + +const getAuthor = `-- name: GetAuthor :one +SELECT id, name, bio FROM authors +WHERE id = $p0 +` + +func (q *Queries) GetAuthor(ctx context.Context, p0 uint64) (Author, error) { + row := q.db.QueryRowContext(ctx, getAuthor, p0) + var i Author + err := row.Scan(&i.ID, &i.Name, &i.Bio) + return i, err +} + +const getAuthorsByName = `-- name: GetAuthorsByName :many +SELECT id, name, bio FROM authors +WHERE name = $p0 +` + +func (q *Queries) GetAuthorsByName(ctx context.Context, p0 string) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, getAuthorsByName, p0) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthors = `-- name: ListAuthors :many +SELECT id, name, bio FROM authors +` + +func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthors) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsWithIdModulo = `-- name: ListAuthorsWithIdModulo :many +SELECT id, name, bio FROM authors +WHERE id % 2 = 0 +` + +func (q *Queries) ListAuthorsWithIdModulo(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsWithIdModulo) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsWithNullBio = `-- name: ListAuthorsWithNullBio :many +SELECT id, name, bio FROM authors +WHERE bio IS NULL +` + +func (q *Queries) ListAuthorsWithNullBio(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsWithNullBio) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/examples/authors/ydb/schema.sql b/examples/authors/ydb/schema.sql new file mode 100644 index 0000000000..ee9329e809 --- /dev/null +++ b/examples/authors/ydb/schema.sql @@ -0,0 +1,6 @@ +CREATE TABLE authors ( + id Uint64, + name Utf8 NOT NULL, + bio Utf8, + PRIMARY KEY (id) +); diff --git a/internal/codegen/golang/go_type.go b/internal/codegen/golang/go_type.go index c4aac84dd6..11eb8931df 100644 --- a/internal/codegen/golang/go_type.go +++ b/internal/codegen/golang/go_type.go @@ -89,6 +89,8 @@ func goInnerType(req *plugin.GenerateRequest, options *opts.Options, col *plugin return postgresType(req, options, col) case "sqlite": return sqliteType(req, options, col) + case "ydb": + return YDBType(req, options, col) default: return "interface{}" } diff --git a/internal/codegen/golang/ydb_type.go b/internal/codegen/golang/ydb_type.go new file mode 100644 index 0000000000..8a5b1711b3 --- /dev/null +++ b/internal/codegen/golang/ydb_type.go @@ -0,0 +1,154 @@ +package golang + +import ( + "log" + "strings" + + "github.com/sqlc-dev/sqlc/internal/codegen/golang/opts" + "github.com/sqlc-dev/sqlc/internal/codegen/sdk" + "github.com/sqlc-dev/sqlc/internal/debug" + "github.com/sqlc-dev/sqlc/internal/plugin" +) + +func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string { + columnType := strings.ToLower(sdk.DataType(col.Type)) + notNull := col.NotNull || col.IsArray + emitPointersForNull := options.EmitPointersForNullTypes + + // https://ydb.tech/docs/ru/yql/reference/types/ + switch columnType { + // decimal types + case "bool": + if notNull { + return "bool" + } + if emitPointersForNull { + return "*bool" + } + return "sql.NullBool" + + case "int8": + if notNull { + return "int8" + } + if emitPointersForNull { + return "*int8" + } + // The database/sql package does not have a sql.NullInt8 type, so we + // use the smallest type they have which is NullInt16 + return "sql.NullInt16" + case "int16": + if notNull { + return "int16" + } + if emitPointersForNull { + return "*int16" + } + return "sql.NullInt16" + case "int32": + if notNull { + return "int32" + } + if emitPointersForNull { + return "*int32" + } + return "sql.NullInt32" + case "int64": + if notNull { + return "int64" + } + if emitPointersForNull { + return "*int64" + } + return "sql.NullInt64" + + case "uint8": + if emitPointersForNull { + return "*uint8" + } + return "uint8" + case "uint16": + if emitPointersForNull { + return "*uint16" + } + return "uint16" + case "uint32": + if emitPointersForNull { + return "*uint32" + } + return "uint32" + case "uint64": + if emitPointersForNull { + return "*uint64" + } + return "uint64" + + case "float": + if notNull { + return "float32" + } + if emitPointersForNull { + return "*float32" + } + // The database/sql package does not have a sql.NullFloat32 type, so we + // use the smallest type they have which is NullFloat64 + return "sql.NullFloat64" + case "double": + if notNull { + return "float64" + } + if emitPointersForNull { + return "*float64" + } + return "sql.NullFloat64" + + // string types + case "string", "utf8", "text": + if notNull { + return "string" + } + if emitPointersForNull { + return "*string" + } + return "sql.NullString" + + // serial types + case "smallserial", "serial2": + if notNull { + return "int16" + } + if emitPointersForNull { + return "*int16" + } + return "sql.NullInt16" + + case "serial", "serial4": + if notNull { + return "int32" + } + if emitPointersForNull { + return "*int32" + } + return "sql.NullInt32" + + case "bigserial", "serial8": + if notNull { + return "int64" + } + if emitPointersForNull { + return "*int64" + } + return "sql.NullInt64" + + case "null": + return "sql.Null" + + default: + if debug.Active { + log.Printf("unknown SQLite type: %s\n", columnType) + } + + return "interface{}" + } + +} diff --git a/internal/compiler/engine.go b/internal/compiler/engine.go index f742bfd999..245552b07f 100644 --- a/internal/compiler/engine.go +++ b/internal/compiler/engine.go @@ -11,6 +11,7 @@ import ( "github.com/sqlc-dev/sqlc/internal/engine/postgresql" pganalyze "github.com/sqlc-dev/sqlc/internal/engine/postgresql/analyzer" "github.com/sqlc-dev/sqlc/internal/engine/sqlite" + "github.com/sqlc-dev/sqlc/internal/engine/ydb" "github.com/sqlc-dev/sqlc/internal/opts" "github.com/sqlc-dev/sqlc/internal/sql/catalog" ) @@ -41,6 +42,10 @@ func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, err c.parser = sqlite.NewParser() c.catalog = sqlite.NewCatalog() c.selector = newSQLiteSelector() + case config.EngineYDB: + c.parser = ydb.NewParser() + c.catalog = ydb.NewCatalog() + c.selector = newDefaultSelector() case config.EngineMySQL: c.parser = dolphin.NewParser() c.catalog = dolphin.NewCatalog() diff --git a/internal/config/config.go b/internal/config/config.go index 0ff805fccd..f7df94e5f8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -54,6 +54,7 @@ const ( EngineMySQL Engine = "mysql" EnginePostgreSQL Engine = "postgresql" EngineSQLite Engine = "sqlite" + EngineYDB Engine = "ydb" ) type Config struct { diff --git a/internal/engine/ydb/catalog.go b/internal/engine/ydb/catalog.go new file mode 100644 index 0000000000..f191d936f3 --- /dev/null +++ b/internal/engine/ydb/catalog.go @@ -0,0 +1,19 @@ +package ydb + +import "github.com/sqlc-dev/sqlc/internal/sql/catalog" + + +func NewCatalog() *catalog.Catalog { + def := "main" + return &catalog.Catalog{ + DefaultSchema: def, + Schemas: []*catalog.Schema{ + defaultSchema(def), + }, + Extensions: map[string]struct{}{}, + } +} + +func NewTestCatalog() *catalog.Catalog { + return catalog.New("main") +} diff --git a/internal/engine/ydb/catalog_tests/create_table_test.go b/internal/engine/ydb/catalog_tests/create_table_test.go new file mode 100644 index 0000000000..e98288d75a --- /dev/null +++ b/internal/engine/ydb/catalog_tests/create_table_test.go @@ -0,0 +1,166 @@ +package ydb_test + +import ( + "strconv" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/sqlc-dev/sqlc/internal/engine/ydb" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func TestCreateTable(t *testing.T) { + tests := []struct { + stmt string + s *catalog.Schema + }{ + { + stmt: `CREATE TABLE users ( + id Uint64, + age Int32, + score Float, + PRIMARY KEY (id) + )`, + s: &catalog.Schema{ + Name: "main", + Tables: []*catalog.Table{ + { + Rel: &ast.TableName{Name: "users"}, + Columns: []*catalog.Column{ + { + Name: "id", + Type: ast.TypeName{Name: "Uint64"}, + IsNotNull: true, + }, + { + Name: "age", + Type: ast.TypeName{Name: "Int32"}, + }, + { + Name: "score", + Type: ast.TypeName{Name: "Float"}, + }, + }, + }, + }, + }, + }, + { + stmt: `CREATE TABLE posts ( + id Uint64, + title Utf8 NOT NULL, + content String, + metadata Json, + PRIMARY KEY (id) + )`, + s: &catalog.Schema{ + Name: "main", + Tables: []*catalog.Table{ + { + Rel: &ast.TableName{Name: "posts"}, + Columns: []*catalog.Column{ + { + Name: "id", + Type: ast.TypeName{Name: "Uint64"}, + IsNotNull: true, + }, + { + Name: "title", + Type: ast.TypeName{Name: "Utf8"}, + IsNotNull: true, + }, + { + Name: "content", + Type: ast.TypeName{Name: "String"}, + }, + { + Name: "metadata", + Type: ast.TypeName{Name: "Json"}, + }, + }, + }, + }, + }, + }, + { + stmt: `CREATE TABLE orders ( + id Uuid, + amount Decimal(22,9), + created_at Uint64, + PRIMARY KEY (id) + )`, + s: &catalog.Schema{ + Name: "main", + Tables: []*catalog.Table{ + { + Rel: &ast.TableName{Name: "orders"}, + Columns: []*catalog.Column{ + { + Name: "id", + Type: ast.TypeName{Name: "Uuid"}, + IsNotNull: true, + }, + { + Name: "amount", + Type: ast.TypeName{ + Name: "Decimal", + Names: &ast.List{ + Items: []ast.Node{ + &ast.Integer{Ival: 22}, + &ast.Integer{Ival: 9}, + }, + }, + }, + }, + { + Name: "created_at", + Type: ast.TypeName{Name: "Uint64"}, + }, + }, + }, + }, + }, + }, + } + + p := ydb.NewParser() + for i, tc := range tests { + test := tc + t.Run(strconv.Itoa(i), func(t *testing.T) { + stmts, err := p.Parse(strings.NewReader(test.stmt)) + if err != nil { + t.Log(test.stmt) + t.Fatal(err) + } + + c := ydb.NewTestCatalog() + if err := c.Build(stmts); err != nil { + t.Log(test.stmt) + t.Fatal(err) + } + + e := ydb.NewTestCatalog() + if test.s != nil { + var replaced bool + for i := range e.Schemas { + if e.Schemas[i].Name == test.s.Name { + e.Schemas[i] = test.s + replaced = true + break + } + } + if !replaced { + e.Schemas = append(e.Schemas, test.s) + } + } + + if diff := cmp.Diff(e, c, cmpopts.EquateEmpty(), cmpopts.IgnoreUnexported(catalog.Column{})); diff != "" { + t.Log(test.stmt) + t.Errorf("catalog mismatch:\n%s", diff) + } + }) + } +} diff --git a/internal/engine/ydb/catalog_tests/select_test.go b/internal/engine/ydb/catalog_tests/select_test.go new file mode 100644 index 0000000000..95ae49163d --- /dev/null +++ b/internal/engine/ydb/catalog_tests/select_test.go @@ -0,0 +1,386 @@ +package ydb_test + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/sqlc-dev/sqlc/internal/engine/ydb" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func strPtr(s string) *string { + return &s +} + +func TestSelect(t *testing.T) { + tests := []struct { + stmt string + expected ast.Node + }{ + // Basic Types Select + { + stmt: `SELECT 52`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.A_Const{ + Val: &ast.Integer{Ival: 52}, + }, + }, + }, + }, + }, + }, + }, + }, + { + stmt: `SELECT 'hello'`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.A_Const{ + Val: &ast.String{Str: "hello"}, + }, + }, + }, + }, + }, + }, + }, + }, + { + stmt: `SELECT 'it\'s string with quote in it'`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.A_Const{ + Val: &ast.String{Str: `it\'s string with quote in it`}, + }, + }, + }, + }, + }, + }, + }, + }, + { + stmt: "SELECT 3.14", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.A_Const{ + Val: &ast.Float{Str: "3.14"}, + }, + }, + }, + }, + }, + }, + }, + }, + { + stmt: "SELECT NULL", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.Null{}, + }, + }, + }, + }, + }, + }, + }, + { + stmt: "SELECT true", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.Boolean{Boolval: true}, + }, + }, + }, + }, + }, + }, + }, + { + stmt: "SELECT 2+3*4", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.A_Expr{ + Name: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "+"}, + }, + }, + Lexpr: &ast.A_Const{ + Val: &ast.Integer{Ival: 2}, + }, + Rexpr: &ast.A_Expr{ + Name: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "*"}, + }, + }, + Lexpr: &ast.A_Const{ + Val: &ast.Integer{Ival: 3}, + }, + Rexpr: &ast.A_Const{ + Val: &ast.Integer{Ival: 4}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + + // Select with From Clause tests + { + stmt: `SELECT * FROM users`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.A_Star{}, + }, + }, + }, + }, + }, + }, + FromClause: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{ + Relname: strPtr("users"), + }, + }, + }, + }, + }, + }, + }, + { + stmt: "SELECT id AS identifier FROM users", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Name: strPtr("identifier"), + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + }, + }, + }, + FromClause: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{ + Relname: strPtr("users"), + }, + }, + }, + }, + }, + }, + }, + { + stmt: "SELECT a.b.c FROM table", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "a"}, + &ast.String{Str: "b"}, + &ast.String{Str: "c"}, + }, + }, + }, + }, + }, + }, + FromClause: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{ + Relname: strPtr("table"), + }, + }, + }, + }, + }, + }, + }, + { + stmt: "SELECT id.age, 3.14, 'abc', NULL, false FROM users", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + &ast.String{Str: "age"}, + }, + }, + }, + }, + &ast.ResTarget{ + Val: &ast.A_Const{ + Val: &ast.Float{Str: "3.14"}, + }, + }, + &ast.ResTarget{ + Val: &ast.A_Const{ + Val: &ast.String{Str: "abc"}, + }, + }, + &ast.ResTarget{ + Val: &ast.Null{}, + }, + &ast.ResTarget{ + Val: &ast.Boolean{Boolval: false}, + }, + }, + }, + FromClause: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{ + Relname: strPtr("users"), + }, + }, + }, + }, + }, + }, + }, + { + stmt: `SELECT id, name FROM users WHERE age > 30`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + }, + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "name"}, + }, + }, + }, + }, + }, + }, + FromClause: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{ + Relname: strPtr("users"), + }, + }, + }, + WhereClause: &ast.A_Expr{ + Name: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: ">"}, + }, + }, + Lexpr: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "age"}, + }, + }, + }, + Rexpr: &ast.A_Const{ + Val: &ast.Integer{Ival: 30}, + }, + }, + }, + }, + }, + }, + } + + p := ydb.NewParser() + + for _, tc := range tests { + t.Run(tc.stmt, func(t *testing.T) { + stmts, err := p.Parse(strings.NewReader(tc.stmt)) + if err != nil { + t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + } + if len(stmts) == 0 { + t.Fatalf("Запрос %q не распарсен", tc.stmt) + } + + diff := cmp.Diff(tc.expected, &stmts[0], + cmpopts.IgnoreFields(ast.RawStmt{}, "StmtLocation", "StmtLen"), + // cmpopts.IgnoreFields(ast.SelectStmt{}, "Location"), + cmpopts.IgnoreFields(ast.A_Const{}, "Location"), + cmpopts.IgnoreFields(ast.ResTarget{}, "Location"), + cmpopts.IgnoreFields(ast.ColumnRef{}, "Location"), + cmpopts.IgnoreFields(ast.A_Expr{}, "Location"), + cmpopts.IgnoreFields(ast.RangeVar{}, "Location"), + ) + if diff != "" { + t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + } + }) + } +} diff --git a/internal/engine/ydb/convert.go b/internal/engine/ydb/convert.go new file mode 100755 index 0000000000..341de1a2e9 --- /dev/null +++ b/internal/engine/ydb/convert.go @@ -0,0 +1,1616 @@ +package ydb + +import ( + "log" + "strconv" + "strings" + + "github.com/antlr4-go/antlr/v4" + "github.com/sqlc-dev/sqlc/internal/debug" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + parser "github.com/ydb-platform/yql-parsers/go" +) + +type cc struct { + paramCount int +} + +type node interface { + GetParser() antlr.Parser +} + +func todo(funcname string, n node) *ast.TODO { + if debug.Active { + log.Printf("sqlite.%s: Unknown node type %T\n", funcname, n) + } + return &ast.TODO{} +} + +func identifier(id string) string { + if len(id) >= 2 && id[0] == '"' && id[len(id)-1] == '"' { + unquoted, _ := strconv.Unquote(id) + return unquoted + } + return strings.ToLower(id) +} + +func NewIdentifier(t string) *ast.String { + return &ast.String{Str: identifier(t)} +} + +func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) ast.Node { + tableRef := parseTableName(n.Simple_table_ref().Simple_table_ref_core()) + + stmt := &ast.AlterTableStmt{ + Table: tableRef, + Cmds: &ast.List{}, + } + for _, action := range n.AllAlter_table_action() { + if add := action.Alter_table_add_column(); add != nil { + } + } + return stmt +} + +func (c *cc) convertSelectStmtContext(n *parser.Select_stmtContext) ast.Node { + skp := n.Select_kind_parenthesis(0) + if skp == nil { + return nil + } + partial := skp.Select_kind_partial() + if partial == nil { + return nil + } + sk := partial.Select_kind() + if sk == nil { + return nil + } + selectStmt := &ast.SelectStmt{} + + switch { + case sk.Process_core() != nil: + cnode := c.convert(sk.Process_core()) + stmt, ok := cnode.(*ast.SelectStmt) + if !ok { + return nil + } + selectStmt = stmt + case sk.Select_core() != nil: + cnode := c.convert(sk.Select_core()) + stmt, ok := cnode.(*ast.SelectStmt) + if !ok { + return nil + } + selectStmt = stmt + case sk.Reduce_core() != nil: + cnode := c.convert(sk.Reduce_core()) + stmt, ok := cnode.(*ast.SelectStmt) + if !ok { + return nil + } + selectStmt = stmt + } + + // todo: cover process and reduce core, + // todo: cover LIMIT and OFFSET + + return selectStmt +} + +func (c *cc) convertSelectCoreContext(n *parser.Select_coreContext) ast.Node { + stmt := &ast.SelectStmt{} + if n.Opt_set_quantifier() != nil { + oq := n.Opt_set_quantifier() + if oq.DISTINCT() != nil { + // todo: add distinct support + stmt.DistinctClause = &ast.List{} + } + } + resultCols := n.AllResult_column() + if len(resultCols) > 0 { + var items []ast.Node + for _, rc := range resultCols { + resCol, ok := rc.(*parser.Result_columnContext) + if !ok { + continue + } + convNode := c.convertResultColumn(resCol) + if convNode != nil { + items = append(items, convNode) + } + } + stmt.TargetList = &ast.List{ + Items: items, + } + } + jsList := n.AllJoin_source() + if len(n.AllFROM()) > 0 && len(jsList) > 0 { + var fromItems []ast.Node + for _, js := range jsList { + jsCon, ok := js.(*parser.Join_sourceContext) + if !ok { + continue + } + + joinNode := c.convertJoinSource(jsCon) + if joinNode != nil { + fromItems = append(fromItems, joinNode) + } + } + stmt.FromClause = &ast.List{ + Items: fromItems, + } + } + if n.WHERE() != nil { + whereCtx := n.Expr(0) + if whereCtx != nil { + stmt.WhereClause = c.convert(whereCtx) + } + } + return stmt +} + +func (c *cc) convertResultColumn(n *parser.Result_columnContext) ast.Node { + exprCtx := n.Expr() + if exprCtx == nil { + // todo + } + target := &ast.ResTarget{ + Location: n.GetStart().GetStart(), + } + var val ast.Node + iexpr := n.Expr() + switch { + case n.ASTERISK() != nil: + val = c.convertWildCardField(n) + case iexpr != nil: + val = c.convert(iexpr) + } + + if val == nil { + return nil + } + switch { + case n.AS() != nil && n.An_id_or_type() != nil: + name := parseAnIdOrType(n.An_id_or_type()) + target.Name = &name + case n.An_id_as_compat() != nil: + // todo: parse as_compat + } + target.Val = val + return target +} + +func (c *cc) convertJoinSource(n *parser.Join_sourceContext) ast.Node { + fsList := n.AllFlatten_source() + if len(fsList) == 0 { + return nil + } + joinOps := n.AllJoin_op() + joinConstraints := n.AllJoin_constraint() + + // todo: add ANY support + + leftNode := c.convertFlattenSource(fsList[0]) + if leftNode == nil { + return nil + } + for i, jopCtx := range joinOps { + if i+1 >= len(fsList) { + break + } + rightNode := c.convertFlattenSource(fsList[i+1]) + if rightNode == nil { + return leftNode + } + jexpr := &ast.JoinExpr{ + Larg: leftNode, + Rarg: rightNode, + } + if jopCtx.NATURAL() != nil { + jexpr.IsNatural = true + } + // todo: cover semi/only/exclusion/ + switch { + case jopCtx.LEFT() != nil: + jexpr.Jointype = ast.JoinTypeLeft + case jopCtx.RIGHT() != nil: + jexpr.Jointype = ast.JoinTypeRight + case jopCtx.FULL() != nil: + jexpr.Jointype = ast.JoinTypeFull + case jopCtx.INNER() != nil: + jexpr.Jointype = ast.JoinTypeInner + case jopCtx.COMMA() != nil: + jexpr.Jointype = ast.JoinTypeInner + default: + jexpr.Jointype = ast.JoinTypeInner + } + if i < len(joinConstraints) { + if jc := joinConstraints[i]; jc != nil { + switch { + case jc.ON() != nil: + if exprCtx := jc.Expr(); exprCtx != nil { + jexpr.Quals = c.convert(exprCtx) + } + case jc.USING() != nil: + if pureListCtx := jc.Pure_column_or_named_list(); pureListCtx != nil { + var using ast.List + pureItems := pureListCtx.AllPure_column_or_named() + for _, pureCtx := range pureItems { + if anID := pureCtx.An_id(); anID != nil { + using.Items = append(using.Items, NewIdentifier(parseAnId(anID))) + } else if bp := pureCtx.Bind_parameter(); bp != nil { + bindPar := c.convert(bp) + using.Items = append(using.Items, bindPar) + } + } + jexpr.UsingClause = &using + } + } + } + } + leftNode = jexpr + } + return leftNode +} + +func (c *cc) convertFlattenSource(n parser.IFlatten_sourceContext) ast.Node { + if n == nil { + return nil + } + nss := n.Named_single_source() + if nss == nil { + return nil + } + namedSingleSource, ok := nss.(*parser.Named_single_sourceContext) + if !ok { + return nil + } + return c.convertNamedSingleSource(namedSingleSource) +} + +func (c *cc) convertNamedSingleSource(n *parser.Named_single_sourceContext) ast.Node { + ss := n.Single_source() + if ss == nil { + return nil + } + SingleSource, ok := ss.(*parser.Single_sourceContext) + if !ok { + return nil + } + base := c.convertSingleSource(SingleSource) + + if n.AS() != nil && n.An_id() != nil { + aliasText := parseAnId(n.An_id()) + switch source := base.(type) { + case *ast.RangeVar: + source.Alias = &ast.Alias{Aliasname: &aliasText} + case *ast.RangeSubselect: + source.Alias = &ast.Alias{Aliasname: &aliasText} + } + } else if n.An_id_as_compat() != nil { + // todo: parse as_compat + } + return base +} + +func (c *cc) convertSingleSource(n *parser.Single_sourceContext) ast.Node { + if n.Table_ref() != nil { + tableName := n.Table_ref().GetText() // !! debug !! + return &ast.RangeVar{ + Relname: &tableName, + Location: n.GetStart().GetStart(), + } + } + + if n.Select_stmt() != nil { + subquery := c.convert(n.Select_stmt()) + return &ast.RangeSubselect{ + Subquery: subquery, + } + + } + // todo: Values stmt + + return nil +} + +func (c *cc) convertBindParameter(n *parser.Bind_parameterContext) ast.Node { + // !!debug later!! + if n.DOLLAR() != nil { + if n.TRUE() != nil { + return &ast.Boolean{ + Boolval: true, + } + } + if n.FALSE() != nil { + return &ast.Boolean{ + Boolval: false, + } + } + + if an := n.An_id_or_type(); an != nil { + idText := parseAnIdOrType(an) + return &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "@"}}}, + Rexpr: &ast.String{Str: idText}, + Location: n.GetStart().GetStart(), + } + } + c.paramCount++ + return &ast.ParamRef{ + Number: c.paramCount, + Location: n.GetStart().GetStart(), + Dollar: true, + } + } + return &ast.TODO{} +} + +func (c *cc) convertWildCardField(n *parser.Result_columnContext) *ast.ColumnRef { + prefixCtx := n.Opt_id_prefix() + prefix := c.convertOptIdPrefix(prefixCtx) + + items := []ast.Node{} + if prefix != "" { + items = append(items, NewIdentifier(prefix)) + } + + items = append(items, &ast.A_Star{}) + return &ast.ColumnRef{ + Fields: &ast.List{Items: items}, + Location: n.GetStart().GetStart(), + } +} + +func (c *cc) convertOptIdPrefix(ctx parser.IOpt_id_prefixContext) string { + if ctx == nil { + return "" + } + if ctx.An_id() != nil { + return ctx.An_id().GetText() + } + return "" +} + +func (c *cc) convertCreate_table_stmtContext(n *parser.Create_table_stmtContext) ast.Node { + stmt := &ast.CreateTableStmt{ + Name: parseTableName(n.Simple_table_ref().Simple_table_ref_core()), + IfNotExists: n.EXISTS() != nil, + } + for _, idef := range n.AllCreate_table_entry() { + if def, ok := idef.(*parser.Create_table_entryContext); ok { + switch { + case def.Column_schema() != nil: + if colCtx, ok := def.Column_schema().(*parser.Column_schemaContext); ok { + colDef := c.convertColumnSchema(colCtx) + if colDef != nil { + stmt.Cols = append(stmt.Cols, colDef) + } + } + case def.Table_constraint() != nil: + if conCtx, ok := def.Table_constraint().(*parser.Table_constraintContext); ok { + switch { + case conCtx.PRIMARY() != nil && conCtx.KEY() != nil: + for _, cname := range conCtx.AllAn_id() { + for _, col := range stmt.Cols { + if col.Colname == parseAnId(cname) { + col.IsNotNull = true + } + } + } + case conCtx.PARTITION() != nil && conCtx.BY() != nil: + _ = conCtx + // todo: partition by constraint + case conCtx.ORDER() != nil && conCtx.BY() != nil: + _ = conCtx + // todo: order by constraint + } + } + + case def.Table_index() != nil: + if indCtx, ok := def.Table_index().(*parser.Table_indexContext); ok { + _ = indCtx + // todo + } + case def.Family_entry() != nil: + if famCtx, ok := def.Family_entry().(*parser.Family_entryContext); ok { + _ = famCtx + // todo + } + case def.Changefeed() != nil: // таблица ориентированная + if cgfCtx, ok := def.Changefeed().(*parser.ChangefeedContext); ok { + _ = cgfCtx + // todo + } + } + } + } + return stmt +} + +func (c *cc) convertColumnSchema(n *parser.Column_schemaContext) *ast.ColumnDef { + + col := &ast.ColumnDef{} + + if anId := n.An_id_schema(); anId != nil { + col.Colname = identifier(parseAnIdSchema(anId)) + } + if tnb := n.Type_name_or_bind(); tnb != nil { + col.TypeName = c.convertTypeNameOrBind(tnb) + } + if colCons := n.Opt_column_constraints(); colCons != nil { + col.IsNotNull = colCons.NOT() != nil && colCons.NULL() != nil + //todo: cover exprs if needed + } + // todo: family + + return col +} + +func (c *cc) convertTypeNameOrBind(n parser.IType_name_or_bindContext) *ast.TypeName { + if t := n.Type_name(); t != nil { + return c.convertTypeName(t) + } else if b := n.Bind_parameter(); b != nil { + return &ast.TypeName{Name: "BIND:" + identifier(parseAnIdOrType(b.An_id_or_type()))} + } + return nil +} + +func (c *cc) convertTypeName(n parser.IType_nameContext) *ast.TypeName { + if n == nil { + return nil + } + + // Handle composite types + if composite := n.Type_name_composite(); composite != nil { + if node := c.convertTypeNameComposite(composite); node != nil { + if typeName, ok := node.(*ast.TypeName); ok { + return typeName + } + } + } + + // Handle decimal type (e.g., DECIMAL(10,2)) + if decimal := n.Type_name_decimal(); decimal != nil { + if integerOrBinds := decimal.AllInteger_or_bind(); len(integerOrBinds) >= 2 { + return &ast.TypeName{ + Name: "Decimal", + TypeOid: 0, + Names: &ast.List{ + Items: []ast.Node{ + c.convertIntegerOrBind(integerOrBinds[0]), + c.convertIntegerOrBind(integerOrBinds[1]), + }, + }, + } + } + } + + // Handle simple types + if simple := n.Type_name_simple(); simple != nil { + return &ast.TypeName{ + Name: simple.GetText(), + TypeOid: 0, + } + } + + return nil +} + +func (c *cc) convertIntegerOrBind(n parser.IInteger_or_bindContext) ast.Node { + if n == nil { + return nil + } + + if integer := n.Integer(); integer != nil { + val, err := parseIntegerValue(integer.GetText()) + if err != nil { + return &ast.TODO{} + } + return &ast.Integer{Ival: val} + } + + if bind := n.Bind_parameter(); bind != nil { + return c.convertBindParameter(bind.(*parser.Bind_parameterContext)) + } + + return nil +} + +func (c *cc) convertTypeNameComposite(n parser.IType_name_compositeContext) ast.Node { + if n == nil { + return nil + } + + if opt := n.Type_name_optional(); opt != nil { + if typeName := opt.Type_name_or_bind(); typeName != nil { + return &ast.TypeName{ + Name: "Optional", + TypeOid: 0, + Names: &ast.List{ + Items: []ast.Node{c.convertTypeNameOrBind(typeName)}, + }, + } + } + } + + if tuple := n.Type_name_tuple(); tuple != nil { + if typeNames := tuple.AllType_name_or_bind(); len(typeNames) > 0 { + var items []ast.Node + for _, tn := range typeNames { + items = append(items, c.convertTypeNameOrBind(tn)) + } + return &ast.TypeName{ + Name: "Tuple", + TypeOid: 0, + Names: &ast.List{Items: items}, + } + } + } + + if struct_ := n.Type_name_struct(); struct_ != nil { + if structArgs := struct_.AllStruct_arg(); len(structArgs) > 0 { + var items []ast.Node + for _, _ = range structArgs { + // TODO: Handle struct field names and types + items = append(items, &ast.TODO{}) + } + return &ast.TypeName{ + Name: "Struct", + TypeOid: 0, + Names: &ast.List{Items: items}, + } + } + } + + if variant := n.Type_name_variant(); variant != nil { + if variantArgs := variant.AllVariant_arg(); len(variantArgs) > 0 { + var items []ast.Node + for _, _ = range variantArgs { + // TODO: Handle variant arguments + items = append(items, &ast.TODO{}) + } + return &ast.TypeName{ + Name: "Variant", + TypeOid: 0, + Names: &ast.List{Items: items}, + } + } + } + + if list := n.Type_name_list(); list != nil { + if typeName := list.Type_name_or_bind(); typeName != nil { + return &ast.TypeName{ + Name: "List", + TypeOid: 0, + Names: &ast.List{ + Items: []ast.Node{c.convertTypeNameOrBind(typeName)}, + }, + } + } + } + + if stream := n.Type_name_stream(); stream != nil { + if typeName := stream.Type_name_or_bind(); typeName != nil { + return &ast.TypeName{ + Name: "Stream", + TypeOid: 0, + Names: &ast.List{ + Items: []ast.Node{c.convertTypeNameOrBind(typeName)}, + }, + } + } + } + + if flow := n.Type_name_flow(); flow != nil { + if typeName := flow.Type_name_or_bind(); typeName != nil { + return &ast.TypeName{ + Name: "Flow", + TypeOid: 0, + Names: &ast.List{ + Items: []ast.Node{c.convertTypeNameOrBind(typeName)}, + }, + } + } + } + + if dict := n.Type_name_dict(); dict != nil { + if typeNames := dict.AllType_name_or_bind(); len(typeNames) >= 2 { + return &ast.TypeName{ + Name: "Dict", + TypeOid: 0, + Names: &ast.List{ + Items: []ast.Node{ + c.convertTypeNameOrBind(typeNames[0]), + c.convertTypeNameOrBind(typeNames[1]), + }, + }, + } + } + } + + if set := n.Type_name_set(); set != nil { + if typeName := set.Type_name_or_bind(); typeName != nil { + return &ast.TypeName{ + Name: "Set", + TypeOid: 0, + Names: &ast.List{ + Items: []ast.Node{c.convertTypeNameOrBind(typeName)}, + }, + } + } + } + + if enum := n.Type_name_enum(); enum != nil { + if typeTags := enum.AllType_name_tag(); len(typeTags) > 0 { + var items []ast.Node + for _, _ = range typeTags { // todo: Handle enum tags + items = append(items, &ast.TODO{}) + } + return &ast.TypeName{ + Name: "Enum", + TypeOid: 0, + Names: &ast.List{Items: items}, + } + } + } + + if resource := n.Type_name_resource(); resource != nil { + if typeTag := resource.Type_name_tag(); typeTag != nil { + // TODO: Handle resource tag + return &ast.TypeName{ + Name: "Resource", + TypeOid: 0, + Names: &ast.List{ + Items: []ast.Node{&ast.TODO{}}, + }, + } + } + } + + if tagged := n.Type_name_tagged(); tagged != nil { + if typeName := tagged.Type_name_or_bind(); typeName != nil { + if typeTag := tagged.Type_name_tag(); typeTag != nil { + // TODO: Handle tagged type and tag + return &ast.TypeName{ + Name: "Tagged", + TypeOid: 0, + Names: &ast.List{ + Items: []ast.Node{ + c.convertTypeNameOrBind(typeName), + &ast.TODO{}, + }, + }, + } + } + } + } + + if callable := n.Type_name_callable(); callable != nil { + // TODO: Handle callable argument list and return type + return &ast.TypeName{ + Name: "Callable", + TypeOid: 0, + Names: &ast.List{ + Items: []ast.Node{&ast.TODO{}}, + }, + } + } + + return nil +} + +func (c *cc) convertSqlStmtCore(n parser.ISql_stmt_coreContext) ast.Node { + if n == nil { + return nil + } + + if stmt := n.Pragma_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Select_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Named_nodes_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Create_table_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Drop_table_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Use_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Into_table_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Commit_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Update_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Delete_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Rollback_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Declare_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Import_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Export_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Alter_table_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Alter_external_table_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Do_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Define_action_or_subquery_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.If_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.For_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Values_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Create_user_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Alter_user_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Create_group_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Alter_group_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Drop_role_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Create_object_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Alter_object_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Drop_object_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Create_external_data_source_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Alter_external_data_source_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Drop_external_data_source_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Create_replication_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Drop_replication_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Create_topic_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Alter_topic_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Drop_topic_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Grant_permissions_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Revoke_permissions_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Alter_table_store_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Upsert_object_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Create_view_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Drop_view_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Alter_replication_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Create_resource_pool_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Alter_resource_pool_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Drop_resource_pool_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Create_backup_collection_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Alter_backup_collection_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Drop_backup_collection_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Analyze_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Create_resource_pool_classifier_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Alter_resource_pool_classifier_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Drop_resource_pool_classifier_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Backup_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Restore_stmt(); stmt != nil { + return c.convert(stmt) + } + if stmt := n.Alter_sequence_stmt(); stmt != nil { + return c.convert(stmt) + } + return nil +} + +func (c *cc) convertExpr(n *parser.ExprContext) ast.Node { + if n == nil { + return nil + } + + if tn := n.Type_name_composite(); tn != nil { + return c.convertTypeNameComposite(tn) + } + + orSubs := n.AllOr_subexpr() + if len(orSubs) == 0 { + return nil + } + + orSub, ok := orSubs[0].(*parser.Or_subexprContext) + if !ok { + return nil + } + + left := c.convertOrSubExpr(orSub) + for i := 1; i < len(orSubs); i++ { + orSub, ok = orSubs[i].(*parser.Or_subexprContext) + if !ok { + return nil + } + right := c.convertOrSubExpr(orSub) + left = &ast.BoolExpr{ + Boolop: ast.BoolExprTypeOr, + Args: &ast.List{Items: []ast.Node{left, right}}, + Location: n.GetStart().GetStart(), + } + } + return left +} + +func (c *cc) convertOrSubExpr(n *parser.Or_subexprContext) ast.Node { + if n == nil { + return nil + } + andSubs := n.AllAnd_subexpr() + if len(andSubs) == 0 { + return nil + } + andSub, ok := andSubs[0].(*parser.And_subexprContext) + if !ok { + return nil + } + + left := c.convertAndSubexpr(andSub) + for i := 1; i < len(andSubs); i++ { + andSub, ok = andSubs[i].(*parser.And_subexprContext) + if !ok { + return nil + } + right := c.convertAndSubexpr(andSub) + left = &ast.BoolExpr{ + Boolop: ast.BoolExprTypeAnd, + Args: &ast.List{Items: []ast.Node{left, right}}, + Location: n.GetStart().GetStart(), + } + } + return left +} + +func (c *cc) convertAndSubexpr(n *parser.And_subexprContext) ast.Node { + if n == nil { + return nil + } + + xors := n.AllXor_subexpr() + if len(xors) == 0 { + return nil + } + + xor, ok := xors[0].(*parser.Xor_subexprContext) + if !ok { + return nil + } + + left := c.convertXorSubexpr(xor) + for i := 1; i < len(xors); i++ { + xor, ok = xors[i].(*parser.Xor_subexprContext) + if !ok { + return nil + } + right := c.convertXorSubexpr(xor) + left = &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "XOR"}}}, + Lexpr: left, + Rexpr: right, + Location: n.GetStart().GetStart(), + } + } + return left +} + +func (c *cc) convertXorSubexpr(n *parser.Xor_subexprContext) ast.Node { + if n == nil { + return nil + } + es := n.Eq_subexpr() + if es == nil { + return nil + } + subExpr, ok := es.(*parser.Eq_subexprContext) + if !ok { + return nil + } + base := c.convertEqSubexpr(subExpr) + if cond := n.Cond_expr(); cond != nil { + condCtx, ok := cond.(*parser.Cond_exprContext) + if !ok { + return base + } + + switch { + case condCtx.IN() != nil: + if inExpr := condCtx.In_expr(); inExpr != nil { + return &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "IN"}}}, + Lexpr: base, + Rexpr: c.convert(inExpr), + } + } + case condCtx.BETWEEN() != nil: + if eqSubs := condCtx.AllEq_subexpr(); len(eqSubs) >= 2 { + return &ast.BetweenExpr{ + Expr: base, + Left: c.convert(eqSubs[0]), + Right: c.convert(eqSubs[1]), + Not: condCtx.NOT() != nil, + Location: n.GetStart().GetStart(), + } + } + case condCtx.ISNULL() != nil: + return &ast.NullTest{ + Arg: base, + Nulltesttype: 1, // IS NULL + Location: n.GetStart().GetStart(), + } + case condCtx.NOTNULL() != nil: + return &ast.NullTest{ + Arg: base, + Nulltesttype: 2, // IS NOT NULL + Location: n.GetStart().GetStart(), + } + case condCtx.IS() != nil && condCtx.NULL() != nil: + return &ast.NullTest{ + Arg: base, + Nulltesttype: 1, // IS NULL + Location: n.GetStart().GetStart(), + } + case condCtx.IS() != nil && condCtx.NOT() != nil && condCtx.NULL() != nil: + return &ast.NullTest{ + Arg: base, + Nulltesttype: 2, // IS NOT NULL + Location: n.GetStart().GetStart(), + } + case condCtx.Match_op() != nil: + // debug!!! + matchOp := condCtx.Match_op().GetText() + if eqSubs := condCtx.AllEq_subexpr(); len(eqSubs) >= 1 { + expr := &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: matchOp}}}, + Lexpr: base, + Rexpr: c.convert(eqSubs[0]), + } + if condCtx.ESCAPE() != nil && len(eqSubs) >= 2 { + // todo: Add ESCAPE support + } + return expr + } + case len(condCtx.AllEQUALS()) > 0 || len(condCtx.AllEQUALS2()) > 0 || + len(condCtx.AllNOT_EQUALS()) > 0 || len(condCtx.AllNOT_EQUALS2()) > 0: + // debug!!! + var op string + switch { + case len(condCtx.AllEQUALS()) > 0: + op = "=" + case len(condCtx.AllEQUALS2()) > 0: + op = "==" + case len(condCtx.AllNOT_EQUALS()) > 0: + op = "!=" + case len(condCtx.AllNOT_EQUALS2()) > 0: + op = "<>" + } + if eqSubs := condCtx.AllEq_subexpr(); len(eqSubs) >= 1 { + return &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: op}}}, + Lexpr: base, + Rexpr: c.convert(eqSubs[0]), + } + } + case len(condCtx.AllDistinct_from_op()) > 0: + // debug!!! + distinctOps := condCtx.AllDistinct_from_op() + for _, distinctOp := range distinctOps { + if eqSubs := condCtx.AllEq_subexpr(); len(eqSubs) >= 1 { + not := distinctOp.NOT() != nil + op := "IS DISTINCT FROM" + if not { + op = "IS NOT DISTINCT FROM" + } + return &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: op}}}, + Lexpr: base, + Rexpr: c.convert(eqSubs[0]), + } + } + } + } + } + return base +} + +func (c *cc) convertEqSubexpr(n *parser.Eq_subexprContext) ast.Node { + if n == nil { + return nil + } + neqList := n.AllNeq_subexpr() + if len(neqList) == 0 { + return nil + } + neq, ok := neqList[0].(*parser.Neq_subexprContext) + if !ok { + return nil + } + left := c.convertNeqSubexpr(neq) + ops := c.collectComparisonOps(n) + for i := 1; i < len(neqList); i++ { + neq, ok = neqList[i].(*parser.Neq_subexprContext) + if !ok { + return nil + } + right := c.convertNeqSubexpr(neq) + opText := ops[i-1].GetText() + left = &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: opText}}}, + Lexpr: left, + Rexpr: right, + Location: n.GetStart().GetStart(), + } + } + return left +} + +func (c *cc) collectComparisonOps(n parser.IEq_subexprContext) []antlr.TerminalNode { + var ops []antlr.TerminalNode + for _, child := range n.GetChildren() { + if tn, ok := child.(antlr.TerminalNode); ok { + switch tn.GetText() { + case "<", "<=", ">", ">=": + ops = append(ops, tn) + } + } + } + return ops +} + +func (c *cc) convertNeqSubexpr(n *parser.Neq_subexprContext) ast.Node { + if n == nil { + return nil + } + bitList := n.AllBit_subexpr() + if len(bitList) == 0 { + return nil + } + + bl, ok := bitList[0].(*parser.Bit_subexprContext) + if !ok { + return nil + } + left := c.convertBitSubexpr(bl) + ops := c.collectBitwiseOps(n) + for i := 1; i < len(bitList); i++ { + bl, ok = bitList[i].(*parser.Bit_subexprContext) + if !ok { + return nil + } + right := c.convertBitSubexpr(bl) + opText := ops[i-1].GetText() + left = &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: opText}}}, + Lexpr: left, + Rexpr: right, + Location: n.GetStart().GetStart(), + } + } + + if n.Double_question() != nil { + nextCtx := n.Neq_subexpr() + if nextCtx != nil { + neq, ok2 := nextCtx.(*parser.Neq_subexprContext) + if !ok2 { + return nil + } + right := c.convertNeqSubexpr(neq) + left = &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "??"}}}, + Lexpr: left, + Rexpr: right, + Location: n.GetStart().GetStart(), + } + } + } else { + // !! debug !! + qCount := len(n.AllQUESTION()) + if qCount > 0 { + questionOp := "?" + if qCount > 1 { + questionOp = strings.Repeat("?", qCount) + } + left = &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: questionOp}}}, + Lexpr: left, + Location: n.GetStart().GetStart(), + } + } + } + + return left +} + +func (c *cc) collectBitwiseOps(ctx parser.INeq_subexprContext) []antlr.TerminalNode { + var ops []antlr.TerminalNode + children := ctx.GetChildren() + for _, child := range children { + if tn, ok := child.(antlr.TerminalNode); ok { + txt := tn.GetText() + switch txt { + case "<<", ">>", "<<|", ">>|", "&", "|", "^": + ops = append(ops, tn) + } + } + } + return ops +} + +func (c *cc) convertBitSubexpr(n *parser.Bit_subexprContext) ast.Node { + addList := n.AllAdd_subexpr() + left := c.convertAddSubexpr(addList[0].(*parser.Add_subexprContext)) + + ops := c.collectBitOps(n) + for i := 1; i < len(addList); i++ { + right := c.convertAddSubexpr(addList[i].(*parser.Add_subexprContext)) + opText := ops[i-1].GetText() + left = &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: opText}}}, + Lexpr: left, + Rexpr: right, + Location: n.GetStart().GetStart(), + } + } + return left +} + +func (c *cc) collectBitOps(ctx parser.IBit_subexprContext) []antlr.TerminalNode { + var ops []antlr.TerminalNode + children := ctx.GetChildren() + for _, child := range children { + if tn, ok := child.(antlr.TerminalNode); ok { + txt := tn.GetText() + switch txt { + case "+", "-": + ops = append(ops, tn) + } + } + } + return ops +} + +func (c *cc) convertAddSubexpr(n *parser.Add_subexprContext) ast.Node { + mulList := n.AllMul_subexpr() + left := c.convertMulSubexpr(mulList[0].(*parser.Mul_subexprContext)) + + ops := c.collectAddOps(n) + for i := 1; i < len(mulList); i++ { + right := c.convertMulSubexpr(mulList[i].(*parser.Mul_subexprContext)) + opText := ops[i-1].GetText() + left = &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: opText}}}, + Lexpr: left, + Rexpr: right, + Location: n.GetStart().GetStart(), + } + } + return left +} + +func (c *cc) collectAddOps(ctx parser.IAdd_subexprContext) []antlr.TerminalNode { + var ops []antlr.TerminalNode + for _, child := range ctx.GetChildren() { + if tn, ok := child.(antlr.TerminalNode); ok { + switch tn.GetText() { + case "*", "/", "%": + ops = append(ops, tn) + } + } + } + return ops +} + +func (c *cc) convertMulSubexpr(n *parser.Mul_subexprContext) ast.Node { + conList := n.AllCon_subexpr() + left := c.convertConSubexpr(conList[0].(*parser.Con_subexprContext)) + + for i := 1; i < len(conList); i++ { + right := c.convertConSubexpr(conList[i].(*parser.Con_subexprContext)) + left = &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "||"}}}, + Lexpr: left, + Rexpr: right, + Location: n.GetStart().GetStart(), + } + } + return left +} + +func (c *cc) convertConSubexpr(n *parser.Con_subexprContext) ast.Node { + if opCtx := n.Unary_op(); opCtx != nil { + op := opCtx.GetText() + operand := c.convertUnarySubexpr(n.Unary_subexpr().(*parser.Unary_subexprContext)) + return &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: op}}}, + Rexpr: operand, + Location: n.GetStart().GetStart(), + } + } + return c.convertUnarySubexpr(n.Unary_subexpr().(*parser.Unary_subexprContext)) +} + +func (c *cc) convertUnarySubexpr(n *parser.Unary_subexprContext) ast.Node { + if casual := n.Unary_casual_subexpr(); casual != nil { + return c.convertUnaryCasualSubexpr(casual.(*parser.Unary_casual_subexprContext)) + } + if jsonExpr := n.Json_api_expr(); jsonExpr != nil { + return c.convertJsonApiExpr(jsonExpr.(*parser.Json_api_exprContext)) + } + return nil +} + +func (c *cc) convertJsonApiExpr(n *parser.Json_api_exprContext) ast.Node { + return &ast.TODO{} // todo +} + +func (c *cc) convertUnaryCasualSubexpr(n *parser.Unary_casual_subexprContext) ast.Node { + var baseExpr ast.Node + + if idExpr := n.Id_expr(); idExpr != nil { + baseExpr = c.convertIdExpr(idExpr.(*parser.Id_exprContext)) + } else if atomExpr := n.Atom_expr(); atomExpr != nil { + baseExpr = c.convertAtomExpr(atomExpr.(*parser.Atom_exprContext)) + } + + suffixCtx := n.Unary_subexpr_suffix() + if suffixCtx != nil { + ctx, ok := suffixCtx.(*parser.Unary_subexpr_suffixContext) + if !ok { + return baseExpr + } + baseExpr = c.convertUnarySubexprSuffix(baseExpr, ctx) + } + + return baseExpr +} + +func (c *cc) convertUnarySubexprSuffix(base ast.Node, n *parser.Unary_subexpr_suffixContext) ast.Node { + if n == nil { + return base + } + colRef, ok := base.(*ast.ColumnRef) + if !ok { + return base // todo: cover case when unary subexpr with atomic expr + } + + for i := 0; i < n.GetChildCount(); i++ { + child := n.GetChild(i) + switch v := child.(type) { + case parser.IKey_exprContext: + node := c.convert(v.(*parser.Key_exprContext)) + if node != nil { + colRef.Fields.Items = append(colRef.Fields.Items, node) + } + + case parser.IInvoke_exprContext: + node := c.convert(v.(*parser.Invoke_exprContext)) + if node != nil { + colRef.Fields.Items = append(colRef.Fields.Items, node) + } + case antlr.TerminalNode: + if v.GetText() == "." { + if i+1 < n.GetChildCount() { + next := n.GetChild(i + 1) + switch w := next.(type) { + case parser.IBind_parameterContext: + // !!! debug !!! + node := c.convert(next.(*parser.Bind_parameterContext)) + colRef.Fields.Items = append(colRef.Fields.Items, node) + case antlr.TerminalNode: + // !!! debug !!! + val, err := parseIntegerValue(w.GetText()) + if err != nil { + if debug.Active { + log.Printf("Failed to parse integer value '%s': %v", w.GetText(), err) + } + return &ast.TODO{} + } + node := &ast.A_Const{Val: &ast.Integer{Ival: val}, Location: n.GetStart().GetStart()} + colRef.Fields.Items = append(colRef.Fields.Items, node) + case parser.IAn_id_or_typeContext: + idText := parseAnIdOrType(w) + colRef.Fields.Items = append(colRef.Fields.Items, &ast.String{Str: idText}) + default: + colRef.Fields.Items = append(colRef.Fields.Items, &ast.TODO{}) + } + i++ + } + } + } + } + + if n.COLLATE() != nil && n.An_id() != nil { + // todo: Handle COLLATE + } + return colRef +} + +func (c *cc) convertIdExpr(n *parser.Id_exprContext) ast.Node { + if id := n.Identifier(); id != nil { + return &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + NewIdentifier(id.GetText()), + }, + }, + } + } + return &ast.TODO{} +} + +func (c *cc) convertAtomExpr(n *parser.Atom_exprContext) ast.Node { + switch { + case n.An_id_or_type() != nil: + return NewIdentifier(parseAnIdOrType(n.An_id_or_type())) + case n.Literal_value() != nil: + return c.convertLiteralValue(n.Literal_value().(*parser.Literal_valueContext)) + case n.Bind_parameter() != nil: + return c.convertBindParameter(n.Bind_parameter().(*parser.Bind_parameterContext)) + default: + return &ast.TODO{} + } +} + +func (c *cc) convertLiteralValue(n *parser.Literal_valueContext) ast.Node { + switch { + case n.Integer() != nil: + text := n.Integer().GetText() + val, err := parseIntegerValue(text) + if err != nil { + if debug.Active { + log.Printf("Failed to parse integer value '%s': %v", text, err) + } + return &ast.TODO{} + } + return &ast.A_Const{Val: &ast.Integer{Ival: val}, Location: n.GetStart().GetStart()} + + case n.Real_() != nil: + text := n.Real_().GetText() + return &ast.A_Const{Val: &ast.Float{Str: text}, Location: n.GetStart().GetStart()} + + case n.STRING_VALUE() != nil: // !!! debug !!! (problem with quoted strings) + val := n.STRING_VALUE().GetText() + if len(val) >= 2 { + val = val[1 : len(val)-1] + } + return &ast.A_Const{Val: &ast.String{Str: val}, Location: n.GetStart().GetStart()} + + case n.Bool_value() != nil: + var i bool + if n.Bool_value().TRUE() != nil { + i = true + } + return &ast.Boolean{Boolval: i} + + case n.NULL() != nil: + return &ast.Null{} + + case n.CURRENT_TIME() != nil: + if debug.Active { + log.Printf("TODO: Implement CURRENT_TIME") + } + return &ast.TODO{} + + case n.CURRENT_DATE() != nil: + if debug.Active { + log.Printf("TODO: Implement CURRENT_DATE") + } + return &ast.TODO{} + + case n.CURRENT_TIMESTAMP() != nil: + if debug.Active { + log.Printf("TODO: Implement CURRENT_TIMESTAMP") + } + return &ast.TODO{} + + case n.BLOB() != nil: + blobText := n.BLOB().GetText() + return &ast.A_Const{Val: &ast.String{Str: blobText}, Location: n.GetStart().GetStart()} + + case n.EMPTY_ACTION() != nil: + if debug.Active { + log.Printf("TODO: Implement EMPTY_ACTION") + } + return &ast.TODO{} + + default: + if debug.Active { + log.Printf("Unknown literal value type: %T", n) + } + return &ast.TODO{} + } +} + +func (c *cc) convertSqlStmt(n *parser.Sql_stmtContext) ast.Node { + if n == nil { + return nil + } + // todo: handle explain + if core := n.Sql_stmt_core(); core != nil { + return c.convert(core) + } + + return nil +} + +func (c *cc) convert(node node) ast.Node { + switch n := node.(type) { + case *parser.Sql_stmtContext: + return c.convertSqlStmt(n) + + case *parser.Sql_stmt_coreContext: + return c.convertSqlStmtCore(n) + + case *parser.Create_table_stmtContext: + return c.convertCreate_table_stmtContext(n) + + case *parser.Select_stmtContext: + return c.convertSelectStmtContext(n) + + case *parser.Select_coreContext: + return c.convertSelectCoreContext(n) + + case *parser.Result_columnContext: + return c.convertResultColumn(n) + + case *parser.Join_sourceContext: + return c.convertJoinSource(n) + + case *parser.Flatten_sourceContext: + return c.convertFlattenSource(n) + + case *parser.Named_single_sourceContext: + return c.convertNamedSingleSource(n) + + case *parser.Single_sourceContext: + return c.convertSingleSource(n) + + case *parser.Bind_parameterContext: + return c.convertBindParameter(n) + + case *parser.ExprContext: + return c.convertExpr(n) + + case *parser.Or_subexprContext: + return c.convertOrSubExpr(n) + + case *parser.And_subexprContext: + return c.convertAndSubexpr(n) + + case *parser.Xor_subexprContext: + return c.convertXorSubexpr(n) + + case *parser.Eq_subexprContext: + return c.convertEqSubexpr(n) + + case *parser.Neq_subexprContext: + return c.convertNeqSubexpr(n) + + case *parser.Bit_subexprContext: + return c.convertBitSubexpr(n) + + case *parser.Add_subexprContext: + return c.convertAddSubexpr(n) + + case *parser.Mul_subexprContext: + return c.convertMulSubexpr(n) + + case *parser.Con_subexprContext: + return c.convertConSubexpr(n) + + case *parser.Unary_subexprContext: + return c.convertUnarySubexpr(n) + + case *parser.Unary_casual_subexprContext: + return c.convertUnaryCasualSubexpr(n) + + case *parser.Id_exprContext: + return c.convertIdExpr(n) + + case *parser.Atom_exprContext: + return c.convertAtomExpr(n) + + case *parser.Literal_valueContext: + return c.convertLiteralValue(n) + + case *parser.Json_api_exprContext: + return c.convertJsonApiExpr(n) + + case *parser.Type_name_compositeContext: + return c.convertTypeNameComposite(n) + + case *parser.Type_nameContext: + return c.convertTypeName(n) + + case *parser.Integer_or_bindContext: + return c.convertIntegerOrBind(n) + + case *parser.Type_name_or_bindContext: + return c.convertTypeNameOrBind(n) + + default: + return todo("convert(case=default)", n) + } +} diff --git a/internal/engine/ydb/parse.go b/internal/engine/ydb/parse.go new file mode 100755 index 0000000000..797710988c --- /dev/null +++ b/internal/engine/ydb/parse.go @@ -0,0 +1,93 @@ +package ydb + +import ( + "errors" + "fmt" + "io" + + "github.com/antlr4-go/antlr/v4" + "github.com/sqlc-dev/sqlc/internal/source" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + parser "github.com/ydb-platform/yql-parsers/go" +) + +type errorListener struct { + *antlr.DefaultErrorListener + + err string +} + +func (el *errorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) { + el.err = msg +} + +// func (el *errorListener) ReportAmbiguity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, exact bool, ambigAlts *antlr.BitSet, configs antlr.ATNConfigSet) { +// } +// +// func (el *errorListener) ReportAttemptingFullContext(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, conflictingAlts *antlr.BitSet, configs antlr.ATNConfigSet) { +// } +// +// func (el *errorListener) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex, prediction int, configs antlr.ATNConfigSet) { +// } + +func NewParser() *Parser { + return &Parser{} +} + +type Parser struct { +} + +func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) { + blob, err := io.ReadAll(r) + if err != nil { + return nil, err + } + input := antlr.NewInputStream(string(blob)) + lexer := parser.NewYQLLexer(input) + stream := antlr.NewCommonTokenStream(lexer, 0) + pp := parser.NewYQLParser(stream) + el := &errorListener{} + pp.AddErrorListener(el) + // pp.BuildParseTrees = true + tree := pp.Sql_query() + if el.err != "" { + return nil, errors.New(el.err) + } + pctx, ok := tree.(*parser.Sql_queryContext) + if !ok { + return nil, fmt.Errorf("expected ParserContext; got %T\n ", tree) + } + var stmts []ast.Statement + stmtListCtx := pctx.Sql_stmt_list() + if stmtListCtx != nil { + loc := 0 + for _, stmt := range stmtListCtx.AllSql_stmt() { + converter := &cc{} + out := converter.convert(stmt) + if _, ok := out.(*ast.TODO); ok { + loc = stmt.GetStop().GetStop() + 2 + continue + } + if out != nil { + len := (stmt.GetStop().GetStop() + 1) - loc + stmts = append(stmts, ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: out, + StmtLocation: loc, + StmtLen: len, + }, + }) + loc = stmt.GetStop().GetStop() + 2 + } + } + } + return stmts, nil +} + +func (p *Parser) CommentSyntax() source.CommentSyntax { + return source.CommentSyntax{ + Dash: true, + Hash: false, + SlashStar: true, + } +} diff --git a/internal/engine/ydb/reserved.go b/internal/engine/ydb/reserved.go new file mode 100644 index 0000000000..8db504c0b9 --- /dev/null +++ b/internal/engine/ydb/reserved.go @@ -0,0 +1,301 @@ +package ydb + +import "strings" + + +func (p *Parser) IsReservedKeyword(s string) bool { + switch strings.ToLower(s) { + case "abort": + case "action": + case "add": + case "after": + case "all": + case "alter": + case "analyze": + case "and": + case "ansi": + case "any": + case "array": + case "as": + case "asc": + case "assume": + case "asymmetric": + case "async": + case "at": + case "attach": + case "attributes": + case "autoincrement": + case "automap": + case "backup": + case "batch": + case "collection": + case "before": + case "begin": + case "bernoulli": + case "between": + case "bitcast": + case "by": + case "callable": + case "cascade": + case "case": + case "cast": + case "changefeed": + case "check": + case "classifier": + case "collate": + case "column": + case "columns": + case "commit": + case "compact": + case "conditional": + case "conflict": + case "connect": + case "constraint": + case "consumer": + case "cover": + case "create": + case "cross": + case "cube": + case "current": + case "current_date": + case "current_time": + case "current_timestamp": + case "data": + case "database": + case "decimal": + case "declare": + case "default": + case "deferrable": + case "deferred": + case "define": + case "delete": + case "desc": + case "describe": + case "detach": + case "dict": + case "directory": + case "disable": + case "discard": + case "distinct": + case "do": + case "drop": + case "each": + case "else": + case "empty": + case "empty_action": + case "encrypted": + case "end": + case "enum": + case "erase": + case "error": + case "escape": + case "evaluate": + case "except": + case "exclude": + case "exclusion": + case "exclusive": + case "exists": + case "explain": + case "export": + case "external": + case "fail": + case "false": + case "family": + case "filter": + case "first": + case "flatten": + case "flow": + case "following": + case "for": + case "foreign": + case "from": + case "full": + case "function": + case "glob": + case "global": + case "grant": + case "group": + case "grouping": + case "groups": + case "hash": + case "having": + case "hop": + case "if": + case "ignore": + case "ilike": + case "immediate": + case "import": + case "in": + case "increment": + case "incremental": + case "index": + case "indexed": + case "inherits": + case "initial": + case "initially": + case "inner": + case "insert": + case "instead": + case "intersect": + case "into": + case "is": + case "isnull": + case "join": + case "json_exists": + case "json_query": + case "json_value": + case "key": + case "last": + case "left": + case "legacy": + case "like": + case "limit": + case "list": + case "local": + case "login": + case "manage": + case "match": + case "matches": + case "match_recognize": + case "measures": + case "microseconds": + case "milliseconds": + case "modify": + case "nanoseconds": + case "natural": + case "next": + case "no": + case "nologin": + case "not": + case "notnull": + case "null": + case "nulls": + case "object": + case "of": + case "offset": + case "omit": + case "on": + case "one": + case "only": + case "option": + case "optional": + case "or": + case "order": + case "others": + case "outer": + case "over": + case "owner": + case "parallel": + case "partition": + case "passing": + case "password": + case "past": + case "pattern": + case "per": + case "permute": + case "plan": + case "pool": + case "pragma": + case "preceding": + case "presort": + case "primary": + case "privileges": + case "process": + case "query": + case "queue": + case "raise": + case "range": + case "reduce": + case "references": + case "regexp": + case "reindex": + case "release": + case "remove": + case "rename": + case "repeatable": + case "replace": + case "replication": + case "reset": + case "resource": + case "respect": + case "restart": + case "restore": + case "restrict": + case "result": + case "return": + case "returning": + case "revert": + case "revoke": + case "right": + case "rlike": + case "rollback": + case "rollup": + case "row": + case "rows": + case "sample": + case "savepoint": + case "schema": + case "seconds": + case "seek": + case "select": + case "semi": + case "set": + case "sets": + case "show": + case "tskip": + case "sequence": + case "source": + case "start": + case "stream": + case "struct": + case "subquery": + case "subset": + case "symbols": + case "symmetric": + case "sync": + case "system": + case "table": + case "tables": + case "tablesample": + case "tablestore": + case "tagged": + case "temp": + case "temporary": + case "then": + case "ties": + case "to": + case "topic": + case "transaction": + case "transfer": + case "trigger": + case "true": + case "tuple": + case "type": + case "unbounded": + case "unconditional": + case "union": + case "unique": + case "unknown": + case "unmatched": + case "update": + case "upsert": + case "use": + case "user": + case "using": + case "vacuum": + case "values": + case "variant": + case "view": + case "virtual": + case "when": + case "where": + case "window": + case "with": + case "without": + case "wrapper": + case "xor": + default: + return false + } + return true +} diff --git a/internal/engine/ydb/stdlib.go b/internal/engine/ydb/stdlib.go new file mode 100644 index 0000000000..fd78d7de38 --- /dev/null +++ b/internal/engine/ydb/stdlib.go @@ -0,0 +1,12 @@ +package ydb + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func defaultSchema(name string) *catalog.Schema { + s := &catalog.Schema{Name: name} + s.Funcs = []*catalog.Function{} + + return s +} diff --git a/internal/engine/ydb/utils.go b/internal/engine/ydb/utils.go new file mode 100755 index 0000000000..0fe41d356f --- /dev/null +++ b/internal/engine/ydb/utils.go @@ -0,0 +1,143 @@ +package ydb + +import ( + "strconv" + "strings" + + "github.com/antlr4-go/antlr/v4" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + parser "github.com/ydb-platform/yql-parsers/go" +) + +type objectRefProvider interface { + antlr.ParserRuleContext + Object_ref() parser.IObject_refContext +} + +func parseTableName(ctx objectRefProvider) *ast.TableName { + return parseObjectRef(ctx.Object_ref()) +} + +func parseObjectRef(r parser.IObject_refContext) *ast.TableName { + if r == nil { + return nil + } + ref := r.(*parser.Object_refContext) + + parts := []string{} + + if cl := ref.Cluster_expr(); cl != nil { + parts = append(parts, parseClusterExpr(cl)) + } + + if idOrAt := ref.Id_or_at(); idOrAt != nil { + parts = append(parts, parseIdOrAt(idOrAt)) + } + + objectName := strings.Join(parts, ".") + + return &ast.TableName{ + Schema: "", + Name: identifier(objectName), + } +} + +func parseClusterExpr(ctx parser.ICluster_exprContext) string { + if ctx == nil { + return "" + } + return identifier(ctx.GetText()) +} + +func parseIdOrAt(ctx parser.IId_or_atContext) string { + if ctx == nil { + return "" + } + idOrAt := ctx.(*parser.Id_or_atContext) + + if ao := idOrAt.An_id_or_type(); ao != nil { + return identifier(parseAnIdOrType(ao)) + } + return "" +} + +func parseAnIdOrType(ctx parser.IAn_id_or_typeContext) string { + if ctx == nil { + return "" + } + anId := ctx.(*parser.An_id_or_typeContext) + + if anId.Id_or_type() != nil { + return identifier(parseIdOrType(anId.Id_or_type())) + } + + if anId.STRING_VALUE() != nil { + return identifier(anId.STRING_VALUE().GetText()) + } + + return "" +} + +func parseIdOrType(ctx parser.IId_or_typeContext) string { + if ctx == nil { + return "" + } + Id := ctx.(*parser.Id_or_typeContext) + if Id.Id() != nil { + return identifier(parseIdTable(Id.Id())) + } + + return "" +} + +func parseAnId(ctx parser.IAn_idContext) string { + if id := ctx.Id(); id != nil { + return id.GetText() + } else if str := ctx.STRING_VALUE(); str != nil { + return str.GetText() + } + return "" +} + +func parseAnIdSchema(ctx parser.IAn_id_schemaContext) string { + if ctx == nil { + return "" + } + if id := ctx.Id_schema(); id != nil { + return id.GetText() + } else if str := ctx.STRING_VALUE(); str != nil { + return str.GetText() + } + return "" +} + +func parseIdTable(ctx parser.IIdContext) string { + if ctx == nil { + return "" + } + return ctx.GetText() +} + +func parseIntegerValue(text string) (int64, error) { + text = strings.ToLower(text) + base := 10 + + switch { + case strings.HasPrefix(text, "0x"): + base = 16 + text = strings.TrimPrefix(text, "0x") + + case strings.HasPrefix(text, "0o"): + base = 8 + text = strings.TrimPrefix(text, "0o") + + case strings.HasPrefix(text, "0b"): + base = 2 + text = strings.TrimPrefix(text, "0b") + } + + // debug!!! + text = strings.TrimRight(text, "pulstibn") + + return strconv.ParseInt(text, base, 64) +} diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index d1ea1a22cc..9146d17e08 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -172,6 +172,8 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, replace = "?" } else if engine == config.EngineSQLite { replace = fmt.Sprintf("?%d", argn) + } else if engine == config.EngineYDB { + replace = fmt.Sprintf("$%s", paramName) } else { replace = fmt.Sprintf("$%d", argn) } diff --git a/internal/sqltest/local/ydb.go b/internal/sqltest/local/ydb.go new file mode 100644 index 0000000000..79be58241f --- /dev/null +++ b/internal/sqltest/local/ydb.go @@ -0,0 +1,117 @@ +package local + +import ( + "context" + "database/sql" + "fmt" + "hash/fnv" + "math/rand" + "net" + "os" + "testing" + "time" + + migrate "github.com/sqlc-dev/sqlc/internal/migrations" + "github.com/sqlc-dev/sqlc/internal/sql/sqlpath" + "github.com/ydb-platform/ydb-go-sdk/v3" +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func YDB(t *testing.T, migrations []string) TestYDB { + return link_YDB(t, migrations, true) +} + +func ReadOnlyYDB(t *testing.T, migrations []string) TestYDB { + return link_YDB(t, migrations, false) +} + +type TestYDB struct { + DB *sql.DB + Prefix string +} + +func link_YDB(t *testing.T, migrations []string, rw bool) TestYDB { + t.Helper() + + // 1) Контекст с таймаутом + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + dbuiri := os.Getenv("YDB_SERVER_URI") + if dbuiri == "" { + t.Skip("YDB_SERVER_URI is empty") + } + host, _, err := net.SplitHostPort(dbuiri) + if err != nil { + t.Fatalf("invalid YDB_SERVER_URI: %q", dbuiri) + } + + baseDB := os.Getenv("YDB_DATABASE") + if baseDB == "" { + baseDB = "/local" + } + + // собираем миграции + var seed []string + files, err := sqlpath.Glob(migrations) + if err != nil { + t.Fatal(err) + } + h := fnv.New64() + for _, f := range files { + blob, err := os.ReadFile(f) + if err != nil { + t.Fatal(err) + } + h.Write(blob) + seed = append(seed, migrate.RemoveRollbackStatements(string(blob))) + } + + var name string + if rw { + // name = fmt.Sprintf("sqlc_test_%s", id()) + name = fmt.Sprintf("sqlc_test_%s", "test_new") + } else { + name = fmt.Sprintf("sqlc_test_%x", h.Sum(nil)) + } + prefix := fmt.Sprintf("%s/%s", baseDB, name) + + // 2) Открываем драйвер к корню "/" + rootDSN := fmt.Sprintf("grpc://%s?database=%s", dbuiri, baseDB) + t.Logf("→ Opening root driver: %s", rootDSN) + driver, err := ydb.Open(ctx, rootDSN, + ydb.WithInsecure(), + ydb.WithDiscoveryInterval(time.Hour), + ydb.WithNodeAddressMutator(func(_ string) string { + return host + }), + ) + if err != nil { + t.Fatalf("failed to open root YDB connection: %s", err) + } + + connector, err := ydb.Connector( + driver, + ydb.WithTablePathPrefix(prefix), + ydb.WithAutoDeclare(), + ) + if err != nil { + t.Fatalf("failed to create connector: %s", err) + } + + db := sql.OpenDB(connector) + + t.Log("→ Applying migrations to prefix: ", prefix) + + schemeCtx := ydb.WithQueryMode(ctx, ydb.SchemeQueryMode) + for _, stmt := range seed { + _, err := db.ExecContext(schemeCtx, stmt) + if err != nil { + t.Fatalf("failed to apply migration: %s\nSQL: %s", err, stmt) + } + } + return TestYDB{DB: db, Prefix: prefix} +} From 4e46daafafafbcae0c3e264ff0f6bbc8f31c6bbd Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov Date: Sun, 27 Apr 2025 02:00:26 +0300 Subject: [PATCH 02/18] Added almost full INSERT support and basic DELETE (without ON) support --- examples/authors/sqlc.yaml | 82 ++++----- examples/authors/ydb/db_test.go | 97 +++++++++-- examples/authors/ydb/models.go | 6 +- examples/authors/ydb/query.sql | 10 ++ examples/authors/ydb/query.sql.go | 42 +++++ internal/codegen/golang/ydb_type.go | 40 +++-- .../engine/ydb/catalog_tests/insert_test.go | 135 +++++++++++++++ internal/engine/ydb/convert.go | 160 ++++++++++++++++++ internal/sql/ast/delete_stmt.go | 6 + internal/sql/ast/insert_stmt.go | 15 +- internal/sql/ast/on_conflict_action_type.go | 15 ++ internal/sqltest/local/ydb.go | 2 +- 12 files changed, 532 insertions(+), 78 deletions(-) create mode 100644 internal/engine/ydb/catalog_tests/insert_test.go create mode 100644 internal/sql/ast/on_conflict_action_type.go diff --git a/examples/authors/sqlc.yaml b/examples/authors/sqlc.yaml index 8d6bc3db28..30d904875e 100644 --- a/examples/authors/sqlc.yaml +++ b/examples/authors/sqlc.yaml @@ -2,47 +2,47 @@ version: '2' cloud: project: "01HAQMMECEYQYKFJN8MP16QC41" sql: -- name: postgresql - schema: postgresql/schema.sql - queries: postgresql/query.sql - engine: postgresql - database: - uri: "${VET_TEST_EXAMPLES_POSTGRES_AUTHORS}" - analyzer: - database: false - rules: - - sqlc/db-prepare - - postgresql-query-too-costly - gen: - go: - package: authors - sql_package: pgx/v5 - out: postgresql -- name: mysql - schema: mysql/schema.sql - queries: mysql/query.sql - engine: mysql - database: - uri: "${VET_TEST_EXAMPLES_MYSQL_AUTHORS}" - rules: - - sqlc/db-prepare - # - mysql-query-too-costly - gen: - go: - package: authors - out: mysql -- name: sqlite - schema: sqlite/schema.sql - queries: sqlite/query.sql - engine: sqlite - database: - uri: file:authors?mode=memory&cache=shared - rules: - - sqlc/db-prepare - gen: - go: - package: authors - out: sqlite +# - name: postgresql +# schema: postgresql/schema.sql +# queries: postgresql/query.sql +# engine: postgresql +# database: +# uri: "${VET_TEST_EXAMPLES_POSTGRES_AUTHORS}" +# analyzer: +# database: false +# rules: +# - sqlc/db-prepare +# - postgresql-query-too-costly +# gen: +# go: +# package: authors +# sql_package: pgx/v5 +# out: postgresql +# - name: mysql +# schema: mysql/schema.sql +# queries: mysql/query.sql +# engine: mysql +# database: +# uri: "${VET_TEST_EXAMPLES_MYSQL_AUTHORS}" +# rules: +# - sqlc/db-prepare +# # - mysql-query-too-costly +# gen: +# go: +# package: authors +# out: mysql +# - name: sqlite +# schema: sqlite/schema.sql +# queries: sqlite/query.sql +# engine: sqlite +# database: +# uri: file:authors?mode=memory&cache=shared +# rules: +# - sqlc/db-prepare +# gen: +# go: +# package: authors +# out: sqlite - name: ydb schema: ydb/schema.sql queries: ydb/query.sql diff --git a/examples/authors/ydb/db_test.go b/examples/authors/ydb/db_test.go index 181ee64ed1..44e330a2f6 100644 --- a/examples/authors/ydb/db_test.go +++ b/examples/authors/ydb/db_test.go @@ -8,6 +8,10 @@ import ( _ "github.com/ydb-platform/ydb-go-sdk/v3" ) +func ptr(s string) *string { + return &s +} + func TestAuthors(t *testing.T) { ctx := context.Background() @@ -16,6 +20,53 @@ func TestAuthors(t *testing.T) { q := New(test.DB) + t.Run("InsertAuthors", func(t *testing.T) { + authorsToInsert := []CreateOrUpdateAuthorParams{ + {P0: 1, P1: "Лев Толстой", P2: ptr("Русский писатель, автор \"Война и мир\"")}, + {P0: 2, P1: "Александр Пушкин", P2: ptr("Автор \"Евгения Онегина\"")}, + {P0: 3, P1: "Александр Пушкин", P2: ptr("Русский поэт, драматург и прозаик")}, + {P0: 4, P1: "Фёдор Достоевский", P2: ptr("Автор \"Преступление и наказание\"")}, + {P0: 5, P1: "Николай Гоголь", P2: ptr("Автор \"Мёртвые души\"")}, + {P0: 6, P1: "Антон Чехов", P2: nil}, + {P0: 7, P1: "Иван Тургенев", P2: ptr("Автор \"Отцы и дети\"")}, + {P0: 8, P1: "Михаил Лермонтов", P2: nil}, + {P0: 9, P1: "Даниил Хармс", P2: ptr("Абсурдист, писатель и поэт")}, + {P0: 10, P1: "Максим Горький", P2: ptr("Автор \"На дне\"")}, + {P0: 11, P1: "Владимир Маяковский", P2: nil}, + {P0: 12, P1: "Сергей Есенин", P2: ptr("Русский лирик")}, + {P0: 13, P1: "Борис Пастернак", P2: ptr("Автор \"Доктор Живаго\"")}, + } + + for _, author := range authorsToInsert { + if _, err := q.CreateOrUpdateAuthor(ctx, author); err != nil { + t.Fatalf("failed to insert author %q: %v", author.P1, err) + } + } + }) + + t.Run("CreateOrUpdateAuthorReturningBio", func(t *testing.T) { + newBio := "Обновленная биография автора" + arg := CreateOrUpdateAuthorRetunringBioParams{ + P0: 3, + P1: "Тестовый Автор", + P2: &newBio, + } + + returnedBio, err := q.CreateOrUpdateAuthorRetunringBio(ctx, arg) + if err != nil { + t.Fatalf("failed to create or update author: %v", err) + } + + if returnedBio == nil { + t.Fatal("expected non-nil bio, got nil") + } + if *returnedBio != newBio { + t.Fatalf("expected bio %q, got %q", newBio, *returnedBio) + } + + t.Logf("Author created or updated successfully with bio: %s", *returnedBio) + }) + t.Run("ListAuthors", func(t *testing.T) { authors, err := q.ListAuthors(ctx) if err != nil { @@ -26,9 +77,9 @@ func TestAuthors(t *testing.T) { } t.Log("Authors:") for _, a := range authors { - bio := "NULL" - if a.Bio.Valid { - bio = a.Bio.String + bio := "Null" + if a.Bio != nil { + bio = *a.Bio } t.Logf("- ID: %d | Name: %s | Bio: %s", a.ID, a.Name, bio) } @@ -39,9 +90,9 @@ func TestAuthors(t *testing.T) { if err != nil { t.Fatal(err) } - bio := "NULL" - if singleAuthor.Bio.Valid { - bio = singleAuthor.Bio.String + bio := "Null" + if singleAuthor.Bio != nil { + bio = *singleAuthor.Bio } t.Logf("- ID: %d | Name: %s | Bio: %s", singleAuthor.ID, singleAuthor.Name, bio) }) @@ -56,9 +107,9 @@ func TestAuthors(t *testing.T) { } t.Log("Authors with this name:") for _, a := range authors { - bio := "NULL" - if a.Bio.Valid { - bio = a.Bio.String + bio := "Null" + if a.Bio != nil { + bio = *a.Bio } t.Logf("- ID: %d | Name: %s | Bio: %s", a.ID, a.Name, bio) } @@ -74,9 +125,9 @@ func TestAuthors(t *testing.T) { } t.Log("Authors with even IDs:") for _, a := range authors { - bio := "NULL" - if a.Bio.Valid { - bio = a.Bio.String + bio := "Null" + if a.Bio != nil { + bio = *a.Bio } t.Logf("- ID: %d | Name: %s | Bio: %s", a.ID, a.Name, bio) } @@ -92,11 +143,27 @@ func TestAuthors(t *testing.T) { } t.Log("Authors with NULL bio:") for _, a := range authors { - bio := "NULL" - if a.Bio.Valid { - bio = a.Bio.String + bio := "Null" + if a.Bio != nil { + bio = *a.Bio } t.Logf("- ID: %d | Name: %s | Bio: %s", a.ID, a.Name, bio) } }) + + t.Run("Delete All Authors", func(t *testing.T) { + var i uint64 + for i = 1; i <= 13; i++ { + if err := q.DeleteAuthor(ctx, i); err != nil { + t.Fatalf("failed to delete authors: %v", err) + } + } + authors, err := q.ListAuthors(ctx) + if err != nil { + t.Fatal(err) + } + if len(authors) != 0 { + t.Fatalf("expected no authors, got %d", len(authors)) + } + }) } diff --git a/examples/authors/ydb/models.go b/examples/authors/ydb/models.go index e899b195b0..337ea597f4 100644 --- a/examples/authors/ydb/models.go +++ b/examples/authors/ydb/models.go @@ -4,12 +4,8 @@ package authors -import ( - "database/sql" -) - type Author struct { ID uint64 Name string - Bio sql.NullString + Bio *string } diff --git a/examples/authors/ydb/query.sql b/examples/authors/ydb/query.sql index 219d680ba1..20b5eb8d5b 100644 --- a/examples/authors/ydb/query.sql +++ b/examples/authors/ydb/query.sql @@ -17,3 +17,13 @@ WHERE name = $p0; SELECT * FROM authors WHERE bio IS NULL; +-- name: CreateOrUpdateAuthor :execresult +UPSERT INTO authors (id, name, bio) VALUES ($p0, $p1, $p2); + +-- name: CreateOrUpdateAuthorRetunringBio :one +UPSERT INTO authors (id, name, bio) VALUES ($p0, $p1, $p2) RETURNING bio; + +-- name: DeleteAuthor :exec +DELETE FROM authors +WHERE id = $p0; + diff --git a/examples/authors/ydb/query.sql.go b/examples/authors/ydb/query.sql.go index 53ed896128..289cbb7741 100644 --- a/examples/authors/ydb/query.sql.go +++ b/examples/authors/ydb/query.sql.go @@ -7,8 +7,50 @@ package authors import ( "context" + "database/sql" ) +const createOrUpdateAuthor = `-- name: CreateOrUpdateAuthor :execresult +UPSERT INTO authors (id, name, bio) VALUES ($p0, $p1, $p2) +` + +type CreateOrUpdateAuthorParams struct { + P0 uint64 + P1 string + P2 *string +} + +func (q *Queries) CreateOrUpdateAuthor(ctx context.Context, arg CreateOrUpdateAuthorParams) (sql.Result, error) { + return q.db.ExecContext(ctx, createOrUpdateAuthor, arg.P0, arg.P1, arg.P2) +} + +const createOrUpdateAuthorRetunringBio = `-- name: CreateOrUpdateAuthorRetunringBio :one +UPSERT INTO authors (id, name, bio) VALUES ($p0, $p1, $p2) RETURNING bio +` + +type CreateOrUpdateAuthorRetunringBioParams struct { + P0 uint64 + P1 string + P2 *string +} + +func (q *Queries) CreateOrUpdateAuthorRetunringBio(ctx context.Context, arg CreateOrUpdateAuthorRetunringBioParams) (*string, error) { + row := q.db.QueryRowContext(ctx, createOrUpdateAuthorRetunringBio, arg.P0, arg.P1, arg.P2) + var bio *string + err := row.Scan(&bio) + return bio, err +} + +const deleteAuthor = `-- name: DeleteAuthor :exec +DELETE FROM authors +WHERE id = $p0 +` + +func (q *Queries) DeleteAuthor(ctx context.Context, p0 uint64) error { + _, err := q.db.ExecContext(ctx, deleteAuthor, p0) + return err +} + const getAuthor = `-- name: GetAuthor :one SELECT id, name, bio FROM authors WHERE id = $p0 diff --git a/internal/codegen/golang/ydb_type.go b/internal/codegen/golang/ydb_type.go index 8a5b1711b3..aba149ab03 100644 --- a/internal/codegen/golang/ydb_type.go +++ b/internal/codegen/golang/ydb_type.go @@ -16,6 +16,7 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col emitPointersForNull := options.EmitPointersForNullTypes // https://ydb.tech/docs/ru/yql/reference/types/ + // ydb-go-sdk doesn't support sql.Null* yet switch columnType { // decimal types case "bool": @@ -25,7 +26,8 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col if emitPointersForNull { return "*bool" } - return "sql.NullBool" + // return "sql.NullBool" + return "*bool" case "int8": if notNull { @@ -34,9 +36,10 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col if emitPointersForNull { return "*int8" } - // The database/sql package does not have a sql.NullInt8 type, so we - // use the smallest type they have which is NullInt16 - return "sql.NullInt16" + // // The database/sql package does not have a sql.NullInt8 type, so we + // // use the smallest type they have which is NullInt16 + // return "sql.NullInt16" + return "*int8" case "int16": if notNull { return "int16" @@ -44,7 +47,8 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col if emitPointersForNull { return "*int16" } - return "sql.NullInt16" + // return "sql.NullInt16" + return "*int16" case "int32": if notNull { return "int32" @@ -52,7 +56,8 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col if emitPointersForNull { return "*int32" } - return "sql.NullInt32" + // return "sql.NullInt32" + return "*int32" case "int64": if notNull { return "int64" @@ -60,7 +65,8 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col if emitPointersForNull { return "*int64" } - return "sql.NullInt64" + // return "sql.NullInt64" + return "*int64" case "uint8": if emitPointersForNull { @@ -92,7 +98,8 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col } // The database/sql package does not have a sql.NullFloat32 type, so we // use the smallest type they have which is NullFloat64 - return "sql.NullFloat64" + // return "sql.NullFloat64" + return "*float32" case "double": if notNull { return "float64" @@ -100,7 +107,8 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col if emitPointersForNull { return "*float64" } - return "sql.NullFloat64" + // return "sql.NullFloat64" + return "*float64" // string types case "string", "utf8", "text": @@ -110,7 +118,7 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col if emitPointersForNull { return "*string" } - return "sql.NullString" + return "*string" // serial types case "smallserial", "serial2": @@ -120,7 +128,8 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col if emitPointersForNull { return "*int16" } - return "sql.NullInt16" + // return "sql.NullInt16" + return "*int16" case "serial", "serial4": if notNull { @@ -129,7 +138,8 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col if emitPointersForNull { return "*int32" } - return "sql.NullInt32" + // return "sql.NullInt32" + return "*int32" case "bigserial", "serial8": if notNull { @@ -138,10 +148,12 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col if emitPointersForNull { return "*int64" } - return "sql.NullInt64" + // return "sql.NullInt64" + return "*int64" case "null": - return "sql.Null" + // return "sql.Null" + return "interface{}" default: if debug.Active { diff --git a/internal/engine/ydb/catalog_tests/insert_test.go b/internal/engine/ydb/catalog_tests/insert_test.go new file mode 100644 index 0000000000..0164a6302f --- /dev/null +++ b/internal/engine/ydb/catalog_tests/insert_test.go @@ -0,0 +1,135 @@ +package ydb_test + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/sqlc-dev/sqlc/internal/engine/ydb" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func TestInsert(t *testing.T) { + tests := []struct { + stmt string + expected ast.Node + }{ + { + stmt: "INSERT INTO users (id, name) VALUES (1, 'Alice') RETURNING *", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.InsertStmt{ + Relation: &ast.RangeVar{Relname: strPtr("users")}, + Cols: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{Name: strPtr("id")}, + &ast.ResTarget{Name: strPtr("name")}, + }, + }, + SelectStmt: &ast.SelectStmt{ + ValuesLists: &ast.List{ + Items: []ast.Node{ + &ast.List{ + Items: []ast.Node{ + &ast.A_Const{Val: &ast.Integer{Ival: 1}}, + &ast.A_Const{Val: &ast.String{Str: "Alice"}}, + }, + }, + }, + }, + }, + OnConflictClause: &ast.OnConflictClause{}, + ReturningList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{Items: []ast.Node{&ast.A_Star{}}}, + }, + }, + }, + }, + }, + }, + }, + }, + { + stmt: "INSERT OR IGNORE INTO users (id) VALUES (3) RETURNING id, name", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.InsertStmt{ + Relation: &ast.RangeVar{Relname: strPtr("users")}, + Cols: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{Name: strPtr("id")}, + }, + }, + SelectStmt: &ast.SelectStmt{ + ValuesLists: &ast.List{ + Items: []ast.Node{ + &ast.List{ + Items: []ast.Node{ + &ast.A_Const{Val: &ast.Integer{Ival: 3}}, + }, + }, + }, + }, + }, + OnConflictClause: &ast.OnConflictClause{ + Action: ast.OnConflictAction_INSERT_OR_IGNORE, + }, + ReturningList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "id"}}}}, + }, + &ast.ResTarget{ + Val: &ast.ColumnRef{Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "name"}}}}, + }, + }, + }, + }, + }, + }, + }, + { + stmt: "UPSERT INTO users (id) VALUES (4) RETURNING id", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.InsertStmt{ + Relation: &ast.RangeVar{Relname: strPtr("users")}, + Cols: &ast.List{Items: []ast.Node{&ast.ResTarget{Name: strPtr("id")}}}, + SelectStmt: &ast.SelectStmt{ValuesLists: &ast.List{Items: []ast.Node{&ast.List{Items: []ast.Node{&ast.A_Const{Val: &ast.Integer{Ival: 4}}}}}}}, + OnConflictClause: &ast.OnConflictClause{Action: ast.OnConflictAction_UPSERT}, + ReturningList: &ast.List{Items: []ast.Node{&ast.ResTarget{Val: &ast.ColumnRef{Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "id"}}}}}}}, + }, + }, + }, + }, + } + + p := ydb.NewParser() + for _, tc := range tests { + t.Run(tc.stmt, func(t *testing.T) { + stmts, err := p.Parse(strings.NewReader(tc.stmt)) + if err != nil { + t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + } + if len(stmts) == 0 { + t.Fatalf("Запрос %q не распарсен", tc.stmt) + } + + diff := cmp.Diff(tc.expected, &stmts[0], + cmpopts.IgnoreFields(ast.RawStmt{}, "StmtLocation", "StmtLen"), + cmpopts.IgnoreFields(ast.A_Const{}, "Location"), + cmpopts.IgnoreFields(ast.ResTarget{}, "Location"), + cmpopts.IgnoreFields(ast.ColumnRef{}, "Location"), + cmpopts.IgnoreFields(ast.A_Expr{}, "Location"), + cmpopts.IgnoreFields(ast.RangeVar{}, "Location"), + ) + if diff != "" { + t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + } + }) + } +} diff --git a/internal/engine/ydb/convert.go b/internal/engine/ydb/convert.go index 341de1a2e9..40e0341411 100755 --- a/internal/engine/ydb/convert.go +++ b/internal/engine/ydb/convert.go @@ -38,6 +38,154 @@ func NewIdentifier(t string) *ast.String { return &ast.String{Str: identifier(t)} } +func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { + batch := n.BATCH() != nil + + tableName := identifier(n.Simple_table_ref().Simple_table_ref_core().GetText()) + rel := &ast.RangeVar{Relname: &tableName} + + var where ast.Node + if n.WHERE() != nil && n.Expr() != nil { + where = c.convert(n.Expr()) + } + if n.ON() != nil && n.Into_values_source() != nil { + // todo: handle delete with into values source + } + + returning := &ast.List{} + if ret := n.Returning_columns_list(); ret != nil { + returning = c.convert(ret).(*ast.List) + } + + stmts := &ast.DeleteStmt{ + Relations: &ast.List{Items: []ast.Node{rel}}, + WhereClause: where, + ReturningList: returning, + Batch: batch, + } + + return stmts +} + +func (c *cc) convertInto_table_stmtContext(n *parser.Into_table_stmtContext) ast.Node { + tableName := identifier(n.Into_simple_table_ref().Simple_table_ref().Simple_table_ref_core().GetText()) + rel := &ast.RangeVar{Relname: &tableName} + + onConflict := &ast.OnConflictClause{} + switch { + case n.INSERT() != nil && n.OR() != nil && n.ABORT() != nil: + onConflict.Action = ast.OnConflictAction_INSERT_OR_ABORT + case n.INSERT() != nil && n.OR() != nil && n.REVERT() != nil: + onConflict.Action = ast.OnConflictAction_INSERT_OR_REVERT + case n.INSERT() != nil && n.OR() != nil && n.IGNORE() != nil: + onConflict.Action = ast.OnConflictAction_INSERT_OR_IGNORE + case n.UPSERT() != nil: + onConflict.Action = ast.OnConflictAction_UPSERT + case n.REPLACE() != nil: + onConflict.Action = ast.OnConflictAction_REPLACE + } + + var cols *ast.List + var source ast.Node + if nVal := n.Into_values_source(); nVal != nil { + if nVal.DEFAULT() != nil && nVal.VALUES() != nil { + // todo: handle default values when implemented + } + if pureCols := nVal.Pure_column_list(); pureCols != nil { + cols = &ast.List{} + for _, anID := range pureCols.AllAn_id() { + name := identifier(parseAnId(anID)) + cols.Items = append(cols.Items, &ast.ResTarget{ + Name: &name, + }) + } + } + + valSource := nVal.Values_source() + if valSource != nil { + switch { + case valSource.Values_stmt() != nil: + source = &ast.SelectStmt{ + ValuesLists: c.convert(valSource.Values_stmt()).(*ast.List), + FromClause: &ast.List{}, + TargetList: &ast.List{}, + } + + case valSource.Select_stmt() != nil: + source = c.convert(valSource.Select_stmt()) + } + } + } + + returning := &ast.List{} + if ret := n.Returning_columns_list(); ret != nil { + returning = c.convert(ret).(*ast.List) + } + + stmts := &ast.InsertStmt{ + Relation: rel, + Cols: cols, + SelectStmt: source, + OnConflictClause: onConflict, + ReturningList: returning, + } + + return stmts +} + +func (c *cc) convertValues_stmtContext(n *parser.Values_stmtContext) ast.Node { + mainList := &ast.List{} + + for _, rowCtx := range n.Values_source_row_list().AllValues_source_row() { + rowList := &ast.List{} + exprListCtx := rowCtx.Expr_list().(*parser.Expr_listContext) + + for _, exprCtx := range exprListCtx.AllExpr() { + if converted := c.convert(exprCtx); converted != nil { + rowList.Items = append(rowList.Items, converted) + } + } + + mainList.Items = append(mainList.Items, rowList) + + } + + return mainList +} + +func (c *cc) convertReturning_columns_listContext(n *parser.Returning_columns_listContext) ast.Node { + list := &ast.List{Items: []ast.Node{}} + + if n.ASTERISK() != nil { + target := &ast.ResTarget{ + Indirection: &ast.List{}, + Val: &ast.ColumnRef{ + Fields: &ast.List{Items: []ast.Node{&ast.A_Star{}}}, + Location: n.ASTERISK().GetSymbol().GetStart(), + }, + Location: n.ASTERISK().GetSymbol().GetStart(), + } + list.Items = append(list.Items, target) + return list + } + + for _, idCtx := range n.AllAn_id() { + target := &ast.ResTarget{ + Indirection: &ast.List{}, + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{NewIdentifier(parseAnId(idCtx))}, + }, + Location: idCtx.GetStart().GetStart(), + }, + Location: idCtx.GetStart().GetStart(), + } + list.Items = append(list.Items, target) + } + + return list +} + func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) ast.Node { tableRef := parseTableName(n.Simple_table_ref().Simple_table_ref_core()) @@ -1610,6 +1758,18 @@ func (c *cc) convert(node node) ast.Node { case *parser.Type_name_or_bindContext: return c.convertTypeNameOrBind(n) + case *parser.Into_table_stmtContext: + return c.convertInto_table_stmtContext(n) + + case *parser.Values_stmtContext: + return c.convertValues_stmtContext(n) + + case *parser.Returning_columns_listContext: + return c.convertReturning_columns_listContext(n) + + case *parser.Delete_stmtContext: + return c.convertDelete_stmtContext(n) + default: return todo("convert(case=default)", n) } diff --git a/internal/sql/ast/delete_stmt.go b/internal/sql/ast/delete_stmt.go index d77f043a12..c6fbc8149f 100644 --- a/internal/sql/ast/delete_stmt.go +++ b/internal/sql/ast/delete_stmt.go @@ -7,6 +7,9 @@ type DeleteStmt struct { LimitCount Node ReturningList *List WithClause *WithClause + + // YDB specific + Batch bool } func (n *DeleteStmt) Pos() int { @@ -22,6 +25,9 @@ func (n *DeleteStmt) Format(buf *TrackedBuffer) { buf.astFormat(n.WithClause) buf.WriteString(" ") } + if n.Batch { + buf.WriteString("BATCH ") + } buf.WriteString("DELETE FROM ") if items(n.Relations) { diff --git a/internal/sql/ast/insert_stmt.go b/internal/sql/ast/insert_stmt.go index 3cdf854091..7be5a183c9 100644 --- a/internal/sql/ast/insert_stmt.go +++ b/internal/sql/ast/insert_stmt.go @@ -24,7 +24,18 @@ func (n *InsertStmt) Format(buf *TrackedBuffer) { buf.WriteString(" ") } - buf.WriteString("INSERT INTO ") + switch n.OnConflictClause.Action { + case OnConflictAction_INSERT_OR_ABORT: + buf.WriteString("INSERT OR ABORT INTO ") + case OnConflictAction_INSERT_OR_REVERT: + buf.WriteString("INSERT OR REVERT INTO ") + case OnConflictAction_INSERT_OR_IGNORE: + buf.WriteString("INSERT OR IGNORE INTO ") + case OnConflictAction_UPSERT: + buf.WriteString("UPSERT INTO ") + default: + buf.WriteString("INSERT INTO ") + } if n.Relation != nil { buf.astFormat(n.Relation) } @@ -38,7 +49,7 @@ func (n *InsertStmt) Format(buf *TrackedBuffer) { buf.astFormat(n.SelectStmt) } - if n.OnConflictClause != nil { + if n.OnConflictClause != nil && n.OnConflictClause.Action < 4 { buf.WriteString(" ON CONFLICT DO NOTHING ") } diff --git a/internal/sql/ast/on_conflict_action_type.go b/internal/sql/ast/on_conflict_action_type.go new file mode 100644 index 0000000000..c149fe8d04 --- /dev/null +++ b/internal/sql/ast/on_conflict_action_type.go @@ -0,0 +1,15 @@ +package ast + +const ( + OnConflictAction_ON_CONFLICT_ACTION_UNDEFINED OnConflictAction = 0 + OnConflictAction_ONCONFLICT_NONE OnConflictAction = 1 + OnConflictAction_ONCONFLICT_NOTHING OnConflictAction = 2 + OnConflictAction_ONCONFLICT_UPDATE OnConflictAction = 3 + + // YQL-specific + OnConflictAction_INSERT_OR_ABORT OnConflictAction = 4 + OnConflictAction_INSERT_OR_REVERT OnConflictAction = 5 + OnConflictAction_INSERT_OR_IGNORE OnConflictAction = 6 + OnConflictAction_UPSERT OnConflictAction = 7 + OnConflictAction_REPLACE OnConflictAction = 8 +) diff --git a/internal/sqltest/local/ydb.go b/internal/sqltest/local/ydb.go index 79be58241f..2064b063dd 100644 --- a/internal/sqltest/local/ydb.go +++ b/internal/sqltest/local/ydb.go @@ -54,7 +54,6 @@ func link_YDB(t *testing.T, migrations []string, rw bool) TestYDB { baseDB = "/local" } - // собираем миграции var seed []string files, err := sqlpath.Glob(migrations) if err != nil { @@ -97,6 +96,7 @@ func link_YDB(t *testing.T, migrations []string, rw bool) TestYDB { driver, ydb.WithTablePathPrefix(prefix), ydb.WithAutoDeclare(), + ydb.WithNumericArgs(), ) if err != nil { t.Fatalf("failed to create connector: %s", err) From 3a4f818c50827e387ae895afb424fe6bed4e8cf2 Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov Date: Mon, 28 Apr 2025 21:39:23 +0300 Subject: [PATCH 03/18] Added UPDATE and DELETE full logic. Updated examples in example/authors --- examples/authors/ydb/db_test.go | 40 +++---- examples/authors/ydb/query.sql | 11 +- examples/authors/ydb/query.sql.go | 56 ++++----- go.mod | 5 + go.sum | 112 +++++++++++++++++- internal/engine/ydb/convert.go | 191 +++++++++++++++++++++++++----- internal/sql/ast/delete_stmt.go | 32 ++++- internal/sql/ast/insert_stmt.go | 2 + internal/sql/ast/update_stmt.go | 23 ++++ internal/sql/astutils/rewrite.go | 4 + internal/sql/astutils/walk.go | 12 ++ 11 files changed, 391 insertions(+), 97 deletions(-) diff --git a/examples/authors/ydb/db_test.go b/examples/authors/ydb/db_test.go index 44e330a2f6..76b37306ef 100644 --- a/examples/authors/ydb/db_test.go +++ b/examples/authors/ydb/db_test.go @@ -46,13 +46,13 @@ func TestAuthors(t *testing.T) { t.Run("CreateOrUpdateAuthorReturningBio", func(t *testing.T) { newBio := "Обновленная биография автора" - arg := CreateOrUpdateAuthorRetunringBioParams{ + arg := CreateOrUpdateAuthorReturningBioParams{ P0: 3, P1: "Тестовый Автор", P2: &newBio, } - returnedBio, err := q.CreateOrUpdateAuthorRetunringBio(ctx, arg) + returnedBio, err := q.CreateOrUpdateAuthorReturningBio(ctx, arg) if err != nil { t.Fatalf("failed to create or update author: %v", err) } @@ -67,6 +67,24 @@ func TestAuthors(t *testing.T) { t.Logf("Author created or updated successfully with bio: %s", *returnedBio) }) + t.Run("Update Author", func(t *testing.T) { + arg := UpdateAuthorByIDParams{ + P0: "Максим Горький", + P1: ptr("Обновленная биография"), + P2: 10, + } + + singleAuthor, err := q.UpdateAuthorByID(ctx, arg) + if err != nil { + t.Fatal(err) + } + bio := "Null" + if singleAuthor.Bio != nil { + bio = *singleAuthor.Bio + } + t.Logf("- ID: %d | Name: %s | Bio: %s", singleAuthor.ID, singleAuthor.Name, bio) + }) + t.Run("ListAuthors", func(t *testing.T) { authors, err := q.ListAuthors(ctx) if err != nil { @@ -115,24 +133,6 @@ func TestAuthors(t *testing.T) { } }) - t.Run("ListAuthorsWithIdModulo", func(t *testing.T) { - authors, err := q.ListAuthorsWithIdModulo(ctx) - if err != nil { - t.Fatal(err) - } - if len(authors) == 0 { - t.Fatal("expected at least one author with even ID, got none") - } - t.Log("Authors with even IDs:") - for _, a := range authors { - bio := "Null" - if a.Bio != nil { - bio = *a.Bio - } - t.Logf("- ID: %d | Name: %s | Bio: %s", a.ID, a.Name, bio) - } - }) - t.Run("ListAuthorsWithNullBio", func(t *testing.T) { authors, err := q.ListAuthorsWithNullBio(ctx) if err != nil { diff --git a/examples/authors/ydb/query.sql b/examples/authors/ydb/query.sql index 20b5eb8d5b..67ce89a6a7 100644 --- a/examples/authors/ydb/query.sql +++ b/examples/authors/ydb/query.sql @@ -5,10 +5,6 @@ SELECT * FROM authors; SELECT * FROM authors WHERE id = $p0; --- name: ListAuthorsWithIdModulo :many -SELECT * FROM authors -WHERE id % 2 = 0; - -- name: GetAuthorsByName :many SELECT * FROM authors WHERE name = $p0; @@ -20,10 +16,11 @@ WHERE bio IS NULL; -- name: CreateOrUpdateAuthor :execresult UPSERT INTO authors (id, name, bio) VALUES ($p0, $p1, $p2); --- name: CreateOrUpdateAuthorRetunringBio :one +-- name: CreateOrUpdateAuthorReturningBio :one UPSERT INTO authors (id, name, bio) VALUES ($p0, $p1, $p2) RETURNING bio; -- name: DeleteAuthor :exec -DELETE FROM authors -WHERE id = $p0; +DELETE FROM authors WHERE id = $p0; +-- name: UpdateAuthorByID :one +UPDATE authors SET name = $p0, bio = $p1 WHERE id = $p2 RETURNING *; diff --git a/examples/authors/ydb/query.sql.go b/examples/authors/ydb/query.sql.go index 289cbb7741..64126bf254 100644 --- a/examples/authors/ydb/query.sql.go +++ b/examples/authors/ydb/query.sql.go @@ -24,26 +24,25 @@ func (q *Queries) CreateOrUpdateAuthor(ctx context.Context, arg CreateOrUpdateAu return q.db.ExecContext(ctx, createOrUpdateAuthor, arg.P0, arg.P1, arg.P2) } -const createOrUpdateAuthorRetunringBio = `-- name: CreateOrUpdateAuthorRetunringBio :one +const createOrUpdateAuthorReturningBio = `-- name: CreateOrUpdateAuthorReturningBio :one UPSERT INTO authors (id, name, bio) VALUES ($p0, $p1, $p2) RETURNING bio ` -type CreateOrUpdateAuthorRetunringBioParams struct { +type CreateOrUpdateAuthorReturningBioParams struct { P0 uint64 P1 string P2 *string } -func (q *Queries) CreateOrUpdateAuthorRetunringBio(ctx context.Context, arg CreateOrUpdateAuthorRetunringBioParams) (*string, error) { - row := q.db.QueryRowContext(ctx, createOrUpdateAuthorRetunringBio, arg.P0, arg.P1, arg.P2) +func (q *Queries) CreateOrUpdateAuthorReturningBio(ctx context.Context, arg CreateOrUpdateAuthorReturningBioParams) (*string, error) { + row := q.db.QueryRowContext(ctx, createOrUpdateAuthorReturningBio, arg.P0, arg.P1, arg.P2) var bio *string err := row.Scan(&bio) return bio, err } const deleteAuthor = `-- name: DeleteAuthor :exec -DELETE FROM authors -WHERE id = $p0 +DELETE FROM authors WHERE id = $p0 ` func (q *Queries) DeleteAuthor(ctx context.Context, p0 uint64) error { @@ -118,34 +117,6 @@ func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) { return items, nil } -const listAuthorsWithIdModulo = `-- name: ListAuthorsWithIdModulo :many -SELECT id, name, bio FROM authors -WHERE id % 2 = 0 -` - -func (q *Queries) ListAuthorsWithIdModulo(ctx context.Context) ([]Author, error) { - rows, err := q.db.QueryContext(ctx, listAuthorsWithIdModulo) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Author - for rows.Next() { - var i Author - if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - const listAuthorsWithNullBio = `-- name: ListAuthorsWithNullBio :many SELECT id, name, bio FROM authors WHERE bio IS NULL @@ -173,3 +144,20 @@ func (q *Queries) ListAuthorsWithNullBio(ctx context.Context) ([]Author, error) } return items, nil } + +const updateAuthorByID = `-- name: UpdateAuthorByID :one +UPDATE authors SET name = $p0, bio = $p1 WHERE id = $p2 RETURNING id, name, bio +` + +type UpdateAuthorByIDParams struct { + P0 string + P1 *string + P2 uint64 +} + +func (q *Queries) UpdateAuthorByID(ctx context.Context, arg UpdateAuthorByIDParams) (Author, error) { + row := q.db.QueryRowContext(ctx, updateAuthorByID, arg.P0, arg.P1, arg.P2) + var i Author + err := row.Scan(&i.ID, &i.Name, &i.Bio) + return i, err +} diff --git a/go.mod b/go.mod index 34aaa12c5f..ebd7884d70 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,8 @@ require ( github.com/tetratelabs/wazero v1.9.0 github.com/wasilibs/go-pgquery v0.0.0-20250409022910-10ac41983c07 github.com/xeipuuv/gojsonschema v1.2.0 + github.com/ydb-platform/ydb-go-sdk/v3 v3.108.0 + github.com/ydb-platform/yql-parsers v0.0.0-20250309001738-7d693911f333 golang.org/x/sync v0.16.0 google.golang.org/grpc v1.75.0 google.golang.org/protobuf v1.36.8 @@ -35,6 +37,7 @@ require ( cel.dev/expr v0.24.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/golang-jwt/jwt/v4 v4.5.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect @@ -45,6 +48,7 @@ require ( github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgtype v1.14.0 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/jonboulle/clockwork v0.3.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pingcap/errors v0.11.5-0.20240311024730-e056997136bb // indirect @@ -56,6 +60,7 @@ require ( github.com/wasilibs/wazero-helpers v0.0.0-20240620070341-3dff1577cd52 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect + github.com/ydb-platform/ydb-go-genproto v0.0.0-20241112172322-ea1f63298f77 // indirect go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect diff --git a/go.sum b/go.sum index fd8d405059..910c0e9fca 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,24 @@ cel.dev/expr v0.24.0 h1:56OvJKSH3hDGL0ml5uSxZmz3/3Pq4tJ+fb1unVLAFcY= cel.dev/expr v0.24.0/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= +github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XPnfI9Z40ddWsdw2W/uZgQLFXToKeRcDiI= +github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -20,8 +32,15 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1/go.mod h1:KJwIaB5Mv44NWtYuAOFCVOjcI94vtpEz2JU/D2v6IjE= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= @@ -33,19 +52,43 @@ github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI6 github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/cel-go v0.26.1 h1:iPbVVEdkhTX++hpe3lzSk7D3G3QSYqLGoHOcEio+UXQ= github.com/google/cel-go v0.26.1/go.mod h1:A9O8OU9rdvrK5MQyrqfIxo1a0u4g3sF8KB6PUIaryMM= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= @@ -101,6 +144,8 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jonboulle/clockwork v0.3.0 h1:9BSCMi8C+0qdApAp4auwX0RkLGUjs956h0EkuQymUhg= +github.com/jonboulle/clockwork v0.3.0/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -143,10 +188,14 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rekby/fixenv v0.6.1 h1:jUFiSPpajT4WY2cYuc++7Y1zWrnCxnovGCIX72PZniM= +github.com/rekby/fixenv v0.6.1/go.mod h1:/b5LRc06BYJtslRtHKxsPWFT/ySpHV+rWvzTg+XWk4c= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/riza-io/grpc-go v0.2.0 h1:2HxQKFVE7VuYstcJ8zqpN84VnAoJ4dCL6YFhJewNcHQ= github.com/riza-io/grpc-go v0.2.0/go.mod h1:2bDvR9KkKC3KhtlSHfR3dAXjUMT86kg4UfWFyVGWqi8= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= @@ -175,8 +224,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I= github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM= github.com/wasilibs/go-pgquery v0.0.0-20250409022910-10ac41983c07 h1:mJdDDPblDfPe7z7go8Dvv1AJQDI3eQ/5xith3q2mFlo= @@ -189,6 +238,12 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +github.com/ydb-platform/ydb-go-genproto v0.0.0-20241112172322-ea1f63298f77 h1:LY6cI8cP4B9rrpTleZk95+08kl2gF4rixG7+V/dwL6Q= +github.com/ydb-platform/ydb-go-genproto v0.0.0-20241112172322-ea1f63298f77/go.mod h1:Er+FePu1dNUieD+XTMDduGpQuCPssK5Q4BjF+IIXJ3I= +github.com/ydb-platform/ydb-go-sdk/v3 v3.108.0 h1:TwWSp3gRMcja/hRpOofncLvgxAXCmzpz5cGtmdaoITw= +github.com/ydb-platform/ydb-go-sdk/v3 v3.108.0/go.mod h1:l5sSv153E18VvYcsmr51hok9Sjc16tEC8AXGbwrk+ho= +github.com/ydb-platform/yql-parsers v0.0.0-20250309001738-7d693911f333 h1:KFtJwlPdOxWjCKXX0jFJ8k1FlbqbRbUW3k/kYSZX7SA= +github.com/ydb-platform/yql-parsers v0.0.0-20250309001738-7d693911f333/go.mod h1:vrPJPS8cdPSV568YcXhB4bUwhyV8bmWKqmQ5c5Xi99o= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= @@ -202,6 +257,7 @@ go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFh go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= @@ -213,6 +269,8 @@ go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0 go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= @@ -238,23 +296,39 @@ golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -265,7 +339,10 @@ golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= @@ -280,8 +357,11 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= @@ -298,13 +378,38 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7 h1:FiusG7LWj+4byqhbvmB+Q93B/mOxJLN2DTozDuZm4EU= google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:kXqgZtrWaf6qS3jZOCnCH7WYfrvFjkC51bM8fz3RsCA= google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 h1:pFyd6EwwL2TqFf8emdthzeX+gZE1ElRq3iM8pui4KBY= google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= +google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= +google.golang.org/grpc v1.47.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4= google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= @@ -318,11 +423,14 @@ gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24 gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM= modernc.org/cc/v4 v4.26.2/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= diff --git a/internal/engine/ydb/convert.go b/internal/engine/ydb/convert.go index 40e0341411..8e447254b1 100755 --- a/internal/engine/ydb/convert.go +++ b/internal/engine/ydb/convert.go @@ -48,8 +48,37 @@ func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { if n.WHERE() != nil && n.Expr() != nil { where = c.convert(n.Expr()) } + var cols *ast.List + var source ast.Node if n.ON() != nil && n.Into_values_source() != nil { - // todo: handle delete with into values source + nVal := n.Into_values_source() + + // todo: handle default values when implemented + + if pureCols := nVal.Pure_column_list(); pureCols != nil { + cols = &ast.List{} + for _, anID := range pureCols.AllAn_id() { + name := identifier(parseAnId(anID)) + cols.Items = append(cols.Items, &ast.ResTarget{ + Name: &name, + }) + } + } + + valSource := nVal.Values_source() + if valSource != nil { + switch { + case valSource.Values_stmt() != nil: + source = &ast.SelectStmt{ + ValuesLists: c.convert(valSource.Values_stmt()).(*ast.List), + FromClause: &ast.List{}, + TargetList: &ast.List{}, + } + + case valSource.Select_stmt() != nil: + source = c.convert(valSource.Select_stmt()) + } + } } returning := &ast.List{} @@ -62,6 +91,125 @@ func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { WhereClause: where, ReturningList: returning, Batch: batch, + OnCols: cols, + OnSelectStmt: source, + } + + return stmts +} + +func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { + batch := n.BATCH() != nil + + tableName := identifier(n.Simple_table_ref().Simple_table_ref_core().GetText()) + rel := &ast.RangeVar{Relname: &tableName} + + var where ast.Node + var setList *ast.List + var cols *ast.List + var source ast.Node + + if n.SET() != nil && n.Set_clause_choice() != nil { + nSet := n.Set_clause_choice() + setList = &ast.List{Items: []ast.Node{}} + + switch { + case nSet.Set_clause_list() != nil: + for _, clause := range nSet.Set_clause_list().AllSet_clause() { + targetCtx := clause.Set_target() + columnName := identifier(targetCtx.Column_name().GetText()) + expr := c.convert(clause.Expr()) + resTarget := &ast.ResTarget{ + Name: &columnName, + Val: expr, + } + setList.Items = append(setList.Items, resTarget) + } + + case nSet.Multiple_column_assignment() != nil: + multiAssign := nSet.Multiple_column_assignment() + targetsCtx := multiAssign.Set_target_list() + valuesCtx := multiAssign.Simple_values_source() + + var colNames []string + for _, target := range targetsCtx.AllSet_target() { + targetCtx := target.(*parser.Set_targetContext) + colNames = append(colNames, targetCtx.Column_name().GetText()) + } + + var rowExpr *ast.RowExpr + if exprList := valuesCtx.Expr_list(); exprList != nil { + rowExpr = &ast.RowExpr{ + Args: &ast.List{}, + } + for _, expr := range exprList.AllExpr() { + rowExpr.Args.Items = append(rowExpr.Args.Items, c.convert(expr)) + } + } + + for i, colName := range colNames { + name := identifier(colName) + setList.Items = append(setList.Items, &ast.ResTarget{ + Name: &name, + Val: &ast.MultiAssignRef{ + Source: rowExpr, + Colno: i + 1, + Ncolumns: len(colNames), + }, + }) + } + } + + if n.WHERE() != nil && n.Expr() != nil { + where = c.convert(n.Expr()) + } + } else if n.ON() != nil && n.Into_values_source() != nil { + + // todo: handle default values when implemented + + nVal := n.Into_values_source() + + if pureCols := nVal.Pure_column_list(); pureCols != nil { + cols = &ast.List{} + for _, anID := range pureCols.AllAn_id() { + name := identifier(parseAnId(anID)) + cols.Items = append(cols.Items, &ast.ResTarget{ + Name: &name, + }) + } + } + + valSource := nVal.Values_source() + if valSource != nil { + switch { + case valSource.Values_stmt() != nil: + source = &ast.SelectStmt{ + ValuesLists: c.convert(valSource.Values_stmt()).(*ast.List), + FromClause: &ast.List{}, + TargetList: &ast.List{}, + } + + case valSource.Select_stmt() != nil: + source = c.convert(valSource.Select_stmt()) + } + } + } + + returning := &ast.List{} + if ret := n.Returning_columns_list(); ret != nil { + returning = c.convert(ret).(*ast.List) + } + + stmts := &ast.UpdateStmt{ + Relations: &ast.List{Items: []ast.Node{rel}}, + TargetList: setList, + WhereClause: where, + ReturningList: returning, + FromClause: &ast.List{}, + WithClause: nil, + Batch: batch, + OnCols: cols, + OnSelectStmt: source, } return stmts @@ -88,9 +236,8 @@ func (c *cc) convertInto_table_stmtContext(n *parser.Into_table_stmtContext) ast var cols *ast.List var source ast.Node if nVal := n.Into_values_source(); nVal != nil { - if nVal.DEFAULT() != nil && nVal.VALUES() != nil { - // todo: handle default values when implemented - } + // todo: handle default values when implemented + if pureCols := nVal.Pure_column_list(); pureCols != nil { cols = &ast.List{} for _, anID := range pureCols.AllAn_id() { @@ -186,20 +333,6 @@ func (c *cc) convertReturning_columns_listContext(n *parser.Returning_columns_li return list } -func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) ast.Node { - tableRef := parseTableName(n.Simple_table_ref().Simple_table_ref_core()) - - stmt := &ast.AlterTableStmt{ - Table: tableRef, - Cmds: &ast.List{}, - } - for _, action := range n.AllAlter_table_action() { - if add := action.Alter_table_add_column(); add != nil { - } - } - return stmt -} - func (c *cc) convertSelectStmtContext(n *parser.Select_stmtContext) ast.Node { skp := n.Select_kind_parenthesis(0) if skp == nil { @@ -299,10 +432,7 @@ func (c *cc) convertSelectCoreContext(n *parser.Select_coreContext) ast.Node { } func (c *cc) convertResultColumn(n *parser.Result_columnContext) ast.Node { - exprCtx := n.Expr() - if exprCtx == nil { - // todo - } + // todo: support opt_id_prefix target := &ast.ResTarget{ Location: n.GetStart().GetStart(), } @@ -322,7 +452,7 @@ func (c *cc) convertResultColumn(n *parser.Result_columnContext) ast.Node { case n.AS() != nil && n.An_id_or_type() != nil: name := parseAnIdOrType(n.An_id_or_type()) target.Name = &name - case n.An_id_as_compat() != nil: + case n.An_id_as_compat() != nil: //nolint // todo: parse as_compat } target.Val = val @@ -436,7 +566,7 @@ func (c *cc) convertNamedSingleSource(n *parser.Named_single_sourceContext) ast. case *ast.RangeSubselect: source.Alias = &ast.Alias{Aliasname: &aliasText} } - } else if n.An_id_as_compat() != nil { + } else if n.An_id_as_compat() != nil { //nolint // todo: parse as_compat } return base @@ -700,7 +830,7 @@ func (c *cc) convertTypeNameComposite(n parser.IType_name_compositeContext) ast. if struct_ := n.Type_name_struct(); struct_ != nil { if structArgs := struct_.AllStruct_arg(); len(structArgs) > 0 { var items []ast.Node - for _, _ = range structArgs { + for range structArgs { // TODO: Handle struct field names and types items = append(items, &ast.TODO{}) } @@ -715,7 +845,7 @@ func (c *cc) convertTypeNameComposite(n parser.IType_name_compositeContext) ast. if variant := n.Type_name_variant(); variant != nil { if variantArgs := variant.AllVariant_arg(); len(variantArgs) > 0 { var items []ast.Node - for _, _ = range variantArgs { + for range variantArgs { // TODO: Handle variant arguments items = append(items, &ast.TODO{}) } @@ -793,7 +923,7 @@ func (c *cc) convertTypeNameComposite(n parser.IType_name_compositeContext) ast. if enum := n.Type_name_enum(); enum != nil { if typeTags := enum.AllType_name_tag(); len(typeTags) > 0 { var items []ast.Node - for _, _ = range typeTags { // todo: Handle enum tags + for range typeTags { // todo: Handle enum tags items = append(items, &ast.TODO{}) } return &ast.TypeName{ @@ -1195,7 +1325,7 @@ func (c *cc) convertXorSubexpr(n *parser.Xor_subexprContext) ast.Node { Lexpr: base, Rexpr: c.convert(eqSubs[0]), } - if condCtx.ESCAPE() != nil && len(eqSubs) >= 2 { + if condCtx.ESCAPE() != nil && len(eqSubs) >= 2 { //nolint // todo: Add ESCAPE support } return expr @@ -1549,7 +1679,7 @@ func (c *cc) convertUnarySubexprSuffix(base ast.Node, n *parser.Unary_subexpr_su } } - if n.COLLATE() != nil && n.An_id() != nil { + if n.COLLATE() != nil && n.An_id() != nil { //nolint // todo: Handle COLLATE } return colRef @@ -1770,6 +1900,9 @@ func (c *cc) convert(node node) ast.Node { case *parser.Delete_stmtContext: return c.convertDelete_stmtContext(n) + case *parser.Update_stmtContext: + return c.convertUpdate_stmtContext(n) + default: return todo("convert(case=default)", n) } diff --git a/internal/sql/ast/delete_stmt.go b/internal/sql/ast/delete_stmt.go index c6fbc8149f..8bc2b89e9d 100644 --- a/internal/sql/ast/delete_stmt.go +++ b/internal/sql/ast/delete_stmt.go @@ -1,15 +1,22 @@ package ast +import ( + "fmt" +) + type DeleteStmt struct { - Relations *List - UsingClause *List - WhereClause Node - LimitCount Node + Relations *List + UsingClause *List + WhereClause Node + LimitCount Node + ReturningList *List WithClause *WithClause // YDB specific - Batch bool + Batch bool + OnCols *List + OnSelectStmt Node } func (n *DeleteStmt) Pos() int { @@ -17,6 +24,7 @@ func (n *DeleteStmt) Pos() int { } func (n *DeleteStmt) Format(buf *TrackedBuffer) { + fmt.Println("DeleteStmt.Format") if n == nil { return } @@ -39,6 +47,20 @@ func (n *DeleteStmt) Format(buf *TrackedBuffer) { buf.astFormat(n.WhereClause) } + if items(n.OnCols) || set(n.OnSelectStmt) { + buf.WriteString(" ON ") + + if items(n.OnCols) { + buf.WriteString("(") + buf.astFormat(n.OnCols) + buf.WriteString(") ") + } + + if set(n.OnSelectStmt) { + buf.astFormat(n.OnSelectStmt) + } + } + if set(n.LimitCount) { buf.WriteString(" LIMIT ") buf.astFormat(n.LimitCount) diff --git a/internal/sql/ast/insert_stmt.go b/internal/sql/ast/insert_stmt.go index 7be5a183c9..954fb4665c 100644 --- a/internal/sql/ast/insert_stmt.go +++ b/internal/sql/ast/insert_stmt.go @@ -33,6 +33,8 @@ func (n *InsertStmt) Format(buf *TrackedBuffer) { buf.WriteString("INSERT OR IGNORE INTO ") case OnConflictAction_UPSERT: buf.WriteString("UPSERT INTO ") + case OnConflictAction_REPLACE: + buf.WriteString("REPLACE INTO ") default: buf.WriteString("INSERT INTO ") } diff --git a/internal/sql/ast/update_stmt.go b/internal/sql/ast/update_stmt.go index efd496ad75..ea3e5bc65f 100644 --- a/internal/sql/ast/update_stmt.go +++ b/internal/sql/ast/update_stmt.go @@ -10,6 +10,11 @@ type UpdateStmt struct { LimitCount Node ReturningList *List WithClause *WithClause + + // YDB specific + Batch bool + OnCols *List + OnSelectStmt Node } func (n *UpdateStmt) Pos() int { @@ -25,6 +30,10 @@ func (n *UpdateStmt) Format(buf *TrackedBuffer) { buf.WriteString(" ") } + if n.Batch { + buf.WriteString("BATCH ") + } + buf.WriteString("UPDATE ") if items(n.Relations) { buf.astFormat(n.Relations) @@ -100,6 +109,20 @@ func (n *UpdateStmt) Format(buf *TrackedBuffer) { buf.astFormat(n.WhereClause) } + if items(n.OnCols) || set(n.OnSelectStmt) { + buf.WriteString(" ON ") + + if items(n.OnCols) { + buf.WriteString("(") + buf.astFormat(n.OnCols) + buf.WriteString(") ") + } + + if set(n.OnSelectStmt) { + buf.astFormat(n.OnSelectStmt) + } + } + if set(n.LimitCount) { buf.WriteString(" LIMIT ") buf.astFormat(n.LimitCount) diff --git a/internal/sql/astutils/rewrite.go b/internal/sql/astutils/rewrite.go index 93c5be3cfb..f5c8028ab6 100644 --- a/internal/sql/astutils/rewrite.go +++ b/internal/sql/astutils/rewrite.go @@ -685,6 +685,8 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. a.apply(n, "Relations", nil, n.Relations) a.apply(n, "UsingClause", nil, n.UsingClause) a.apply(n, "WhereClause", nil, n.WhereClause) + a.apply(n, "Cols", nil, n.OnCols) + a.apply(n, "SelectStmt", nil, n.OnSelectStmt) a.apply(n, "ReturningList", nil, n.ReturningList) a.apply(n, "WithClause", nil, n.WithClause) @@ -1159,6 +1161,8 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. a.apply(n, "TargetList", nil, n.TargetList) a.apply(n, "WhereClause", nil, n.WhereClause) a.apply(n, "FromClause", nil, n.FromClause) + a.apply(n, "Cols", nil, n.OnCols) + a.apply(n, "SelectStmt", nil, n.OnSelectStmt) a.apply(n, "ReturningList", nil, n.ReturningList) a.apply(n, "WithClause", nil, n.WithClause) diff --git a/internal/sql/astutils/walk.go b/internal/sql/astutils/walk.go index 0943379f03..2f28ba0243 100644 --- a/internal/sql/astutils/walk.go +++ b/internal/sql/astutils/walk.go @@ -1068,6 +1068,12 @@ func Walk(f Visitor, node ast.Node) { if n.WhereClause != nil { Walk(f, n.WhereClause) } + if n.OnCols != nil { + Walk(f, n.OnCols) + } + if n.OnSelectStmt != nil { + Walk(f, n.OnSelectStmt) + } if n.LimitCount != nil { Walk(f, n.LimitCount) } @@ -2041,6 +2047,12 @@ func Walk(f Visitor, node ast.Node) { if n.FromClause != nil { Walk(f, n.FromClause) } + if n.OnCols != nil { + Walk(f, n.OnCols) + } + if n.OnSelectStmt != nil { + Walk(f, n.OnSelectStmt) + } if n.LimitCount != nil { Walk(f, n.LimitCount) } From e980d194e4ddda2fcab2c598d5c50b1d7e7ddf02 Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov Date: Mon, 28 Apr 2025 22:24:11 +0300 Subject: [PATCH 04/18] Added some examples and README for demo --- Makefile | 15 +++++++++-- examples/authors/ydb/README.md | 47 ++++++++++++++++++++++++++++++++++ internal/sqltest/local/ydb.go | 4 +-- 3 files changed, 62 insertions(+), 4 deletions(-) create mode 100644 examples/authors/ydb/README.md diff --git a/Makefile b/Makefile index b8745e57dc..18d6ca91b5 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: build build-endtoend test test-ci test-examples test-endtoend start psql mysqlsh proto +.PHONY: build build-endtoend test test-ci test-examples test-endtoend start psql mysqlsh proto sqlc-dev ydb test-examples-ydb gen-examples-ydb build: go build ./... @@ -18,13 +18,21 @@ vet: test-examples: go test --tags=examples ./... +ydb-examples: sqlc-dev ydb gen-examples-ydb test-examples-ydb + +test-examples-ydb: + YDB_SERVER_URI=localhost:2136 go test -v ./examples/authors/ydb/... -count=1 + +gen-examples-ydb: + cd examples/authors/ && SQLCDEBUG=1 ~/bin/sqlc-dev generate && cd ../.. + build-endtoend: cd ./internal/endtoend/testdata && go build ./... test-ci: test-examples build-endtoend vet sqlc-dev: - go build -o ~/bin/sqlc-dev ./cmd/sqlc/ + go build -x -v -o ~/bin/sqlc-dev ./cmd/sqlc/ sqlc-pg-gen: go build -o ~/bin/sqlc-pg-gen ./internal/tools/sqlc-pg-gen @@ -38,6 +46,9 @@ test-json-process-plugin: start: docker compose up -d +ydb: + docker compose up -d ydb + fmt: go fmt ./... diff --git a/examples/authors/ydb/README.md b/examples/authors/ydb/README.md new file mode 100644 index 0000000000..9e77fc7886 --- /dev/null +++ b/examples/authors/ydb/README.md @@ -0,0 +1,47 @@ +# Инструкция по генерации + +В файлах `schema.sql` и `query.sql` записаны, соответственно, схема базы данных и запросы, из которых вы хотите сгенерировать код к базе данных. +В `db_test.go` находятся тесты для последних сгенерированных команд. +Ниже находятся команды для генерации и запуска тестов. + +--- + +### 1. Создание бинарника sqlc + +```bash +make sqlc-dev +``` + +### 2. Запуск YDB через Docker Compose + +```bash +make ydb +``` + +### 3. Генерация кода для примеров для YDB + +```bash +make gen-examples-ydb +``` + +### 4. Запуск тестов примеров для YDB + +```bash +make test-examples-ydb +``` + +### 5. Полный цикл: сборка, генерация, тестирование (удобно одной командой) + +```bash +make ydb-examples +``` + +Эта команда выполнит: + +- Сборку `sqlc-dev` +- Запуск контейнера YDB +- Генерацию примеров +- Тестирование примеров + +--- + diff --git a/internal/sqltest/local/ydb.go b/internal/sqltest/local/ydb.go index 2064b063dd..8703b170b5 100644 --- a/internal/sqltest/local/ydb.go +++ b/internal/sqltest/local/ydb.go @@ -36,7 +36,8 @@ type TestYDB struct { func link_YDB(t *testing.T, migrations []string, rw bool) TestYDB { t.Helper() - // 1) Контекст с таймаутом + time.Sleep(1 * time.Second) // wait for YDB to start + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -78,7 +79,6 @@ func link_YDB(t *testing.T, migrations []string, rw bool) TestYDB { } prefix := fmt.Sprintf("%s/%s", baseDB, name) - // 2) Открываем драйвер к корню "/" rootDSN := fmt.Sprintf("grpc://%s?database=%s", dbuiri, baseDB) t.Logf("→ Opening root driver: %s", rootDSN) driver, err := ydb.Open(ctx, rootDSN, From 3013c844b36546a0f1436d489ff8ca7c9e27911c Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov Date: Sat, 17 May 2025 22:26:14 +0300 Subject: [PATCH 05/18] Supported some more YQL constructions: USE ... PRAGMA CREATE USER CREATE GROUP COMMIT RESTORE also added some new tests and debugged old code for mistakes --- internal/codegen/golang/ydb_type.go | 7 +- .../ydb/catalog_tests/create_group_test.go | 110 +++++ .../ydb/catalog_tests/create_user_test.go | 129 ++++++ .../engine/ydb/catalog_tests/delete_test.go | 200 +++++++++ .../engine/ydb/catalog_tests/insert_test.go | 15 +- .../engine/ydb/catalog_tests/pragma_test.go | 118 +++++ .../engine/ydb/catalog_tests/select_test.go | 15 +- .../engine/ydb/catalog_tests/update_test.go | 185 ++++++++ internal/engine/ydb/convert.go | 407 +++++++++++++++++- internal/engine/ydb/utils.go | 2 + internal/sql/ast/create_role_stmt.go | 3 + internal/sql/ast/delete_stmt.go | 5 - internal/sql/ast/param_ref.go | 7 + internal/sql/ast/pragma_stmt.go | 43 ++ internal/sql/ast/role_spec.go | 2 + internal/sql/ast/use_stmt.go | 11 + internal/sql/astutils/rewrite.go | 12 +- internal/sql/astutils/walk.go | 23 +- 18 files changed, 1266 insertions(+), 28 deletions(-) create mode 100644 internal/engine/ydb/catalog_tests/create_group_test.go create mode 100644 internal/engine/ydb/catalog_tests/create_user_test.go create mode 100644 internal/engine/ydb/catalog_tests/delete_test.go create mode 100644 internal/engine/ydb/catalog_tests/pragma_test.go create mode 100644 internal/engine/ydb/catalog_tests/update_test.go create mode 100644 internal/sql/ast/pragma_stmt.go create mode 100644 internal/sql/ast/use_stmt.go diff --git a/internal/codegen/golang/ydb_type.go b/internal/codegen/golang/ydb_type.go index aba149ab03..e9e5c46344 100644 --- a/internal/codegen/golang/ydb_type.go +++ b/internal/codegen/golang/ydb_type.go @@ -49,7 +49,7 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col } // return "sql.NullInt16" return "*int16" - case "int32": + case "int", "int32": //ydb doesn't have int type, but we need it to support untyped constants if notNull { return "int32" } @@ -155,9 +155,12 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col // return "sql.Null" return "interface{}" + case "any": + return "interface{}" + default: if debug.Active { - log.Printf("unknown SQLite type: %s\n", columnType) + log.Printf("unknown YDB type: %s\n", columnType) } return "interface{}" diff --git a/internal/engine/ydb/catalog_tests/create_group_test.go b/internal/engine/ydb/catalog_tests/create_group_test.go new file mode 100644 index 0000000000..724e912168 --- /dev/null +++ b/internal/engine/ydb/catalog_tests/create_group_test.go @@ -0,0 +1,110 @@ +package ydb_test + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/sqlc-dev/sqlc/internal/engine/ydb" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func TestCreateGroup(t *testing.T) { + tests := []struct { + stmt string + expected ast.Node + }{ + { + stmt: `CREATE GROUP group1`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.CreateRoleStmt{ + StmtType: ast.RoleStmtType(3), // CREATE GROUP + Role: strPtr("group1"), + Options: &ast.List{}, + }, + }, + }, + }, + { + stmt: `CREATE GROUP group1 WITH USER alice`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.CreateRoleStmt{ + StmtType: ast.RoleStmtType(3), + Role: strPtr("group1"), + Options: &ast.List{ + Items: []ast.Node{ + &ast.DefElem{ + Defname: strPtr("rolemembers"), + Arg: &ast.List{ + Items: []ast.Node{ + &ast.RoleSpec{ + Roletype: ast.RoleSpecType(1), + Rolename: strPtr("alice"), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + stmt: `CREATE GROUP group1 WITH USER alice, bebik`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.CreateRoleStmt{ + StmtType: ast.RoleStmtType(3), + Role: strPtr("group1"), + Options: &ast.List{ + Items: []ast.Node{ + &ast.DefElem{ + Defname: strPtr("rolemembers"), + Arg: &ast.List{ + Items: []ast.Node{ + &ast.RoleSpec{ + Roletype: ast.RoleSpecType(1), + Rolename: strPtr("alice"), + }, + &ast.RoleSpec{ + Roletype: ast.RoleSpecType(1), + Rolename: strPtr("bebik"), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + p := ydb.NewParser() + for _, tc := range tests { + t.Run(tc.stmt, func(t *testing.T) { + stmts, err := p.Parse(strings.NewReader(tc.stmt)) + if err != nil { + t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + } + if len(stmts) == 0 { + t.Fatalf("Запрос %q не распарсен", tc.stmt) + } + + diff := cmp.Diff(tc.expected, &stmts[0], + cmpopts.IgnoreFields(ast.RawStmt{}, "StmtLocation", "StmtLen"), + cmpopts.IgnoreFields(ast.DefElem{}, "Location"), + cmpopts.IgnoreFields(ast.RoleSpec{}, "Location"), + cmpopts.IgnoreFields(ast.A_Const{}, "Location"), + ) + if diff != "" { + t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + } + }) + } +} diff --git a/internal/engine/ydb/catalog_tests/create_user_test.go b/internal/engine/ydb/catalog_tests/create_user_test.go new file mode 100644 index 0000000000..be53e9dd79 --- /dev/null +++ b/internal/engine/ydb/catalog_tests/create_user_test.go @@ -0,0 +1,129 @@ +package ydb_test + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/sqlc-dev/sqlc/internal/engine/ydb" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func TestCreateUser(t *testing.T) { + tests := []struct { + stmt string + expected ast.Node + }{ + { + stmt: `CREATE USER alice`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.CreateRoleStmt{ + StmtType: ast.RoleStmtType(2), // CREATE USER + Role: strPtr("alice"), + Options: &ast.List{}, + }, + }, + }, + }, + { + stmt: `CREATE USER bob PASSWORD 'secret'`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.CreateRoleStmt{ + StmtType: ast.RoleStmtType(2), + Role: strPtr("bob"), + Options: &ast.List{ + Items: []ast.Node{ + &ast.DefElem{ + Defname: strPtr("password"), + Arg: &ast.String{Str: "secret"}, + }, + }, + }, + }, + }, + }, + }, + { + stmt: `CREATE USER charlie LOGIN`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.CreateRoleStmt{ + StmtType: ast.RoleStmtType(2), + Role: strPtr("charlie"), + Options: &ast.List{ + Items: []ast.Node{ + &ast.DefElem{ + Defname: strPtr("login"), + Arg: &ast.Boolean{Boolval: true}, + }, + }, + }, + }, + }, + }, + }, + { + stmt: `CREATE USER dave NOLOGIN`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.CreateRoleStmt{ + StmtType: ast.RoleStmtType(2), + Role: strPtr("dave"), + Options: &ast.List{ + Items: []ast.Node{ + &ast.DefElem{ + Defname: strPtr("nologin"), + Arg: &ast.Boolean{Boolval: false}, + }, + }, + }, + }, + }, + }, + }, + { + stmt: `CREATE USER bjorn HASH 'abc123'`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.CreateRoleStmt{ + StmtType: ast.RoleStmtType(2), + Role: strPtr("bjorn"), + Options: &ast.List{ + Items: []ast.Node{ + &ast.DefElem{ + Defname: strPtr("hash"), + Arg: &ast.String{Str: "abc123"}, + }, + }, + }, + }, + }, + }, + }, + } + + p := ydb.NewParser() + for _, tc := range tests { + t.Run(tc.stmt, func(t *testing.T) { + stmts, err := p.Parse(strings.NewReader(tc.stmt)) + if err != nil { + t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + } + if len(stmts) == 0 { + t.Fatalf("Запрос %q не распарсен", tc.stmt) + } + + diff := cmp.Diff(tc.expected, &stmts[0], + cmpopts.IgnoreFields(ast.RawStmt{}, "StmtLocation", "StmtLen"), + cmpopts.IgnoreFields(ast.A_Const{}, "Location"), + cmpopts.IgnoreFields(ast.DefElem{}, "Location"), + ) + if diff != "" { + t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + } + }) + } +} diff --git a/internal/engine/ydb/catalog_tests/delete_test.go b/internal/engine/ydb/catalog_tests/delete_test.go new file mode 100644 index 0000000000..b75591a9ef --- /dev/null +++ b/internal/engine/ydb/catalog_tests/delete_test.go @@ -0,0 +1,200 @@ +package ydb_test + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/sqlc-dev/sqlc/internal/engine/ydb" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func TestDelete(t *testing.T) { + tests := []struct { + stmt string + expected ast.Node + }{ + { + stmt: "DELETE FROM users WHERE id = 1 RETURNING id", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.DeleteStmt{ + Relations: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{Relname: strPtr("users")}, + }, + }, + WhereClause: &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "="}}}, + Lexpr: &ast.ColumnRef{ + Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "id"}}}, + }, + Rexpr: &ast.A_Const{ + Val: &ast.Integer{Ival: 1}, + }, + }, + ReturningList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Indirection: &ast.List{}, + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{&ast.String{Str: "id"}}, + }, + }, + }, + }, + }, + Batch: false, + OnCols: nil, + OnSelectStmt: nil, + }, + }, + }, + }, + { + stmt: "BATCH DELETE FROM users WHERE is_deleted = true RETURNING *", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.DeleteStmt{ + Relations: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{Relname: strPtr("users")}, + }, + }, + WhereClause: &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "="}}}, + Lexpr: &ast.ColumnRef{ + Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "is_deleted"}}}, + }, + Rexpr: &ast.A_Const{ + Val: &ast.Boolean{Boolval: true}, + }, + }, + ReturningList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Indirection: &ast.List{}, + Val: &ast.ColumnRef{ + Fields: &ast.List{Items: []ast.Node{&ast.A_Star{}}}, + }, + }, + }, + }, + Batch: true, + OnCols: nil, + OnSelectStmt: nil, + }, + }, + }, + }, + { + stmt: "DELETE FROM users ON (id) VALUES (5) RETURNING id", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.DeleteStmt{ + Relations: &ast.List{Items: []ast.Node{&ast.RangeVar{Relname: strPtr("users")}}}, + OnCols: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{Name: strPtr("id")}, + }, + }, + OnSelectStmt: &ast.SelectStmt{ + ValuesLists: &ast.List{ + Items: []ast.Node{ + &ast.List{ + Items: []ast.Node{ + &ast.A_Const{Val: &ast.Integer{Ival: 5}}, + }, + }, + }, + }, + FromClause: &ast.List{}, + TargetList: &ast.List{}, + }, + ReturningList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Indirection: &ast.List{}, + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + }, + }, + }, + Batch: false, + WhereClause: nil, + }, + }, + }, + }, + { + stmt: "DELETE FROM users ON (id) SELECT 1 AS id RETURNING id", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.DeleteStmt{ + Relations: &ast.List{Items: []ast.Node{&ast.RangeVar{Relname: strPtr("users")}}}, + OnCols: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{Name: strPtr("id")}, + }, + }, + OnSelectStmt: &ast.SelectStmt{ + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Name: strPtr("id"), + Val: &ast.A_Const{Val: &ast.Integer{Ival: 1}}, + }, + }, + }, + FromClause: &ast.List{}, + }, + ReturningList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Indirection: &ast.List{}, + Val: &ast.ColumnRef{ + Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "id"}}}, + }, + }, + }, + }, + Batch: false, + WhereClause: nil, + }, + }, + }, + }, + } + + p := ydb.NewParser() + for _, tc := range tests { + t.Run(tc.stmt, func(t *testing.T) { + stmts, err := p.Parse(strings.NewReader(tc.stmt)) + if err != nil { + t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + } + if len(stmts) == 0 { + t.Fatalf("Запрос %q не распарсен", tc.stmt) + } + + diff := cmp.Diff(tc.expected, &stmts[0], + cmpopts.IgnoreFields(ast.RawStmt{}, "StmtLocation", "StmtLen"), + cmpopts.IgnoreFields(ast.A_Const{}, "Location"), + cmpopts.IgnoreFields(ast.ResTarget{}, "Location"), + cmpopts.IgnoreFields(ast.ColumnRef{}, "Location"), + cmpopts.IgnoreFields(ast.A_Expr{}, "Location"), + cmpopts.IgnoreFields(ast.RangeVar{}, "Location"), + ) + if diff != "" { + t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + } + }) + } +} diff --git a/internal/engine/ydb/catalog_tests/insert_test.go b/internal/engine/ydb/catalog_tests/insert_test.go index 0164a6302f..40f116a3ba 100644 --- a/internal/engine/ydb/catalog_tests/insert_test.go +++ b/internal/engine/ydb/catalog_tests/insert_test.go @@ -38,11 +38,14 @@ func TestInsert(t *testing.T) { }, }, }, + TargetList: &ast.List{}, + FromClause: &ast.List{}, }, OnConflictClause: &ast.OnConflictClause{}, ReturningList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ + Indirection: &ast.List{}, Val: &ast.ColumnRef{ Fields: &ast.List{Items: []ast.Node{&ast.A_Star{}}}, }, @@ -74,6 +77,8 @@ func TestInsert(t *testing.T) { }, }, }, + TargetList: &ast.List{}, + FromClause: &ast.List{}, }, OnConflictClause: &ast.OnConflictClause{ Action: ast.OnConflictAction_INSERT_OR_IGNORE, @@ -81,10 +86,12 @@ func TestInsert(t *testing.T) { ReturningList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ - Val: &ast.ColumnRef{Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "id"}}}}, + Indirection: &ast.List{}, + Val: &ast.ColumnRef{Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "id"}}}}, }, &ast.ResTarget{ - Val: &ast.ColumnRef{Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "name"}}}}, + Indirection: &ast.List{}, + Val: &ast.ColumnRef{Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "name"}}}}, }, }, }, @@ -99,9 +106,9 @@ func TestInsert(t *testing.T) { Stmt: &ast.InsertStmt{ Relation: &ast.RangeVar{Relname: strPtr("users")}, Cols: &ast.List{Items: []ast.Node{&ast.ResTarget{Name: strPtr("id")}}}, - SelectStmt: &ast.SelectStmt{ValuesLists: &ast.List{Items: []ast.Node{&ast.List{Items: []ast.Node{&ast.A_Const{Val: &ast.Integer{Ival: 4}}}}}}}, + SelectStmt: &ast.SelectStmt{ValuesLists: &ast.List{Items: []ast.Node{&ast.List{Items: []ast.Node{&ast.A_Const{Val: &ast.Integer{Ival: 4}}}}}}, TargetList: &ast.List{}, FromClause: &ast.List{}}, OnConflictClause: &ast.OnConflictClause{Action: ast.OnConflictAction_UPSERT}, - ReturningList: &ast.List{Items: []ast.Node{&ast.ResTarget{Val: &ast.ColumnRef{Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "id"}}}}}}}, + ReturningList: &ast.List{Items: []ast.Node{&ast.ResTarget{Val: &ast.ColumnRef{Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "id"}}}}, Indirection: &ast.List{}}}}, }, }, }, diff --git a/internal/engine/ydb/catalog_tests/pragma_test.go b/internal/engine/ydb/catalog_tests/pragma_test.go new file mode 100644 index 0000000000..9db4406c53 --- /dev/null +++ b/internal/engine/ydb/catalog_tests/pragma_test.go @@ -0,0 +1,118 @@ +package ydb_test + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/sqlc-dev/sqlc/internal/engine/ydb" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func TestPragma(t *testing.T) { + tests := []struct { + stmt string + expected ast.Node + }{ + { + stmt: `PRAGMA AutoCommit`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.Pragma_stmt{ + Name: &ast.List{ + Items: []ast.Node{ + &ast.A_Const{Val: &ast.String{Str: "autocommit"}}, + }, + }, + }, + }, + }, + }, + { + stmt: `PRAGMA TablePathPrefix = "home/yql"`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.Pragma_stmt{ + Name: &ast.List{ + Items: []ast.Node{ + &ast.A_Const{Val: &ast.String{Str: "tablepathprefix"}}, + }, + }, + Equals: true, + Values: &ast.List{ + Items: []ast.Node{ + &ast.A_Const{Val: &ast.String{Str: "home/yql"}}, + }, + }, + }, + }, + }, + }, + { + stmt: `PRAGMA Warning("disable", "1101")`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.Pragma_stmt{ + Name: &ast.List{ + Items: []ast.Node{ + &ast.A_Const{Val: &ast.String{Str: "warning"}}, + }, + }, + Equals: false, + Values: &ast.List{ + Items: []ast.Node{ + &ast.A_Const{Val: &ast.String{Str: "disable"}}, + &ast.A_Const{Val: &ast.String{Str: "1101"}}, + }, + }, + }, + }, + }, + }, + { + stmt: `PRAGMA yson.AutoConvert = true`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.Pragma_stmt{ + Name: &ast.List{ + Items: []ast.Node{ + &ast.A_Const{Val: &ast.String{Str: "yson"}}, + &ast.A_Const{Val: &ast.String{Str: "autoconvert"}}, + }, + }, + Equals: true, + Values: &ast.List{ + Items: []ast.Node{ + &ast.A_Const{Val: &ast.Boolean{Boolval: true}}, + }, + }, + }, + }, + }, + }, + } + + p := ydb.NewParser() + for _, tc := range tests { + t.Run(tc.stmt, func(t *testing.T) { + stmts, err := p.Parse(strings.NewReader(tc.stmt)) + if err != nil { + t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + } + if len(stmts) == 0 { + t.Fatalf("Запрос %q не распарсен", tc.stmt) + } + + diff := cmp.Diff(tc.expected, &stmts[0], + cmpopts.IgnoreFields(ast.RawStmt{}, "StmtLocation", "StmtLen"), + cmpopts.IgnoreFields(ast.Pragma_stmt{}, "Location"), + cmpopts.IgnoreFields(ast.ColumnRef{}, "Location"), + cmpopts.IgnoreFields(ast.A_Const{}, "Location"), + ) + if diff != "" { + t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + } + }) + } +} diff --git a/internal/engine/ydb/catalog_tests/select_test.go b/internal/engine/ydb/catalog_tests/select_test.go index 95ae49163d..47794ce81f 100644 --- a/internal/engine/ydb/catalog_tests/select_test.go +++ b/internal/engine/ydb/catalog_tests/select_test.go @@ -34,6 +34,7 @@ func TestSelect(t *testing.T) { }, }, }, + FromClause: &ast.List{}, }, }, }, @@ -52,6 +53,7 @@ func TestSelect(t *testing.T) { }, }, }, + FromClause: &ast.List{}, }, }, }, @@ -70,6 +72,7 @@ func TestSelect(t *testing.T) { }, }, }, + FromClause: &ast.List{}, }, }, }, @@ -88,6 +91,7 @@ func TestSelect(t *testing.T) { }, }, }, + FromClause: &ast.List{}, }, }, }, @@ -104,6 +108,7 @@ func TestSelect(t *testing.T) { }, }, }, + FromClause: &ast.List{}, }, }, }, @@ -116,10 +121,13 @@ func TestSelect(t *testing.T) { TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ - Val: &ast.Boolean{Boolval: true}, + Val: &ast.A_Const{ + Val: &ast.Boolean{Boolval: true}, + }, }, }, }, + FromClause: &ast.List{}, }, }, }, @@ -158,6 +166,7 @@ func TestSelect(t *testing.T) { }, }, }, + FromClause: &ast.List{}, }, }, }, @@ -285,7 +294,9 @@ func TestSelect(t *testing.T) { Val: &ast.Null{}, }, &ast.ResTarget{ - Val: &ast.Boolean{Boolval: false}, + Val: &ast.A_Const{ + Val: &ast.Boolean{Boolval: false}, + }, }, }, }, diff --git a/internal/engine/ydb/catalog_tests/update_test.go b/internal/engine/ydb/catalog_tests/update_test.go new file mode 100644 index 0000000000..57c90bc2a4 --- /dev/null +++ b/internal/engine/ydb/catalog_tests/update_test.go @@ -0,0 +1,185 @@ +package ydb_test + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/sqlc-dev/sqlc/internal/engine/ydb" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func TestUpdate(t *testing.T) { + tests := []struct { + stmt string + expected ast.Node + }{ + { + stmt: "UPDATE users SET name = 'Bob' WHERE id = 1 RETURNING id;", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.UpdateStmt{ + Relations: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{Relname: strPtr("users")}, + }, + }, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Name: strPtr("name"), + Val: &ast.A_Const{ + Val: &ast.String{Str: "Bob"}, + }, + }, + }, + }, + WhereClause: &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "="}}}, + Lexpr: &ast.ColumnRef{ + Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "id"}}}, + }, + Rexpr: &ast.A_Const{ + Val: &ast.Integer{Ival: 1}, + }, + }, + ReturningList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Indirection: &ast.List{}, + Val: &ast.ColumnRef{ + Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "id"}}}, + }, + }, + }, + }, + FromClause: &ast.List{}, + WithClause: nil, + Batch: false, + OnCols: nil, + OnSelectStmt: nil, + }, + }, + }, + }, + { + stmt: "BATCH UPDATE users SET name = 'Charlie' WHERE id = 2 RETURNING *", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.UpdateStmt{ + Relations: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{Relname: strPtr("users")}, + }, + }, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Name: strPtr("name"), + Val: &ast.A_Const{Val: &ast.String{Str: "Charlie"}}, + }, + }, + }, + WhereClause: &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "="}}}, + Lexpr: &ast.ColumnRef{ + Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "id"}}}, + }, + Rexpr: &ast.A_Const{ + Val: &ast.Integer{Ival: 2}, + }, + }, + ReturningList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Indirection: &ast.List{}, + Val: &ast.ColumnRef{ + Fields: &ast.List{Items: []ast.Node{&ast.A_Star{}}}, + }, + }, + }, + }, + FromClause: &ast.List{}, + WithClause: nil, + Batch: true, + OnCols: nil, + OnSelectStmt: nil, + }, + }, + }, + }, + { + stmt: "UPDATE users ON (id) VALUES (5) RETURNING id", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.UpdateStmt{ + Relations: &ast.List{Items: []ast.Node{&ast.RangeVar{Relname: strPtr("users")}}}, + OnCols: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{Name: strPtr("id")}, + }, + }, + OnSelectStmt: &ast.SelectStmt{ + ValuesLists: &ast.List{ + Items: []ast.Node{ + &ast.List{ + Items: []ast.Node{ + &ast.A_Const{Val: &ast.Integer{Ival: 5}}, + }, + }, + }, + }, + FromClause: &ast.List{}, + TargetList: &ast.List{}, + }, + ReturningList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Indirection: &ast.List{}, + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + }, + }, + }, + FromClause: &ast.List{}, + WithClause: nil, + Batch: false, + TargetList: nil, + WhereClause: nil, + }, + }, + }, + }, + } + + p := ydb.NewParser() + for _, tc := range tests { + t.Run(tc.stmt, func(t *testing.T) { + stmts, err := p.Parse(strings.NewReader(tc.stmt)) + if err != nil { + t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + } + if len(stmts) == 0 { + t.Fatalf("Запрос %q не распарсен", tc.stmt) + } + + diff := cmp.Diff(tc.expected, &stmts[0], + cmpopts.IgnoreFields(ast.RawStmt{}, "StmtLocation", "StmtLen"), + cmpopts.IgnoreFields(ast.A_Const{}, "Location"), + cmpopts.IgnoreFields(ast.ResTarget{}, "Location"), + cmpopts.IgnoreFields(ast.ColumnRef{}, "Location"), + cmpopts.IgnoreFields(ast.A_Expr{}, "Location"), + cmpopts.IgnoreFields(ast.RangeVar{}, "Location"), + ) + if diff != "" { + t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + } + }) + } +} diff --git a/internal/engine/ydb/convert.go b/internal/engine/ydb/convert.go index 8e447254b1..6fd2fe0ea3 100755 --- a/internal/engine/ydb/convert.go +++ b/internal/engine/ydb/convert.go @@ -21,7 +21,7 @@ type node interface { func todo(funcname string, n node) *ast.TODO { if debug.Active { - log.Printf("sqlite.%s: Unknown node type %T\n", funcname, n) + log.Printf("ydb.%s: Unknown node type %T\n", funcname, n) } return &ast.TODO{} } @@ -34,10 +34,280 @@ func identifier(id string) string { return strings.ToLower(id) } +func stripQuotes(s string) string { + if len(s) >= 2 && (s[0] == '\'' || s[0] == '"') && s[0] == s[len(s)-1] { + return s[1 : len(s)-1] + } + return s +} + func NewIdentifier(t string) *ast.String { return &ast.String{Str: identifier(t)} } +func (c *cc) convertCreate_group_stmtContext(n *parser.Create_group_stmtContext) ast.Node { + if n.CREATE() == nil || n.GROUP() == nil || len(n.AllRole_name()) == 0 { + return todo("Create_group_stmtContext", n) + } + groupName := c.convert(n.Role_name(0)) + + stmt := &ast.CreateRoleStmt{ + StmtType: ast.RoleStmtType(3), + Options: &ast.List{}, + } + + paramFlag := true + switch v := groupName.(type) { + case *ast.A_Const: + switch val := v.Val.(type) { + case *ast.String: + paramFlag = false + stmt.Role = &val.Str + case *ast.Boolean: + stmt.BindRole = groupName + default: + return todo("convertCreate_group_stmtContext", n) + } + case *ast.ParamRef, *ast.A_Expr: + stmt.BindRole = groupName + default: + return todo("convertCreate_group_stmtContext", n) + } + + if debug.Active && paramFlag { + log.Printf("YDB does not currently support parameters in the CREATE GROUP statement") + } + + if n.WITH() != nil && n.USER() != nil && len(n.AllRole_name()) > 1 { + defname := "rolemembers" + optionList := &ast.List{} + for _, role := range n.AllRole_name()[1:] { + roleNode := c.convert(role) + roleSpec := &ast.RoleSpec{ + Roletype: ast.RoleSpecType(1), + Location: role.GetStart().GetStart(), + } + isParam := true + switch v := roleNode.(type) { + case *ast.A_Const: + switch val := v.Val.(type) { + case *ast.String: + isParam = false + roleSpec.Rolename = &val.Str + case *ast.Boolean: + roleSpec.BindRolename = roleNode + default: + return todo("convertCreate_group_stmtContext", n) + } + case *ast.ParamRef, *ast.A_Expr: + roleSpec.BindRolename = roleNode + default: + return todo("convertCreate_group_stmtContext", n) + } + + if debug.Active && isParam && !paramFlag { + log.Printf("YDB does not currently support parameters in the CREATE GROUP statement") + } + + optionList.Items = append(optionList.Items, roleSpec) + } + + stmt.Options.Items = append(stmt.Options.Items, &ast.DefElem{ + Defname: &defname, + Arg: optionList, + Location: n.GetStart().GetStart(), + }) + } + + return stmt +} + +func (c *cc) convertUse_stmtContext(n *parser.Use_stmtContext) ast.Node { + if n.USE() != nil && n.Cluster_expr() != nil { + clusterExpr := c.convert(n.Cluster_expr()) + stmt := &ast.UseStmt{ + Xpr: clusterExpr, + Location: n.GetStart().GetStart(), + } + return stmt + } + return todo("convertUse_stmtContext", n) +} + +func (c *cc) convertCluster_exprContext(n *parser.Cluster_exprContext) ast.Node { + var node ast.Node + + switch { + case n.Pure_column_or_named() != nil: + pureCtx := n.Pure_column_or_named() + if anID := pureCtx.An_id(); anID != nil { + name := parseAnId(anID) + node = &ast.ColumnRef{ + Fields: &ast.List{Items: []ast.Node{NewIdentifier(name)}}, + Location: anID.GetStart().GetStart(), + } + } else if bp := pureCtx.Bind_parameter(); bp != nil { + node = c.convert(bp) + } + case n.ASTERISK() != nil: + node = &ast.A_Star{} + default: + return todo("convertCluster_exprContext", n) + } + + if n.An_id() != nil && n.COLON() != nil { + name := parseAnId(n.An_id()) + return &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: ":"}}}, + Lexpr: &ast.String{Str: name}, + Rexpr: node, + Location: n.GetStart().GetStart(), + } + } + + return node +} + +func (c *cc) convertCreate_user_stmtContext(n *parser.Create_user_stmtContext) ast.Node { + if n.CREATE() == nil || n.USER() == nil || n.Role_name() == nil { + return todo("convertCreate_user_stmtContext", n) + } + roleNode := c.convert(n.Role_name()) + + stmt := &ast.CreateRoleStmt{ + StmtType: ast.RoleStmtType(2), + Options: &ast.List{}, + } + + paramFlag := true + switch v := roleNode.(type) { + case *ast.A_Const: + switch val := v.Val.(type) { + case *ast.String: + paramFlag = false + stmt.Role = &val.Str + case *ast.Boolean: + stmt.BindRole = roleNode + default: + return todo("convertCreate_user_stmtContext", n) + } + case *ast.ParamRef, *ast.A_Expr: + stmt.BindRole = roleNode + default: + return todo("convertCreate_user_stmtContext", n) + } + + if debug.Active && paramFlag { + log.Printf("YDB does not currently support parameters in the CREATE USER statement") + } + + if len(n.AllUser_option()) > 0 { + options := []ast.Node{} + for _, opt := range n.AllUser_option() { + if node := c.convert(opt); node != nil { + options = append(options, node) + } + } + if len(options) > 0 { + stmt.Options = &ast.List{Items: options} + } + } + return stmt +} + +func (c *cc) convertUser_optionContext(n *parser.User_optionContext) ast.Node { + switch { + case n.Authentication_option() != nil: + aOpt := n.Authentication_option() + if pOpt := aOpt.Password_option(); pOpt != nil { + if pOpt.PASSWORD() != nil { + name := "password" + pValue := pOpt.Password_value() + var password ast.Node + if pValue.STRING_VALUE() != nil { + password = &ast.String{Str: stripQuotes(pValue.STRING_VALUE().GetText())} + } else { + password = &ast.Null{} + } + return &ast.DefElem{ + Defname: &name, + Arg: password, + Location: pOpt.GetStart().GetStart(), + } + } + } else if hOpt := aOpt.Hash_option(); hOpt != nil { + if debug.Active { + log.Printf("YDB does not currently support HASH in CREATE USER statement") + } + var pass string + if hOpt.HASH() != nil && hOpt.STRING_VALUE() != nil { + pass = stripQuotes(hOpt.STRING_VALUE().GetText()) + } + name := "hash" + return &ast.DefElem{ + Defname: &name, + Arg: &ast.String{Str: pass}, + Location: hOpt.GetStart().GetStart(), + } + } + + case n.Login_option() != nil: + lOpt := n.Login_option() + var name string + if lOpt.LOGIN() != nil { + name = "login" + } else if lOpt.NOLOGIN() != nil { + name = "nologin" + } + return &ast.DefElem{ + Defname: &name, + Arg: &ast.Boolean{Boolval: lOpt.LOGIN() != nil}, + Location: lOpt.GetStart().GetStart(), + } + default: + return todo("convertUser_optionContext", n) + } + return nil +} + +func (c *cc) convertRole_nameContext(n *parser.Role_nameContext) ast.Node { + switch { + case n.An_id_or_type() != nil: + name := parseAnIdOrType(n.An_id_or_type()) + return &ast.A_Const{Val: NewIdentifier(name), Location: n.GetStart().GetStart()} + case n.Bind_parameter() != nil: + bindPar := c.convert(n.Bind_parameter()) + return bindPar + } + return todo("convertRole_nameContext", n) +} + +func (c *cc) convertCommit_stmtContext(n *parser.Commit_stmtContext) ast.Node { + if n.COMMIT() != nil { + return &ast.TransactionStmt{Kind: ast.TransactionStmtKind(3)} + } + return todo("convertCommit_stmtContext", n) +} + +func (c *cc) convertRollback_stmtContext(n *parser.Rollback_stmtContext) ast.Node { + if n.ROLLBACK() != nil { + return &ast.TransactionStmt{Kind: ast.TransactionStmtKind(4)} + } + return todo("convertRollback_stmtContext", n) +} + +func (c *cc) convertDrop_table_stmtContext(n *parser.Drop_table_stmtContext) ast.Node { + if n.DROP() != nil && (n.TABLESTORE() != nil || (n.EXTERNAL() != nil && n.TABLE() != nil) || n.TABLE() != nil) { + name := parseTableName(n.Simple_table_ref().Simple_table_ref_core()) + stmt := &ast.DropTableStmt{ + IfExists: n.IF() != nil && n.EXISTS() != nil, + Tables: []*ast.TableName{name}, + } + return stmt + } + return todo("convertDrop_Table_stmtContxt", n) +} + func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { batch := n.BATCH() != nil @@ -52,9 +322,7 @@ func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { var source ast.Node if n.ON() != nil && n.Into_values_source() != nil { nVal := n.Into_values_source() - // todo: handle default values when implemented - if pureCols := nVal.Pure_column_list(); pureCols != nil { cols = &ast.List{} for _, anID := range pureCols.AllAn_id() { @@ -98,7 +366,85 @@ func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { return stmts } +func (c *cc) convertPragma_stmtContext(n *parser.Pragma_stmtContext) ast.Node { + if n.PRAGMA() != nil && n.An_id() != nil { + prefix := "" + if p := n.Opt_id_prefix_or_type(); p != nil { + prefix = parseAnIdOrType(p.An_id_or_type()) + } + items := []ast.Node{} + if prefix != "" { + items = append(items, &ast.A_Const{Val: NewIdentifier(prefix)}) + } + + name := parseAnId(n.An_id()) + items = append(items, &ast.A_Const{Val: NewIdentifier(name)}) + + stmt := &ast.Pragma_stmt{ + Name: &ast.List{Items: items}, + Location: n.An_id().GetStart().GetStart(), + } + + if n.EQUALS() != nil { + stmt.Equals = true + if val := n.Pragma_value(0); val != nil { + stmt.Values = &ast.List{Items: []ast.Node{c.convert(val)}} + } + } else if lp := n.LPAREN(); lp != nil { + values := []ast.Node{} + for _, v := range n.AllPragma_value() { + values = append(values, c.convert(v)) + } + stmt.Values = &ast.List{Items: values} + } + + return stmt + } + return todo("convertPragma_stmtContext", n) +} + +func (c *cc) convertPragma_valueContext(n *parser.Pragma_valueContext) ast.Node { + switch { + case n.Signed_number() != nil: + if n.Signed_number().Integer() != nil { + text := n.Signed_number().GetText() + val, err := parseIntegerValue(text) + if err != nil { + if debug.Active { + log.Printf("Failed to parse integer value '%s': %v", text, err) + } + return &ast.TODO{} + } + return &ast.A_Const{Val: &ast.Integer{Ival: val}, Location: n.GetStart().GetStart()} + } + if n.Signed_number().Real_() != nil { + text := n.Signed_number().GetText() + return &ast.A_Const{Val: &ast.Float{Str: text}, Location: n.GetStart().GetStart()} + } + case n.STRING_VALUE() != nil: + val := n.STRING_VALUE().GetText() + if len(val) >= 2 { + val = val[1 : len(val)-1] + } + return &ast.A_Const{Val: &ast.String{Str: val}, Location: n.GetStart().GetStart()} + case n.Bool_value() != nil: + var i bool + if n.Bool_value().TRUE() != nil { + i = true + } + return &ast.A_Const{Val: &ast.Boolean{Boolval: i}, Location: n.GetStart().GetStart()} + case n.Bind_parameter() != nil: + bindPar := c.convert(n.Bind_parameter()) + return bindPar + } + + return todo("convertPragma_valueContext", n) +} + func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { + if n.UPDATE() == nil { + return nil + } batch := n.BATCH() != nil tableName := identifier(n.Simple_table_ref().Simple_table_ref_core().GetText()) @@ -205,8 +551,8 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { TargetList: setList, WhereClause: where, ReturningList: returning, - FromClause: &ast.List{}, - WithClause: nil, + FromClause: &ast.List{}, + WithClause: nil, Batch: batch, OnCols: cols, OnSelectStmt: source, @@ -379,7 +725,10 @@ func (c *cc) convertSelectStmtContext(n *parser.Select_stmtContext) ast.Node { } func (c *cc) convertSelectCoreContext(n *parser.Select_coreContext) ast.Node { - stmt := &ast.SelectStmt{} + stmt := &ast.SelectStmt{ + TargetList: &ast.List{}, + FromClause: &ast.List{}, + } if n.Opt_set_quantifier() != nil { oq := n.Opt_set_quantifier() if oq.DISTINCT() != nil { @@ -460,6 +809,9 @@ func (c *cc) convertResultColumn(n *parser.Result_columnContext) ast.Node { } func (c *cc) convertJoinSource(n *parser.Join_sourceContext) ast.Node { + if n == nil { + return nil + } fsList := n.AllFlatten_source() if len(fsList) == 0 { return nil @@ -597,14 +949,10 @@ func (c *cc) convertBindParameter(n *parser.Bind_parameterContext) ast.Node { // !!debug later!! if n.DOLLAR() != nil { if n.TRUE() != nil { - return &ast.Boolean{ - Boolval: true, - } + return &ast.A_Const{Val: &ast.Boolean{Boolval: true}, Location: n.GetStart().GetStart()} } if n.FALSE() != nil { - return &ast.Boolean{ - Boolval: false, - } + return &ast.A_Const{Val: &ast.Boolean{Boolval: false}, Location: n.GetStart().GetStart()} } if an := n.An_id_or_type(); an != nil { @@ -1740,7 +2088,7 @@ func (c *cc) convertLiteralValue(n *parser.Literal_valueContext) ast.Node { if n.Bool_value().TRUE() != nil { i = true } - return &ast.Boolean{Boolval: i} + return &ast.A_Const{Val: &ast.Boolean{Boolval: i}, Location: n.GetStart().GetStart()} case n.NULL() != nil: return &ast.Null{} @@ -1903,6 +2251,39 @@ func (c *cc) convert(node node) ast.Node { case *parser.Update_stmtContext: return c.convertUpdate_stmtContext(n) + case *parser.Drop_table_stmtContext: + return c.convertDrop_table_stmtContext(n) + + case *parser.Commit_stmtContext: + return c.convertCommit_stmtContext(n) + + case *parser.Rollback_stmtContext: + return c.convertRollback_stmtContext(n) + + case *parser.Pragma_valueContext: + return c.convertPragma_valueContext(n) + + case *parser.Pragma_stmtContext: + return c.convertPragma_stmtContext(n) + + case *parser.Use_stmtContext: + return c.convertUse_stmtContext(n) + + case *parser.Cluster_exprContext: + return c.convertCluster_exprContext(n) + + case *parser.Create_user_stmtContext: + return c.convertCreate_user_stmtContext(n) + + case *parser.Role_nameContext: + return c.convertRole_nameContext(n) + + case *parser.User_optionContext: + return c.convertUser_optionContext(n) + + case *parser.Create_group_stmtContext: + return c.convertCreate_group_stmtContext(n) + default: return todo("convert(case=default)", n) } diff --git a/internal/engine/ydb/utils.go b/internal/engine/ydb/utils.go index 0fe41d356f..5201e6c9dd 100755 --- a/internal/engine/ydb/utils.go +++ b/internal/engine/ydb/utils.go @@ -14,6 +14,8 @@ type objectRefProvider interface { Object_ref() parser.IObject_refContext } + + func parseTableName(ctx objectRefProvider) *ast.TableName { return parseObjectRef(ctx.Object_ref()) } diff --git a/internal/sql/ast/create_role_stmt.go b/internal/sql/ast/create_role_stmt.go index 144540863e..44565fb64c 100644 --- a/internal/sql/ast/create_role_stmt.go +++ b/internal/sql/ast/create_role_stmt.go @@ -4,6 +4,9 @@ type CreateRoleStmt struct { StmtType RoleStmtType Role *string Options *List + + // YDB specific + BindRole Node } func (n *CreateRoleStmt) Pos() int { diff --git a/internal/sql/ast/delete_stmt.go b/internal/sql/ast/delete_stmt.go index 8bc2b89e9d..715f976ae6 100644 --- a/internal/sql/ast/delete_stmt.go +++ b/internal/sql/ast/delete_stmt.go @@ -1,9 +1,5 @@ package ast -import ( - "fmt" -) - type DeleteStmt struct { Relations *List UsingClause *List @@ -24,7 +20,6 @@ func (n *DeleteStmt) Pos() int { } func (n *DeleteStmt) Format(buf *TrackedBuffer) { - fmt.Println("DeleteStmt.Format") if n == nil { return } diff --git a/internal/sql/ast/param_ref.go b/internal/sql/ast/param_ref.go index 8bd724993d..6ffe8cc5f0 100644 --- a/internal/sql/ast/param_ref.go +++ b/internal/sql/ast/param_ref.go @@ -6,6 +6,9 @@ type ParamRef struct { Number int Location int Dollar bool + + // YDB specific + Plike bool } func (n *ParamRef) Pos() int { @@ -16,5 +19,9 @@ func (n *ParamRef) Format(buf *TrackedBuffer) { if n == nil { return } + if n.Plike { + fmt.Fprintf(buf, "$p%d", n.Number) + return + } fmt.Fprintf(buf, "$%d", n.Number) } diff --git a/internal/sql/ast/pragma_stmt.go b/internal/sql/ast/pragma_stmt.go new file mode 100644 index 0000000000..46dc40dbf5 --- /dev/null +++ b/internal/sql/ast/pragma_stmt.go @@ -0,0 +1,43 @@ +package ast + +// YDB specific +type Pragma_stmt struct { + Name Node + Cols *List + Equals bool + Values *List + Location int +} + +func (n *Pragma_stmt) Pos() int { + return n.Location +} + +func (n *Pragma_stmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + + buf.WriteString("PRAGMA ") + if n.Name != nil { + buf.astFormat(n.Name) + } + if n.Cols != nil { + buf.astFormat(n.Cols) + } + + if n.Equals { + buf.WriteString(" = ") + } + + if n.Values != nil { + if n.Equals { + buf.astFormat(n.Values) + } else { + buf.WriteString("(") + buf.astFormat(n.Values) + buf.WriteString(")") + } + } + +} diff --git a/internal/sql/ast/role_spec.go b/internal/sql/ast/role_spec.go index fba4cecf7d..5b7c871c54 100644 --- a/internal/sql/ast/role_spec.go +++ b/internal/sql/ast/role_spec.go @@ -4,6 +4,8 @@ type RoleSpec struct { Roletype RoleSpecType Rolename *string Location int + + BindRolename Node } func (n *RoleSpec) Pos() int { diff --git a/internal/sql/ast/use_stmt.go b/internal/sql/ast/use_stmt.go new file mode 100644 index 0000000000..dee393c321 --- /dev/null +++ b/internal/sql/ast/use_stmt.go @@ -0,0 +1,11 @@ +package ast + +// YDB specific +type UseStmt struct { + Xpr Node + Location int +} + +func (n *UseStmt) Pos() int { + return n.Location +} diff --git a/internal/sql/astutils/rewrite.go b/internal/sql/astutils/rewrite.go index f5c8028ab6..8e8eefbff4 100644 --- a/internal/sql/astutils/rewrite.go +++ b/internal/sql/astutils/rewrite.go @@ -605,7 +605,9 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. a.apply(n, "Params", nil, n.Params) case *ast.CreateRoleStmt: + a.apply(n, "BindRole", nil, n.BindRole) a.apply(n, "Options", nil, n.Options) + case *ast.CreateSchemaStmt: a.apply(n, "Authrole", nil, n.Authrole) @@ -924,6 +926,11 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. case *ast.PartitionSpec: a.apply(n, "PartParams", nil, n.PartParams) + case *ast.Pragma_stmt: + a.apply(n, "Pragma", nil, n.Name) + a.apply(n, "Args", nil, n.Cols) + a.apply(n, "Options", nil, n.Values) + case *ast.PrepareStmt: a.apply(n, "Argtypes", nil, n.Argtypes) a.apply(n, "Query", nil, n.Query) @@ -1029,7 +1036,7 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. a.apply(n, "Val", nil, n.Val) case *ast.RoleSpec: - // pass + a.apply(n, "BindRolename", nil, n.BindRolename) case *ast.RowCompareExpr: a.apply(n, "Xpr", nil, n.Xpr) @@ -1166,6 +1173,9 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. a.apply(n, "ReturningList", nil, n.ReturningList) a.apply(n, "WithClause", nil, n.WithClause) + case *ast.UseStmt: + a.apply(n, "Xpr", nil, n.Xpr) + case *ast.VacuumStmt: a.apply(n, "Relation", nil, n.Relation) a.apply(n, "VaCols", nil, n.VaCols) diff --git a/internal/sql/astutils/walk.go b/internal/sql/astutils/walk.go index 2f28ba0243..e7b78d126b 100644 --- a/internal/sql/astutils/walk.go +++ b/internal/sql/astutils/walk.go @@ -898,6 +898,9 @@ func Walk(f Visitor, node ast.Node) { } case *ast.CreateRoleStmt: + if n.BindRole != nil { + Walk(f, n.BindRole) + } if n.Options != nil { Walk(f, n.Options) } @@ -1516,6 +1519,17 @@ func Walk(f Visitor, node ast.Node) { Walk(f, n.PartParams) } + case *ast.Pragma_stmt: + if n.Name != nil { + Walk(f, n.Name) + } + if n.Cols != nil { + Walk(f, n.Cols) + } + if n.Values != nil { + Walk(f, n.Values) + } + case *ast.PrepareStmt: if n.Argtypes != nil { Walk(f, n.Argtypes) @@ -1758,7 +1772,9 @@ func Walk(f Visitor, node ast.Node) { } case *ast.RoleSpec: - // pass + if n.BindRolename != nil { + Walk(f, n.BindRolename) + } case *ast.RowCompareExpr: if n.Xpr != nil { @@ -2063,6 +2079,11 @@ func Walk(f Visitor, node ast.Node) { Walk(f, n.WithClause) } + case *ast.UseStmt: + if n.Xpr != nil { + Walk(f, n.Xpr) + } + case *ast.VacuumStmt: if n.Relation != nil { Walk(f, n.Relation) From e687cf6d78ab4d61d5ff0788c46869c76719b259 Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov Date: Mon, 19 May 2025 18:35:44 +0300 Subject: [PATCH 06/18] ADDED DROP/ALTER USER/GROUP --- .../ydb/catalog_tests/alter_group_test.go | 122 ++++++++++ .../ydb/catalog_tests/alter_user_test.go | 153 +++++++++++++ .../ydb/catalog_tests/drop_role_test.go | 87 +++++++ internal/engine/ydb/convert.go | 213 ++++++++++++++++-- internal/engine/ydb/utils.go | 29 +++ 5 files changed, 583 insertions(+), 21 deletions(-) create mode 100644 internal/engine/ydb/catalog_tests/alter_group_test.go create mode 100644 internal/engine/ydb/catalog_tests/alter_user_test.go create mode 100644 internal/engine/ydb/catalog_tests/drop_role_test.go diff --git a/internal/engine/ydb/catalog_tests/alter_group_test.go b/internal/engine/ydb/catalog_tests/alter_group_test.go new file mode 100644 index 0000000000..eef9f919e9 --- /dev/null +++ b/internal/engine/ydb/catalog_tests/alter_group_test.go @@ -0,0 +1,122 @@ +package ydb_test + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/sqlc-dev/sqlc/internal/engine/ydb" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func TestAlterGroup(t *testing.T) { + tests := []struct { + stmt string + expected ast.Node + }{ + { + stmt: `ALTER GROUP admins RENAME TO superusers`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.AlterRoleStmt{ + Role: &ast.RoleSpec{ + Rolename: strPtr("admins"), + Roletype: ast.RoleSpecType(1), + }, + Action: 1, + Options: &ast.List{ + Items: []ast.Node{ + &ast.DefElem{ + Defname: strPtr("rename"), + Defaction: ast.DefElemAction(1), + Arg: &ast.String{Str: "superusers"}, + }, + }, + }, + }, + }, + }, + }, + { + stmt: `ALTER GROUP devs ADD USER alice, bob, carol`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.AlterRoleStmt{ + Role: &ast.RoleSpec{ + Rolename: strPtr("devs"), + Roletype: ast.RoleSpecType(1), + }, + Action: 1, + Options: &ast.List{ + Items: []ast.Node{ + &ast.DefElem{ + Defname: strPtr("rolemembers"), + Defaction: ast.DefElemAction(3), + Arg: &ast.List{ + Items: []ast.Node{ + &ast.RoleSpec{Rolename: strPtr("alice"), Roletype: ast.RoleSpecType(1)}, + &ast.RoleSpec{Rolename: strPtr("bob"), Roletype: ast.RoleSpecType(1)}, + &ast.RoleSpec{Rolename: strPtr("carol"), Roletype: ast.RoleSpecType(1)}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + stmt: `ALTER GROUP ops DROP USER dan, erin`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.AlterRoleStmt{ + Role: &ast.RoleSpec{ + Rolename: strPtr("ops"), + Roletype: ast.RoleSpecType(1), + }, + Action: 1, + Options: &ast.List{ + Items: []ast.Node{ + &ast.DefElem{ + Defname: strPtr("rolemembers"), + Defaction: ast.DefElemAction(4), + Arg: &ast.List{ + Items: []ast.Node{ + &ast.RoleSpec{Rolename: strPtr("dan"), Roletype: ast.RoleSpecType(1)}, + &ast.RoleSpec{Rolename: strPtr("erin"), Roletype: ast.RoleSpecType(1)}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + p := ydb.NewParser() + for _, tc := range tests { + t.Run(tc.stmt, func(t *testing.T) { + stmts, err := p.Parse(strings.NewReader(tc.stmt)) + if err != nil { + t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + } + if len(stmts) == 0 { + t.Fatalf("Запрос %q не распарсен", tc.stmt) + } + + diff := cmp.Diff(tc.expected, &stmts[0], + cmpopts.IgnoreFields(ast.RawStmt{}, "StmtLocation", "StmtLen"), + cmpopts.IgnoreFields(ast.DefElem{}, "Location"), + cmpopts.IgnoreFields(ast.RoleSpec{}, "Location"), + cmpopts.IgnoreFields(ast.A_Const{}, "Location"), + ) + if diff != "" { + t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + } + }) + } +} diff --git a/internal/engine/ydb/catalog_tests/alter_user_test.go b/internal/engine/ydb/catalog_tests/alter_user_test.go new file mode 100644 index 0000000000..dd4dc5bb93 --- /dev/null +++ b/internal/engine/ydb/catalog_tests/alter_user_test.go @@ -0,0 +1,153 @@ +package ydb_test + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/sqlc-dev/sqlc/internal/engine/ydb" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func TestAlterUser(t *testing.T) { + tests := []struct { + stmt string + expected ast.Node + }{ + { + stmt: `ALTER USER alice RENAME TO queen`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.AlterRoleStmt{ + Role: &ast.RoleSpec{ + Rolename: strPtr("alice"), + Roletype: ast.RoleSpecType(1), + }, + Action: 1, + Options: &ast.List{ + Items: []ast.Node{ + &ast.DefElem{ + Defname: strPtr("rename"), + Arg: &ast.String{Str: "queen"}, + Defaction: ast.DefElemAction(1), + }, + }, + }, + }, + }, + }, + }, + { + stmt: `ALTER USER bob LOGIN`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.AlterRoleStmt{ + Role: &ast.RoleSpec{ + Rolename: strPtr("bob"), + Roletype: ast.RoleSpecType(1), + }, + Action: 1, + Options: &ast.List{ + Items: []ast.Node{ + &ast.DefElem{ + Defname: strPtr("login"), + Arg: &ast.Boolean{Boolval: true}, + }, + }, + }, + }, + }, + }, + }, + { + stmt: `ALTER USER charlie NOLOGIN`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.AlterRoleStmt{ + Role: &ast.RoleSpec{ + Rolename: strPtr("charlie"), + Roletype: ast.RoleSpecType(1), + }, + Action: 1, + Options: &ast.List{ + Items: []ast.Node{ + &ast.DefElem{ + Defname: strPtr("nologin"), + Arg: &ast.Boolean{Boolval: false}, + }, + }, + }, + }, + }, + }, + }, + { + stmt: `ALTER USER dave PASSWORD 'qwerty'`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.AlterRoleStmt{ + Role: &ast.RoleSpec{ + Rolename: strPtr("dave"), + Roletype: ast.RoleSpecType(1), + }, + Action: 1, + Options: &ast.List{ + Items: []ast.Node{ + &ast.DefElem{ + Defname: strPtr("password"), + Arg: &ast.String{Str: "qwerty"}, + }, + }, + }, + }, + }, + }, + }, + { + stmt: `ALTER USER elena HASH 'abc123'`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.AlterRoleStmt{ + Role: &ast.RoleSpec{ + Rolename: strPtr("elena"), + Roletype: ast.RoleSpecType(1), + }, + Action: 1, + Options: &ast.List{ + Items: []ast.Node{ + &ast.DefElem{ + Defname: strPtr("hash"), + Arg: &ast.String{Str: "abc123"}, + }, + }, + }, + }, + }, + }, + }, + } + + p := ydb.NewParser() + for _, tc := range tests { + t.Run(tc.stmt, func(t *testing.T) { + stmts, err := p.Parse(strings.NewReader(tc.stmt)) + if err != nil { + t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + } + if len(stmts) == 0 { + t.Fatalf("Запрос %q не распарсен", tc.stmt) + } + + diff := cmp.Diff(tc.expected, &stmts[0], + cmpopts.IgnoreFields(ast.RawStmt{}, "StmtLocation", "StmtLen"), + cmpopts.IgnoreFields(ast.DefElem{}, "Location"), + cmpopts.IgnoreFields(ast.RoleSpec{}, "Location"), + cmpopts.IgnoreFields(ast.A_Const{}, "Location"), + ) + if diff != "" { + t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + } + }) + } +} diff --git a/internal/engine/ydb/catalog_tests/drop_role_test.go b/internal/engine/ydb/catalog_tests/drop_role_test.go new file mode 100644 index 0000000000..1d7c6a7658 --- /dev/null +++ b/internal/engine/ydb/catalog_tests/drop_role_test.go @@ -0,0 +1,87 @@ +package ydb_test + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/sqlc-dev/sqlc/internal/engine/ydb" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func TestDropRole(t *testing.T) { + tests := []struct { + stmt string + expected ast.Node + }{ + { + stmt: `DROP USER user1;`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.DropRoleStmt{ + MissingOk: false, + Roles: &ast.List{ + Items: []ast.Node{ + &ast.RoleSpec{Rolename: strPtr("user1"), Roletype: ast.RoleSpecType(1)}, + }, + }, + }, + }, + }, + }, + { + stmt: "DROP USER IF EXISTS admin, user2", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.DropRoleStmt{ + MissingOk: true, + Roles: &ast.List{ + Items: []ast.Node{ + &ast.RoleSpec{Rolename: strPtr("admin"), Roletype: ast.RoleSpecType(1)}, + &ast.RoleSpec{Rolename: strPtr("user2"), Roletype: ast.RoleSpecType(1)}, + }, + }, + }, + }, + }, + }, + { + stmt: "DROP GROUP team1, team2", + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.DropRoleStmt{ + MissingOk: false, + Roles: &ast.List{ + Items: []ast.Node{ + &ast.RoleSpec{Rolename: strPtr("team1"), Roletype: ast.RoleSpecType(1)}, + &ast.RoleSpec{Rolename: strPtr("team2"), Roletype: ast.RoleSpecType(1)}, + }, + }, + }, + }, + }, + }, + } + + p := ydb.NewParser() + for _, tc := range tests { + t.Run(tc.stmt, func(t *testing.T) { + stmts, err := p.Parse(strings.NewReader(tc.stmt)) + if err != nil { + t.Fatalf("Error parsing %q: %v", tc.stmt, err) + } + if len(stmts) == 0 { + t.Fatalf("Statement %q was not parsed", tc.stmt) + } + + diff := cmp.Diff(tc.expected, &stmts[0], + cmpopts.IgnoreFields(ast.RawStmt{}, "StmtLocation", "StmtLen"), + cmpopts.IgnoreFields(ast.RoleSpec{}, "Location"), + ) + if diff != "" { + t.Errorf("AST mismatch for %q (-expected +got):\n%s", tc.stmt, diff) + } + }) + } +} diff --git a/internal/engine/ydb/convert.go b/internal/engine/ydb/convert.go index 6fd2fe0ea3..4dab9189a4 100755 --- a/internal/engine/ydb/convert.go +++ b/internal/engine/ydb/convert.go @@ -45,6 +45,186 @@ func NewIdentifier(t string) *ast.String { return &ast.String{Str: identifier(t)} } +func (c *cc) convertDrop_role_stmtCOntext(n *parser.Drop_role_stmtContext) ast.Node { + if n.DROP() == nil || (n.USER() == nil && n.GROUP() == nil) || len(n.AllRole_name()) == 0 { + return todo("Drop_role_stmtContext", n) + } + + stmt := &ast.DropRoleStmt{ + MissingOk: n.IF() != nil && n.EXISTS() != nil, + Roles: &ast.List{}, + } + + for _, role := range n.AllRole_name() { + member, isParam, _ := c.extractRoleSpec(role, ast.RoleSpecType(1)) + if member == nil { + return todo("Drop_role_stmtContext", n) + } + + if debug.Active && isParam { + log.Printf("YDB does not currently support parameters in the DROP ROLE statement") + } + + stmt.Roles.Items = append(stmt.Roles.Items, member) + } + + return stmt +} + +func (c *cc) convertAlter_group_stmtContext(n *parser.Alter_group_stmtContext) ast.Node { + if n.ALTER() == nil || n.GROUP() == nil || len(n.AllRole_name()) == 0 { + return todo("convertAlter_group_stmtContext", n) + } + role, paramFlag, _ := c.extractRoleSpec(n.Role_name(0), ast.RoleSpecType(1)) + if role == nil { + return todo("convertAlter_group_stmtContext", n) + } + + if debug.Active && paramFlag { + log.Printf("YDB does not currently support parameters in the ALTER GROUP statement") + } + + stmt := &ast.AlterRoleStmt{ + Role: role, + Action: 1, + Options: &ast.List{}, + } + + switch { + case n.RENAME() != nil && n.TO() != nil && len(n.AllRole_name()) > 1: + newName := c.convert(n.Role_name(1)) + action := "rename" + + defElem := &ast.DefElem{ + Defname: &action, + Defaction: ast.DefElemAction(1), + Location: n.Role_name(1).GetStart().GetStart(), + } + + bindFlag := true + switch v := newName.(type) { + case *ast.A_Const: + switch val := v.Val.(type) { + case *ast.String: + bindFlag = false + defElem.Arg = val + case *ast.Boolean: + defElem.Arg = val + default: + return todo("convertAlter_group_stmtContext", n) + } + case *ast.ParamRef, *ast.A_Expr: + defElem.Arg = newName + default: + return todo("convertAlter_group_stmtContext", n) + } + + if debug.Active && !paramFlag && bindFlag { + log.Printf("YDB does not currently support parameters in the ALTER GROUP statement") + } + + stmt.Options.Items = append(stmt.Options.Items, defElem) + + case (n.ADD() != nil || n.DROP() != nil) && len(n.AllRole_name()) > 1: + defname := "rolemembers" + optionList := &ast.List{} + for _, role := range n.AllRole_name()[1:] { + member, isParam, _ := c.extractRoleSpec(role, ast.RoleSpecType(1)) + if member == nil { + return todo("convertAlter_group_stmtContext", n) + } + + if debug.Active && isParam && !paramFlag { + log.Printf("YDB does not currently support parameters in the ALTER GROUP statement") + } + + optionList.Items = append(optionList.Items, member) + } + + var action ast.DefElemAction + if n.ADD() != nil { + action = 3 + } else { + action = 4 + } + + stmt.Options.Items = append(stmt.Options.Items, &ast.DefElem{ + Defname: &defname, + Arg: optionList, + Defaction: action, + Location: n.GetStart().GetStart(), + }) + } + + return stmt +} + +func (c *cc) convertAlter_user_stmtContext(n *parser.Alter_user_stmtContext) ast.Node { + if n.ALTER() == nil || n.USER() == nil || len(n.AllRole_name()) == 0 { + return todo("Alter_user_stmtContext", n) + } + + role, paramFlag, _ := c.extractRoleSpec(n.Role_name(0), ast.RoleSpecType(1)) + if role == nil { + return todo("convertAlter_group_stmtContext", n) + } + + if debug.Active && paramFlag { + log.Printf("YDB does not currently support parameters in the ALTER USER statement") + } + + stmt := &ast.AlterRoleStmt{ + Role: role, + Action: 1, + Options: &ast.List{}, + } + + switch { + case n.RENAME() != nil && n.TO() != nil && len(n.AllRole_name()) > 1: + newName := c.convert(n.Role_name(1)) + action := "rename" + + defElem := &ast.DefElem{ + Defname: &action, + Defaction: ast.DefElemAction(1), + Location: n.Role_name(1).GetStart().GetStart(), + } + + bindFlag := true + switch v := newName.(type) { + case *ast.A_Const: + switch val := v.Val.(type) { + case *ast.String: + bindFlag = false + defElem.Arg = val + case *ast.Boolean: + defElem.Arg = val + default: + return todo("Alter_user_stmtContext", n) + } + case *ast.ParamRef, *ast.A_Expr: + defElem.Arg = newName + default: + return todo("Alter_user_stmtContext", n) + } + + if debug.Active && !paramFlag && bindFlag { + log.Printf("YDB does not currently support parameters in the ALTER USER statement") + } + + stmt.Options.Items = append(stmt.Options.Items, defElem) + + case len(n.AllUser_option()) > 0: + for _, opt := range n.AllUser_option() { + if node := c.convert(opt); node != nil { + stmt.Options.Items = append(stmt.Options.Items, node) + } + } + } + + return stmt +} + func (c *cc) convertCreate_group_stmtContext(n *parser.Create_group_stmtContext) ast.Node { if n.CREATE() == nil || n.GROUP() == nil || len(n.AllRole_name()) == 0 { return todo("Create_group_stmtContext", n) @@ -82,26 +262,8 @@ func (c *cc) convertCreate_group_stmtContext(n *parser.Create_group_stmtContext) defname := "rolemembers" optionList := &ast.List{} for _, role := range n.AllRole_name()[1:] { - roleNode := c.convert(role) - roleSpec := &ast.RoleSpec{ - Roletype: ast.RoleSpecType(1), - Location: role.GetStart().GetStart(), - } - isParam := true - switch v := roleNode.(type) { - case *ast.A_Const: - switch val := v.Val.(type) { - case *ast.String: - isParam = false - roleSpec.Rolename = &val.Str - case *ast.Boolean: - roleSpec.BindRolename = roleNode - default: - return todo("convertCreate_group_stmtContext", n) - } - case *ast.ParamRef, *ast.A_Expr: - roleSpec.BindRolename = roleNode - default: + member, isParam, _ := c.extractRoleSpec(role, ast.RoleSpecType(1)) + if member == nil { return todo("convertCreate_group_stmtContext", n) } @@ -109,7 +271,7 @@ func (c *cc) convertCreate_group_stmtContext(n *parser.Create_group_stmtContext) log.Printf("YDB does not currently support parameters in the CREATE GROUP statement") } - optionList.Items = append(optionList.Items, roleSpec) + optionList.Items = append(optionList.Items, member) } stmt.Options.Items = append(stmt.Options.Items, &ast.DefElem{ @@ -2284,6 +2446,15 @@ func (c *cc) convert(node node) ast.Node { case *parser.Create_group_stmtContext: return c.convertCreate_group_stmtContext(n) + case *parser.Alter_user_stmtContext: + return c.convertAlter_user_stmtContext(n) + + case *parser.Alter_group_stmtContext: + return c.convertAlter_group_stmtContext(n) + + case *parser.Drop_role_stmtContext: + return c.convertDrop_role_stmtCOntext(n) + default: return todo("convert(case=default)", n) } diff --git a/internal/engine/ydb/utils.go b/internal/engine/ydb/utils.go index 5201e6c9dd..0748de8bdf 100755 --- a/internal/engine/ydb/utils.go +++ b/internal/engine/ydb/utils.go @@ -143,3 +143,32 @@ func parseIntegerValue(text string) (int64, error) { return strconv.ParseInt(text, base, 64) } + +func (c *cc) extractRoleSpec(n parser.IRole_nameContext, roletype ast.RoleSpecType) (*ast.RoleSpec, bool, ast.Node) { + roleNode := c.convert(n) + + roleSpec := &ast.RoleSpec{ + Roletype: roletype, + Location: n.GetStart().GetStart(), + } + + isParam := true + switch v := roleNode.(type) { + case *ast.A_Const: + switch val := v.Val.(type) { + case *ast.String: + roleSpec.Rolename = &val.Str + isParam = false + case *ast.Boolean: + roleSpec.BindRolename = roleNode + default: + return nil, false, nil + } + case *ast.ParamRef, *ast.A_Expr: + roleSpec.BindRolename = roleNode + default: + return nil, false, nil + } + + return roleSpec, isParam, roleNode +} From 0293f6afc046ca9fc7e590a38f1e2edc111d41e3 Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov Date: Wed, 21 May 2025 01:22:26 +0300 Subject: [PATCH 07/18] Added Funcs support! --- examples/authors/sqlc.yaml | 82 +++--- examples/authors/ydb/db.go | 2 +- examples/authors/ydb/models.go | 2 +- examples/authors/ydb/query.sql.go | 2 +- internal/engine/ydb/convert.go | 376 ++++++++++++++++-------- internal/engine/ydb/lib/aggregate.go | 330 +++++++++++++++++++++ internal/engine/ydb/lib/basic.go | 203 +++++++++++++ internal/engine/ydb/parse.go | 11 +- internal/engine/ydb/stdlib.go | 10 +- internal/engine/ydb/utils.go | 26 +- internal/sql/ast/insert_stmt.go | 29 +- internal/sql/ast/recursive_func_call.go | 33 +++ internal/sql/astutils/rewrite.go | 9 +- internal/sql/astutils/walk.go | 20 ++ 14 files changed, 950 insertions(+), 185 deletions(-) create mode 100644 internal/engine/ydb/lib/aggregate.go create mode 100644 internal/engine/ydb/lib/basic.go create mode 100644 internal/sql/ast/recursive_func_call.go diff --git a/examples/authors/sqlc.yaml b/examples/authors/sqlc.yaml index 30d904875e..8d6bc3db28 100644 --- a/examples/authors/sqlc.yaml +++ b/examples/authors/sqlc.yaml @@ -2,47 +2,47 @@ version: '2' cloud: project: "01HAQMMECEYQYKFJN8MP16QC41" sql: -# - name: postgresql -# schema: postgresql/schema.sql -# queries: postgresql/query.sql -# engine: postgresql -# database: -# uri: "${VET_TEST_EXAMPLES_POSTGRES_AUTHORS}" -# analyzer: -# database: false -# rules: -# - sqlc/db-prepare -# - postgresql-query-too-costly -# gen: -# go: -# package: authors -# sql_package: pgx/v5 -# out: postgresql -# - name: mysql -# schema: mysql/schema.sql -# queries: mysql/query.sql -# engine: mysql -# database: -# uri: "${VET_TEST_EXAMPLES_MYSQL_AUTHORS}" -# rules: -# - sqlc/db-prepare -# # - mysql-query-too-costly -# gen: -# go: -# package: authors -# out: mysql -# - name: sqlite -# schema: sqlite/schema.sql -# queries: sqlite/query.sql -# engine: sqlite -# database: -# uri: file:authors?mode=memory&cache=shared -# rules: -# - sqlc/db-prepare -# gen: -# go: -# package: authors -# out: sqlite +- name: postgresql + schema: postgresql/schema.sql + queries: postgresql/query.sql + engine: postgresql + database: + uri: "${VET_TEST_EXAMPLES_POSTGRES_AUTHORS}" + analyzer: + database: false + rules: + - sqlc/db-prepare + - postgresql-query-too-costly + gen: + go: + package: authors + sql_package: pgx/v5 + out: postgresql +- name: mysql + schema: mysql/schema.sql + queries: mysql/query.sql + engine: mysql + database: + uri: "${VET_TEST_EXAMPLES_MYSQL_AUTHORS}" + rules: + - sqlc/db-prepare + # - mysql-query-too-costly + gen: + go: + package: authors + out: mysql +- name: sqlite + schema: sqlite/schema.sql + queries: sqlite/query.sql + engine: sqlite + database: + uri: file:authors?mode=memory&cache=shared + rules: + - sqlc/db-prepare + gen: + go: + package: authors + out: sqlite - name: ydb schema: ydb/schema.sql queries: ydb/query.sql diff --git a/examples/authors/ydb/db.go b/examples/authors/ydb/db.go index 2bb1bfc27d..e2b0a86b13 100644 --- a/examples/authors/ydb/db.go +++ b/examples/authors/ydb/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.29.0 package authors diff --git a/examples/authors/ydb/models.go b/examples/authors/ydb/models.go index 337ea597f4..8edcdc7b33 100644 --- a/examples/authors/ydb/models.go +++ b/examples/authors/ydb/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.29.0 package authors diff --git a/examples/authors/ydb/query.sql.go b/examples/authors/ydb/query.sql.go index 64126bf254..e244f62c54 100644 --- a/examples/authors/ydb/query.sql.go +++ b/examples/authors/ydb/query.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.29.0 // source: query.sql package authors diff --git a/internal/engine/ydb/convert.go b/internal/engine/ydb/convert.go index 4dab9189a4..b4d9490d0b 100755 --- a/internal/engine/ydb/convert.go +++ b/internal/engine/ydb/convert.go @@ -13,6 +13,15 @@ import ( type cc struct { paramCount int + content string +} + +func (c *cc) pos(token antlr.Token) int { + if token == nil { + return 0 + } + runeIdx := token.GetStart() + return byteOffsetFromRuneIndex(c.content, runeIdx) } type node interface { @@ -52,7 +61,7 @@ func (c *cc) convertDrop_role_stmtCOntext(n *parser.Drop_role_stmtContext) ast.N stmt := &ast.DropRoleStmt{ MissingOk: n.IF() != nil && n.EXISTS() != nil, - Roles: &ast.List{}, + Roles: &ast.List{}, } for _, role := range n.AllRole_name() { @@ -98,7 +107,7 @@ func (c *cc) convertAlter_group_stmtContext(n *parser.Alter_group_stmtContext) a defElem := &ast.DefElem{ Defname: &action, Defaction: ast.DefElemAction(1), - Location: n.Role_name(1).GetStart().GetStart(), + Location: c.pos(n.Role_name(1).GetStart()), } bindFlag := true @@ -152,7 +161,7 @@ func (c *cc) convertAlter_group_stmtContext(n *parser.Alter_group_stmtContext) a Defname: &defname, Arg: optionList, Defaction: action, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), }) } @@ -187,7 +196,7 @@ func (c *cc) convertAlter_user_stmtContext(n *parser.Alter_user_stmtContext) ast defElem := &ast.DefElem{ Defname: &action, Defaction: ast.DefElemAction(1), - Location: n.Role_name(1).GetStart().GetStart(), + Location: c.pos(n.Role_name(1).GetStart()), } bindFlag := true @@ -277,7 +286,7 @@ func (c *cc) convertCreate_group_stmtContext(n *parser.Create_group_stmtContext) stmt.Options.Items = append(stmt.Options.Items, &ast.DefElem{ Defname: &defname, Arg: optionList, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), }) } @@ -289,7 +298,7 @@ func (c *cc) convertUse_stmtContext(n *parser.Use_stmtContext) ast.Node { clusterExpr := c.convert(n.Cluster_expr()) stmt := &ast.UseStmt{ Xpr: clusterExpr, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } return stmt } @@ -306,7 +315,7 @@ func (c *cc) convertCluster_exprContext(n *parser.Cluster_exprContext) ast.Node name := parseAnId(anID) node = &ast.ColumnRef{ Fields: &ast.List{Items: []ast.Node{NewIdentifier(name)}}, - Location: anID.GetStart().GetStart(), + Location: c.pos(anID.GetStart()), } } else if bp := pureCtx.Bind_parameter(); bp != nil { node = c.convert(bp) @@ -323,7 +332,7 @@ func (c *cc) convertCluster_exprContext(n *parser.Cluster_exprContext) ast.Node Name: &ast.List{Items: []ast.Node{&ast.String{Str: ":"}}}, Lexpr: &ast.String{Str: name}, Rexpr: node, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } } @@ -394,7 +403,7 @@ func (c *cc) convertUser_optionContext(n *parser.User_optionContext) ast.Node { return &ast.DefElem{ Defname: &name, Arg: password, - Location: pOpt.GetStart().GetStart(), + Location: c.pos(pOpt.GetStart()), } } } else if hOpt := aOpt.Hash_option(); hOpt != nil { @@ -409,7 +418,7 @@ func (c *cc) convertUser_optionContext(n *parser.User_optionContext) ast.Node { return &ast.DefElem{ Defname: &name, Arg: &ast.String{Str: pass}, - Location: hOpt.GetStart().GetStart(), + Location: c.pos(hOpt.GetStart()), } } @@ -424,7 +433,7 @@ func (c *cc) convertUser_optionContext(n *parser.User_optionContext) ast.Node { return &ast.DefElem{ Defname: &name, Arg: &ast.Boolean{Boolval: lOpt.LOGIN() != nil}, - Location: lOpt.GetStart().GetStart(), + Location: c.pos(lOpt.GetStart()), } default: return todo("convertUser_optionContext", n) @@ -436,7 +445,7 @@ func (c *cc) convertRole_nameContext(n *parser.Role_nameContext) ast.Node { switch { case n.An_id_or_type() != nil: name := parseAnIdOrType(n.An_id_or_type()) - return &ast.A_Const{Val: NewIdentifier(name), Location: n.GetStart().GetStart()} + return &ast.A_Const{Val: NewIdentifier(name), Location: c.pos(n.GetStart())} case n.Bind_parameter() != nil: bindPar := c.convert(n.Bind_parameter()) return bindPar @@ -490,7 +499,8 @@ func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { for _, anID := range pureCols.AllAn_id() { name := identifier(parseAnId(anID)) cols.Items = append(cols.Items, &ast.ResTarget{ - Name: &name, + Name: &name, + Location: c.pos(anID.GetStart()), }) } } @@ -544,7 +554,7 @@ func (c *cc) convertPragma_stmtContext(n *parser.Pragma_stmtContext) ast.Node { stmt := &ast.Pragma_stmt{ Name: &ast.List{Items: items}, - Location: n.An_id().GetStart().GetStart(), + Location: c.pos(n.An_id().GetStart()), } if n.EQUALS() != nil { @@ -577,24 +587,24 @@ func (c *cc) convertPragma_valueContext(n *parser.Pragma_valueContext) ast.Node } return &ast.TODO{} } - return &ast.A_Const{Val: &ast.Integer{Ival: val}, Location: n.GetStart().GetStart()} + return &ast.A_Const{Val: &ast.Integer{Ival: val}, Location: c.pos(n.GetStart())} } if n.Signed_number().Real_() != nil { text := n.Signed_number().GetText() - return &ast.A_Const{Val: &ast.Float{Str: text}, Location: n.GetStart().GetStart()} + return &ast.A_Const{Val: &ast.Float{Str: text}, Location: c.pos(n.GetStart())} } case n.STRING_VALUE() != nil: val := n.STRING_VALUE().GetText() if len(val) >= 2 { val = val[1 : len(val)-1] } - return &ast.A_Const{Val: &ast.String{Str: val}, Location: n.GetStart().GetStart()} + return &ast.A_Const{Val: &ast.String{Str: val}, Location: c.pos(n.GetStart())} case n.Bool_value() != nil: var i bool if n.Bool_value().TRUE() != nil { i = true } - return &ast.A_Const{Val: &ast.Boolean{Boolval: i}, Location: n.GetStart().GetStart()} + return &ast.A_Const{Val: &ast.Boolean{Boolval: i}, Location: c.pos(n.GetStart())} case n.Bind_parameter() != nil: bindPar := c.convert(n.Bind_parameter()) return bindPar @@ -628,8 +638,9 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { columnName := identifier(targetCtx.Column_name().GetText()) expr := c.convert(clause.Expr()) resTarget := &ast.ResTarget{ - Name: &columnName, - Val: expr, + Name: &columnName, + Val: expr, + Location: c.pos(clause.Expr().GetStart()), } setList.Items = append(setList.Items, resTarget) } @@ -664,6 +675,7 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { Colno: i + 1, Ncolumns: len(colNames), }, + Location: c.pos(targetsCtx.Set_target(i).GetStart()), }) } } @@ -682,7 +694,8 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { for _, anID := range pureCols.AllAn_id() { name := identifier(parseAnId(anID)) cols.Items = append(cols.Items, &ast.ResTarget{ - Name: &name, + Name: &name, + Location: c.pos(anID.GetStart()), }) } } @@ -725,7 +738,10 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { func (c *cc) convertInto_table_stmtContext(n *parser.Into_table_stmtContext) ast.Node { tableName := identifier(n.Into_simple_table_ref().Simple_table_ref().Simple_table_ref_core().GetText()) - rel := &ast.RangeVar{Relname: &tableName} + rel := &ast.RangeVar{ + Relname: &tableName, + Location: c.pos(n.Into_simple_table_ref().GetStart()), + } onConflict := &ast.OnConflictClause{} switch { @@ -751,7 +767,8 @@ func (c *cc) convertInto_table_stmtContext(n *parser.Into_table_stmtContext) ast for _, anID := range pureCols.AllAn_id() { name := identifier(parseAnId(anID)) cols.Items = append(cols.Items, &ast.ResTarget{ - Name: &name, + Name: &name, + Location: c.pos(anID.GetStart()), }) } } @@ -816,9 +833,9 @@ func (c *cc) convertReturning_columns_listContext(n *parser.Returning_columns_li Indirection: &ast.List{}, Val: &ast.ColumnRef{ Fields: &ast.List{Items: []ast.Node{&ast.A_Star{}}}, - Location: n.ASTERISK().GetSymbol().GetStart(), + Location: c.pos(n.ASTERISK().GetSymbol()), }, - Location: n.ASTERISK().GetSymbol().GetStart(), + Location: c.pos(n.ASTERISK().GetSymbol()), } list.Items = append(list.Items, target) return list @@ -831,9 +848,9 @@ func (c *cc) convertReturning_columns_listContext(n *parser.Returning_columns_li Fields: &ast.List{ Items: []ast.Node{NewIdentifier(parseAnId(idCtx))}, }, - Location: idCtx.GetStart().GetStart(), + Location: c.pos(idCtx.GetStart()), }, - Location: idCtx.GetStart().GetStart(), + Location: c.pos(idCtx.GetStart()), } list.Items = append(list.Items, target) } @@ -945,7 +962,7 @@ func (c *cc) convertSelectCoreContext(n *parser.Select_coreContext) ast.Node { func (c *cc) convertResultColumn(n *parser.Result_columnContext) ast.Node { // todo: support opt_id_prefix target := &ast.ResTarget{ - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } var val ast.Node iexpr := n.Expr() @@ -1091,7 +1108,7 @@ func (c *cc) convertSingleSource(n *parser.Single_sourceContext) ast.Node { tableName := n.Table_ref().GetText() // !! debug !! return &ast.RangeVar{ Relname: &tableName, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } } @@ -1111,10 +1128,10 @@ func (c *cc) convertBindParameter(n *parser.Bind_parameterContext) ast.Node { // !!debug later!! if n.DOLLAR() != nil { if n.TRUE() != nil { - return &ast.A_Const{Val: &ast.Boolean{Boolval: true}, Location: n.GetStart().GetStart()} + return &ast.A_Const{Val: &ast.Boolean{Boolval: true}, Location: c.pos(n.GetStart())} } if n.FALSE() != nil { - return &ast.A_Const{Val: &ast.Boolean{Boolval: false}, Location: n.GetStart().GetStart()} + return &ast.A_Const{Val: &ast.Boolean{Boolval: false}, Location: c.pos(n.GetStart())} } if an := n.An_id_or_type(); an != nil { @@ -1122,13 +1139,13 @@ func (c *cc) convertBindParameter(n *parser.Bind_parameterContext) ast.Node { return &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: "@"}}}, Rexpr: &ast.String{Str: idText}, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } } c.paramCount++ return &ast.ParamRef{ Number: c.paramCount, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), Dollar: true, } } @@ -1147,7 +1164,7 @@ func (c *cc) convertWildCardField(n *parser.Result_columnContext) *ast.ColumnRef items = append(items, &ast.A_Star{}) return &ast.ColumnRef{ Fields: &ast.List{Items: items}, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } } @@ -1250,7 +1267,6 @@ func (c *cc) convertTypeName(n parser.IType_nameContext) *ast.TypeName { return nil } - // Handle composite types if composite := n.Type_name_composite(); composite != nil { if node := c.convertTypeNameComposite(composite); node != nil { if typeName, ok := node.(*ast.TypeName); ok { @@ -1259,7 +1275,6 @@ func (c *cc) convertTypeName(n parser.IType_nameContext) *ast.TypeName { } } - // Handle decimal type (e.g., DECIMAL(10,2)) if decimal := n.Type_name_decimal(); decimal != nil { if integerOrBinds := decimal.AllInteger_or_bind(); len(integerOrBinds) >= 2 { return &ast.TypeName{ @@ -1697,7 +1712,7 @@ func (c *cc) convertExpr(n *parser.ExprContext) ast.Node { left = &ast.BoolExpr{ Boolop: ast.BoolExprTypeOr, Args: &ast.List{Items: []ast.Node{left, right}}, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } } return left @@ -1726,7 +1741,7 @@ func (c *cc) convertOrSubExpr(n *parser.Or_subexprContext) ast.Node { left = &ast.BoolExpr{ Boolop: ast.BoolExprTypeAnd, Args: &ast.List{Items: []ast.Node{left, right}}, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } } return left @@ -1758,7 +1773,7 @@ func (c *cc) convertAndSubexpr(n *parser.And_subexprContext) ast.Node { Name: &ast.List{Items: []ast.Node{&ast.String{Str: "XOR"}}}, Lexpr: left, Rexpr: right, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } } return left @@ -1799,32 +1814,32 @@ func (c *cc) convertXorSubexpr(n *parser.Xor_subexprContext) ast.Node { Left: c.convert(eqSubs[0]), Right: c.convert(eqSubs[1]), Not: condCtx.NOT() != nil, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } } case condCtx.ISNULL() != nil: return &ast.NullTest{ Arg: base, Nulltesttype: 1, // IS NULL - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } case condCtx.NOTNULL() != nil: return &ast.NullTest{ Arg: base, Nulltesttype: 2, // IS NOT NULL - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } case condCtx.IS() != nil && condCtx.NULL() != nil: return &ast.NullTest{ Arg: base, Nulltesttype: 1, // IS NULL - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } case condCtx.IS() != nil && condCtx.NOT() != nil && condCtx.NULL() != nil: return &ast.NullTest{ Arg: base, Nulltesttype: 2, // IS NOT NULL - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } case condCtx.Match_op() != nil: // debug!!! @@ -1908,7 +1923,7 @@ func (c *cc) convertEqSubexpr(n *parser.Eq_subexprContext) ast.Node { Name: &ast.List{Items: []ast.Node{&ast.String{Str: opText}}}, Lexpr: left, Rexpr: right, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } } return left @@ -1953,7 +1968,7 @@ func (c *cc) convertNeqSubexpr(n *parser.Neq_subexprContext) ast.Node { Name: &ast.List{Items: []ast.Node{&ast.String{Str: opText}}}, Lexpr: left, Rexpr: right, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } } @@ -1969,7 +1984,7 @@ func (c *cc) convertNeqSubexpr(n *parser.Neq_subexprContext) ast.Node { Name: &ast.List{Items: []ast.Node{&ast.String{Str: "??"}}}, Lexpr: left, Rexpr: right, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } } } else { @@ -1983,7 +1998,7 @@ func (c *cc) convertNeqSubexpr(n *parser.Neq_subexprContext) ast.Node { left = &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: questionOp}}}, Lexpr: left, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } } } @@ -2018,7 +2033,7 @@ func (c *cc) convertBitSubexpr(n *parser.Bit_subexprContext) ast.Node { Name: &ast.List{Items: []ast.Node{&ast.String{Str: opText}}}, Lexpr: left, Rexpr: right, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } } return left @@ -2051,7 +2066,7 @@ func (c *cc) convertAddSubexpr(n *parser.Add_subexprContext) ast.Node { Name: &ast.List{Items: []ast.Node{&ast.String{Str: opText}}}, Lexpr: left, Rexpr: right, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } } return left @@ -2080,7 +2095,7 @@ func (c *cc) convertMulSubexpr(n *parser.Mul_subexprContext) ast.Node { Name: &ast.List{Items: []ast.Node{&ast.String{Str: "||"}}}, Lexpr: left, Rexpr: right, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } } return left @@ -2093,7 +2108,7 @@ func (c *cc) convertConSubexpr(n *parser.Con_subexprContext) ast.Node { return &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: op}}}, Rexpr: operand, - Location: n.GetStart().GetStart(), + Location: c.pos(n.GetStart()), } } return c.convertUnarySubexpr(n.Unary_subexpr().(*parser.Unary_subexprContext)) @@ -2110,89 +2125,211 @@ func (c *cc) convertUnarySubexpr(n *parser.Unary_subexprContext) ast.Node { } func (c *cc) convertJsonApiExpr(n *parser.Json_api_exprContext) ast.Node { - return &ast.TODO{} // todo + return todo("Json_api_exprContext", n) } func (c *cc) convertUnaryCasualSubexpr(n *parser.Unary_casual_subexprContext) ast.Node { - var baseExpr ast.Node + var current ast.Node + switch { + case n.Id_expr() != nil: + current = c.convertIdExpr(n.Id_expr().(*parser.Id_exprContext)) + case n.Atom_expr() != nil: + current = c.convertAtomExpr(n.Atom_expr().(*parser.Atom_exprContext)) + default: + return todo("Unary_casual_subexprContext", n) + } - if idExpr := n.Id_expr(); idExpr != nil { - baseExpr = c.convertIdExpr(idExpr.(*parser.Id_exprContext)) - } else if atomExpr := n.Atom_expr(); atomExpr != nil { - baseExpr = c.convertAtomExpr(atomExpr.(*parser.Atom_exprContext)) + if suffix := n.Unary_subexpr_suffix(); suffix != nil { + current = c.processSuffixChain(current, suffix.(*parser.Unary_subexpr_suffixContext)) } - suffixCtx := n.Unary_subexpr_suffix() - if suffixCtx != nil { - ctx, ok := suffixCtx.(*parser.Unary_subexpr_suffixContext) - if !ok { - return baseExpr + return current +} + +func (c *cc) processSuffixChain(base ast.Node, suffix *parser.Unary_subexpr_suffixContext) ast.Node { + current := base + for i := 0; i < suffix.GetChildCount(); i++ { + child := suffix.GetChild(i) + switch elem := child.(type) { + case *parser.Key_exprContext: + current = c.handleKeySuffix(current, elem) + case *parser.Invoke_exprContext: + current = c.handleInvokeSuffix(current, elem, i) + case antlr.TerminalNode: + if elem.GetText() == "." { + current = c.handleDotSuffix(current, suffix, &i) + } } - baseExpr = c.convertUnarySubexprSuffix(baseExpr, ctx) } - - return baseExpr + return current } -func (c *cc) convertUnarySubexprSuffix(base ast.Node, n *parser.Unary_subexpr_suffixContext) ast.Node { - if n == nil { - return base +func (c *cc) handleKeySuffix(base ast.Node, keyCtx *parser.Key_exprContext) ast.Node { + keyNode := c.convertKey_exprContext(keyCtx) + ind, ok := keyNode.(*ast.A_Indirection) + if !ok { + return todo("Key_exprContext", keyCtx) } - colRef, ok := base.(*ast.ColumnRef) + + if indirection, ok := base.(*ast.A_Indirection); ok { + indirection.Indirection.Items = append(indirection.Indirection.Items, ind.Indirection.Items...) + return indirection + } + + return &ast.A_Indirection{ + Arg: base, + Indirection: &ast.List{ + Items: []ast.Node{keyNode}, + }, + } +} + +func (c *cc) handleInvokeSuffix(base ast.Node, invokeCtx *parser.Invoke_exprContext, idx int) ast.Node { + funcCall, ok := c.convertInvoke_exprContext(invokeCtx).(*ast.FuncCall) if !ok { - return base // todo: cover case when unary subexpr with atomic expr + return todo("Invoke_exprContext", invokeCtx) } - for i := 0; i < n.GetChildCount(); i++ { - child := n.GetChild(i) - switch v := child.(type) { - case parser.IKey_exprContext: - node := c.convert(v.(*parser.Key_exprContext)) - if node != nil { - colRef.Fields.Items = append(colRef.Fields.Items, node) - } + if idx == 0 { + switch baseNode := base.(type) { + case *ast.ColumnRef: + if len(baseNode.Fields.Items) > 0 { + var nameParts []string + for _, item := range baseNode.Fields.Items { + if s, ok := item.(*ast.String); ok { + nameParts = append(nameParts, s.Str) + } + } + funcName := strings.Join(nameParts, ".") - case parser.IInvoke_exprContext: - node := c.convert(v.(*parser.Invoke_exprContext)) - if node != nil { - colRef.Fields.Items = append(colRef.Fields.Items, node) - } - case antlr.TerminalNode: - if v.GetText() == "." { - if i+1 < n.GetChildCount() { - next := n.GetChild(i + 1) - switch w := next.(type) { - case parser.IBind_parameterContext: - // !!! debug !!! - node := c.convert(next.(*parser.Bind_parameterContext)) - colRef.Fields.Items = append(colRef.Fields.Items, node) - case antlr.TerminalNode: - // !!! debug !!! - val, err := parseIntegerValue(w.GetText()) - if err != nil { - if debug.Active { - log.Printf("Failed to parse integer value '%s': %v", w.GetText(), err) - } - return &ast.TODO{} - } - node := &ast.A_Const{Val: &ast.Integer{Ival: val}, Location: n.GetStart().GetStart()} - colRef.Fields.Items = append(colRef.Fields.Items, node) - case parser.IAn_id_or_typeContext: - idText := parseAnIdOrType(w) - colRef.Fields.Items = append(colRef.Fields.Items, &ast.String{Str: idText}) - default: - colRef.Fields.Items = append(colRef.Fields.Items, &ast.TODO{}) + if funcName == "coalesce" { + return &ast.CoalesceExpr{ + Args: funcCall.Args, + Location: baseNode.Location, } - i++ } + + funcCall.Func = &ast.FuncName{Name: funcName} + funcCall.Funcname.Items = append(funcCall.Funcname.Items, &ast.String{Str: funcName}) + + return funcCall } + default: + return todo("Invoke_exprContext", invokeCtx) + } + } + + stmt := &ast.RecursiveFuncCall{ + Func: base, + Funcname: funcCall.Funcname, + AggStar: funcCall.AggStar, + Location: funcCall.Location, + Args: funcCall.Args, + AggDistinct: funcCall.AggDistinct, + } + stmt.Funcname.Items = append(stmt.Funcname.Items, base) + return stmt +} + +func (c *cc) handleDotSuffix(base ast.Node, suffix *parser.Unary_subexpr_suffixContext, idx *int) ast.Node { + if *idx+1 >= suffix.GetChildCount() { + return base + } + + next := suffix.GetChild(*idx + 1) + *idx++ + + var field ast.Node + switch v := next.(type) { + case *parser.Bind_parameterContext: + field = c.convertBindParameter(v) + case *parser.An_id_or_typeContext: + field = &ast.String{Str: parseAnIdOrType(v)} + case antlr.TerminalNode: + if val, err := parseIntegerValue(v.GetText()); err == nil { + field = &ast.A_Const{Val: &ast.Integer{Ival: val}} + } else { + return &ast.TODO{} } } - if n.COLLATE() != nil && n.An_id() != nil { //nolint - // todo: Handle COLLATE + if field == nil { + return base } - return colRef + + if cr, ok := base.(*ast.ColumnRef); ok { + cr.Fields.Items = append(cr.Fields.Items, field) + return cr + } + return &ast.ColumnRef{ + Fields: &ast.List{Items: []ast.Node{base, field}}, + } +} + +func (c *cc) convertKey_exprContext(n *parser.Key_exprContext) ast.Node { + if n.LBRACE_SQUARE() == nil || n.RBRACE_SQUARE() == nil || n.Expr() == nil { + return todo("Key_exprContext", n) + } + + stmt := &ast.A_Indirection{ + Indirection: &ast.List{}, + } + + expr := c.convert(n.Expr()) + + stmt.Indirection.Items = append(stmt.Indirection.Items, &ast.A_Indices{ + Uidx: expr, + }) + + return stmt +} + +func (c *cc) convertInvoke_exprContext(n *parser.Invoke_exprContext) ast.Node { + if n.LPAREN() == nil || n.RPAREN() == nil { + return todo("Invoke_exprContext", n) + } + + distinct := false + if n.Opt_set_quantifier() != nil { + distinct = n.Opt_set_quantifier().DISTINCT() != nil + } + + stmt := &ast.FuncCall{ + AggDistinct: distinct, + Funcname: &ast.List{}, + AggOrder: &ast.List{}, + Args: &ast.List{}, + Location: c.pos(n.GetStart()), + } + + if nList := n.Named_expr_list(); nList != nil { + for _, namedExpr := range nList.AllNamed_expr() { + name := parseAnIdOrType(namedExpr.An_id_or_type()) + expr := c.convert(namedExpr.Expr()) + + var res ast.Node + if rt, ok := expr.(*ast.ResTarget); ok { + if name != "" { + rt.Name = &name + } + res = rt + } else if name != "" { + res = &ast.ResTarget{ + Name: &name, + Val: expr, + Location: c.pos(namedExpr.Expr().GetStart()), + } + } else { + res = expr + } + + stmt.Args.Items = append(stmt.Args.Items, res) + } + } else if n.ASTERISK() != nil { + stmt.AggStar = true + } + + return stmt } func (c *cc) convertIdExpr(n *parser.Id_exprContext) ast.Node { @@ -2203,6 +2340,7 @@ func (c *cc) convertIdExpr(n *parser.Id_exprContext) ast.Node { NewIdentifier(id.GetText()), }, }, + Location: c.pos(id.GetStart()), } } return &ast.TODO{} @@ -2210,6 +2348,8 @@ func (c *cc) convertIdExpr(n *parser.Id_exprContext) ast.Node { func (c *cc) convertAtomExpr(n *parser.Atom_exprContext) ast.Node { switch { + case n.An_id_or_type() != nil && n.NAMESPACE() != nil: + return NewIdentifier(parseAnIdOrType(n.An_id_or_type()) + "::" + parseIdOrType(n.Id_or_type())) case n.An_id_or_type() != nil: return NewIdentifier(parseAnIdOrType(n.An_id_or_type())) case n.Literal_value() != nil: @@ -2232,25 +2372,25 @@ func (c *cc) convertLiteralValue(n *parser.Literal_valueContext) ast.Node { } return &ast.TODO{} } - return &ast.A_Const{Val: &ast.Integer{Ival: val}, Location: n.GetStart().GetStart()} + return &ast.A_Const{Val: &ast.Integer{Ival: val}, Location: c.pos(n.GetStart())} case n.Real_() != nil: text := n.Real_().GetText() - return &ast.A_Const{Val: &ast.Float{Str: text}, Location: n.GetStart().GetStart()} + return &ast.A_Const{Val: &ast.Float{Str: text}, Location: c.pos(n.GetStart())} case n.STRING_VALUE() != nil: // !!! debug !!! (problem with quoted strings) val := n.STRING_VALUE().GetText() if len(val) >= 2 { val = val[1 : len(val)-1] } - return &ast.A_Const{Val: &ast.String{Str: val}, Location: n.GetStart().GetStart()} + return &ast.A_Const{Val: &ast.String{Str: val}, Location: c.pos(n.GetStart())} case n.Bool_value() != nil: var i bool if n.Bool_value().TRUE() != nil { i = true } - return &ast.A_Const{Val: &ast.Boolean{Boolval: i}, Location: n.GetStart().GetStart()} + return &ast.A_Const{Val: &ast.Boolean{Boolval: i}, Location: c.pos(n.GetStart())} case n.NULL() != nil: return &ast.Null{} @@ -2275,7 +2415,7 @@ func (c *cc) convertLiteralValue(n *parser.Literal_valueContext) ast.Node { case n.BLOB() != nil: blobText := n.BLOB().GetText() - return &ast.A_Const{Val: &ast.String{Str: blobText}, Location: n.GetStart().GetStart()} + return &ast.A_Const{Val: &ast.String{Str: blobText}, Location: c.pos(n.GetStart())} case n.EMPTY_ACTION() != nil: if debug.Active { diff --git a/internal/engine/ydb/lib/aggregate.go b/internal/engine/ydb/lib/aggregate.go new file mode 100644 index 0000000000..dfb3924e90 --- /dev/null +++ b/internal/engine/ydb/lib/aggregate.go @@ -0,0 +1,330 @@ +package lib + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func AggregateFunctions() []*catalog.Function { + var funcs []*catalog.Function + + // COUNT(*) + funcs = append(funcs, &catalog.Function{ + Name: "COUNT", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }) + + // COUNT(T) и COUNT(T?) + for _, typ := range types { + funcs = append(funcs, &catalog.Function{ + Name: "COUNT", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }) + funcs = append(funcs, &catalog.Function{ + Name: "COUNT", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}, Mode: ast.FuncParamVariadic}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }) + } + + // MIN и MAX + for _, typ := range types { + funcs = append(funcs, &catalog.Function{ + Name: "MIN", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: typ}, + ReturnTypeNullable: true, + }) + funcs = append(funcs, &catalog.Function{ + Name: "MAX", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: typ}, + ReturnTypeNullable: true, + }) + } + + // SUM для unsigned типов + for _, typ := range unsignedTypes { + funcs = append(funcs, &catalog.Function{ + Name: "SUM", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + ReturnTypeNullable: true, + }) + } + + // SUM для signed типов + for _, typ := range signedTypes { + funcs = append(funcs, &catalog.Function{ + Name: "SUM", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: "Int64"}, + ReturnTypeNullable: true, + }) + } + + // SUM для float/double + for _, typ := range []string{"float", "double"} { + funcs = append(funcs, &catalog.Function{ + Name: "SUM", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: typ}, + ReturnTypeNullable: true, + }) + } + + // AVG для целочисленных типов + for _, typ := range append(unsignedTypes, signedTypes...) { + funcs = append(funcs, &catalog.Function{ + Name: "AVG", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + ReturnTypeNullable: true, + }) + } + + // AVG для float/double + for _, typ := range []string{"float", "double"} { + funcs = append(funcs, &catalog.Function{ + Name: "AVG", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: typ}, + ReturnTypeNullable: true, + }) + } + + // COUNT_IF + funcs = append(funcs, &catalog.Function{ + Name: "COUNT_IF", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + ReturnTypeNullable: true, + }) + + // SUM_IF для unsigned + for _, typ := range unsignedTypes { + funcs = append(funcs, &catalog.Function{ + Name: "SUM_IF", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + ReturnTypeNullable: true, + }) + } + + // SUM_IF для signed + for _, typ := range signedTypes { + funcs = append(funcs, &catalog.Function{ + Name: "SUM_IF", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "Int64"}, + ReturnTypeNullable: true, + }) + } + + // SUM_IF для float/double + for _, typ := range []string{"float", "double"} { + funcs = append(funcs, &catalog.Function{ + Name: "SUM_IF", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: typ}, + ReturnTypeNullable: true, + }) + } + + // AVG_IF для целочисленных + for _, typ := range append(unsignedTypes, signedTypes...) { + funcs = append(funcs, &catalog.Function{ + Name: "AVG_IF", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + ReturnTypeNullable: true, + }) + } + + // AVG_IF для float/double + for _, typ := range []string{"float", "double"} { + funcs = append(funcs, &catalog.Function{ + Name: "AVG_IF", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: typ}, + ReturnTypeNullable: true, + }) + } + + // SOME + for _, typ := range types { + funcs = append(funcs, &catalog.Function{ + Name: "SOME", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: typ}, + ReturnTypeNullable: true, + }) + } + + // AGGREGATE_LIST и AGGREGATE_LIST_DISTINCT + for _, typ := range types { + funcs = append(funcs, &catalog.Function{ + Name: "AGGREGATE_LIST", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: "List<" + typ + ">"}, + }) + funcs = append(funcs, &catalog.Function{ + Name: "AGGREGATE_LIST_DISTINCT", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: "List<" + typ + ">"}, + }) + } + + // BOOL_AND, BOOL_OR, BOOL_XOR + boolAggrs := []string{"BOOL_AND", "BOOL_OR", "BOOL_XOR"} + for _, name := range boolAggrs { + funcs = append(funcs, &catalog.Function{ + Name: name, + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + ReturnTypeNullable: true, + }) + } + + // BIT_AND, BIT_OR, BIT_XOR + bitAggrs := []string{"BIT_AND", "BIT_OR", "BIT_XOR"} + for _, typ := range append(unsignedTypes, signedTypes...) { + for _, name := range bitAggrs { + funcs = append(funcs, &catalog.Function{ + Name: name, + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: typ}, + ReturnTypeNullable: true, + }) + } + } + + // STDDEV и VARIANCE + stdDevVariants := []struct { + name string + returnType string + }{ + {"STDDEV", "Double"}, + {"VARIANCE", "Double"}, + {"STDDEV_SAMPLE", "Double"}, + {"VARIANCE_SAMPLE", "Double"}, + {"STDDEV_POPULATION", "Double"}, + {"VARIANCE_POPULATION", "Double"}, + } + for _, variant := range stdDevVariants { + funcs = append(funcs, &catalog.Function{ + Name: variant.name, + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: variant.returnType}, + ReturnTypeNullable: true, + }) + } + + // CORRELATION и COVARIANCE + corrCovar := []string{"CORRELATION", "COVARIANCE", "COVARIANCE_SAMPLE", "COVARIANCE_POPULATION"} + for _, name := range corrCovar { + funcs = append(funcs, &catalog.Function{ + Name: name, + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + ReturnTypeNullable: true, + }) + } + + // HISTOGRAM + funcs = append(funcs, &catalog.Function{ + Name: "HISTOGRAM", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "HistogramStruct"}, + ReturnTypeNullable: true, + }) + + // TOP и BOTTOM + topBottom := []string{"TOP", "BOTTOM"} + for _, name := range topBottom { + for _, typ := range types { + funcs = append(funcs, &catalog.Function{ + Name: name, + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + }, + ReturnType: &ast.TypeName{Name: "List<" + typ + ">"}, + }) + } + } + + // MAX_BY и MIN_BY + minMaxBy := []string{"MAX_BY", "MIN_BY"} + for _, name := range minMaxBy { + for _, typ := range types { + funcs = append(funcs, &catalog.Function{ + Name: name, + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: typ}, + ReturnTypeNullable: true, + }) + } + } + + // ... (добавьте другие агрегатные функции по аналогии) + + return funcs +} diff --git a/internal/engine/ydb/lib/basic.go b/internal/engine/ydb/lib/basic.go new file mode 100644 index 0000000000..08c0011787 --- /dev/null +++ b/internal/engine/ydb/lib/basic.go @@ -0,0 +1,203 @@ +package lib + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +var types = []string{ + "bool", + "int8", "int16", "int32", "int64", + "uint8", "uint16", "uint32", "uint64", + "float", "double", + "string", "utf8", + "any", +} + +var ( + unsignedTypes = []string{"uint8", "uint16", "uint32", "uint64"} + signedTypes = []string{"int8", "int16", "int32", "int64"} + numericTypes = append(append(unsignedTypes, signedTypes...), "float", "double") +) + +func BasicFunctions() []*catalog.Function { + var funcs []*catalog.Function + + for _, typ := range types { + // COALESCE, NVL + funcs = append(funcs, &catalog.Function{ + Name: "COALESCE", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: typ}}, + { + Type: &ast.TypeName{Name: typ}, + Mode: ast.FuncParamVariadic, + }, + }, + ReturnType: &ast.TypeName{Name: typ}, + ReturnTypeNullable: false, + }) + funcs = append(funcs, &catalog.Function{ + Name: "NVL", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: typ}}, + { + Type: &ast.TypeName{Name: typ}, + Mode: ast.FuncParamVariadic, + }, + }, + ReturnType: &ast.TypeName{Name: typ}, + ReturnTypeNullable: false, + }) + + // IF(Bool, T, T) -> T + funcs = append(funcs, &catalog.Function{ + Name: "IF", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: typ}, + ReturnTypeNullable: false, + }) + + // LENGTH, LEN + funcs = append(funcs, &catalog.Function{ + Name: "LENGTH", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + ReturnTypeNullable: true, + }) + funcs = append(funcs, &catalog.Function{ + Name: "LEN", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + ReturnTypeNullable: true, + }) + + // StartsWith, EndsWith + funcs = append(funcs, &catalog.Function{ + Name: "StartsWith", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }) + funcs = append(funcs, &catalog.Function{ + Name: "EndsWith", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }) + + // ABS(T) -> T + } + + // SUBSTRING + funcs = append(funcs, &catalog.Function{ + Name: "Substring", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }) + funcs = append(funcs, &catalog.Function{ + Name: "Substring", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }) + funcs = append(funcs, &catalog.Function{ + Name: "Substring", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }) + + // FIND / RFIND + for _, name := range []string{"FIND", "RFIND"} { + for _, typ := range []string{"String", "Utf8"} { + funcs = append(funcs, &catalog.Function{ + Name: name, + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }) + funcs = append(funcs, &catalog.Function{ + Name: name, + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }) + } + } + + for _, typ := range numericTypes { + funcs = append(funcs, &catalog.Function{ + Name: "Abs", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: typ}}, + }, + ReturnType: &ast.TypeName{Name: typ}, + }) + } + + // NANVL + funcs = append(funcs, &catalog.Function{ + Name: "NANVL", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Float"}}, + {Type: &ast.TypeName{Name: "Float"}}, + }, + ReturnType: &ast.TypeName{Name: "Float"}, + }) + funcs = append(funcs, &catalog.Function{ + Name: "NANVL", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }) + + // Random* + funcs = append(funcs, &catalog.Function{ + Name: "Random", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "Double"}, + }) + funcs = append(funcs, &catalog.Function{ + Name: "RandomNumber", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }) + funcs = append(funcs, &catalog.Function{ + Name: "RandomUuid", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "Uuid"}, + }) + + // todo: add all remain functions + + return funcs +} diff --git a/internal/engine/ydb/parse.go b/internal/engine/ydb/parse.go index 797710988c..1c263924a5 100755 --- a/internal/engine/ydb/parse.go +++ b/internal/engine/ydb/parse.go @@ -42,7 +42,8 @@ func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) { if err != nil { return nil, err } - input := antlr.NewInputStream(string(blob)) + content := string(blob) + input := antlr.NewInputStream(content) lexer := parser.NewYQLLexer(input) stream := antlr.NewCommonTokenStream(lexer, 0) pp := parser.NewYQLParser(stream) @@ -62,14 +63,14 @@ func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) { if stmtListCtx != nil { loc := 0 for _, stmt := range stmtListCtx.AllSql_stmt() { - converter := &cc{} + converter := &cc{content: string(blob)} out := converter.convert(stmt) if _, ok := out.(*ast.TODO); ok { - loc = stmt.GetStop().GetStop() + 2 + loc = byteOffset(content, stmt.GetStop().GetStop() + 2) continue } if out != nil { - len := (stmt.GetStop().GetStop() + 1) - loc + len := byteOffset(content, stmt.GetStop().GetStop() + 1) - loc stmts = append(stmts, ast.Statement{ Raw: &ast.RawStmt{ Stmt: out, @@ -77,7 +78,7 @@ func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) { StmtLen: len, }, }) - loc = stmt.GetStop().GetStop() + 2 + loc = byteOffset(content, stmt.GetStop().GetStop() + 2) } } } diff --git a/internal/engine/ydb/stdlib.go b/internal/engine/ydb/stdlib.go index fd78d7de38..21dc21242b 100644 --- a/internal/engine/ydb/stdlib.go +++ b/internal/engine/ydb/stdlib.go @@ -1,12 +1,18 @@ package ydb import ( + "github.com/sqlc-dev/sqlc/internal/engine/ydb/lib" "github.com/sqlc-dev/sqlc/internal/sql/catalog" ) func defaultSchema(name string) *catalog.Schema { - s := &catalog.Schema{Name: name} - s.Funcs = []*catalog.Function{} + s := &catalog.Schema{ + Name: name, + Funcs: make([]*catalog.Function, 0, 128), + } + + s.Funcs = append(s.Funcs, lib.BasicFunctions()...) + s.Funcs = append(s.Funcs, lib.AggregateFunctions()...) return s } diff --git a/internal/engine/ydb/utils.go b/internal/engine/ydb/utils.go index 0748de8bdf..3847ee5055 100755 --- a/internal/engine/ydb/utils.go +++ b/internal/engine/ydb/utils.go @@ -3,6 +3,7 @@ package ydb import ( "strconv" "strings" + "unicode/utf8" "github.com/antlr4-go/antlr/v4" "github.com/sqlc-dev/sqlc/internal/sql/ast" @@ -14,8 +15,6 @@ type objectRefProvider interface { Object_ref() parser.IObject_refContext } - - func parseTableName(ctx objectRefProvider) *ast.TableName { return parseObjectRef(ctx.Object_ref()) } @@ -172,3 +171,26 @@ func (c *cc) extractRoleSpec(n parser.IRole_nameContext, roletype ast.RoleSpecTy return roleSpec, isParam, roleNode } + +func byteOffset(s string, runeIndex int) int { + count := 0 + for i := range s { + if count == runeIndex { + return i + } + count++ + } + return len(s) +} + +func byteOffsetFromRuneIndex(s string, runeIndex int) int { + if runeIndex <= 0 { + return 0 + } + bytePos := 0 + for i := 0; i < runeIndex && bytePos < len(s); i++ { + _, size := utf8.DecodeRuneInString(s[bytePos:]) + bytePos += size + } + return bytePos +} diff --git a/internal/sql/ast/insert_stmt.go b/internal/sql/ast/insert_stmt.go index 954fb4665c..b3e7c60809 100644 --- a/internal/sql/ast/insert_stmt.go +++ b/internal/sql/ast/insert_stmt.go @@ -23,19 +23,22 @@ func (n *InsertStmt) Format(buf *TrackedBuffer) { buf.astFormat(n.WithClause) buf.WriteString(" ") } - - switch n.OnConflictClause.Action { - case OnConflictAction_INSERT_OR_ABORT: - buf.WriteString("INSERT OR ABORT INTO ") - case OnConflictAction_INSERT_OR_REVERT: - buf.WriteString("INSERT OR REVERT INTO ") - case OnConflictAction_INSERT_OR_IGNORE: - buf.WriteString("INSERT OR IGNORE INTO ") - case OnConflictAction_UPSERT: - buf.WriteString("UPSERT INTO ") - case OnConflictAction_REPLACE: - buf.WriteString("REPLACE INTO ") - default: + if n.OnConflictClause != nil { + switch n.OnConflictClause.Action { + case OnConflictAction_INSERT_OR_ABORT: + buf.WriteString("INSERT OR ABORT INTO ") + case OnConflictAction_INSERT_OR_REVERT: + buf.WriteString("INSERT OR REVERT INTO ") + case OnConflictAction_INSERT_OR_IGNORE: + buf.WriteString("INSERT OR IGNORE INTO ") + case OnConflictAction_UPSERT: + buf.WriteString("UPSERT INTO ") + case OnConflictAction_REPLACE: + buf.WriteString("REPLACE INTO ") + default: + buf.WriteString("INSERT INTO ") + } + } else { buf.WriteString("INSERT INTO ") } if n.Relation != nil { diff --git a/internal/sql/ast/recursive_func_call.go b/internal/sql/ast/recursive_func_call.go new file mode 100644 index 0000000000..1c7c0a8125 --- /dev/null +++ b/internal/sql/ast/recursive_func_call.go @@ -0,0 +1,33 @@ +package ast + +type RecursiveFuncCall struct { + Func Node + Funcname *List + Args *List + AggOrder *List + AggFilter Node + AggWithinGroup bool + AggStar bool + AggDistinct bool + FuncVariadic bool + Over *WindowDef + Location int +} + +func (n *RecursiveFuncCall) Pos() int { + return n.Location +} + +func (n *RecursiveFuncCall) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.astFormat(n.Func) + buf.WriteString("(") + if n.AggStar { + buf.WriteString("*") + } else { + buf.astFormat(n.Args) + } + buf.WriteString(")") +} diff --git a/internal/sql/astutils/rewrite.go b/internal/sql/astutils/rewrite.go index 8e8eefbff4..bcc7c17e40 100644 --- a/internal/sql/astutils/rewrite.go +++ b/internal/sql/astutils/rewrite.go @@ -607,7 +607,6 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. case *ast.CreateRoleStmt: a.apply(n, "BindRole", nil, n.BindRole) a.apply(n, "Options", nil, n.Options) - case *ast.CreateSchemaStmt: a.apply(n, "Authrole", nil, n.Authrole) @@ -1014,6 +1013,14 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. a.apply(n, "Roles", nil, n.Roles) a.apply(n, "Newrole", nil, n.Newrole) + case *ast.RecursiveFuncCall: + a.apply(n, "Func", nil, n.Func) + a.apply(n, "Funcname", nil, n.Funcname) + a.apply(n, "Args", nil, n.Args) + a.apply(n, "AggOrder", nil, n.AggOrder) + a.apply(n, "AggFilter", nil, n.AggFilter) + a.apply(n, "Over", nil, n.Over) + case *ast.RefreshMatViewStmt: a.apply(n, "Relation", nil, n.Relation) diff --git a/internal/sql/astutils/walk.go b/internal/sql/astutils/walk.go index e7b78d126b..dfc313fda1 100644 --- a/internal/sql/astutils/walk.go +++ b/internal/sql/astutils/walk.go @@ -1734,6 +1734,26 @@ func Walk(f Visitor, node ast.Node) { Walk(f, n.Newrole) } + case *ast.RecursiveFuncCall: + if n.Func != nil { + Walk(f, n.Func) + } + if n.Funcname != nil { + Walk(f, n.Funcname) + } + if n.Args != nil { + Walk(f, n.Args) + } + if n.AggOrder != nil { + Walk(f, n.AggOrder) + } + if n.AggFilter != nil { + Walk(f, n.AggFilter) + } + if n.Over != nil { + Walk(f, n.Over) + } + case *ast.RefreshMatViewStmt: if n.Relation != nil { Walk(f, n.Relation) From fc8c9322c5de37b185d43ce1debc4749daa72dd9 Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov Date: Wed, 21 May 2025 10:14:09 +0300 Subject: [PATCH 08/18] Added examples for funcs codegen --- examples/authors/sqlc.yaml | 1 + examples/authors/ydb/models.go | 6 +-- examples/authors/ydb/query.sql | 6 +++ examples/authors/ydb/query.sql.go | 62 ++++++++++++++++++++++++++----- examples/authors/ydb/schema.sql | 8 ++-- 5 files changed, 67 insertions(+), 16 deletions(-) diff --git a/examples/authors/sqlc.yaml b/examples/authors/sqlc.yaml index 8d6bc3db28..143cb608b0 100644 --- a/examples/authors/sqlc.yaml +++ b/examples/authors/sqlc.yaml @@ -51,6 +51,7 @@ sql: go: package: authors out: ydb + emit_json_tags: true rules: diff --git a/examples/authors/ydb/models.go b/examples/authors/ydb/models.go index 8edcdc7b33..12b4f3a604 100644 --- a/examples/authors/ydb/models.go +++ b/examples/authors/ydb/models.go @@ -5,7 +5,7 @@ package authors type Author struct { - ID uint64 - Name string - Bio *string + ID uint64 `json:"id"` + Name string `json:"name"` + Bio *string `json:"bio"` } diff --git a/examples/authors/ydb/query.sql b/examples/authors/ydb/query.sql index 67ce89a6a7..bf672042c5 100644 --- a/examples/authors/ydb/query.sql +++ b/examples/authors/ydb/query.sql @@ -13,6 +13,12 @@ WHERE name = $p0; SELECT * FROM authors WHERE bio IS NULL; +-- name: Count :one +SELECT COUNT(*) FROM authors; + +-- name: COALESCE :many +SELECT id, name, COALESCE(bio, 'Null value!') FROM authors; + -- name: CreateOrUpdateAuthor :execresult UPSERT INTO authors (id, name, bio) VALUES ($p0, $p1, $p2); diff --git a/examples/authors/ydb/query.sql.go b/examples/authors/ydb/query.sql.go index e244f62c54..45d86c96fd 100644 --- a/examples/authors/ydb/query.sql.go +++ b/examples/authors/ydb/query.sql.go @@ -10,14 +10,58 @@ import ( "database/sql" ) +const cOALESCE = `-- name: COALESCE :many +SELECT id, name, COALESCE(bio, 'Null value!') FROM authors +` + +type COALESCERow struct { + ID uint64 `json:"id"` + Name string `json:"name"` + Bio string `json:"bio"` +} + +func (q *Queries) COALESCE(ctx context.Context) ([]COALESCERow, error) { + rows, err := q.db.QueryContext(ctx, cOALESCE) + if err != nil { + return nil, err + } + defer rows.Close() + var items []COALESCERow + for rows.Next() { + var i COALESCERow + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const count = `-- name: Count :one +SELECT COUNT(*) FROM authors +` + +func (q *Queries) Count(ctx context.Context) (uint64, error) { + row := q.db.QueryRowContext(ctx, count) + var count uint64 + err := row.Scan(&count) + return count, err +} + const createOrUpdateAuthor = `-- name: CreateOrUpdateAuthor :execresult UPSERT INTO authors (id, name, bio) VALUES ($p0, $p1, $p2) ` type CreateOrUpdateAuthorParams struct { - P0 uint64 - P1 string - P2 *string + P0 uint64 `json:"p0"` + P1 string `json:"p1"` + P2 *string `json:"p2"` } func (q *Queries) CreateOrUpdateAuthor(ctx context.Context, arg CreateOrUpdateAuthorParams) (sql.Result, error) { @@ -29,9 +73,9 @@ UPSERT INTO authors (id, name, bio) VALUES ($p0, $p1, $p2) RETURNING bio ` type CreateOrUpdateAuthorReturningBioParams struct { - P0 uint64 - P1 string - P2 *string + P0 uint64 `json:"p0"` + P1 string `json:"p1"` + P2 *string `json:"p2"` } func (q *Queries) CreateOrUpdateAuthorReturningBio(ctx context.Context, arg CreateOrUpdateAuthorReturningBioParams) (*string, error) { @@ -150,9 +194,9 @@ UPDATE authors SET name = $p0, bio = $p1 WHERE id = $p2 RETURNING id, name, bio ` type UpdateAuthorByIDParams struct { - P0 string - P1 *string - P2 uint64 + P0 string `json:"p0"` + P1 *string `json:"p1"` + P2 uint64 `json:"p2"` } func (q *Queries) UpdateAuthorByID(ctx context.Context, arg UpdateAuthorByIDParams) (Author, error) { diff --git a/examples/authors/ydb/schema.sql b/examples/authors/ydb/schema.sql index ee9329e809..5207fb3b1e 100644 --- a/examples/authors/ydb/schema.sql +++ b/examples/authors/ydb/schema.sql @@ -1,6 +1,6 @@ CREATE TABLE authors ( - id Uint64, - name Utf8 NOT NULL, - bio Utf8, - PRIMARY KEY (id) + id Uint64, + name Utf8 NOT NULL, + bio Utf8, + PRIMARY KEY (id) ); From d0e1550c3fb736ca7447d287f78939a8ca9963ff Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov Date: Tue, 27 May 2025 14:19:48 +0300 Subject: [PATCH 09/18] Upgraded jwt to 4.5.2 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index ebd7884d70..b3e8647b6c 100644 --- a/go.mod +++ b/go.mod @@ -37,7 +37,7 @@ require ( cel.dev/expr v0.24.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/golang-jwt/jwt/v4 v4.5.0 // indirect + github.com/golang-jwt/jwt/v4 v4.5.2 // indirect github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect diff --git a/go.sum b/go.sum index 910c0e9fca..8094d57fc2 100644 --- a/go.sum +++ b/go.sum @@ -52,8 +52,8 @@ github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI6 github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= -github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= +github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= From 0b1752a833ddc0131dfd2ae07c0d7889c11acf16 Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov <150552906+1NepuNep1@users.noreply.github.com> Date: Tue, 2 Sep 2025 15:53:01 +0300 Subject: [PATCH 10/18] First ydb-go-sdk generation version (See #4) (#6) This PR adds native YDB Go SDK support to sqlc for YDB database engine, moving from the standard database/sql interface to the YDB-specific SDK. The implementation includes code generation templates, configuration updates, and example adaptations. Adds YDB Go SDK as a new SQL package option (ydb-go-sdk) Implements YDB-specific code generation templates for queries, interfaces, and database connections Updates configuration schema to support YDB as an engine option --- docker-compose.yml | 2 + examples/authors/sqlc.yaml | 1 + examples/authors/ydb/db.go | 17 +- examples/authors/ydb/db_test.go | 95 +------- examples/authors/ydb/query.sql | 29 +-- examples/authors/ydb/query.sql.go | 222 ++++++------------ go.mod | 4 +- go.sum | 4 + internal/codegen/golang/driver.go | 2 + internal/codegen/golang/gen.go | 9 + internal/codegen/golang/imports.go | 30 ++- internal/codegen/golang/opts/enum.go | 10 + internal/codegen/golang/query.go | 54 +++++ .../codegen/golang/templates/template.tmpl | 6 + .../golang/templates/ydb-go-sdk/dbCode.tmpl | 24 ++ .../templates/ydb-go-sdk/interfaceCode.tmpl | 36 +++ .../templates/ydb-go-sdk/queryCode.tmpl | 145 ++++++++++++ internal/config/v_two.json | 3 +- internal/sqltest/local/ydb.go | 90 +++---- 19 files changed, 446 insertions(+), 337 deletions(-) create mode 100644 internal/codegen/golang/templates/ydb-go-sdk/dbCode.tmpl create mode 100644 internal/codegen/golang/templates/ydb-go-sdk/interfaceCode.tmpl create mode 100644 internal/codegen/golang/templates/ydb-go-sdk/queryCode.tmpl diff --git a/docker-compose.yml b/docker-compose.yml index e7c66b42ae..255527a3d1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -27,8 +27,10 @@ services: - "2136:2136" - "8765:8765" restart: always + hostname: localhost environment: - YDB_USE_IN_MEMORY_PDISKS=true - GRPC_TLS_PORT=2135 - GRPC_PORT=2136 - MON_PORT=8765 + diff --git a/examples/authors/sqlc.yaml b/examples/authors/sqlc.yaml index 143cb608b0..49fe62ff76 100644 --- a/examples/authors/sqlc.yaml +++ b/examples/authors/sqlc.yaml @@ -52,6 +52,7 @@ sql: package: authors out: ydb emit_json_tags: true + sql_package: ydb-go-sdk rules: diff --git a/examples/authors/ydb/db.go b/examples/authors/ydb/db.go index e2b0a86b13..c3b16b4481 100644 --- a/examples/authors/ydb/db.go +++ b/examples/authors/ydb/db.go @@ -6,14 +6,15 @@ package authors import ( "context" - "database/sql" + + "github.com/ydb-platform/ydb-go-sdk/v3/query" ) type DBTX interface { - ExecContext(context.Context, string, ...interface{}) (sql.Result, error) - PrepareContext(context.Context, string) (*sql.Stmt, error) - QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) - QueryRowContext(context.Context, string, ...interface{}) *sql.Row + Exec(ctx context.Context, sql string, opts ...query.ExecuteOption) error + Query(ctx context.Context, sql string, opts ...query.ExecuteOption) (query.Result, error) + QueryResultSet(ctx context.Context, sql string, opts ...query.ExecuteOption) (query.ClosableResultSet, error) + QueryRow(ctx context.Context, sql string, opts ...query.ExecuteOption) (query.Row, error) } func New(db DBTX) *Queries { @@ -23,9 +24,3 @@ func New(db DBTX) *Queries { type Queries struct { db DBTX } - -func (q *Queries) WithTx(tx *sql.Tx) *Queries { - return &Queries{ - db: tx, - } -} diff --git a/examples/authors/ydb/db_test.go b/examples/authors/ydb/db_test.go index 76b37306ef..6ab15913f0 100644 --- a/examples/authors/ydb/db_test.go +++ b/examples/authors/ydb/db_test.go @@ -6,6 +6,7 @@ import ( "github.com/sqlc-dev/sqlc/internal/sqltest/local" _ "github.com/ydb-platform/ydb-go-sdk/v3" + "github.com/ydb-platform/ydb-go-sdk/v3/query" ) func ptr(s string) *string { @@ -15,10 +16,10 @@ func ptr(s string) *string { func TestAuthors(t *testing.T) { ctx := context.Background() - test := local.YDB(t, []string{"schema.sql"}) - defer test.DB.Close() + db := local.YDB(t, []string{"schema.sql"}) + defer db.Close(ctx) - q := New(test.DB) + q := New(db.Query()) t.Run("InsertAuthors", func(t *testing.T) { authorsToInsert := []CreateOrUpdateAuthorParams{ @@ -38,53 +39,12 @@ func TestAuthors(t *testing.T) { } for _, author := range authorsToInsert { - if _, err := q.CreateOrUpdateAuthor(ctx, author); err != nil { + if err := q.CreateOrUpdateAuthor(ctx, author, query.WithIdempotent()); err != nil { t.Fatalf("failed to insert author %q: %v", author.P1, err) } } }) - t.Run("CreateOrUpdateAuthorReturningBio", func(t *testing.T) { - newBio := "Обновленная биография автора" - arg := CreateOrUpdateAuthorReturningBioParams{ - P0: 3, - P1: "Тестовый Автор", - P2: &newBio, - } - - returnedBio, err := q.CreateOrUpdateAuthorReturningBio(ctx, arg) - if err != nil { - t.Fatalf("failed to create or update author: %v", err) - } - - if returnedBio == nil { - t.Fatal("expected non-nil bio, got nil") - } - if *returnedBio != newBio { - t.Fatalf("expected bio %q, got %q", newBio, *returnedBio) - } - - t.Logf("Author created or updated successfully with bio: %s", *returnedBio) - }) - - t.Run("Update Author", func(t *testing.T) { - arg := UpdateAuthorByIDParams{ - P0: "Максим Горький", - P1: ptr("Обновленная биография"), - P2: 10, - } - - singleAuthor, err := q.UpdateAuthorByID(ctx, arg) - if err != nil { - t.Fatal(err) - } - bio := "Null" - if singleAuthor.Bio != nil { - bio = *singleAuthor.Bio - } - t.Logf("- ID: %d | Name: %s | Bio: %s", singleAuthor.ID, singleAuthor.Name, bio) - }) - t.Run("ListAuthors", func(t *testing.T) { authors, err := q.ListAuthors(ctx) if err != nil { @@ -115,46 +75,10 @@ func TestAuthors(t *testing.T) { t.Logf("- ID: %d | Name: %s | Bio: %s", singleAuthor.ID, singleAuthor.Name, bio) }) - t.Run("GetAuthorByName", func(t *testing.T) { - authors, err := q.GetAuthorsByName(ctx, "Александр Пушкин") - if err != nil { - t.Fatal(err) - } - if len(authors) == 0 { - t.Fatal("expected at least one author with this name, got none") - } - t.Log("Authors with this name:") - for _, a := range authors { - bio := "Null" - if a.Bio != nil { - bio = *a.Bio - } - t.Logf("- ID: %d | Name: %s | Bio: %s", a.ID, a.Name, bio) - } - }) - - t.Run("ListAuthorsWithNullBio", func(t *testing.T) { - authors, err := q.ListAuthorsWithNullBio(ctx) - if err != nil { - t.Fatal(err) - } - if len(authors) == 0 { - t.Fatal("expected at least one author with NULL bio, got none") - } - t.Log("Authors with NULL bio:") - for _, a := range authors { - bio := "Null" - if a.Bio != nil { - bio = *a.Bio - } - t.Logf("- ID: %d | Name: %s | Bio: %s", a.ID, a.Name, bio) - } - }) - t.Run("Delete All Authors", func(t *testing.T) { var i uint64 for i = 1; i <= 13; i++ { - if err := q.DeleteAuthor(ctx, i); err != nil { + if err := q.DeleteAuthor(ctx, i, query.WithIdempotent()); err != nil { t.Fatalf("failed to delete authors: %v", err) } } @@ -166,4 +90,11 @@ func TestAuthors(t *testing.T) { t.Fatalf("expected no authors, got %d", len(authors)) } }) + + t.Run("Drop Table Authors", func(t *testing.T) { + err := q.DropTable(ctx) + if err != nil { + t.Fatal(err) + } + }) } diff --git a/examples/authors/ydb/query.sql b/examples/authors/ydb/query.sql index bf672042c5..804150615d 100644 --- a/examples/authors/ydb/query.sql +++ b/examples/authors/ydb/query.sql @@ -1,32 +1,15 @@ --- name: ListAuthors :many -SELECT * FROM authors; - -- name: GetAuthor :one SELECT * FROM authors -WHERE id = $p0; +WHERE id = $p0 LIMIT 1; --- name: GetAuthorsByName :many -SELECT * FROM authors -WHERE name = $p0; - --- name: ListAuthorsWithNullBio :many -SELECT * FROM authors -WHERE bio IS NULL; - --- name: Count :one -SELECT COUNT(*) FROM authors; - --- name: COALESCE :many -SELECT id, name, COALESCE(bio, 'Null value!') FROM authors; +-- name: ListAuthors :many +SELECT * FROM authors ORDER BY name; --- name: CreateOrUpdateAuthor :execresult +-- name: CreateOrUpdateAuthor :exec UPSERT INTO authors (id, name, bio) VALUES ($p0, $p1, $p2); --- name: CreateOrUpdateAuthorReturningBio :one -UPSERT INTO authors (id, name, bio) VALUES ($p0, $p1, $p2) RETURNING bio; - -- name: DeleteAuthor :exec DELETE FROM authors WHERE id = $p0; --- name: UpdateAuthorByID :one -UPDATE authors SET name = $p0, bio = $p1 WHERE id = $p2 RETURNING *; +-- name: DropTable :exec +DROP TABLE IF EXISTS authors; \ No newline at end of file diff --git a/examples/authors/ydb/query.sql.go b/examples/authors/ydb/query.sql.go index 45d86c96fd..7459482b3a 100644 --- a/examples/authors/ydb/query.sql.go +++ b/examples/authors/ydb/query.sql.go @@ -7,54 +7,13 @@ package authors import ( "context" - "database/sql" -) - -const cOALESCE = `-- name: COALESCE :many -SELECT id, name, COALESCE(bio, 'Null value!') FROM authors -` - -type COALESCERow struct { - ID uint64 `json:"id"` - Name string `json:"name"` - Bio string `json:"bio"` -} -func (q *Queries) COALESCE(ctx context.Context) ([]COALESCERow, error) { - rows, err := q.db.QueryContext(ctx, cOALESCE) - if err != nil { - return nil, err - } - defer rows.Close() - var items []COALESCERow - for rows.Next() { - var i COALESCERow - if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const count = `-- name: Count :one -SELECT COUNT(*) FROM authors -` - -func (q *Queries) Count(ctx context.Context) (uint64, error) { - row := q.db.QueryRowContext(ctx, count) - var count uint64 - err := row.Scan(&count) - return count, err -} + "github.com/ydb-platform/ydb-go-sdk/v3" + "github.com/ydb-platform/ydb-go-sdk/v3/pkg/xerrors" + "github.com/ydb-platform/ydb-go-sdk/v3/query" +) -const createOrUpdateAuthor = `-- name: CreateOrUpdateAuthor :execresult +const createOrUpdateAuthor = `-- name: CreateOrUpdateAuthor :exec UPSERT INTO authors (id, name, bio) VALUES ($p0, $p1, $p2) ` @@ -64,144 +23,97 @@ type CreateOrUpdateAuthorParams struct { P2 *string `json:"p2"` } -func (q *Queries) CreateOrUpdateAuthor(ctx context.Context, arg CreateOrUpdateAuthorParams) (sql.Result, error) { - return q.db.ExecContext(ctx, createOrUpdateAuthor, arg.P0, arg.P1, arg.P2) -} - -const createOrUpdateAuthorReturningBio = `-- name: CreateOrUpdateAuthorReturningBio :one -UPSERT INTO authors (id, name, bio) VALUES ($p0, $p1, $p2) RETURNING bio -` - -type CreateOrUpdateAuthorReturningBioParams struct { - P0 uint64 `json:"p0"` - P1 string `json:"p1"` - P2 *string `json:"p2"` -} - -func (q *Queries) CreateOrUpdateAuthorReturningBio(ctx context.Context, arg CreateOrUpdateAuthorReturningBioParams) (*string, error) { - row := q.db.QueryRowContext(ctx, createOrUpdateAuthorReturningBio, arg.P0, arg.P1, arg.P2) - var bio *string - err := row.Scan(&bio) - return bio, err +func (q *Queries) CreateOrUpdateAuthor(ctx context.Context, arg CreateOrUpdateAuthorParams, opts ...query.ExecuteOption) error { + err := q.db.Exec(ctx, createOrUpdateAuthor, + append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ + "$p0": arg.P0, + "$p1": arg.P1, + "$p2": arg.P2, + })))..., + ) + if err != nil { + return xerrors.WithStackTrace(err) + } + return nil } const deleteAuthor = `-- name: DeleteAuthor :exec DELETE FROM authors WHERE id = $p0 ` -func (q *Queries) DeleteAuthor(ctx context.Context, p0 uint64) error { - _, err := q.db.ExecContext(ctx, deleteAuthor, p0) - return err -} - -const getAuthor = `-- name: GetAuthor :one -SELECT id, name, bio FROM authors -WHERE id = $p0 -` - -func (q *Queries) GetAuthor(ctx context.Context, p0 uint64) (Author, error) { - row := q.db.QueryRowContext(ctx, getAuthor, p0) - var i Author - err := row.Scan(&i.ID, &i.Name, &i.Bio) - return i, err +func (q *Queries) DeleteAuthor(ctx context.Context, p0 uint64, opts ...query.ExecuteOption) error { + err := q.db.Exec(ctx, deleteAuthor, + append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ + "$p0": p0, + })))..., + ) + if err != nil { + return xerrors.WithStackTrace(err) + } + return nil } -const getAuthorsByName = `-- name: GetAuthorsByName :many -SELECT id, name, bio FROM authors -WHERE name = $p0 +const dropTable = `-- name: DropTable :exec +DROP TABLE IF EXISTS authors ` -func (q *Queries) GetAuthorsByName(ctx context.Context, p0 string) ([]Author, error) { - rows, err := q.db.QueryContext(ctx, getAuthorsByName, p0) +func (q *Queries) DropTable(ctx context.Context, opts ...query.ExecuteOption) error { + err := q.db.Exec(ctx, dropTable, opts...) if err != nil { - return nil, err - } - defer rows.Close() - var items []Author - for rows.Next() { - var i Author - if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err + return xerrors.WithStackTrace(err) } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil + return nil } -const listAuthors = `-- name: ListAuthors :many +const getAuthor = `-- name: GetAuthor :one SELECT id, name, bio FROM authors +WHERE id = $p0 LIMIT 1 ` -func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) { - rows, err := q.db.QueryContext(ctx, listAuthors) +func (q *Queries) GetAuthor(ctx context.Context, p0 uint64, opts ...query.ExecuteOption) (Author, error) { + row, err := q.db.QueryRow(ctx, getAuthor, + append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ + "$p0": p0, + })))..., + ) + var i Author if err != nil { - return nil, err - } - defer rows.Close() - var items []Author - for rows.Next() { - var i Author - if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { - return nil, err - } - items = append(items, i) + return i, xerrors.WithStackTrace(err) } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err + err = row.Scan(&i.ID, &i.Name, &i.Bio) + if err != nil { + return i, xerrors.WithStackTrace(err) } - return items, nil + return i, nil } -const listAuthorsWithNullBio = `-- name: ListAuthorsWithNullBio :many -SELECT id, name, bio FROM authors -WHERE bio IS NULL +const listAuthors = `-- name: ListAuthors :many +SELECT id, name, bio FROM authors ORDER BY name ` -func (q *Queries) ListAuthorsWithNullBio(ctx context.Context) ([]Author, error) { - rows, err := q.db.QueryContext(ctx, listAuthorsWithNullBio) +func (q *Queries) ListAuthors(ctx context.Context, opts ...query.ExecuteOption) ([]Author, error) { + result, err := q.db.Query(ctx, listAuthors, opts...) if err != nil { - return nil, err + return nil, xerrors.WithStackTrace(err) } - defer rows.Close() var items []Author - for rows.Next() { - var i Author - if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { - return nil, err + for set, err := range result.ResultSets(ctx) { + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + for row, err := range set.Rows(ctx) { + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + var i Author + if err := row.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, xerrors.WithStackTrace(err) + } + items = append(items, i) } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err } - if err := rows.Err(); err != nil { - return nil, err + if err := result.Close(ctx); err != nil { + return nil, xerrors.WithStackTrace(err) } return items, nil } - -const updateAuthorByID = `-- name: UpdateAuthorByID :one -UPDATE authors SET name = $p0, bio = $p1 WHERE id = $p2 RETURNING id, name, bio -` - -type UpdateAuthorByIDParams struct { - P0 string `json:"p0"` - P1 *string `json:"p1"` - P2 uint64 `json:"p2"` -} - -func (q *Queries) UpdateAuthorByID(ctx context.Context, arg UpdateAuthorByIDParams) (Author, error) { - row := q.db.QueryRowContext(ctx, updateAuthorByID, arg.P0, arg.P1, arg.P2) - var i Author - err := row.Scan(&i.ID, &i.Name, &i.Bio) - return i, err -} diff --git a/go.mod b/go.mod index b3e8647b6c..c72f29b6b1 100644 --- a/go.mod +++ b/go.mod @@ -24,7 +24,7 @@ require ( github.com/tetratelabs/wazero v1.9.0 github.com/wasilibs/go-pgquery v0.0.0-20250409022910-10ac41983c07 github.com/xeipuuv/gojsonschema v1.2.0 - github.com/ydb-platform/ydb-go-sdk/v3 v3.108.0 + github.com/ydb-platform/ydb-go-sdk/v3 v3.115.3 github.com/ydb-platform/yql-parsers v0.0.0-20250309001738-7d693911f333 golang.org/x/sync v0.16.0 google.golang.org/grpc v1.75.0 @@ -48,7 +48,7 @@ require ( github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgtype v1.14.0 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect - github.com/jonboulle/clockwork v0.3.0 // indirect + github.com/jonboulle/clockwork v0.5.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pingcap/errors v0.11.5-0.20240311024730-e056997136bb // indirect diff --git a/go.sum b/go.sum index 8094d57fc2..53cb8e8aec 100644 --- a/go.sum +++ b/go.sum @@ -146,6 +146,8 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jonboulle/clockwork v0.3.0 h1:9BSCMi8C+0qdApAp4auwX0RkLGUjs956h0EkuQymUhg= github.com/jonboulle/clockwork v0.3.0/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= +github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbdFz6I= +github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -242,6 +244,8 @@ github.com/ydb-platform/ydb-go-genproto v0.0.0-20241112172322-ea1f63298f77 h1:LY github.com/ydb-platform/ydb-go-genproto v0.0.0-20241112172322-ea1f63298f77/go.mod h1:Er+FePu1dNUieD+XTMDduGpQuCPssK5Q4BjF+IIXJ3I= github.com/ydb-platform/ydb-go-sdk/v3 v3.108.0 h1:TwWSp3gRMcja/hRpOofncLvgxAXCmzpz5cGtmdaoITw= github.com/ydb-platform/ydb-go-sdk/v3 v3.108.0/go.mod h1:l5sSv153E18VvYcsmr51hok9Sjc16tEC8AXGbwrk+ho= +github.com/ydb-platform/ydb-go-sdk/v3 v3.115.3 h1:SFeSK2c+PmiToyNIhr143u+YDzLhl/kboXwKLYDk0O4= +github.com/ydb-platform/ydb-go-sdk/v3 v3.115.3/go.mod h1:Pp1w2xxUoLQ3NCNAwV7pvDq0TVQOdtAqs+ZiC+i8r14= github.com/ydb-platform/yql-parsers v0.0.0-20250309001738-7d693911f333 h1:KFtJwlPdOxWjCKXX0jFJ8k1FlbqbRbUW3k/kYSZX7SA= github.com/ydb-platform/yql-parsers v0.0.0-20250309001738-7d693911f333/go.mod h1:vrPJPS8cdPSV568YcXhB4bUwhyV8bmWKqmQ5c5Xi99o= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= diff --git a/internal/codegen/golang/driver.go b/internal/codegen/golang/driver.go index 5e3a533dcc..6e0596172f 100644 --- a/internal/codegen/golang/driver.go +++ b/internal/codegen/golang/driver.go @@ -8,6 +8,8 @@ func parseDriver(sqlPackage string) opts.SQLDriver { return opts.SQLDriverPGXV4 case opts.SQLPackagePGXV5: return opts.SQLDriverPGXV5 + case opts.SQLPackageYDBGoSDK: + return opts.SQLDriverYDBGoSDK default: return opts.SQLDriverLibPQ } diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 7df56a0a41..4b48f34bde 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -209,6 +209,15 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, return nil, errors.New(":batch* commands are only supported by pgx") } + if tctx.SQLDriver.IsYDBGoSDK() { + for _, q := range queries { + switch q.Cmd { + case metadata.CmdExecResult, metadata.CmdExecRows, metadata.CmdExecLastId: + return nil, fmt.Errorf("%s is not supported by ydb-go-sdk", q.Cmd) + } + } + } + funcMap := template.FuncMap{ "lowerTitle": sdk.LowerTitle, "comment": sdk.DoubleSlashComment, diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index ccca4f603c..17e426b8f9 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -132,6 +132,8 @@ func (i *importer) dbImports() fileImports { case opts.SQLDriverPGXV5: pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5/pgconn"}) pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5"}) + case opts.SQLDriverYDBGoSDK: + pkg = append(pkg, ImportSpec{Path: "github.com/ydb-platform/ydb-go-sdk/v3/query"}) default: std = append(std, ImportSpec{Path: "database/sql"}) if i.Options.EmitPreparedQueries { @@ -177,7 +179,9 @@ func buildImports(options *opts.Options, queries []Query, uses func(string) bool case opts.SQLDriverPGXV5: pkg[ImportSpec{Path: "github.com/jackc/pgx/v5/pgconn"}] = struct{}{} default: - std["database/sql"] = struct{}{} + if !sqlpkg.IsYDBGoSDK() { + std["database/sql"] = struct{}{} + } } } } @@ -267,6 +271,11 @@ func (i *importer) interfaceImports() fileImports { }) std["context"] = struct{}{} + + sqlpkg := parseDriver(i.Options.SqlPackage) + if sqlpkg.IsYDBGoSDK() { + pkg[ImportSpec{Path: "github.com/ydb-platform/ydb-go-sdk/v3/query"}] = struct{}{} + } return sortedImports(std, pkg) } @@ -395,13 +404,28 @@ func (i *importer) queryImports(filename string) fileImports { } sqlpkg := parseDriver(i.Options.SqlPackage) - if sqlcSliceScan() && !sqlpkg.IsPGX() { + if sqlcSliceScan() && !sqlpkg.IsPGX() && !sqlpkg.IsYDBGoSDK() { std["strings"] = struct{}{} } - if sliceScan() && !sqlpkg.IsPGX() { + if sliceScan() && !sqlpkg.IsPGX() && !sqlpkg.IsYDBGoSDK() { pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{} } + if sqlpkg.IsYDBGoSDK() { + hasParams := false + for _, q := range gq { + if !q.Arg.isEmpty() { + hasParams = true + break + } + } + if hasParams { + pkg[ImportSpec{Path: "github.com/ydb-platform/ydb-go-sdk/v3"}] = struct{}{} + } + pkg[ImportSpec{Path: "github.com/ydb-platform/ydb-go-sdk/v3/query"}] = struct{}{} + pkg[ImportSpec{Path: "github.com/ydb-platform/ydb-go-sdk/v3/pkg/xerrors"}] = struct{}{} + } + if i.Options.WrapErrors { std["fmt"] = struct{}{} } diff --git a/internal/codegen/golang/opts/enum.go b/internal/codegen/golang/opts/enum.go index 40457d040a..4d57a080a8 100644 --- a/internal/codegen/golang/opts/enum.go +++ b/internal/codegen/golang/opts/enum.go @@ -8,12 +8,14 @@ const ( SQLPackagePGXV4 string = "pgx/v4" SQLPackagePGXV5 string = "pgx/v5" SQLPackageStandard string = "database/sql" + SQLPackageYDBGoSDK string = "ydb-go-sdk" ) var validPackages = map[string]struct{}{ string(SQLPackagePGXV4): {}, string(SQLPackagePGXV5): {}, string(SQLPackageStandard): {}, + string(SQLPackageYDBGoSDK): {}, } func validatePackage(sqlPackage string) error { @@ -28,6 +30,7 @@ const ( SQLDriverPGXV5 = "github.com/jackc/pgx/v5" SQLDriverLibPQ = "github.com/lib/pq" SQLDriverGoSQLDriverMySQL = "github.com/go-sql-driver/mysql" + SQLDriverYDBGoSDK = "github.com/ydb-platform/ydb-go-sdk/v3" ) var validDrivers = map[string]struct{}{ @@ -35,6 +38,7 @@ var validDrivers = map[string]struct{}{ string(SQLDriverPGXV5): {}, string(SQLDriverLibPQ): {}, string(SQLDriverGoSQLDriverMySQL): {}, + string(SQLDriverYDBGoSDK): {}, } func validateDriver(sqlDriver string) error { @@ -52,12 +56,18 @@ func (d SQLDriver) IsGoSQLDriverMySQL() bool { return d == SQLDriverGoSQLDriverMySQL } +func (d SQLDriver) IsYDBGoSDK() bool { + return d == SQLDriverYDBGoSDK +} + func (d SQLDriver) Package() string { switch d { case SQLDriverPGXV4: return SQLPackagePGXV4 case SQLDriverPGXV5: return SQLPackagePGXV5 + case SQLDriverYDBGoSDK: + return SQLPackageYDBGoSDK default: return SQLPackageStandard } diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 3b4fb2fa1a..02a09c3870 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -39,6 +39,10 @@ func (v QueryValue) isEmpty() bool { return v.Typ == "" && v.Name == "" && v.Struct == nil } +func (v QueryValue) IsEmpty() bool { + return v.isEmpty() +} + type Argument struct { Name string Type string @@ -254,6 +258,56 @@ func (v QueryValue) VariableForField(f Field) string { return v.Name + "." + f.Name } +func addDollarPrefix(name string) string { + if name == "" { + return name + } + if strings.HasPrefix(name, "$") { + return name + } + return "$" + name +} + +// YDBParamMapEntries returns entries for a map[string]any literal for YDB parameters. +func (v QueryValue) YDBParamMapEntries() string { + if v.isEmpty() { + return "" + } + + var parts []string + for _, field := range v.getParameterFields() { + if field.Column != nil && field.Column.IsNamedParam { + name := field.Column.GetName() + if name != "" { + key := fmt.Sprintf("%q", addDollarPrefix(name)) + variable := v.VariableForField(field) + parts = append(parts, key+": "+escape(variable)) + } + } + } + + if len(parts) == 0 { + return "" + } + + parts = append(parts, "") + return "\n" + strings.Join(parts, ",\n") +} + +func (v QueryValue) getParameterFields() []Field { + if v.Struct == nil { + return []Field{ + { + Name: v.Name, + DBName: v.DBName, + Type: v.Typ, + Column: v.Column, + }, + } + } + return v.Struct.Fields +} + // A struct used to generate methods and fields on the Queries struct type Query struct { Cmd string diff --git a/internal/codegen/golang/templates/template.tmpl b/internal/codegen/golang/templates/template.tmpl index afd50c01ac..f74b796349 100644 --- a/internal/codegen/golang/templates/template.tmpl +++ b/internal/codegen/golang/templates/template.tmpl @@ -25,6 +25,8 @@ import ( {{if .SQLDriver.IsPGX }} {{- template "dbCodeTemplatePgx" .}} +{{else if .SQLDriver.IsYDBGoSDK }} + {{- template "dbCodeTemplateYDB" .}} {{else}} {{- template "dbCodeTemplateStd" .}} {{end}} @@ -57,6 +59,8 @@ import ( {{define "interfaceCode"}} {{if .SQLDriver.IsPGX }} {{- template "interfaceCodePgx" .}} + {{else if .SQLDriver.IsYDBGoSDK }} + {{- template "interfaceCodeYDB" .}} {{else}} {{- template "interfaceCodeStd" .}} {{end}} @@ -188,6 +192,8 @@ import ( {{define "queryCode"}} {{if .SQLDriver.IsPGX }} {{- template "queryCodePgx" .}} +{{else if .SQLDriver.IsYDBGoSDK }} + {{- template "queryCodeYDB" .}} {{else}} {{- template "queryCodeStd" .}} {{end}} diff --git a/internal/codegen/golang/templates/ydb-go-sdk/dbCode.tmpl b/internal/codegen/golang/templates/ydb-go-sdk/dbCode.tmpl new file mode 100644 index 0000000000..f79831d2e2 --- /dev/null +++ b/internal/codegen/golang/templates/ydb-go-sdk/dbCode.tmpl @@ -0,0 +1,24 @@ +{{define "dbCodeTemplateYDB"}} +type DBTX interface { + Exec(ctx context.Context, sql string, opts ...query.ExecuteOption) error + Query(ctx context.Context, sql string, opts ...query.ExecuteOption) (query.Result, error) + QueryResultSet(ctx context.Context, sql string, opts ...query.ExecuteOption) (query.ClosableResultSet, error) + QueryRow(ctx context.Context, sql string, opts ...query.ExecuteOption) (query.Row, error) +} + +{{ if .EmitMethodsWithDBArgument}} +func New() *Queries { + return &Queries{} +{{- else -}} +func New(db DBTX) *Queries { + return &Queries{db: db} +{{- end}} +} + +type Queries struct { + {{if not .EmitMethodsWithDBArgument}} + db DBTX + {{end}} +} + +{{end}} diff --git a/internal/codegen/golang/templates/ydb-go-sdk/interfaceCode.tmpl b/internal/codegen/golang/templates/ydb-go-sdk/interfaceCode.tmpl new file mode 100644 index 0000000000..f9c06cc705 --- /dev/null +++ b/internal/codegen/golang/templates/ydb-go-sdk/interfaceCode.tmpl @@ -0,0 +1,36 @@ +{{define "interfaceCodeYDB"}} + type Querier interface { + {{- $dbtxParam := .EmitMethodsWithDBArgument -}} + {{- range .GoQueries}} + {{- if and (eq .Cmd ":one") ($dbtxParam) }} + {{range .Comments}}//{{.}} + {{end -}} + {{.MethodName}}(ctx context.Context, db DBTX, {{if not .Arg.IsEmpty}}{{.Arg.Pair}}, {{end}}opts ...query.ExecuteOption) ({{.Ret.DefineType}}, error) + {{- else if eq .Cmd ":one"}} + {{range .Comments}}//{{.}} + {{end -}} + {{.MethodName}}(ctx context.Context, {{if not .Arg.IsEmpty}}{{.Arg.Pair}}, {{end}}opts ...query.ExecuteOption) ({{.Ret.DefineType}}, error) + {{- end}} + {{- if and (eq .Cmd ":many") ($dbtxParam) }} + {{range .Comments}}//{{.}} + {{end -}} + {{.MethodName}}(ctx context.Context, db DBTX, {{if not .Arg.IsEmpty}}{{.Arg.Pair}}, {{end}}opts ...query.ExecuteOption) ([]{{.Ret.DefineType}}, error) + {{- else if eq .Cmd ":many"}} + {{range .Comments}}//{{.}} + {{end -}} + {{.MethodName}}(ctx context.Context, {{if not .Arg.IsEmpty}}{{.Arg.Pair}}, {{end}}opts ...query.ExecuteOption) ([]{{.Ret.DefineType}}, error) + {{- end}} + {{- if and (eq .Cmd ":exec") ($dbtxParam) }} + {{range .Comments}}//{{.}} + {{end -}} + {{.MethodName}}(ctx context.Context, db DBTX, {{if not .Arg.IsEmpty}}{{.Arg.Pair}}, {{end}}opts ...query.ExecuteOption) error + {{- else if eq .Cmd ":exec"}} + {{range .Comments}}//{{.}} + {{end -}} + {{.MethodName}}(ctx context.Context, {{if not .Arg.IsEmpty}}{{.Arg.Pair}}, {{end}}opts ...query.ExecuteOption) error + {{- end}} + {{- end}} + } + + var _ Querier = (*Queries)(nil) +{{end}} diff --git a/internal/codegen/golang/templates/ydb-go-sdk/queryCode.tmpl b/internal/codegen/golang/templates/ydb-go-sdk/queryCode.tmpl new file mode 100644 index 0000000000..ecd78b1344 --- /dev/null +++ b/internal/codegen/golang/templates/ydb-go-sdk/queryCode.tmpl @@ -0,0 +1,145 @@ +{{define "queryCodeYDB"}} +{{range .GoQueries}} +{{if $.OutputQuery .SourceName}} +const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} +{{escape .SQL}} +{{$.Q}} + +{{if .Arg.EmitStruct}} +type {{.Arg.Type}} struct { {{- range .Arg.UniqueFields}} + {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} + {{- end}} +} +{{end}} + +{{if .Ret.EmitStruct}} +type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} + {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} + {{- end}} +} +{{end}} + +{{if eq .Cmd ":one"}} +{{range .Comments}}//{{.}} +{{end -}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBArgument}}db DBTX, {{end}}{{if not .Arg.IsEmpty}}{{.Arg.Pair}}, {{end}}opts ...query.ExecuteOption) ({{.Ret.DefineType}}, error) { + {{- $dbArg := "q.db" }}{{- if $.EmitMethodsWithDBArgument }}{{- $dbArg = "db" }}{{- end -}} + {{- if .Arg.IsEmpty }} + row, err := {{$dbArg}}.QueryRow(ctx, {{.ConstantName}}, opts...) + {{- else }} + row, err := {{$dbArg}}.QueryRow(ctx, {{.ConstantName}}, + append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ {{.Arg.YDBParamMapEntries}} })))..., + ) + {{- end }} + {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} + var {{.Ret.Name}} {{.Ret.Type}} + {{- end}} + if err != nil { + {{- if $.WrapErrors}} + return {{.Ret.ReturnName}}, xerrors.WithStackTrace(fmt.Errorf("query {{.MethodName}}: %w", err)) + {{- else }} + return {{.Ret.ReturnName}}, xerrors.WithStackTrace(err) + {{- end }} + } + err = row.Scan({{.Ret.Scan}}) + {{- if $.WrapErrors}} + if err != nil { + return {{.Ret.ReturnName}}, xerrors.WithStackTrace(fmt.Errorf("query {{.MethodName}}: %w", err)) + } + {{- else }} + if err != nil { + return {{.Ret.ReturnName}}, xerrors.WithStackTrace(err) + } + {{- end}} + return {{.Ret.ReturnName}}, nil +} +{{end}} + +{{if eq .Cmd ":many"}} +{{range .Comments}}//{{.}} +{{end -}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBArgument}}db DBTX, {{end}}{{if not .Arg.IsEmpty}}{{.Arg.Pair}}, {{end}}opts ...query.ExecuteOption) ([]{{.Ret.DefineType}}, error) { + {{- $dbArg := "q.db" }}{{- if $.EmitMethodsWithDBArgument }}{{- $dbArg = "db" }}{{- end -}} + {{- if .Arg.IsEmpty }} + result, err := {{$dbArg}}.Query(ctx, {{.ConstantName}}, opts...) + {{- else }} + result, err := {{$dbArg}}.Query(ctx, {{.ConstantName}}, + append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ {{.Arg.YDBParamMapEntries}} })))..., + ) + {{- end }} + if err != nil { + {{- if $.WrapErrors}} + return nil, xerrors.WithStackTrace(fmt.Errorf("query {{.MethodName}}: %w", err)) + {{- else }} + return nil, xerrors.WithStackTrace(err) + {{- end }} + } + {{- if $.EmitEmptySlices}} + items := []{{.Ret.DefineType}}{} + {{else}} + var items []{{.Ret.DefineType}} + {{end -}} + for set, err := range result.ResultSets(ctx) { + if err != nil { + {{- if $.WrapErrors}} + return nil, xerrors.WithStackTrace(fmt.Errorf("query {{.MethodName}}: %w", err)) + {{- else }} + return nil, xerrors.WithStackTrace(err) + {{- end }} + } + for row, err := range set.Rows(ctx) { + if err != nil { + {{- if $.WrapErrors}} + return nil, xerrors.WithStackTrace(fmt.Errorf("query {{.MethodName}}: %w", err)) + {{- else }} + return nil, xerrors.WithStackTrace(err) + {{- end }} + } + var {{.Ret.Name}} {{.Ret.Type}} + if err := row.Scan({{.Ret.Scan}}); err != nil { + {{- if $.WrapErrors}} + return nil, xerrors.WithStackTrace(fmt.Errorf("query {{.MethodName}}: %w", err)) + {{- else }} + return nil, xerrors.WithStackTrace(err) + {{- end }} + } + items = append(items, {{.Ret.ReturnName}}) + } + } + if err := result.Close(ctx); err != nil { + {{- if $.WrapErrors}} + return nil, xerrors.WithStackTrace(fmt.Errorf("query {{.MethodName}}: %w", err)) + {{- else }} + return nil, xerrors.WithStackTrace(err) + {{- end }} + } + return items, nil +} +{{end}} + +{{if eq .Cmd ":exec"}} +{{range .Comments}}//{{.}} +{{end -}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBArgument}}db DBTX, {{end}}{{if not .Arg.IsEmpty}}{{.Arg.Pair}}, {{end}}opts ...query.ExecuteOption) error { + {{- $dbArg := "q.db" }}{{- if $.EmitMethodsWithDBArgument }}{{- $dbArg = "db" }}{{- end -}} + {{- if .Arg.IsEmpty }} + err := {{$dbArg}}.Exec(ctx, {{.ConstantName}}, opts...) + {{- else }} + err := {{$dbArg}}.Exec(ctx, {{.ConstantName}}, + append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ {{.Arg.YDBParamMapEntries}} })))..., + ) + {{- end }} + if err != nil { + {{- if $.WrapErrors }} + return xerrors.WithStackTrace(fmt.Errorf("query {{.MethodName}}: %w", err)) + {{- else }} + return xerrors.WithStackTrace(err) + {{- end }} + } + return nil +} +{{end}} + +{{end}} +{{end}} +{{end}} diff --git a/internal/config/v_two.json b/internal/config/v_two.json index acf914997d..fd7084d6e8 100644 --- a/internal/config/v_two.json +++ b/internal/config/v_two.json @@ -38,7 +38,8 @@ "enum": [ "postgresql", "mysql", - "sqlite" + "sqlite", + "ydb" ] }, "schema": { diff --git a/internal/sqltest/local/ydb.go b/internal/sqltest/local/ydb.go index 8703b170b5..e3e51e3716 100644 --- a/internal/sqltest/local/ydb.go +++ b/internal/sqltest/local/ydb.go @@ -2,11 +2,8 @@ package local import ( "context" - "database/sql" "fmt" - "hash/fnv" "math/rand" - "net" "os" "testing" "time" @@ -14,104 +11,77 @@ import ( migrate "github.com/sqlc-dev/sqlc/internal/migrations" "github.com/sqlc-dev/sqlc/internal/sql/sqlpath" "github.com/ydb-platform/ydb-go-sdk/v3" + "github.com/ydb-platform/ydb-go-sdk/v3/query" ) func init() { rand.Seed(time.Now().UnixNano()) } -func YDB(t *testing.T, migrations []string) TestYDB { - return link_YDB(t, migrations, true) +func YDB(t *testing.T, migrations []string) *ydb.Driver { + return link_YDB(t, migrations, true, false) } -func ReadOnlyYDB(t *testing.T, migrations []string) TestYDB { - return link_YDB(t, migrations, false) +func YDBTLS(t *testing.T, migrations []string) *ydb.Driver { + return link_YDB(t, migrations, true, true) } -type TestYDB struct { - DB *sql.DB - Prefix string +func ReadOnlyYDB(t *testing.T, migrations []string) *ydb.Driver { + return link_YDB(t, migrations, false, false) } -func link_YDB(t *testing.T, migrations []string, rw bool) TestYDB { - t.Helper() - - time.Sleep(1 * time.Second) // wait for YDB to start +func ReadOnlyYDBTLS(t *testing.T, migrations []string) *ydb.Driver { + return link_YDB(t, migrations, false, true) +} +func link_YDB(t *testing.T, migrations []string, rw bool, tls bool) *ydb.Driver { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() + t.Helper() + dbuiri := os.Getenv("YDB_SERVER_URI") if dbuiri == "" { t.Skip("YDB_SERVER_URI is empty") } - host, _, err := net.SplitHostPort(dbuiri) - if err != nil { - t.Fatalf("invalid YDB_SERVER_URI: %q", dbuiri) - } baseDB := os.Getenv("YDB_DATABASE") if baseDB == "" { baseDB = "/local" } - var seed []string - files, err := sqlpath.Glob(migrations) - if err != nil { - t.Fatal(err) - } - h := fnv.New64() - for _, f := range files { - blob, err := os.ReadFile(f) - if err != nil { - t.Fatal(err) - } - h.Write(blob) - seed = append(seed, migrate.RemoveRollbackStatements(string(blob))) - } - - var name string - if rw { - // name = fmt.Sprintf("sqlc_test_%s", id()) - name = fmt.Sprintf("sqlc_test_%s", "test_new") + var connectionString string + if tls { + connectionString = fmt.Sprintf("grpcs://%s%s", dbuiri, baseDB) } else { - name = fmt.Sprintf("sqlc_test_%x", h.Sum(nil)) + connectionString = fmt.Sprintf("grpc://%s%s", dbuiri, baseDB) } - prefix := fmt.Sprintf("%s/%s", baseDB, name) - rootDSN := fmt.Sprintf("grpc://%s?database=%s", dbuiri, baseDB) - t.Logf("→ Opening root driver: %s", rootDSN) - driver, err := ydb.Open(ctx, rootDSN, + db, err := ydb.Open(ctx, connectionString, ydb.WithInsecure(), ydb.WithDiscoveryInterval(time.Hour), - ydb.WithNodeAddressMutator(func(_ string) string { - return host - }), ) if err != nil { - t.Fatalf("failed to open root YDB connection: %s", err) + t.Fatalf("failed to open YDB connection: %s", err) } - connector, err := ydb.Connector( - driver, - ydb.WithTablePathPrefix(prefix), - ydb.WithAutoDeclare(), - ydb.WithNumericArgs(), - ) + files, err := sqlpath.Glob(migrations) if err != nil { - t.Fatalf("failed to create connector: %s", err) + t.Fatal(err) } - db := sql.OpenDB(connector) - - t.Log("→ Applying migrations to prefix: ", prefix) + for _, f := range files { + blob, err := os.ReadFile(f) + if err != nil { + t.Fatal(err) + } + stmt := migrate.RemoveRollbackStatements(string(blob)) - schemeCtx := ydb.WithQueryMode(ctx, ydb.SchemeQueryMode) - for _, stmt := range seed { - _, err := db.ExecContext(schemeCtx, stmt) + err = db.Query().Exec(ctx, stmt, query.WithTxControl(query.EmptyTxControl())) if err != nil { t.Fatalf("failed to apply migration: %s\nSQL: %s", err, stmt) } } - return TestYDB{DB: db, Prefix: prefix} + + return db } From 4eb8ff53c644d423e0527c66a61a2f10e1b0676d Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov Date: Wed, 3 Sep 2025 12:57:49 +0300 Subject: [PATCH 11/18] Bump sqlc version in examples/authors/ydb to v1.30.0 --- examples/authors/ydb/db.go | 2 +- examples/authors/ydb/models.go | 2 +- examples/authors/ydb/query.sql.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/authors/ydb/db.go b/examples/authors/ydb/db.go index c3b16b4481..9a15b333ce 100644 --- a/examples/authors/ydb/db.go +++ b/examples/authors/ydb/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.29.0 +// sqlc v1.30.0 package authors diff --git a/examples/authors/ydb/models.go b/examples/authors/ydb/models.go index 12b4f3a604..2806beacfe 100644 --- a/examples/authors/ydb/models.go +++ b/examples/authors/ydb/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.29.0 +// sqlc v1.30.0 package authors diff --git a/examples/authors/ydb/query.sql.go b/examples/authors/ydb/query.sql.go index 7459482b3a..3233b705d3 100644 --- a/examples/authors/ydb/query.sql.go +++ b/examples/authors/ydb/query.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.29.0 +// sqlc v1.30.0 // source: query.sql package authors From 14b60c501a7ace05dd0089df8d6e2661a0fabe2e Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov Date: Wed, 3 Sep 2025 14:25:40 +0300 Subject: [PATCH 12/18] Rewrited comments and testdata to eng --- examples/authors/ydb/README.md | 47 ------------------- examples/authors/ydb/db_test.go | 28 +++++------ .../ydb/catalog_tests/alter_group_test.go | 6 +-- .../ydb/catalog_tests/alter_user_test.go | 10 ++-- .../ydb/catalog_tests/create_group_test.go | 6 +-- .../ydb/catalog_tests/create_user_test.go | 6 +-- .../engine/ydb/catalog_tests/delete_test.go | 6 +-- .../ydb/catalog_tests/drop_role_test.go | 4 +- .../engine/ydb/catalog_tests/insert_test.go | 6 +-- .../engine/ydb/catalog_tests/pragma_test.go | 6 +-- .../engine/ydb/catalog_tests/select_test.go | 6 +-- .../engine/ydb/catalog_tests/update_test.go | 6 +-- 12 files changed, 45 insertions(+), 92 deletions(-) delete mode 100644 examples/authors/ydb/README.md diff --git a/examples/authors/ydb/README.md b/examples/authors/ydb/README.md deleted file mode 100644 index 9e77fc7886..0000000000 --- a/examples/authors/ydb/README.md +++ /dev/null @@ -1,47 +0,0 @@ -# Инструкция по генерации - -В файлах `schema.sql` и `query.sql` записаны, соответственно, схема базы данных и запросы, из которых вы хотите сгенерировать код к базе данных. -В `db_test.go` находятся тесты для последних сгенерированных команд. -Ниже находятся команды для генерации и запуска тестов. - ---- - -### 1. Создание бинарника sqlc - -```bash -make sqlc-dev -``` - -### 2. Запуск YDB через Docker Compose - -```bash -make ydb -``` - -### 3. Генерация кода для примеров для YDB - -```bash -make gen-examples-ydb -``` - -### 4. Запуск тестов примеров для YDB - -```bash -make test-examples-ydb -``` - -### 5. Полный цикл: сборка, генерация, тестирование (удобно одной командой) - -```bash -make ydb-examples -``` - -Эта команда выполнит: - -- Сборку `sqlc-dev` -- Запуск контейнера YDB -- Генерацию примеров -- Тестирование примеров - ---- - diff --git a/examples/authors/ydb/db_test.go b/examples/authors/ydb/db_test.go index 6ab15913f0..ab5324e76d 100644 --- a/examples/authors/ydb/db_test.go +++ b/examples/authors/ydb/db_test.go @@ -23,19 +23,19 @@ func TestAuthors(t *testing.T) { t.Run("InsertAuthors", func(t *testing.T) { authorsToInsert := []CreateOrUpdateAuthorParams{ - {P0: 1, P1: "Лев Толстой", P2: ptr("Русский писатель, автор \"Война и мир\"")}, - {P0: 2, P1: "Александр Пушкин", P2: ptr("Автор \"Евгения Онегина\"")}, - {P0: 3, P1: "Александр Пушкин", P2: ptr("Русский поэт, драматург и прозаик")}, - {P0: 4, P1: "Фёдор Достоевский", P2: ptr("Автор \"Преступление и наказание\"")}, - {P0: 5, P1: "Николай Гоголь", P2: ptr("Автор \"Мёртвые души\"")}, - {P0: 6, P1: "Антон Чехов", P2: nil}, - {P0: 7, P1: "Иван Тургенев", P2: ptr("Автор \"Отцы и дети\"")}, - {P0: 8, P1: "Михаил Лермонтов", P2: nil}, - {P0: 9, P1: "Даниил Хармс", P2: ptr("Абсурдист, писатель и поэт")}, - {P0: 10, P1: "Максим Горький", P2: ptr("Автор \"На дне\"")}, - {P0: 11, P1: "Владимир Маяковский", P2: nil}, - {P0: 12, P1: "Сергей Есенин", P2: ptr("Русский лирик")}, - {P0: 13, P1: "Борис Пастернак", P2: ptr("Автор \"Доктор Живаго\"")}, + {P0: 1, P1: "Leo Tolstoy", P2: ptr("Russian writer, author of \"War and Peace\"")}, + {P0: 2, P1: "Alexander Pushkin", P2: ptr("Author of \"Eugene Onegin\"")}, + {P0: 3, P1: "Alexander Pushkin", P2: ptr("Russian poet, playwright, and prose writer")}, + {P0: 4, P1: "Fyodor Dostoevsky", P2: ptr("Author of \"Crime and Punishment\"")}, + {P0: 5, P1: "Nikolai Gogol", P2: ptr("Author of \"Dead Souls\"")}, + {P0: 6, P1: "Anton Chekhov", P2: nil}, + {P0: 7, P1: "Ivan Turgenev", P2: ptr("Author of \"Fathers and Sons\"")}, + {P0: 8, P1: "Mikhail Lermontov", P2: nil}, + {P0: 9, P1: "Daniil Kharms", P2: ptr("Absurdist, writer and poet")}, + {P0: 10, P1: "Maxim Gorky", P2: ptr("Author of \"At the Bottom\"")}, + {P0: 11, P1: "Vladimir Mayakovsky", P2: nil}, + {P0: 12, P1: "Sergei Yesenin", P2: ptr("Russian lyric poet")}, + {P0: 13, P1: "Boris Pasternak", P2: ptr("Author of \"Doctor Zhivago\"")}, } for _, author := range authorsToInsert { @@ -79,7 +79,7 @@ func TestAuthors(t *testing.T) { var i uint64 for i = 1; i <= 13; i++ { if err := q.DeleteAuthor(ctx, i, query.WithIdempotent()); err != nil { - t.Fatalf("failed to delete authors: %v", err) + t.Fatalf("failed to delete author: %v", err) } } authors, err := q.ListAuthors(ctx) diff --git a/internal/engine/ydb/catalog_tests/alter_group_test.go b/internal/engine/ydb/catalog_tests/alter_group_test.go index eef9f919e9..297d9b326a 100644 --- a/internal/engine/ydb/catalog_tests/alter_group_test.go +++ b/internal/engine/ydb/catalog_tests/alter_group_test.go @@ -102,10 +102,10 @@ func TestAlterGroup(t *testing.T) { t.Run(tc.stmt, func(t *testing.T) { stmts, err := p.Parse(strings.NewReader(tc.stmt)) if err != nil { - t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + t.Fatalf("Failed to parse query %q: %v", tc.stmt, err) } if len(stmts) == 0 { - t.Fatalf("Запрос %q не распарсен", tc.stmt) + t.Fatalf("Query %q was not parsed", tc.stmt) } diff := cmp.Diff(tc.expected, &stmts[0], @@ -115,7 +115,7 @@ func TestAlterGroup(t *testing.T) { cmpopts.IgnoreFields(ast.A_Const{}, "Location"), ) if diff != "" { - t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + t.Errorf("AST mismatch for %q (-expected +got):\n%s", tc.stmt, diff) } }) } diff --git a/internal/engine/ydb/catalog_tests/alter_user_test.go b/internal/engine/ydb/catalog_tests/alter_user_test.go index dd4dc5bb93..64a9891e8c 100644 --- a/internal/engine/ydb/catalog_tests/alter_user_test.go +++ b/internal/engine/ydb/catalog_tests/alter_user_test.go @@ -28,8 +28,8 @@ func TestAlterUser(t *testing.T) { Options: &ast.List{ Items: []ast.Node{ &ast.DefElem{ - Defname: strPtr("rename"), - Arg: &ast.String{Str: "queen"}, + Defname: strPtr("rename"), + Arg: &ast.String{Str: "queen"}, Defaction: ast.DefElemAction(1), }, }, @@ -133,10 +133,10 @@ func TestAlterUser(t *testing.T) { t.Run(tc.stmt, func(t *testing.T) { stmts, err := p.Parse(strings.NewReader(tc.stmt)) if err != nil { - t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + t.Fatalf("Failed to parse query %q: %v", tc.stmt, err) } if len(stmts) == 0 { - t.Fatalf("Запрос %q не распарсен", tc.stmt) + t.Fatalf("Query %q was not parsed", tc.stmt) } diff := cmp.Diff(tc.expected, &stmts[0], @@ -146,7 +146,7 @@ func TestAlterUser(t *testing.T) { cmpopts.IgnoreFields(ast.A_Const{}, "Location"), ) if diff != "" { - t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + t.Errorf("AST mismatch for %q (-expected +got):\n%s", tc.stmt, diff) } }) } diff --git a/internal/engine/ydb/catalog_tests/create_group_test.go b/internal/engine/ydb/catalog_tests/create_group_test.go index 724e912168..bc8d8369fd 100644 --- a/internal/engine/ydb/catalog_tests/create_group_test.go +++ b/internal/engine/ydb/catalog_tests/create_group_test.go @@ -90,10 +90,10 @@ func TestCreateGroup(t *testing.T) { t.Run(tc.stmt, func(t *testing.T) { stmts, err := p.Parse(strings.NewReader(tc.stmt)) if err != nil { - t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + t.Fatalf("Failed to parse query %q: %v", tc.stmt, err) } if len(stmts) == 0 { - t.Fatalf("Запрос %q не распарсен", tc.stmt) + t.Fatalf("Query %q was not parsed", tc.stmt) } diff := cmp.Diff(tc.expected, &stmts[0], @@ -103,7 +103,7 @@ func TestCreateGroup(t *testing.T) { cmpopts.IgnoreFields(ast.A_Const{}, "Location"), ) if diff != "" { - t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + t.Errorf("AST mismatch for %q (-expected +got):\n%s", tc.stmt, diff) } }) } diff --git a/internal/engine/ydb/catalog_tests/create_user_test.go b/internal/engine/ydb/catalog_tests/create_user_test.go index be53e9dd79..108d282c7c 100644 --- a/internal/engine/ydb/catalog_tests/create_user_test.go +++ b/internal/engine/ydb/catalog_tests/create_user_test.go @@ -110,10 +110,10 @@ func TestCreateUser(t *testing.T) { t.Run(tc.stmt, func(t *testing.T) { stmts, err := p.Parse(strings.NewReader(tc.stmt)) if err != nil { - t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + t.Fatalf("Failed to parse query %q: %v", tc.stmt, err) } if len(stmts) == 0 { - t.Fatalf("Запрос %q не распарсен", tc.stmt) + t.Fatalf("Query %q was not parsed", tc.stmt) } diff := cmp.Diff(tc.expected, &stmts[0], @@ -122,7 +122,7 @@ func TestCreateUser(t *testing.T) { cmpopts.IgnoreFields(ast.DefElem{}, "Location"), ) if diff != "" { - t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + t.Errorf("AST mismatch for %q (-expected +got):\n%s", tc.stmt, diff) } }) } diff --git a/internal/engine/ydb/catalog_tests/delete_test.go b/internal/engine/ydb/catalog_tests/delete_test.go index b75591a9ef..ab7b709be9 100644 --- a/internal/engine/ydb/catalog_tests/delete_test.go +++ b/internal/engine/ydb/catalog_tests/delete_test.go @@ -178,10 +178,10 @@ func TestDelete(t *testing.T) { t.Run(tc.stmt, func(t *testing.T) { stmts, err := p.Parse(strings.NewReader(tc.stmt)) if err != nil { - t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + t.Fatalf("Failed to parse query %q: %v", tc.stmt, err) } if len(stmts) == 0 { - t.Fatalf("Запрос %q не распарсен", tc.stmt) + t.Fatalf("Query %q was not parsed", tc.stmt) } diff := cmp.Diff(tc.expected, &stmts[0], @@ -193,7 +193,7 @@ func TestDelete(t *testing.T) { cmpopts.IgnoreFields(ast.RangeVar{}, "Location"), ) if diff != "" { - t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + t.Errorf("AST mismatch for %q (-expected +got):\n%s", tc.stmt, diff) } }) } diff --git a/internal/engine/ydb/catalog_tests/drop_role_test.go b/internal/engine/ydb/catalog_tests/drop_role_test.go index 1d7c6a7658..224be53ed1 100644 --- a/internal/engine/ydb/catalog_tests/drop_role_test.go +++ b/internal/engine/ydb/catalog_tests/drop_role_test.go @@ -69,10 +69,10 @@ func TestDropRole(t *testing.T) { t.Run(tc.stmt, func(t *testing.T) { stmts, err := p.Parse(strings.NewReader(tc.stmt)) if err != nil { - t.Fatalf("Error parsing %q: %v", tc.stmt, err) + t.Fatalf("Failed to parse query %q: %v", tc.stmt, err) } if len(stmts) == 0 { - t.Fatalf("Statement %q was not parsed", tc.stmt) + t.Fatalf("Query %q was not parsed", tc.stmt) } diff := cmp.Diff(tc.expected, &stmts[0], diff --git a/internal/engine/ydb/catalog_tests/insert_test.go b/internal/engine/ydb/catalog_tests/insert_test.go index 40f116a3ba..4dea2ceccb 100644 --- a/internal/engine/ydb/catalog_tests/insert_test.go +++ b/internal/engine/ydb/catalog_tests/insert_test.go @@ -120,10 +120,10 @@ func TestInsert(t *testing.T) { t.Run(tc.stmt, func(t *testing.T) { stmts, err := p.Parse(strings.NewReader(tc.stmt)) if err != nil { - t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + t.Fatalf("Failed to parse query %q: %v", tc.stmt, err) } if len(stmts) == 0 { - t.Fatalf("Запрос %q не распарсен", tc.stmt) + t.Fatalf("Query %q was not parsed", tc.stmt) } diff := cmp.Diff(tc.expected, &stmts[0], @@ -135,7 +135,7 @@ func TestInsert(t *testing.T) { cmpopts.IgnoreFields(ast.RangeVar{}, "Location"), ) if diff != "" { - t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + t.Errorf("AST mismatch for %q (-expected +got):\n%s", tc.stmt, diff) } }) } diff --git a/internal/engine/ydb/catalog_tests/pragma_test.go b/internal/engine/ydb/catalog_tests/pragma_test.go index 9db4406c53..fef5f76f88 100644 --- a/internal/engine/ydb/catalog_tests/pragma_test.go +++ b/internal/engine/ydb/catalog_tests/pragma_test.go @@ -98,10 +98,10 @@ func TestPragma(t *testing.T) { t.Run(tc.stmt, func(t *testing.T) { stmts, err := p.Parse(strings.NewReader(tc.stmt)) if err != nil { - t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + t.Fatalf("Failed to parse query %q: %v", tc.stmt, err) } if len(stmts) == 0 { - t.Fatalf("Запрос %q не распарсен", tc.stmt) + t.Fatalf("Query %q was not parsed", tc.stmt) } diff := cmp.Diff(tc.expected, &stmts[0], @@ -111,7 +111,7 @@ func TestPragma(t *testing.T) { cmpopts.IgnoreFields(ast.A_Const{}, "Location"), ) if diff != "" { - t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + t.Errorf("AST mismatch for %q (-expected +got):\n%s", tc.stmt, diff) } }) } diff --git a/internal/engine/ydb/catalog_tests/select_test.go b/internal/engine/ydb/catalog_tests/select_test.go index 47794ce81f..fa7b22677c 100644 --- a/internal/engine/ydb/catalog_tests/select_test.go +++ b/internal/engine/ydb/catalog_tests/select_test.go @@ -374,10 +374,10 @@ func TestSelect(t *testing.T) { t.Run(tc.stmt, func(t *testing.T) { stmts, err := p.Parse(strings.NewReader(tc.stmt)) if err != nil { - t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + t.Fatalf("Failed to parse query %q: %v", tc.stmt, err) } if len(stmts) == 0 { - t.Fatalf("Запрос %q не распарсен", tc.stmt) + t.Fatalf("Query %q was not parsed", tc.stmt) } diff := cmp.Diff(tc.expected, &stmts[0], @@ -390,7 +390,7 @@ func TestSelect(t *testing.T) { cmpopts.IgnoreFields(ast.RangeVar{}, "Location"), ) if diff != "" { - t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + t.Errorf("AST mismatch for %q (-expected +got):\n%s", tc.stmt, diff) } }) } diff --git a/internal/engine/ydb/catalog_tests/update_test.go b/internal/engine/ydb/catalog_tests/update_test.go index 57c90bc2a4..f4f00a92bc 100644 --- a/internal/engine/ydb/catalog_tests/update_test.go +++ b/internal/engine/ydb/catalog_tests/update_test.go @@ -163,10 +163,10 @@ func TestUpdate(t *testing.T) { t.Run(tc.stmt, func(t *testing.T) { stmts, err := p.Parse(strings.NewReader(tc.stmt)) if err != nil { - t.Fatalf("Ошибка парсинга запроса %q: %v", tc.stmt, err) + t.Fatalf("Failed to parse query %q: %v", tc.stmt, err) } if len(stmts) == 0 { - t.Fatalf("Запрос %q не распарсен", tc.stmt) + t.Fatalf("Query %q was not parsed", tc.stmt) } diff := cmp.Diff(tc.expected, &stmts[0], @@ -178,7 +178,7 @@ func TestUpdate(t *testing.T) { cmpopts.IgnoreFields(ast.RangeVar{}, "Location"), ) if diff != "" { - t.Errorf("Несовпадение AST (-ожидалось +получено):\n%s", diff) + t.Errorf("AST mismatch for %q (-expected +got):\n%s", tc.stmt, diff) } }) } From f2d7aeab4c99eea98c10825672c10a3124fbdfea Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov <150552906+1NepuNep1@users.noreply.github.com> Date: Mon, 8 Sep 2025 13:23:38 +0300 Subject: [PATCH 13/18] Expanding YQL Syntax support in sqlc engine (Refactored SELECT (Supported DISTINCT, HAVING, GROUP BY, ORDER BY, LIMIT, OFFSET, UNION), added ALTER TABLE, DO stmt) (#10) * Added ALTER TABLE support * Added DO stmt support * Fixed DO stmt generation * Refactored SELECT logic. Supported DISTINCT, HAVING, GROUP BY, ORDER BY, LIMIT, OFFSET, UNION * Fixed style and types * Refactored DO stmt style & added some additional converts --------- Co-authored-by: Viktor Pentyukhov --- .../engine/ydb/catalog_tests/delete_test.go | 17 +- .../engine/ydb/catalog_tests/insert_test.go | 21 +- .../engine/ydb/catalog_tests/select_test.go | 364 +++++++++++++- .../engine/ydb/catalog_tests/update_test.go | 5 + internal/engine/ydb/convert.go | 469 +++++++++++++++--- internal/engine/ydb/utils.go | 29 +- 6 files changed, 830 insertions(+), 75 deletions(-) diff --git a/internal/engine/ydb/catalog_tests/delete_test.go b/internal/engine/ydb/catalog_tests/delete_test.go index ab7b709be9..1885deb9ce 100644 --- a/internal/engine/ydb/catalog_tests/delete_test.go +++ b/internal/engine/ydb/catalog_tests/delete_test.go @@ -101,6 +101,7 @@ func TestDelete(t *testing.T) { }, }, OnSelectStmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, ValuesLists: &ast.List{ Items: []ast.Node{ &ast.List{ @@ -110,8 +111,12 @@ func TestDelete(t *testing.T) { }, }, }, - FromClause: &ast.List{}, - TargetList: &ast.List{}, + FromClause: &ast.List{}, + TargetList: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, ReturningList: &ast.List{ Items: []ast.Node{ @@ -145,6 +150,7 @@ func TestDelete(t *testing.T) { }, }, OnSelectStmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -153,7 +159,12 @@ func TestDelete(t *testing.T) { }, }, }, - FromClause: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, ReturningList: &ast.List{ Items: []ast.Node{ diff --git a/internal/engine/ydb/catalog_tests/insert_test.go b/internal/engine/ydb/catalog_tests/insert_test.go index 4dea2ceccb..c60d0920da 100644 --- a/internal/engine/ydb/catalog_tests/insert_test.go +++ b/internal/engine/ydb/catalog_tests/insert_test.go @@ -28,6 +28,7 @@ func TestInsert(t *testing.T) { }, }, SelectStmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, ValuesLists: &ast.List{ Items: []ast.Node{ &ast.List{ @@ -40,6 +41,10 @@ func TestInsert(t *testing.T) { }, TargetList: &ast.List{}, FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, OnConflictClause: &ast.OnConflictClause{}, ReturningList: &ast.List{ @@ -68,6 +73,7 @@ func TestInsert(t *testing.T) { }, }, SelectStmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, ValuesLists: &ast.List{ Items: []ast.Node{ &ast.List{ @@ -79,6 +85,10 @@ func TestInsert(t *testing.T) { }, TargetList: &ast.List{}, FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, OnConflictClause: &ast.OnConflictClause{ Action: ast.OnConflictAction_INSERT_OR_IGNORE, @@ -106,7 +116,16 @@ func TestInsert(t *testing.T) { Stmt: &ast.InsertStmt{ Relation: &ast.RangeVar{Relname: strPtr("users")}, Cols: &ast.List{Items: []ast.Node{&ast.ResTarget{Name: strPtr("id")}}}, - SelectStmt: &ast.SelectStmt{ValuesLists: &ast.List{Items: []ast.Node{&ast.List{Items: []ast.Node{&ast.A_Const{Val: &ast.Integer{Ival: 4}}}}}}, TargetList: &ast.List{}, FromClause: &ast.List{}}, + SelectStmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, + ValuesLists: &ast.List{Items: []ast.Node{&ast.List{Items: []ast.Node{&ast.A_Const{Val: &ast.Integer{Ival: 4}}}}}}, + TargetList: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, + }, OnConflictClause: &ast.OnConflictClause{Action: ast.OnConflictAction_UPSERT}, ReturningList: &ast.List{Items: []ast.Node{&ast.ResTarget{Val: &ast.ColumnRef{Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "id"}}}}, Indirection: &ast.List{}}}}, }, diff --git a/internal/engine/ydb/catalog_tests/select_test.go b/internal/engine/ydb/catalog_tests/select_test.go index fa7b22677c..f01171f12a 100644 --- a/internal/engine/ydb/catalog_tests/select_test.go +++ b/internal/engine/ydb/catalog_tests/select_test.go @@ -25,6 +25,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -34,7 +35,12 @@ func TestSelect(t *testing.T) { }, }, }, - FromClause: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -44,6 +50,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -53,7 +60,12 @@ func TestSelect(t *testing.T) { }, }, }, - FromClause: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -63,6 +75,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -72,7 +85,12 @@ func TestSelect(t *testing.T) { }, }, }, - FromClause: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -82,6 +100,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -91,7 +110,12 @@ func TestSelect(t *testing.T) { }, }, }, - FromClause: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -101,6 +125,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -108,7 +133,12 @@ func TestSelect(t *testing.T) { }, }, }, - FromClause: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -118,6 +148,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -127,7 +158,12 @@ func TestSelect(t *testing.T) { }, }, }, - FromClause: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -137,6 +173,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -166,7 +203,12 @@ func TestSelect(t *testing.T) { }, }, }, - FromClause: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -178,6 +220,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -198,6 +241,11 @@ func TestSelect(t *testing.T) { }, }, }, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -207,6 +255,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -228,6 +277,11 @@ func TestSelect(t *testing.T) { }, }, }, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -237,6 +291,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -259,6 +314,11 @@ func TestSelect(t *testing.T) { }, }, }, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -268,6 +328,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -307,6 +368,11 @@ func TestSelect(t *testing.T) { }, }, }, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -316,6 +382,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -362,6 +429,287 @@ func TestSelect(t *testing.T) { Val: &ast.Integer{Ival: 30}, }, }, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, + }, + }, + }, + }, + { + stmt: `(SELECT 1) UNION ALL (SELECT 2)`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, + TargetList: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, + Op: ast.Union, + All: true, + Larg: &ast.SelectStmt{ + DistinctClause: &ast.List{}, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.A_Const{ + Val: &ast.Integer{Ival: 1}, + }, + }, + }, + }, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, + }, + Rarg: &ast.SelectStmt{ + DistinctClause: &ast.List{}, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.A_Const{ + Val: &ast.Integer{Ival: 2}, + }, + }, + }, + }, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, + }, + }, + }, + }, + }, + { + stmt: `SELECT id FROM users ORDER BY id DESC`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + }, + }, + }, + FromClause: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{ + Relname: strPtr("users"), + }, + }, + }, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + LockingClause: &ast.List{}, + SortClause: &ast.List{ + Items: []ast.Node{ + &ast.SortBy{ + Node: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + SortbyDir: ast.SortByDirDesc, + UseOp: &ast.List{}, + }, + }, + }, + }, + }, + }, + }, + { + stmt: `SELECT id FROM users LIMIT 10 OFFSET 5`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + }, + }, + }, + FromClause: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{ + Relname: strPtr("users"), + }, + }, + }, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, + LimitCount: &ast.A_Const{ + Val: &ast.Integer{Ival: 10}, + }, + LimitOffset: &ast.A_Const{ + Val: &ast.Integer{Ival: 5}, + }, + }, + }, + }, + }, + { + stmt: `SELECT id FROM users WHERE id > 10 GROUP BY id HAVING id > 10`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + }, + }, + }, + FromClause: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{ + Relname: strPtr("users"), + }, + }, + }, + GroupClause: &ast.List{ + Items: []ast.Node{ + &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + }, + }, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, + WhereClause: &ast.A_Expr{ + Name: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: ">"}, + }, + }, + Lexpr: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + Rexpr: &ast.A_Const{ + Val: &ast.Integer{Ival: 10}, + }, + }, + HavingClause: &ast.A_Expr{ + Name: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: ">"}, + }, + }, + Lexpr: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + Rexpr: &ast.A_Const{ + Val: &ast.Integer{Ival: 10}, + }, + }, + }, + }, + }, + }, + { + stmt: `SELECT id FROM users GROUP BY ROLLUP (id)`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + }, + }, + }, + FromClause: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{ + Relname: strPtr("users"), + }, + }, + }, + GroupClause: &ast.List{ + Items: []ast.Node{ + &ast.GroupingSet{ + Kind: 1, // T_GroupingSet: ROLLUP + Content: &ast.List{ + Items: []ast.Node{ + &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + }, + }, + }, + }, + }, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -382,12 +730,12 @@ func TestSelect(t *testing.T) { diff := cmp.Diff(tc.expected, &stmts[0], cmpopts.IgnoreFields(ast.RawStmt{}, "StmtLocation", "StmtLen"), - // cmpopts.IgnoreFields(ast.SelectStmt{}, "Location"), cmpopts.IgnoreFields(ast.A_Const{}, "Location"), cmpopts.IgnoreFields(ast.ResTarget{}, "Location"), cmpopts.IgnoreFields(ast.ColumnRef{}, "Location"), cmpopts.IgnoreFields(ast.A_Expr{}, "Location"), cmpopts.IgnoreFields(ast.RangeVar{}, "Location"), + cmpopts.IgnoreFields(ast.SortBy{}, "Location"), ) if diff != "" { t.Errorf("AST mismatch for %q (-expected +got):\n%s", tc.stmt, diff) diff --git a/internal/engine/ydb/catalog_tests/update_test.go b/internal/engine/ydb/catalog_tests/update_test.go index f4f00a92bc..b7ebeb3d6a 100644 --- a/internal/engine/ydb/catalog_tests/update_test.go +++ b/internal/engine/ydb/catalog_tests/update_test.go @@ -121,6 +121,7 @@ func TestUpdate(t *testing.T) { }, }, OnSelectStmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, ValuesLists: &ast.List{ Items: []ast.Node{ &ast.List{ @@ -132,6 +133,10 @@ func TestUpdate(t *testing.T) { }, FromClause: &ast.List{}, TargetList: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, ReturningList: &ast.List{ Items: []ast.Node{ diff --git a/internal/engine/ydb/convert.go b/internal/engine/ydb/convert.go index b4d9490d0b..250dc467e0 100755 --- a/internal/engine/ydb/convert.go +++ b/internal/engine/ydb/convert.go @@ -54,7 +54,7 @@ func NewIdentifier(t string) *ast.String { return &ast.String{Str: identifier(t)} } -func (c *cc) convertDrop_role_stmtCOntext(n *parser.Drop_role_stmtContext) ast.Node { +func (c *cc) convertDrop_role_stmtContext(n *parser.Drop_role_stmtContext) ast.Node { if n.DROP() == nil || (n.USER() == nil && n.GROUP() == nil) || len(n.AllRole_name()) == 0 { return todo("Drop_role_stmtContext", n) } @@ -467,6 +467,145 @@ func (c *cc) convertRollback_stmtContext(n *parser.Rollback_stmtContext) ast.Nod return todo("convertRollback_stmtContext", n) } +func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) ast.Node { + if n.ALTER() == nil || n.TABLE() == nil || n.Simple_table_ref() == nil || len(n.AllAlter_table_action()) == 0 { + return todo("convertAlter_table_stmtContext", n) + } + + stmt := &ast.AlterTableStmt{ + Table: parseTableName(n.Simple_table_ref().Simple_table_ref_core()), + Cmds: &ast.List{}, + } + + for _, action := range n.AllAlter_table_action() { + if action == nil { + continue + } + + switch { + case action.Alter_table_add_column() != nil: + ac := action.Alter_table_add_column() + if ac.ADD() != nil && ac.Column_schema() != nil { + columnDef := c.convertColumnSchema(ac.Column_schema().(*parser.Column_schemaContext)) + stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{ + Name: &columnDef.Colname, + Subtype: ast.AT_AddColumn, + Def: columnDef, + }) + } + case action.Alter_table_drop_column() != nil: + ac := action.Alter_table_drop_column() + if ac.DROP() != nil && ac.An_id() != nil { + name := parseAnId(ac.An_id()) + stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{ + Name: &name, + Subtype: ast.AT_DropColumn, + }) + } + case action.Alter_table_alter_column_drop_not_null() != nil: + ac := action.Alter_table_alter_column_drop_not_null() + if ac.DROP() != nil && ac.NOT() != nil && ac.NULL() != nil && ac.An_id() != nil { + name := parseAnId(ac.An_id()) + stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{ + Name: &name, + Subtype: ast.AT_DropNotNull, + }) + } + case action.Alter_table_rename_to() != nil: + ac := action.Alter_table_rename_to() + if ac.RENAME() != nil && ac.TO() != nil && ac.An_id_table() != nil { + // TODO: Returning here may be incorrect if there are multiple specs + newName := parseAnIdTable(ac.An_id_table()) + return &ast.RenameTableStmt{ + Table: stmt.Table, + NewName: &newName, + } + } + case action.Alter_table_add_index() != nil, + action.Alter_table_drop_index() != nil, + action.Alter_table_add_column_family() != nil, + action.Alter_table_alter_column_family() != nil, + action.Alter_table_set_table_setting_uncompat() != nil, + action.Alter_table_set_table_setting_compat() != nil, + action.Alter_table_reset_table_setting() != nil, + action.Alter_table_add_changefeed() != nil, + action.Alter_table_alter_changefeed() != nil, + action.Alter_table_drop_changefeed() != nil, + action.Alter_table_rename_index_to() != nil, + action.Alter_table_alter_index() != nil: + // All these actions do not change column schema relevant to sqlc; no-op. + // Intentionally ignored. + } + } + + return stmt +} + +func (c *cc) convertDo_stmtContext(n *parser.Do_stmtContext) ast.Node { + if n.DO() == nil || (n.Call_action() == nil && n.Inline_action() == nil) { + return todo("convertDo_stmtContext", n) + } + + switch { + case n.Call_action() != nil: + return c.convert(n.Call_action()) + + case n.Inline_action() != nil: + return c.convert(n.Inline_action()) + } + + return todo("convertDo_stmtContext", n) +} + +func (c *cc) convertCall_actionContext(n *parser.Call_actionContext) ast.Node { + if n == nil { + return nil + } + if n.LPAREN() != nil && n.RPAREN() != nil { + funcCall := &ast.FuncCall{ + Funcname: &ast.List{}, + Args: &ast.List{}, + AggOrder: &ast.List{}, + } + + if n.Bind_parameter() != nil { + funcCall.Funcname.Items = append(funcCall.Funcname.Items, c.convert(n.Bind_parameter())) + } else if n.EMPTY_ACTION() != nil { + funcCall.Funcname.Items = append(funcCall.Funcname.Items, &ast.String{Str: "EMPTY_ACTION"}) + } + + if n.Expr_list() != nil { + for _, expr := range n.Expr_list().AllExpr() { + funcCall.Args.Items = append(funcCall.Args.Items, c.convert(expr)) + } + } + + return &ast.DoStmt{ + Args: &ast.List{Items: []ast.Node{funcCall}}, + } + } + return todo("convertCall_actionContext", n) +} + +func (c *cc) convertInline_actionContext(n *parser.Inline_actionContext) ast.Node { + if n == nil { + return nil + } + if n.BEGIN() != nil && n.END() != nil && n.DO() != nil { + args := &ast.List{} + if defineBody := n.Define_action_or_subquery_body(); defineBody != nil { + cores := defineBody.AllSql_stmt_core() + for _, stmtCore := range cores { + if converted := c.convert(stmtCore); converted != nil { + args.Items = append(args.Items, converted) + } + } + } + return &ast.DoStmt{Args: args} + } + return todo("convertInline_actionContext", n) +} + func (c *cc) convertDrop_table_stmtContext(n *parser.Drop_table_stmtContext) ast.Node { if n.DROP() != nil && (n.TABLESTORE() != nil || (n.EXTERNAL() != nil && n.TABLE() != nil) || n.TABLE() != nil) { name := parseTableName(n.Simple_table_ref().Simple_table_ref_core()) @@ -509,12 +648,9 @@ func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { if valSource != nil { switch { case valSource.Values_stmt() != nil: - source = &ast.SelectStmt{ - ValuesLists: c.convert(valSource.Values_stmt()).(*ast.List), - FromClause: &ast.List{}, - TargetList: &ast.List{}, - } - + stmt := emptySelectStmt() + stmt.ValuesLists = c.convert(valSource.Values_stmt()).(*ast.List) + source = stmt case valSource.Select_stmt() != nil: source = c.convert(valSource.Select_stmt()) } @@ -704,12 +840,9 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { if valSource != nil { switch { case valSource.Values_stmt() != nil: - source = &ast.SelectStmt{ - ValuesLists: c.convert(valSource.Values_stmt()).(*ast.List), - FromClause: &ast.List{}, - TargetList: &ast.List{}, - } - + stmt := emptySelectStmt() + stmt.ValuesLists = c.convert(valSource.Values_stmt()).(*ast.List) + source = stmt case valSource.Select_stmt() != nil: source = c.convert(valSource.Select_stmt()) } @@ -777,12 +910,9 @@ func (c *cc) convertInto_table_stmtContext(n *parser.Into_table_stmtContext) ast if valSource != nil { switch { case valSource.Values_stmt() != nil: - source = &ast.SelectStmt{ - ValuesLists: c.convert(valSource.Values_stmt()).(*ast.List), - FromClause: &ast.List{}, - TargetList: &ast.List{}, - } - + stmt := emptySelectStmt() + stmt.ValuesLists = c.convert(valSource.Values_stmt()).(*ast.List) + source = stmt case valSource.Select_stmt() != nil: source = c.convert(valSource.Select_stmt()) } @@ -859,60 +989,106 @@ func (c *cc) convertReturning_columns_listContext(n *parser.Returning_columns_li } func (c *cc) convertSelectStmtContext(n *parser.Select_stmtContext) ast.Node { + if len(n.AllSelect_kind_parenthesis()) == 0 { + return todo("convertSelectStmtContext", n) + } + skp := n.Select_kind_parenthesis(0) if skp == nil { - return nil + return todo("convertSelectStmtContext", skp) } - partial := skp.Select_kind_partial() - if partial == nil { - return nil + + stmt := c.convertSelectKindParenthesis(skp) + left, ok := stmt.(*ast.SelectStmt) + if left == nil || !ok { + return todo("convertSelectKindParenthesis", skp) } + + kinds := n.AllSelect_kind_parenthesis() + ops := n.AllSelect_op() + + for i := 1; i < len(kinds); i++ { + stmt := c.convertSelectKindParenthesis(kinds[i]) + right, ok := stmt.(*ast.SelectStmt) + if right == nil || !ok { + return todo("convertSelectKindParenthesis", kinds[i]) + } + + var op ast.SetOperation + var all bool + if i-1 < len(ops) && ops[i-1] != nil { + so := ops[i-1] + switch { + case so.UNION() != nil: + op = ast.Union + case so.INTERSECT() != nil: + log.Fatalf("YDB: INTERSECT is not implemented yet") + case so.EXCEPT() != nil: + log.Fatalf("YDB: EXCEPT is not implemented yet") + default: + op = ast.None + } + all = so.ALL() != nil + } + larg := left + left = emptySelectStmt() + left.Op = op + left.All = all + left.Larg = larg + left.Rarg = right + } + + return left +} + +func (c *cc) convertSelectKindParenthesis(n parser.ISelect_kind_parenthesisContext) ast.Node { + if n == nil || n.Select_kind_partial() == nil { + return todo("convertSelectKindParenthesis", n) + } + partial := n.Select_kind_partial() + sk := partial.Select_kind() if sk == nil { - return nil + return todo("convertSelectKind", sk) } - selectStmt := &ast.SelectStmt{} + var base ast.Node switch { - case sk.Process_core() != nil: - cnode := c.convert(sk.Process_core()) - stmt, ok := cnode.(*ast.SelectStmt) - if !ok { - return nil - } - selectStmt = stmt case sk.Select_core() != nil: - cnode := c.convert(sk.Select_core()) - stmt, ok := cnode.(*ast.SelectStmt) - if !ok { - return nil - } - selectStmt = stmt + base = c.convertSelectCoreContext(sk.Select_core()) + case sk.Process_core() != nil: + log.Fatalf("PROCESS is not supported in YDB engine") case sk.Reduce_core() != nil: - cnode := c.convert(sk.Reduce_core()) - stmt, ok := cnode.(*ast.SelectStmt) - if !ok { - return nil - } - selectStmt = stmt + log.Fatalf("REDUCE is not supported in YDB engine") + } + stmt, ok := base.(*ast.SelectStmt) + if !ok || stmt == nil { + return todo("convertSelectKindParenthesis", sk.Select_core()) } - // todo: cover process and reduce core, - // todo: cover LIMIT and OFFSET + // TODO: handle INTO RESULT clause - return selectStmt + if partial.LIMIT() != nil { + exprs := partial.AllExpr() + if len(exprs) >= 1 { + stmt.LimitCount = c.convert(exprs[0]) + } + if partial.OFFSET() != nil { + if len(exprs) >= 2 { + stmt.LimitOffset = c.convert(exprs[1]) + } + } + } + + return stmt } -func (c *cc) convertSelectCoreContext(n *parser.Select_coreContext) ast.Node { - stmt := &ast.SelectStmt{ - TargetList: &ast.List{}, - FromClause: &ast.List{}, - } +func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { + stmt := emptySelectStmt() if n.Opt_set_quantifier() != nil { oq := n.Opt_set_quantifier() if oq.DISTINCT() != nil { - // todo: add distinct support - stmt.DistinctClause = &ast.List{} + stmt.DistinctClause.Items = append(stmt.DistinctClause.Items, &ast.TODO{}) // trick to handle distinct } } resultCols := n.AllResult_column() @@ -932,8 +1108,14 @@ func (c *cc) convertSelectCoreContext(n *parser.Select_coreContext) ast.Node { Items: items, } } + + // TODO: handle WITHOUT clause + jsList := n.AllJoin_source() - if len(n.AllFROM()) > 0 && len(jsList) > 0 { + if len(n.AllFROM()) > 1 { + log.Fatalf("YDB: Only one FROM clause is allowed") + } + if len(jsList) > 0 { var fromItems []ast.Node for _, js := range jsList { jsCon, ok := js.(*parser.Join_sourceContext) @@ -950,15 +1132,137 @@ func (c *cc) convertSelectCoreContext(n *parser.Select_coreContext) ast.Node { Items: fromItems, } } + + exprIdx := 0 if n.WHERE() != nil { - whereCtx := n.Expr(0) - if whereCtx != nil { + if whereCtx := n.Expr(exprIdx); whereCtx != nil { stmt.WhereClause = c.convert(whereCtx) } + exprIdx++ + } + if n.HAVING() != nil { + if havingCtx := n.Expr(exprIdx); havingCtx != nil { + stmt.HavingClause = c.convert(havingCtx) + } + exprIdx++ + } + + if gbc := n.Group_by_clause(); gbc != nil { + if gel := gbc.Grouping_element_list(); gel != nil { + var groups []ast.Node + for _, ne := range gel.AllGrouping_element() { + groups = append(groups, c.convert(ne)) + } + if len(groups) > 0 { + stmt.GroupClause = &ast.List{Items: groups} + } + } + } + + if ext := n.Ext_order_by_clause(); ext != nil { + if ob := ext.Order_by_clause(); ob != nil && ob.ORDER() != nil && ob.BY() != nil { + // TODO: ASSUME ORDER BY + if sl := ob.Sort_specification_list(); sl != nil { + var orderItems []ast.Node + for _, sp := range sl.AllSort_specification() { + ss, ok := sp.(*parser.Sort_specificationContext) + if !ok || ss == nil { + continue + } + expr := c.convert(ss.Expr()) + dir := ast.SortByDirDefault + if ss.ASC() != nil { + dir = ast.SortByDirAsc + } else if ss.DESC() != nil { + dir = ast.SortByDirDesc + } + orderItems = append(orderItems, &ast.SortBy{ + Node: expr, + SortbyDir: dir, + SortbyNulls: ast.SortByNullsUndefined, + UseOp: &ast.List{}, + Location: c.pos(ss.GetStart()), + }) + } + if len(orderItems) > 0 { + stmt.SortClause = &ast.List{Items: orderItems} + } + } + } } return stmt } +func (c *cc) convertGrouping_elementContext(n parser.IGrouping_elementContext) ast.Node { + if n == nil { + return todo("convertGrouping_elementContext", n) + } + if ogs := n.Ordinary_grouping_set(); ogs != nil { + return c.convert(ogs) + } + if rl := n.Rollup_list(); rl != nil { + return c.convert(rl) + } + if cl := n.Cube_list(); cl != nil { + return c.convert(cl) + } + if gss := n.Grouping_sets_specification(); gss != nil { + return c.convert(gss) + } + return todo("convertGrouping_elementContext", n) +} + +func (c *cc) convertOrdinary_grouping_setContext(n *parser.Ordinary_grouping_setContext) ast.Node { + if n == nil || n.Named_expr() == nil { + return todo("convertOrdinary_grouping_setContext", n) + } + + return c.convert(n.Named_expr()) +} + +func (c *cc) convertRollup_listContext(n *parser.Rollup_listContext) ast.Node { + if n == nil || n.ROLLUP() == nil || n.LPAREN() == nil || n.RPAREN() == nil { + return todo("convertRollup_listContext", n) + } + + var items []ast.Node + if list := n.Ordinary_grouping_set_list(); list != nil { + for _, ogs := range list.AllOrdinary_grouping_set() { + items = append(items, c.convert(ogs)) + } + } + return &ast.GroupingSet{Kind: 1, Content: &ast.List{Items: items}} +} + +func (c *cc) convertCube_listContext(n *parser.Cube_listContext) ast.Node { + if n == nil || n.CUBE() == nil || n.LPAREN() == nil || n.RPAREN() == nil { + return todo("convertCube_listContext", n) + } + + var items []ast.Node + if list := n.Ordinary_grouping_set_list(); list != nil { + for _, ogs := range list.AllOrdinary_grouping_set() { + items = append(items, c.convert(ogs)) + } + } + + return &ast.GroupingSet{Kind: 2, Content: &ast.List{Items: items}} +} + +func (c *cc) convertGrouping_sets_specificationContext(n *parser.Grouping_sets_specificationContext) ast.Node { + if n == nil || n.GROUPING() == nil || n.SETS() == nil || n.LPAREN() == nil || n.RPAREN() == nil { + return todo("convertGrouping_sets_specificationContext", n) + } + + var items []ast.Node + if gel := n.Grouping_element_list(); gel != nil { + for _, ge := range gel.AllGrouping_element() { + items = append(items, c.convert(ge)) + } + } + return &ast.GroupingSet{Kind: 3, Content: &ast.List{Items: items}} +} + func (c *cc) convertResultColumn(n *parser.Result_columnContext) ast.Node { // todo: support opt_id_prefix target := &ast.ResTarget{ @@ -1683,6 +1987,22 @@ func (c *cc) convertSqlStmtCore(n parser.ISql_stmt_coreContext) ast.Node { return nil } +func (c *cc) convertNamed_exprContext(n *parser.Named_exprContext) ast.Node { + if n == nil || n.Expr() == nil { + return todo("convertNamed_exprContext", n) + } + expr := c.convert(n.Expr()) + if n.AS() != nil && n.An_id_or_type() != nil { + name := parseAnIdOrType(n.An_id_or_type()) + return &ast.ResTarget{ + Name: &name, + Val: expr, + Location: c.pos(n.Expr().GetStart()), + } + } + return expr +} + func (c *cc) convertExpr(n *parser.ExprContext) ast.Node { if n == nil { return nil @@ -2457,9 +2777,6 @@ func (c *cc) convert(node node) ast.Node { case *parser.Select_stmtContext: return c.convertSelectStmtContext(n) - case *parser.Select_coreContext: - return c.convertSelectCoreContext(n) - case *parser.Result_columnContext: return c.convertResultColumn(n) @@ -2553,6 +2870,12 @@ func (c *cc) convert(node node) ast.Node { case *parser.Update_stmtContext: return c.convertUpdate_stmtContext(n) + case *parser.Alter_table_stmtContext: + return c.convertAlter_table_stmtContext(n) + + case *parser.Do_stmtContext: + return c.convertDo_stmtContext(n) + case *parser.Drop_table_stmtContext: return c.convertDrop_table_stmtContext(n) @@ -2593,7 +2916,31 @@ func (c *cc) convert(node node) ast.Node { return c.convertAlter_group_stmtContext(n) case *parser.Drop_role_stmtContext: - return c.convertDrop_role_stmtCOntext(n) + return c.convertDrop_role_stmtContext(n) + + case *parser.Grouping_elementContext: + return c.convertGrouping_elementContext(n) + + case *parser.Ordinary_grouping_setContext: + return c.convertOrdinary_grouping_setContext(n) + + case *parser.Rollup_listContext: + return c.convertRollup_listContext(n) + + case *parser.Cube_listContext: + return c.convertCube_listContext(n) + + case *parser.Grouping_sets_specificationContext: + return c.convertGrouping_sets_specificationContext(n) + + case *parser.Named_exprContext: + return c.convertNamed_exprContext(n) + + case *parser.Call_actionContext: + return c.convertCall_actionContext(n) + + case *parser.Inline_actionContext: + return c.convertInline_actionContext(n) default: return todo("convert(case=default)", n) diff --git a/internal/engine/ydb/utils.go b/internal/engine/ydb/utils.go index 3847ee5055..f2023e8ba9 100755 --- a/internal/engine/ydb/utils.go +++ b/internal/engine/ydb/utils.go @@ -85,7 +85,7 @@ func parseIdOrType(ctx parser.IId_or_typeContext) string { } Id := ctx.(*parser.Id_or_typeContext) if Id.Id() != nil { - return identifier(parseIdTable(Id.Id())) + return identifier(parseId(Id.Id())) } return "" @@ -112,13 +112,25 @@ func parseAnIdSchema(ctx parser.IAn_id_schemaContext) string { return "" } -func parseIdTable(ctx parser.IIdContext) string { +func parseId(ctx parser.IIdContext) string { if ctx == nil { return "" } return ctx.GetText() } +func parseAnIdTable(ctx parser.IAn_id_tableContext) string { + if ctx == nil { + return "" + } + if id := ctx.Id_table(); id != nil { + return id.GetText() + } else if str := ctx.STRING_VALUE(); str != nil { + return str.GetText() + } + return "" +} + func parseIntegerValue(text string) (int64, error) { text = strings.ToLower(text) base := 10 @@ -194,3 +206,16 @@ func byteOffsetFromRuneIndex(s string, runeIndex int) int { } return bytePos } + +func emptySelectStmt() *ast.SelectStmt { + return &ast.SelectStmt{ + DistinctClause: &ast.List{}, + TargetList: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, + } +} From 296cca76fbbfeb4405f3ad13e2439fc3526e3ca4 Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov <150552906+1NepuNep1@users.noreply.github.com> Date: Tue, 9 Sep 2025 14:45:14 +0300 Subject: [PATCH 14/18] Made new params building logic & got rid off .Query handle (#11) * Made new params building logic & got rid off .Query handle (replaced it with QueryResultSet) * Cosmetics * Bug fixes --------- Co-authored-by: Viktor Pentyukhov --- examples/authors/ydb/query.sql.go | 51 +++++---- internal/codegen/golang/query.go | 101 ++++++++++++++++++ .../templates/ydb-go-sdk/queryCode.tmpl | 43 +++----- internal/codegen/golang/ydb_type.go | 18 ++++ .../ydb/catalog_tests/create_table_test.go | 2 +- internal/engine/ydb/convert.go | 2 +- 6 files changed, 167 insertions(+), 50 deletions(-) diff --git a/examples/authors/ydb/query.sql.go b/examples/authors/ydb/query.sql.go index 3233b705d3..e9b6b332a4 100644 --- a/examples/authors/ydb/query.sql.go +++ b/examples/authors/ydb/query.sql.go @@ -25,11 +25,15 @@ type CreateOrUpdateAuthorParams struct { func (q *Queries) CreateOrUpdateAuthor(ctx context.Context, arg CreateOrUpdateAuthorParams, opts ...query.ExecuteOption) error { err := q.db.Exec(ctx, createOrUpdateAuthor, - append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ - "$p0": arg.P0, - "$p1": arg.P1, - "$p2": arg.P2, - })))..., + append(opts, + query.WithParameters( + ydb.ParamsBuilder(). + Param("$p0").Uint64(arg.P0). + Param("$p1").Text(arg.P1). + Param("$p2").BeginOptional().Text(arg.P2).EndOptional(). + Build(), + ), + )..., ) if err != nil { return xerrors.WithStackTrace(err) @@ -43,9 +47,13 @@ DELETE FROM authors WHERE id = $p0 func (q *Queries) DeleteAuthor(ctx context.Context, p0 uint64, opts ...query.ExecuteOption) error { err := q.db.Exec(ctx, deleteAuthor, - append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ - "$p0": p0, - })))..., + append(opts, + query.WithParameters( + ydb.ParamsBuilder(). + Param("$p0").Uint64(p0). + Build(), + ), + )..., ) if err != nil { return xerrors.WithStackTrace(err) @@ -72,9 +80,13 @@ WHERE id = $p0 LIMIT 1 func (q *Queries) GetAuthor(ctx context.Context, p0 uint64, opts ...query.ExecuteOption) (Author, error) { row, err := q.db.QueryRow(ctx, getAuthor, - append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ - "$p0": p0, - })))..., + append(opts, + query.WithParameters( + ydb.ParamsBuilder(). + Param("$p0").Uint64(p0). + Build(), + ), + )..., ) var i Author if err != nil { @@ -92,25 +104,20 @@ SELECT id, name, bio FROM authors ORDER BY name ` func (q *Queries) ListAuthors(ctx context.Context, opts ...query.ExecuteOption) ([]Author, error) { - result, err := q.db.Query(ctx, listAuthors, opts...) + result, err := q.db.QueryResultSet(ctx, listAuthors, opts...) if err != nil { return nil, xerrors.WithStackTrace(err) } var items []Author - for set, err := range result.ResultSets(ctx) { + for row, err := range result.Rows(ctx) { if err != nil { return nil, xerrors.WithStackTrace(err) } - for row, err := range set.Rows(ctx) { - if err != nil { - return nil, xerrors.WithStackTrace(err) - } - var i Author - if err := row.Scan(&i.ID, &i.Name, &i.Bio); err != nil { - return nil, xerrors.WithStackTrace(err) - } - items = append(items, i) + var i Author + if err := row.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, xerrors.WithStackTrace(err) } + items = append(items, i) } if err := result.Close(ctx); err != nil { return nil, xerrors.WithStackTrace(err) diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 02a09c3870..7cda1b7c2b 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/sqlc-dev/sqlc/internal/codegen/golang/opts" + "github.com/sqlc-dev/sqlc/internal/codegen/sdk" "github.com/sqlc-dev/sqlc/internal/metadata" "github.com/sqlc-dev/sqlc/internal/plugin" ) @@ -294,6 +295,106 @@ func (v QueryValue) YDBParamMapEntries() string { return "\n" + strings.Join(parts, ",\n") } +// ydbBuilderMethodForColumnType maps a YDB column data type to a ParamsBuilder method name. +func ydbBuilderMethodForColumnType(dbType string) string { + switch strings.ToLower(dbType) { + case "bool": + return "Bool" + case "uint64": + return "Uint64" + case "int64": + return "Int64" + case "uint32": + return "Uint32" + case "int32": + return "Int32" + case "uint16": + return "Uint16" + case "int16": + return "Int16" + case "uint8": + return "Uint8" + case "int8": + return "Int8" + case "float": + return "Float" + case "double": + return "Double" + case "json": + return "JSON" + case "jsondocument": + return "JSONDocument" + case "utf8", "text", "string": + return "Text" + case "date": + return "Date" + case "date32": + return "Date32" + case "datetime": + return "Datetime" + case "timestamp": + return "Timestamp" + case "tzdate": + return "TzDate" + case "tzdatetime": + return "TzDatetime" + case "tztimestamp": + return "TzTimestamp" + + //TODO: support other types + default: + return "" + } +} + +// YDBParamsBuilder emits Go code that constructs YDB params using ParamsBuilder. +func (v QueryValue) YDBParamsBuilder() string { + if v.isEmpty() { + return "" + } + + var lines []string + + for _, field := range v.getParameterFields() { + if field.Column != nil && field.Column.IsNamedParam { + name := field.Column.GetName() + if name == "" { + continue + } + paramName := fmt.Sprintf("%q", addDollarPrefix(name)) + variable := escape(v.VariableForField(field)) + + var method string + if field.Column != nil && field.Column.Type != nil { + method = ydbBuilderMethodForColumnType(sdk.DataType(field.Column.Type)) + } + + goType := field.Type + isPtr := strings.HasPrefix(goType, "*") + if isPtr { + goType = strings.TrimPrefix(goType, "*") + } + + if method == "" { + panic(fmt.Sprintf("unknown YDB column type for param %s (goType=%s)", name, goType)) + } + + if isPtr { + lines = append(lines, fmt.Sprintf("\t\t\tParam(%s).BeginOptional().%s(%s).EndOptional().", paramName, method, variable)) + } else { + lines = append(lines, fmt.Sprintf("\t\t\tParam(%s).%s(%s).", paramName, method, variable)) + } + } + } + + if len(lines) == 0 { + return "" + } + + params := strings.Join(lines, "\n") + return fmt.Sprintf("\nquery.WithParameters(\n\t\tydb.ParamsBuilder().\n%s\n\t\t\tBuild(),\n\t\t),\n", params) +} + func (v QueryValue) getParameterFields() []Field { if v.Struct == nil { return []Field{ diff --git a/internal/codegen/golang/templates/ydb-go-sdk/queryCode.tmpl b/internal/codegen/golang/templates/ydb-go-sdk/queryCode.tmpl index ecd78b1344..c56fc953f8 100644 --- a/internal/codegen/golang/templates/ydb-go-sdk/queryCode.tmpl +++ b/internal/codegen/golang/templates/ydb-go-sdk/queryCode.tmpl @@ -27,8 +27,8 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBA {{- if .Arg.IsEmpty }} row, err := {{$dbArg}}.QueryRow(ctx, {{.ConstantName}}, opts...) {{- else }} - row, err := {{$dbArg}}.QueryRow(ctx, {{.ConstantName}}, - append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ {{.Arg.YDBParamMapEntries}} })))..., + row, err := {{$dbArg}}.QueryRow(ctx, {{.ConstantName}}, + append(opts, {{.Arg.YDBParamsBuilder}})..., ) {{- end }} {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} @@ -61,10 +61,10 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBA func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBArgument}}db DBTX, {{end}}{{if not .Arg.IsEmpty}}{{.Arg.Pair}}, {{end}}opts ...query.ExecuteOption) ([]{{.Ret.DefineType}}, error) { {{- $dbArg := "q.db" }}{{- if $.EmitMethodsWithDBArgument }}{{- $dbArg = "db" }}{{- end -}} {{- if .Arg.IsEmpty }} - result, err := {{$dbArg}}.Query(ctx, {{.ConstantName}}, opts...) + result, err := {{$dbArg}}.QueryResultSet(ctx, {{.ConstantName}}, opts...) {{- else }} - result, err := {{$dbArg}}.Query(ctx, {{.ConstantName}}, - append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ {{.Arg.YDBParamMapEntries}} })))..., + result, err := {{$dbArg}}.QueryResultSet(ctx, {{.ConstantName}}, + append(opts, {{.Arg.YDBParamsBuilder}})..., ) {{- end }} if err != nil { @@ -79,7 +79,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBA {{else}} var items []{{.Ret.DefineType}} {{end -}} - for set, err := range result.ResultSets(ctx) { + for row, err := range result.Rows(ctx) { if err != nil { {{- if $.WrapErrors}} return nil, xerrors.WithStackTrace(fmt.Errorf("query {{.MethodName}}: %w", err)) @@ -87,24 +87,15 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBA return nil, xerrors.WithStackTrace(err) {{- end }} } - for row, err := range set.Rows(ctx) { - if err != nil { - {{- if $.WrapErrors}} - return nil, xerrors.WithStackTrace(fmt.Errorf("query {{.MethodName}}: %w", err)) - {{- else }} - return nil, xerrors.WithStackTrace(err) - {{- end }} - } - var {{.Ret.Name}} {{.Ret.Type}} - if err := row.Scan({{.Ret.Scan}}); err != nil { - {{- if $.WrapErrors}} - return nil, xerrors.WithStackTrace(fmt.Errorf("query {{.MethodName}}: %w", err)) - {{- else }} - return nil, xerrors.WithStackTrace(err) - {{- end }} - } - items = append(items, {{.Ret.ReturnName}}) - } + var {{.Ret.Name}} {{.Ret.Type}} + if err := row.Scan({{.Ret.Scan}}); err != nil { + {{- if $.WrapErrors}} + return nil, xerrors.WithStackTrace(fmt.Errorf("query {{.MethodName}}: %w", err)) + {{- else }} + return nil, xerrors.WithStackTrace(err) + {{- end }} + } + items = append(items, {{.Ret.ReturnName}}) } if err := result.Close(ctx); err != nil { {{- if $.WrapErrors}} @@ -125,8 +116,8 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBA {{- if .Arg.IsEmpty }} err := {{$dbArg}}.Exec(ctx, {{.ConstantName}}, opts...) {{- else }} - err := {{$dbArg}}.Exec(ctx, {{.ConstantName}}, - append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ {{.Arg.YDBParamMapEntries}} })))..., + err := {{$dbArg}}.Exec(ctx, {{.ConstantName}}, + append(opts, {{.Arg.YDBParamsBuilder}})..., ) {{- end }} if err != nil { diff --git a/internal/codegen/golang/ydb_type.go b/internal/codegen/golang/ydb_type.go index e9e5c46344..0ef665aee1 100644 --- a/internal/codegen/golang/ydb_type.go +++ b/internal/codegen/golang/ydb_type.go @@ -151,6 +151,24 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col // return "sql.NullInt64" return "*int64" + case "json", "jsondocument": + if notNull { + return "string" + } + if emitPointersForNull { + return "*string" + } + return "*string" + + case "date", "date32", "datetime", "timestamp", "tzdate", "tztimestamp", "tzdatetime": + if notNull { + return "time.Time" + } + if emitPointersForNull { + return "*time.Time" + } + return "*time.Time" + case "null": // return "sql.Null" return "interface{}" diff --git a/internal/engine/ydb/catalog_tests/create_table_test.go b/internal/engine/ydb/catalog_tests/create_table_test.go index e98288d75a..7761118927 100644 --- a/internal/engine/ydb/catalog_tests/create_table_test.go +++ b/internal/engine/ydb/catalog_tests/create_table_test.go @@ -106,7 +106,7 @@ func TestCreateTable(t *testing.T) { { Name: "amount", Type: ast.TypeName{ - Name: "Decimal", + Name: "decimal", Names: &ast.List{ Items: []ast.Node{ &ast.Integer{Ival: 22}, diff --git a/internal/engine/ydb/convert.go b/internal/engine/ydb/convert.go index 250dc467e0..e173cf287e 100755 --- a/internal/engine/ydb/convert.go +++ b/internal/engine/ydb/convert.go @@ -1582,7 +1582,7 @@ func (c *cc) convertTypeName(n parser.IType_nameContext) *ast.TypeName { if decimal := n.Type_name_decimal(); decimal != nil { if integerOrBinds := decimal.AllInteger_or_bind(); len(integerOrBinds) >= 2 { return &ast.TypeName{ - Name: "Decimal", + Name: "decimal", TypeOid: 0, Names: &ast.List{ Items: []ast.Node{ From 1379e43d2e5b508d0afed705e133cbd1db94dbd5 Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov <150552906+1NepuNep1@users.noreply.github.com> Date: Tue, 23 Sep 2025 12:27:49 +0300 Subject: [PATCH 15/18] Rewrited convert to Visitor Interface from ANTLR4 + fixed some types casts & error texts (#12) * Rewrited ydb convert to Visitor Interface + validated type casts and error texts --- go.mod | 2 +- go.sum | 8 +- internal/engine/ydb/convert.go | 1805 +++++++++++++++++--------------- internal/engine/ydb/parse.go | 5 +- internal/engine/ydb/utils.go | 78 +- 5 files changed, 1053 insertions(+), 845 deletions(-) diff --git a/go.mod b/go.mod index c72f29b6b1..4b755a4baa 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( github.com/wasilibs/go-pgquery v0.0.0-20250409022910-10ac41983c07 github.com/xeipuuv/gojsonschema v1.2.0 github.com/ydb-platform/ydb-go-sdk/v3 v3.115.3 - github.com/ydb-platform/yql-parsers v0.0.0-20250309001738-7d693911f333 + github.com/ydb-platform/yql-parsers v0.0.0-20250911122629-e8a65d734cbd golang.org/x/sync v0.16.0 google.golang.org/grpc v1.75.0 google.golang.org/protobuf v1.36.8 diff --git a/go.sum b/go.sum index 53cb8e8aec..eb68917f04 100644 --- a/go.sum +++ b/go.sum @@ -144,8 +144,6 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jonboulle/clockwork v0.3.0 h1:9BSCMi8C+0qdApAp4auwX0RkLGUjs956h0EkuQymUhg= -github.com/jonboulle/clockwork v0.3.0/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbdFz6I= github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= @@ -242,12 +240,10 @@ github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17 github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/ydb-platform/ydb-go-genproto v0.0.0-20241112172322-ea1f63298f77 h1:LY6cI8cP4B9rrpTleZk95+08kl2gF4rixG7+V/dwL6Q= github.com/ydb-platform/ydb-go-genproto v0.0.0-20241112172322-ea1f63298f77/go.mod h1:Er+FePu1dNUieD+XTMDduGpQuCPssK5Q4BjF+IIXJ3I= -github.com/ydb-platform/ydb-go-sdk/v3 v3.108.0 h1:TwWSp3gRMcja/hRpOofncLvgxAXCmzpz5cGtmdaoITw= -github.com/ydb-platform/ydb-go-sdk/v3 v3.108.0/go.mod h1:l5sSv153E18VvYcsmr51hok9Sjc16tEC8AXGbwrk+ho= github.com/ydb-platform/ydb-go-sdk/v3 v3.115.3 h1:SFeSK2c+PmiToyNIhr143u+YDzLhl/kboXwKLYDk0O4= github.com/ydb-platform/ydb-go-sdk/v3 v3.115.3/go.mod h1:Pp1w2xxUoLQ3NCNAwV7pvDq0TVQOdtAqs+ZiC+i8r14= -github.com/ydb-platform/yql-parsers v0.0.0-20250309001738-7d693911f333 h1:KFtJwlPdOxWjCKXX0jFJ8k1FlbqbRbUW3k/kYSZX7SA= -github.com/ydb-platform/yql-parsers v0.0.0-20250309001738-7d693911f333/go.mod h1:vrPJPS8cdPSV568YcXhB4bUwhyV8bmWKqmQ5c5Xi99o= +github.com/ydb-platform/yql-parsers v0.0.0-20250911122629-e8a65d734cbd h1:ZfUkkZ1m5JCAw7jHQavecv+gKJWA6SNxuKLqHQ5/988= +github.com/ydb-platform/yql-parsers v0.0.0-20250911122629-e8a65d734cbd/go.mod h1:vrPJPS8cdPSV568YcXhB4bUwhyV8bmWKqmQ5c5Xi99o= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= diff --git a/internal/engine/ydb/convert.go b/internal/engine/ydb/convert.go index e173cf287e..8b67191ce6 100755 --- a/internal/engine/ydb/convert.go +++ b/internal/engine/ydb/convert.go @@ -12,8 +12,8 @@ import ( ) type cc struct { - paramCount int - content string + parser.BaseYQLVisitor + content string } func (c *cc) pos(token antlr.Token) int { @@ -54,9 +54,9 @@ func NewIdentifier(t string) *ast.String { return &ast.String{Str: identifier(t)} } -func (c *cc) convertDrop_role_stmtContext(n *parser.Drop_role_stmtContext) ast.Node { +func (c *cc) VisitDrop_role_stmt(n *parser.Drop_role_stmtContext) interface{} { if n.DROP() == nil || (n.USER() == nil && n.GROUP() == nil) || len(n.AllRole_name()) == 0 { - return todo("Drop_role_stmtContext", n) + return todo("VisitDrop_role_stmt", n) } stmt := &ast.DropRoleStmt{ @@ -67,7 +67,7 @@ func (c *cc) convertDrop_role_stmtContext(n *parser.Drop_role_stmtContext) ast.N for _, role := range n.AllRole_name() { member, isParam, _ := c.extractRoleSpec(role, ast.RoleSpecType(1)) if member == nil { - return todo("Drop_role_stmtContext", n) + return todo("VisitDrop_role_stmt", role) } if debug.Active && isParam { @@ -80,13 +80,13 @@ func (c *cc) convertDrop_role_stmtContext(n *parser.Drop_role_stmtContext) ast.N return stmt } -func (c *cc) convertAlter_group_stmtContext(n *parser.Alter_group_stmtContext) ast.Node { +func (c *cc) VisitAlter_group_stmt(n *parser.Alter_group_stmtContext) interface{} { if n.ALTER() == nil || n.GROUP() == nil || len(n.AllRole_name()) == 0 { - return todo("convertAlter_group_stmtContext", n) + return todo("VisitAlter_group_stmt", n) } role, paramFlag, _ := c.extractRoleSpec(n.Role_name(0), ast.RoleSpecType(1)) if role == nil { - return todo("convertAlter_group_stmtContext", n) + return todo("VisitAlter_group_stmt", n) } if debug.Active && paramFlag { @@ -101,7 +101,10 @@ func (c *cc) convertAlter_group_stmtContext(n *parser.Alter_group_stmtContext) a switch { case n.RENAME() != nil && n.TO() != nil && len(n.AllRole_name()) > 1: - newName := c.convert(n.Role_name(1)) + newName, ok := n.Role_name(1).Accept(c).(ast.Node) + if !ok { + return todo("VisitAlter_group_stmt", n.Role_name(1)) + } action := "rename" defElem := &ast.DefElem{ @@ -120,12 +123,12 @@ func (c *cc) convertAlter_group_stmtContext(n *parser.Alter_group_stmtContext) a case *ast.Boolean: defElem.Arg = val default: - return todo("convertAlter_group_stmtContext", n) + return todo("VisitAlter_group_stmt", n.Role_name(1)) } case *ast.ParamRef, *ast.A_Expr: defElem.Arg = newName default: - return todo("convertAlter_group_stmtContext", n) + return todo("VisitAlter_group_stmt", n.Role_name(1)) } if debug.Active && !paramFlag && bindFlag { @@ -140,7 +143,7 @@ func (c *cc) convertAlter_group_stmtContext(n *parser.Alter_group_stmtContext) a for _, role := range n.AllRole_name()[1:] { member, isParam, _ := c.extractRoleSpec(role, ast.RoleSpecType(1)) if member == nil { - return todo("convertAlter_group_stmtContext", n) + return todo("VisitAlter_group_stmt", role) } if debug.Active && isParam && !paramFlag { @@ -161,21 +164,21 @@ func (c *cc) convertAlter_group_stmtContext(n *parser.Alter_group_stmtContext) a Defname: &defname, Arg: optionList, Defaction: action, - Location: c.pos(n.GetStart()), + Location: c.pos(n.Role_name(1).GetStart()), }) } return stmt } -func (c *cc) convertAlter_user_stmtContext(n *parser.Alter_user_stmtContext) ast.Node { +func (c *cc) VisitAlter_user_stmt(n *parser.Alter_user_stmtContext) interface{} { if n.ALTER() == nil || n.USER() == nil || len(n.AllRole_name()) == 0 { - return todo("Alter_user_stmtContext", n) + return todo("VisitAlter_user_stmt", n) } role, paramFlag, _ := c.extractRoleSpec(n.Role_name(0), ast.RoleSpecType(1)) if role == nil { - return todo("convertAlter_group_stmtContext", n) + return todo("VisitAlter_group_stmt", n) } if debug.Active && paramFlag { @@ -190,7 +193,10 @@ func (c *cc) convertAlter_user_stmtContext(n *parser.Alter_user_stmtContext) ast switch { case n.RENAME() != nil && n.TO() != nil && len(n.AllRole_name()) > 1: - newName := c.convert(n.Role_name(1)) + newName, ok := n.Role_name(1).Accept(c).(ast.Node) + if !ok { + return todo("VisitAlter_user_stmt", n.Role_name(1)) + } action := "rename" defElem := &ast.DefElem{ @@ -209,12 +215,12 @@ func (c *cc) convertAlter_user_stmtContext(n *parser.Alter_user_stmtContext) ast case *ast.Boolean: defElem.Arg = val default: - return todo("Alter_user_stmtContext", n) + return todo("VisitAlter_user_stmt", n.Role_name(1)) } case *ast.ParamRef, *ast.A_Expr: defElem.Arg = newName default: - return todo("Alter_user_stmtContext", n) + return todo("VisitAlter_user_stmt", n.Role_name(1)) } if debug.Active && !paramFlag && bindFlag { @@ -225,7 +231,11 @@ func (c *cc) convertAlter_user_stmtContext(n *parser.Alter_user_stmtContext) ast case len(n.AllUser_option()) > 0: for _, opt := range n.AllUser_option() { - if node := c.convert(opt); node != nil { + if temp := opt.Accept(c); temp != nil { + var node, ok = temp.(ast.Node) + if !ok { + return todo("VisitAlter_user_stmt", opt) + } stmt.Options.Items = append(stmt.Options.Items, node) } } @@ -234,11 +244,14 @@ func (c *cc) convertAlter_user_stmtContext(n *parser.Alter_user_stmtContext) ast return stmt } -func (c *cc) convertCreate_group_stmtContext(n *parser.Create_group_stmtContext) ast.Node { +func (c *cc) VisitCreate_group_stmt(n *parser.Create_group_stmtContext) interface{} { if n.CREATE() == nil || n.GROUP() == nil || len(n.AllRole_name()) == 0 { - return todo("Create_group_stmtContext", n) + return todo("VisitCreate_group_stmt", n) + } + groupName, ok := n.Role_name(0).Accept(c).(ast.Node) + if !ok { + return todo("VisitCreate_group_stmt", n.Role_name(0)) } - groupName := c.convert(n.Role_name(0)) stmt := &ast.CreateRoleStmt{ StmtType: ast.RoleStmtType(3), @@ -255,12 +268,12 @@ func (c *cc) convertCreate_group_stmtContext(n *parser.Create_group_stmtContext) case *ast.Boolean: stmt.BindRole = groupName default: - return todo("convertCreate_group_stmtContext", n) + return todo("VisitCreate_group_stmt", n.Role_name(0)) } case *ast.ParamRef, *ast.A_Expr: stmt.BindRole = groupName default: - return todo("convertCreate_group_stmtContext", n) + return todo("VisitCreate_group_stmt", n.Role_name(0)) } if debug.Active && paramFlag { @@ -273,7 +286,7 @@ func (c *cc) convertCreate_group_stmtContext(n *parser.Create_group_stmtContext) for _, role := range n.AllRole_name()[1:] { member, isParam, _ := c.extractRoleSpec(role, ast.RoleSpecType(1)) if member == nil { - return todo("convertCreate_group_stmtContext", n) + return todo("VisitCreate_group_stmt", role) } if debug.Active && isParam && !paramFlag { @@ -286,26 +299,29 @@ func (c *cc) convertCreate_group_stmtContext(n *parser.Create_group_stmtContext) stmt.Options.Items = append(stmt.Options.Items, &ast.DefElem{ Defname: &defname, Arg: optionList, - Location: c.pos(n.GetStart()), + Location: c.pos(n.Role_name(1).GetStart()), }) } return stmt } -func (c *cc) convertUse_stmtContext(n *parser.Use_stmtContext) ast.Node { +func (c *cc) VisitUse_stmt(n *parser.Use_stmtContext) interface{} { if n.USE() != nil && n.Cluster_expr() != nil { - clusterExpr := c.convert(n.Cluster_expr()) + clusterExpr, ok := n.Cluster_expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitUse_stmt", n.Cluster_expr()) + } stmt := &ast.UseStmt{ Xpr: clusterExpr, - Location: c.pos(n.GetStart()), + Location: c.pos(n.Cluster_expr().GetStart()), } return stmt } - return todo("convertUse_stmtContext", n) + return todo("VisitUse_stmt", n) } -func (c *cc) convertCluster_exprContext(n *parser.Cluster_exprContext) ast.Node { +func (c *cc) VisitCluster_expr(n *parser.Cluster_exprContext) interface{} { var node ast.Node switch { @@ -318,12 +334,16 @@ func (c *cc) convertCluster_exprContext(n *parser.Cluster_exprContext) ast.Node Location: c.pos(anID.GetStart()), } } else if bp := pureCtx.Bind_parameter(); bp != nil { - node = c.convert(bp) + temp, ok := bp.Accept(c).(ast.Node) + if !ok { + return todo("VisitCluster_expr", bp) + } + node = temp } case n.ASTERISK() != nil: node = &ast.A_Star{} default: - return todo("convertCluster_exprContext", n) + return todo("VisitCluster_expr", n) } if n.An_id() != nil && n.COLON() != nil { @@ -339,11 +359,14 @@ func (c *cc) convertCluster_exprContext(n *parser.Cluster_exprContext) ast.Node return node } -func (c *cc) convertCreate_user_stmtContext(n *parser.Create_user_stmtContext) ast.Node { +func (c *cc) VisitCreate_user_stmt(n *parser.Create_user_stmtContext) interface{} { if n.CREATE() == nil || n.USER() == nil || n.Role_name() == nil { - return todo("convertCreate_user_stmtContext", n) + return todo("VisitCreate_user_stmt", n) + } + roleNode, ok := n.Role_name().Accept(c).(ast.Node) + if !ok { + return todo("VisitCreate_user_stmt", n.Role_name()) } - roleNode := c.convert(n.Role_name()) stmt := &ast.CreateRoleStmt{ StmtType: ast.RoleStmtType(2), @@ -360,12 +383,12 @@ func (c *cc) convertCreate_user_stmtContext(n *parser.Create_user_stmtContext) a case *ast.Boolean: stmt.BindRole = roleNode default: - return todo("convertCreate_user_stmtContext", n) + return todo("VisitCreate_user_stmt", n.Role_name()) } case *ast.ParamRef, *ast.A_Expr: stmt.BindRole = roleNode default: - return todo("convertCreate_user_stmtContext", n) + return todo("VisitCreate_user_stmt", n.Role_name()) } if debug.Active && paramFlag { @@ -375,7 +398,11 @@ func (c *cc) convertCreate_user_stmtContext(n *parser.Create_user_stmtContext) a if len(n.AllUser_option()) > 0 { options := []ast.Node{} for _, opt := range n.AllUser_option() { - if node := c.convert(opt); node != nil { + if temp := opt.Accept(c); temp != nil { + node, ok := temp.(ast.Node) + if !ok { + return todo("VisitCreate_user_stmt", opt) + } options = append(options, node) } } @@ -386,7 +413,7 @@ func (c *cc) convertCreate_user_stmtContext(n *parser.Create_user_stmtContext) a return stmt } -func (c *cc) convertUser_optionContext(n *parser.User_optionContext) ast.Node { +func (c *cc) VisitUser_option(n *parser.User_optionContext) interface{} { switch { case n.Authentication_option() != nil: aOpt := n.Authentication_option() @@ -436,40 +463,43 @@ func (c *cc) convertUser_optionContext(n *parser.User_optionContext) ast.Node { Location: c.pos(lOpt.GetStart()), } default: - return todo("convertUser_optionContext", n) + return todo("VisitUser_option", n) } - return nil + return todo("VisitUser_option", n) } -func (c *cc) convertRole_nameContext(n *parser.Role_nameContext) ast.Node { +func (c *cc) VisitRole_name(n *parser.Role_nameContext) interface{} { switch { case n.An_id_or_type() != nil: name := parseAnIdOrType(n.An_id_or_type()) - return &ast.A_Const{Val: NewIdentifier(name), Location: c.pos(n.GetStart())} + return &ast.A_Const{Val: NewIdentifier(name), Location: c.pos(n.An_id_or_type().GetStart())} case n.Bind_parameter() != nil: - bindPar := c.convert(n.Bind_parameter()) + bindPar, ok := n.Bind_parameter().Accept(c).(ast.Node) + if !ok { + return todo("VisitRole_name", n.Bind_parameter()) + } return bindPar } - return todo("convertRole_nameContext", n) + return todo("VisitRole_name", n) } -func (c *cc) convertCommit_stmtContext(n *parser.Commit_stmtContext) ast.Node { +func (c *cc) VisitCommit_stmt(n *parser.Commit_stmtContext) interface{} { if n.COMMIT() != nil { return &ast.TransactionStmt{Kind: ast.TransactionStmtKind(3)} } - return todo("convertCommit_stmtContext", n) + return todo("VisitCommit_stmt", n) } -func (c *cc) convertRollback_stmtContext(n *parser.Rollback_stmtContext) ast.Node { +func (c *cc) VisitRollback_stmt(n *parser.Rollback_stmtContext) interface{} { if n.ROLLBACK() != nil { return &ast.TransactionStmt{Kind: ast.TransactionStmtKind(4)} } - return todo("convertRollback_stmtContext", n) + return todo("VisitRollback_stmt", n) } -func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) ast.Node { +func (c *cc) VisitAlter_table_stmt(n *parser.Alter_table_stmtContext) interface{} { if n.ALTER() == nil || n.TABLE() == nil || n.Simple_table_ref() == nil || len(n.AllAlter_table_action()) == 0 { - return todo("convertAlter_table_stmtContext", n) + return todo("VisitAlter_table_stmt", n) } stmt := &ast.AlterTableStmt{ @@ -486,7 +516,14 @@ func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) a case action.Alter_table_add_column() != nil: ac := action.Alter_table_add_column() if ac.ADD() != nil && ac.Column_schema() != nil { - columnDef := c.convertColumnSchema(ac.Column_schema().(*parser.Column_schemaContext)) + temp, ok := ac.Column_schema().Accept(c).(ast.Node) + if !ok { + return todo("VisitAlter_table_stmt", ac.Column_schema()) + } + columnDef, ok := temp.(*ast.ColumnDef) + if !ok { + return todo("VisitAlter_table_stmt", ac.Column_schema()) + } stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{ Name: &columnDef.Colname, Subtype: ast.AT_AddColumn, @@ -514,7 +551,7 @@ func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) a case action.Alter_table_rename_to() != nil: ac := action.Alter_table_rename_to() if ac.RENAME() != nil && ac.TO() != nil && ac.An_id_table() != nil { - // TODO: Returning here may be incorrect if there are multiple specs + // FIXME: Returning here may be incorrect if there are multiple specs newName := parseAnIdTable(ac.An_id_table()) return &ast.RenameTableStmt{ Table: stmt.Table, @@ -541,25 +578,33 @@ func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) a return stmt } -func (c *cc) convertDo_stmtContext(n *parser.Do_stmtContext) ast.Node { +func (c *cc) VisitDo_stmt(n *parser.Do_stmtContext) interface{} { if n.DO() == nil || (n.Call_action() == nil && n.Inline_action() == nil) { - return todo("convertDo_stmtContext", n) + return todo("VisitDo_stmt", n) } switch { case n.Call_action() != nil: - return c.convert(n.Call_action()) + result, ok := n.Call_action().Accept(c).(ast.Node) + if !ok { + return todo("VisitDo_stmt", n.Call_action()) + } + return result case n.Inline_action() != nil: - return c.convert(n.Inline_action()) + result, ok := n.Inline_action().Accept(c).(ast.Node) + if !ok { + return todo("VisitDo_stmt", n.Inline_action()) + } + return result } - return todo("convertDo_stmtContext", n) + return todo("VisitDo_stmt", n) } -func (c *cc) convertCall_actionContext(n *parser.Call_actionContext) ast.Node { +func (c *cc) VisitCall_action(n *parser.Call_actionContext) interface{} { if n == nil { - return nil + return todo("VisitCall_action", n) } if n.LPAREN() != nil && n.RPAREN() != nil { funcCall := &ast.FuncCall{ @@ -569,14 +614,22 @@ func (c *cc) convertCall_actionContext(n *parser.Call_actionContext) ast.Node { } if n.Bind_parameter() != nil { - funcCall.Funcname.Items = append(funcCall.Funcname.Items, c.convert(n.Bind_parameter())) + bindPar, ok := n.Bind_parameter().Accept(c).(ast.Node) + if !ok { + return todo("VisitCall_action", n.Bind_parameter()) + } + funcCall.Funcname.Items = append(funcCall.Funcname.Items, bindPar) } else if n.EMPTY_ACTION() != nil { funcCall.Funcname.Items = append(funcCall.Funcname.Items, &ast.String{Str: "EMPTY_ACTION"}) } if n.Expr_list() != nil { for _, expr := range n.Expr_list().AllExpr() { - funcCall.Args.Items = append(funcCall.Args.Items, c.convert(expr)) + exprNode, ok := expr.Accept(c).(ast.Node) + if !ok { + return todo("VisitCall_action", expr) + } + funcCall.Args.Items = append(funcCall.Args.Items, exprNode) } } @@ -584,29 +637,33 @@ func (c *cc) convertCall_actionContext(n *parser.Call_actionContext) ast.Node { Args: &ast.List{Items: []ast.Node{funcCall}}, } } - return todo("convertCall_actionContext", n) + return todo("VisitCall_action", n) } -func (c *cc) convertInline_actionContext(n *parser.Inline_actionContext) ast.Node { +func (c *cc) VisitInline_action(n *parser.Inline_actionContext) interface{} { if n == nil { - return nil + return todo("VisitInline_action", n) } if n.BEGIN() != nil && n.END() != nil && n.DO() != nil { args := &ast.List{} if defineBody := n.Define_action_or_subquery_body(); defineBody != nil { cores := defineBody.AllSql_stmt_core() for _, stmtCore := range cores { - if converted := c.convert(stmtCore); converted != nil { - args.Items = append(args.Items, converted) + if converted := stmtCore.Accept(c); converted != nil { + var convertedNode, ok = converted.(ast.Node) + if !ok { + return todo("VisitInline_action", stmtCore) + } + args.Items = append(args.Items, convertedNode) } } } return &ast.DoStmt{Args: args} } - return todo("convertInline_actionContext", n) + return todo("VisitInline_action", n) } -func (c *cc) convertDrop_table_stmtContext(n *parser.Drop_table_stmtContext) ast.Node { +func (c *cc) VisitDrop_table_stmt(n *parser.Drop_table_stmtContext) interface{} { if n.DROP() != nil && (n.TABLESTORE() != nil || (n.EXTERNAL() != nil && n.TABLE() != nil) || n.TABLE() != nil) { name := parseTableName(n.Simple_table_ref().Simple_table_ref_core()) stmt := &ast.DropTableStmt{ @@ -615,10 +672,10 @@ func (c *cc) convertDrop_table_stmtContext(n *parser.Drop_table_stmtContext) ast } return stmt } - return todo("convertDrop_Table_stmtContxt", n) + return todo("VisitDrop_table_stmt", n) } -func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { +func (c *cc) VisitDelete_stmt(n *parser.Delete_stmtContext) interface{} { batch := n.BATCH() != nil tableName := identifier(n.Simple_table_ref().Simple_table_ref_core().GetText()) @@ -626,7 +683,11 @@ func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { var where ast.Node if n.WHERE() != nil && n.Expr() != nil { - where = c.convert(n.Expr()) + whereNode, ok := n.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitDelete_stmt", n.Expr()) + } + where = whereNode } var cols *ast.List var source ast.Node @@ -649,17 +710,37 @@ func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { switch { case valSource.Values_stmt() != nil: stmt := emptySelectStmt() - stmt.ValuesLists = c.convert(valSource.Values_stmt()).(*ast.List) + temp, ok := valSource.Values_stmt().Accept(c).(ast.Node) + if !ok { + return todo("VisitDelete_stmt", valSource.Values_stmt()) + } + list, ok := temp.(*ast.List) + if !ok { + return todo("VisitDelete_stmt", valSource.Values_stmt()) + } + stmt.ValuesLists = list source = stmt case valSource.Select_stmt() != nil: - source = c.convert(valSource.Select_stmt()) + temp, ok := valSource.Select_stmt().Accept(c).(ast.Node) + if !ok { + return todo("VisitDelete_stmt", valSource.Select_stmt()) + } + source = temp } } } returning := &ast.List{} if ret := n.Returning_columns_list(); ret != nil { - returning = c.convert(ret).(*ast.List) + temp, ok := ret.Accept(c).(ast.Node) + if !ok { + return todo("VisitDelete_stmt", n.Returning_columns_list()) + } + returningNode, ok := temp.(*ast.List) + if !ok { + return todo("VisitDelete_stmt", n.Returning_columns_list()) + } + returning = returningNode } stmts := &ast.DeleteStmt{ @@ -674,7 +755,7 @@ func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { return stmts } -func (c *cc) convertPragma_stmtContext(n *parser.Pragma_stmtContext) ast.Node { +func (c *cc) VisitPragma_stmt(n *parser.Pragma_stmtContext) interface{} { if n.PRAGMA() != nil && n.An_id() != nil { prefix := "" if p := n.Opt_id_prefix_or_type(); p != nil { @@ -696,22 +777,30 @@ func (c *cc) convertPragma_stmtContext(n *parser.Pragma_stmtContext) ast.Node { if n.EQUALS() != nil { stmt.Equals = true if val := n.Pragma_value(0); val != nil { - stmt.Values = &ast.List{Items: []ast.Node{c.convert(val)}} + valNode, ok := val.Accept(c).(ast.Node) + if !ok { + return todo("VisitPragma_stmt", n.Pragma_value(0)) + } + stmt.Values = &ast.List{Items: []ast.Node{valNode}} } } else if lp := n.LPAREN(); lp != nil { values := []ast.Node{} for _, v := range n.AllPragma_value() { - values = append(values, c.convert(v)) + valNode, ok := v.Accept(c).(ast.Node) + if !ok { + return todo("VisitPragma_stmt", v) + } + values = append(values, valNode) } stmt.Values = &ast.List{Items: values} } return stmt } - return todo("convertPragma_stmtContext", n) + return todo("VisitPragma_stmt", n) } -func (c *cc) convertPragma_valueContext(n *parser.Pragma_valueContext) ast.Node { +func (c *cc) VisitPragma_value(n *parser.Pragma_valueContext) interface{} { switch { case n.Signed_number() != nil: if n.Signed_number().Integer() != nil { @@ -742,16 +831,20 @@ func (c *cc) convertPragma_valueContext(n *parser.Pragma_valueContext) ast.Node } return &ast.A_Const{Val: &ast.Boolean{Boolval: i}, Location: c.pos(n.GetStart())} case n.Bind_parameter() != nil: - bindPar := c.convert(n.Bind_parameter()) - return bindPar + bindPar := n.Bind_parameter().Accept(c) + var bindParNode, ok = bindPar.(ast.Node) + if !ok { + return todo("VisitPragma_value", n.Bind_parameter()) + } + return bindParNode } - return todo("convertPragma_valueContext", n) + return todo("VisitPragma_value", n) } -func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { - if n.UPDATE() == nil { - return nil +func (c *cc) VisitUpdate_stmt(n *parser.Update_stmtContext) interface{} { + if n == nil || n.UPDATE() == nil { + return todo("VisitUpdate_stmt", n) } batch := n.BATCH() != nil @@ -772,7 +865,10 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { for _, clause := range nSet.Set_clause_list().AllSet_clause() { targetCtx := clause.Set_target() columnName := identifier(targetCtx.Column_name().GetText()) - expr := c.convert(clause.Expr()) + expr, ok := clause.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitUpdate_stmt", clause.Expr()) + } resTarget := &ast.ResTarget{ Name: &columnName, Val: expr, @@ -798,7 +894,11 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { Args: &ast.List{}, } for _, expr := range exprList.AllExpr() { - rowExpr.Args.Items = append(rowExpr.Args.Items, c.convert(expr)) + exprNode, ok := expr.Accept(c).(ast.Node) + if !ok { + return todo("VisitUpdate_stmt", expr) + } + rowExpr.Args.Items = append(rowExpr.Args.Items, exprNode) } } @@ -817,7 +917,11 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { } if n.WHERE() != nil && n.Expr() != nil { - where = c.convert(n.Expr()) + whereNode, ok := n.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitUpdate_stmt", n.Expr()) + } + where = whereNode } } else if n.ON() != nil && n.Into_values_source() != nil { @@ -841,17 +945,37 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { switch { case valSource.Values_stmt() != nil: stmt := emptySelectStmt() - stmt.ValuesLists = c.convert(valSource.Values_stmt()).(*ast.List) + temp, ok := valSource.Values_stmt().Accept(c).(ast.Node) + if !ok { + return todo("VisitUpdate_stmt", valSource.Values_stmt()) + } + list, ok := temp.(*ast.List) + if !ok { + return todo("VisitUpdate_stmt", valSource.Values_stmt()) + } + stmt.ValuesLists = list source = stmt case valSource.Select_stmt() != nil: - source = c.convert(valSource.Select_stmt()) + temp, ok := valSource.Select_stmt().Accept(c).(ast.Node) + if !ok { + return todo("VisitUpdate_stmt", valSource.Select_stmt()) + } + source = temp } } } returning := &ast.List{} if ret := n.Returning_columns_list(); ret != nil { - returning = c.convert(ret).(*ast.List) + temp, ok := ret.Accept(c).(ast.Node) + if !ok { + return todo("VisitDelete_stmt", n.Returning_columns_list()) + } + returningNode, ok := temp.(*ast.List) + if !ok { + return todo("VisitDelete_stmt", n.Returning_columns_list()) + } + returning = returningNode } stmts := &ast.UpdateStmt{ @@ -869,7 +993,7 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { return stmts } -func (c *cc) convertInto_table_stmtContext(n *parser.Into_table_stmtContext) ast.Node { +func (c *cc) VisitInto_table_stmt(n *parser.Into_table_stmtContext) interface{} { tableName := identifier(n.Into_simple_table_ref().Simple_table_ref().Simple_table_ref_core().GetText()) rel := &ast.RangeVar{ Relname: &tableName, @@ -911,17 +1035,37 @@ func (c *cc) convertInto_table_stmtContext(n *parser.Into_table_stmtContext) ast switch { case valSource.Values_stmt() != nil: stmt := emptySelectStmt() - stmt.ValuesLists = c.convert(valSource.Values_stmt()).(*ast.List) + temp, ok := valSource.Values_stmt().Accept(c).(ast.Node) + if !ok { + return todo("VisitInto_table_stmt", valSource.Values_stmt()) + } + stmtNode, ok := temp.(*ast.List) + if !ok { + return todo("VisitInto_table_stmt", valSource.Values_stmt()) + } + stmt.ValuesLists = stmtNode source = stmt case valSource.Select_stmt() != nil: - source = c.convert(valSource.Select_stmt()) + sourceNode, ok := valSource.Select_stmt().Accept(c).(ast.Node) + if !ok { + return todo("VisitInto_table_stmt", valSource.Select_stmt()) + } + source = sourceNode } } } returning := &ast.List{} if ret := n.Returning_columns_list(); ret != nil { - returning = c.convert(ret).(*ast.List) + temp, ok := ret.Accept(c).(ast.Node) + if !ok { + return todo("VisitInto_table_stmt", n.Returning_columns_list()) + } + returningNode, ok := temp.(*ast.List) + if !ok { + return todo("VisitInto_table_stmt", n.Returning_columns_list()) + } + returning = returningNode } stmts := &ast.InsertStmt{ @@ -935,7 +1079,7 @@ func (c *cc) convertInto_table_stmtContext(n *parser.Into_table_stmtContext) ast return stmts } -func (c *cc) convertValues_stmtContext(n *parser.Values_stmtContext) ast.Node { +func (c *cc) VisitValues_stmt(n *parser.Values_stmtContext) interface{} { mainList := &ast.List{} for _, rowCtx := range n.Values_source_row_list().AllValues_source_row() { @@ -943,8 +1087,12 @@ func (c *cc) convertValues_stmtContext(n *parser.Values_stmtContext) ast.Node { exprListCtx := rowCtx.Expr_list().(*parser.Expr_listContext) for _, exprCtx := range exprListCtx.AllExpr() { - if converted := c.convert(exprCtx); converted != nil { - rowList.Items = append(rowList.Items, converted) + if converted := exprCtx.Accept(c); converted != nil { + var convertedNode, ok = converted.(ast.Node) + if !ok { + return todo("VisitValues_stmt", exprCtx) + } + rowList.Items = append(rowList.Items, convertedNode) } } @@ -955,7 +1103,7 @@ func (c *cc) convertValues_stmtContext(n *parser.Values_stmtContext) ast.Node { return mainList } -func (c *cc) convertReturning_columns_listContext(n *parser.Returning_columns_listContext) ast.Node { +func (c *cc) VisitReturning_columns_list(n *parser.Returning_columns_listContext) interface{} { list := &ast.List{Items: []ast.Node{}} if n.ASTERISK() != nil { @@ -988,30 +1136,36 @@ func (c *cc) convertReturning_columns_listContext(n *parser.Returning_columns_li return list } -func (c *cc) convertSelectStmtContext(n *parser.Select_stmtContext) ast.Node { +func (c *cc) VisitSelect_stmt(n *parser.Select_stmtContext) interface{} { if len(n.AllSelect_kind_parenthesis()) == 0 { - return todo("convertSelectStmtContext", n) + return todo("VisitSelect_stmt", n) } skp := n.Select_kind_parenthesis(0) if skp == nil { - return todo("convertSelectStmtContext", skp) + return todo("VisitSelect_stmt", skp) } - stmt := c.convertSelectKindParenthesis(skp) - left, ok := stmt.(*ast.SelectStmt) + temp, ok := skp.Accept(c).(ast.Node) + if !ok { + return todo("VisitSelect_kind_parenthesis", skp) + } + left, ok := temp.(*ast.SelectStmt) if left == nil || !ok { - return todo("convertSelectKindParenthesis", skp) + return todo("VisitSelect_kind_parenthesis", skp) } kinds := n.AllSelect_kind_parenthesis() ops := n.AllSelect_op() for i := 1; i < len(kinds); i++ { - stmt := c.convertSelectKindParenthesis(kinds[i]) - right, ok := stmt.(*ast.SelectStmt) + temp, ok := kinds[i].Accept(c).(ast.Node) + if !ok { + return todo("VisitSelect_kind_parenthesis", kinds[i]) + } + right, ok := temp.(*ast.SelectStmt) if right == nil || !ok { - return todo("convertSelectKindParenthesis", kinds[i]) + return todo("VisitSelect_kind_parenthesis", kinds[i]) } var op ast.SetOperation @@ -1041,21 +1195,25 @@ func (c *cc) convertSelectStmtContext(n *parser.Select_stmtContext) ast.Node { return left } -func (c *cc) convertSelectKindParenthesis(n parser.ISelect_kind_parenthesisContext) ast.Node { +func (c *cc) VisitSelect_kind_parenthesis(n *parser.Select_kind_parenthesisContext) interface{} { if n == nil || n.Select_kind_partial() == nil { - return todo("convertSelectKindParenthesis", n) + return todo("VisitSelect_kind_parenthesis", n) } partial := n.Select_kind_partial() sk := partial.Select_kind() if sk == nil { - return todo("convertSelectKind", sk) + return todo("VisitSelect_kind_parenthesis", sk) } var base ast.Node switch { case sk.Select_core() != nil: - base = c.convertSelectCoreContext(sk.Select_core()) + baseNode, ok := sk.Select_core().Accept(c).(ast.Node) + if !ok { + return todo("VisitSelect_kind_parenthesis", sk.Select_core()) + } + base = baseNode case sk.Process_core() != nil: log.Fatalf("PROCESS is not supported in YDB engine") case sk.Reduce_core() != nil: @@ -1063,7 +1221,7 @@ func (c *cc) convertSelectKindParenthesis(n parser.ISelect_kind_parenthesisConte } stmt, ok := base.(*ast.SelectStmt) if !ok || stmt == nil { - return todo("convertSelectKindParenthesis", sk.Select_core()) + return todo("VisitSelect_kind_parenthesis", sk.Select_core()) } // TODO: handle INTO RESULT clause @@ -1071,11 +1229,19 @@ func (c *cc) convertSelectKindParenthesis(n parser.ISelect_kind_parenthesisConte if partial.LIMIT() != nil { exprs := partial.AllExpr() if len(exprs) >= 1 { - stmt.LimitCount = c.convert(exprs[0]) + temp, ok := exprs[0].Accept(c).(ast.Node) + if !ok { + return todo("VisitSelect_kind_parenthesis", exprs[0]) + } + stmt.LimitCount = temp } if partial.OFFSET() != nil { if len(exprs) >= 2 { - stmt.LimitOffset = c.convert(exprs[1]) + temp, ok := exprs[1].Accept(c).(ast.Node) + if !ok { + return todo("VisitSelect_kind_parenthesis", exprs[1]) + } + stmt.LimitOffset = temp } } } @@ -1083,7 +1249,7 @@ func (c *cc) convertSelectKindParenthesis(n parser.ISelect_kind_parenthesisConte return stmt } -func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { +func (c *cc) VisitSelect_core(n *parser.Select_coreContext) interface{} { stmt := emptySelectStmt() if n.Opt_set_quantifier() != nil { oq := n.Opt_set_quantifier() @@ -1095,14 +1261,11 @@ func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { if len(resultCols) > 0 { var items []ast.Node for _, rc := range resultCols { - resCol, ok := rc.(*parser.Result_columnContext) + convNode, ok := rc.Accept(c).(ast.Node) if !ok { - continue - } - convNode := c.convertResultColumn(resCol) - if convNode != nil { - items = append(items, convNode) + return todo("VisitSelect_core", rc) } + items = append(items, convNode) } stmt.TargetList = &ast.List{ Items: items, @@ -1118,15 +1281,11 @@ func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { if len(jsList) > 0 { var fromItems []ast.Node for _, js := range jsList { - jsCon, ok := js.(*parser.Join_sourceContext) + joinNode, ok := js.Accept(c).(ast.Node) if !ok { - continue - } - - joinNode := c.convertJoinSource(jsCon) - if joinNode != nil { - fromItems = append(fromItems, joinNode) + return todo("VisitSelect_core", js) } + fromItems = append(fromItems, joinNode) } stmt.FromClause = &ast.List{ Items: fromItems, @@ -1136,13 +1295,21 @@ func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { exprIdx := 0 if n.WHERE() != nil { if whereCtx := n.Expr(exprIdx); whereCtx != nil { - stmt.WhereClause = c.convert(whereCtx) + where, ok := whereCtx.Accept(c).(ast.Node) + if !ok { + return todo("VisitSelect_core", whereCtx) + } + stmt.WhereClause = where } exprIdx++ } if n.HAVING() != nil { if havingCtx := n.Expr(exprIdx); havingCtx != nil { - stmt.HavingClause = c.convert(havingCtx) + having, ok := havingCtx.Accept(c).(ast.Node) + if !ok || having == nil { + return todo("VisitSelect_core", havingCtx) + } + stmt.HavingClause = having } exprIdx++ } @@ -1151,7 +1318,11 @@ func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { if gel := gbc.Grouping_element_list(); gel != nil { var groups []ast.Node for _, ne := range gel.AllGrouping_element() { - groups = append(groups, c.convert(ne)) + groupBy, ok := ne.Accept(c).(ast.Node) + if !ok { + return todo("VisitSelect_core", ne) + } + groups = append(groups, groupBy) } if len(groups) > 0 { stmt.GroupClause = &ast.List{Items: groups} @@ -1165,15 +1336,14 @@ func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { if sl := ob.Sort_specification_list(); sl != nil { var orderItems []ast.Node for _, sp := range sl.AllSort_specification() { - ss, ok := sp.(*parser.Sort_specificationContext) - if !ok || ss == nil { - continue + expr, ok := sp.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitSelect_core", sp.Expr()) } - expr := c.convert(ss.Expr()) dir := ast.SortByDirDefault - if ss.ASC() != nil { + if sp.ASC() != nil { dir = ast.SortByDirAsc - } else if ss.DESC() != nil { + } else if sp.DESC() != nil { dir = ast.SortByDirDesc } orderItems = append(orderItems, &ast.SortBy{ @@ -1181,7 +1351,7 @@ func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { SortbyDir: dir, SortbyNulls: ast.SortByNullsUndefined, UseOp: &ast.List{}, - Location: c.pos(ss.GetStart()), + Location: c.pos(sp.GetStart()), }) } if len(orderItems) > 0 { @@ -1193,77 +1363,109 @@ func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { return stmt } -func (c *cc) convertGrouping_elementContext(n parser.IGrouping_elementContext) ast.Node { +func (c *cc) VisitGrouping_element(n *parser.Grouping_elementContext) interface{} { if n == nil { - return todo("convertGrouping_elementContext", n) + return todo("VisitGrouping_element", n) } if ogs := n.Ordinary_grouping_set(); ogs != nil { - return c.convert(ogs) + groupingSet, ok := ogs.Accept(c).(ast.Node) + if !ok { + return todo("VisitGrouping_element", ogs) + } + return groupingSet } if rl := n.Rollup_list(); rl != nil { - return c.convert(rl) + rollupList, ok := rl.Accept(c).(ast.Node) + if !ok { + return todo("VisitGrouping_element", rl) + } + return rollupList } if cl := n.Cube_list(); cl != nil { - return c.convert(cl) + cubeList, ok := cl.Accept(c).(ast.Node) + if !ok { + return todo("VisitGrouping_element", cl) + } + return cubeList } if gss := n.Grouping_sets_specification(); gss != nil { - return c.convert(gss) + groupingSets, ok := gss.Accept(c).(ast.Node) + if !ok { + return todo("VisitGrouping_element", gss) + } + return groupingSets } - return todo("convertGrouping_elementContext", n) + return todo("VisitGrouping_element", n) } -func (c *cc) convertOrdinary_grouping_setContext(n *parser.Ordinary_grouping_setContext) ast.Node { +func (c *cc) VisitOrdinary_grouping_set(n *parser.Ordinary_grouping_setContext) interface{} { if n == nil || n.Named_expr() == nil { - return todo("convertOrdinary_grouping_setContext", n) + return todo("VisitOrdinary_grouping_set", n) } - return c.convert(n.Named_expr()) + namedExpr, ok := n.Named_expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitOrdinary_grouping_set", n.Named_expr()) + } + return namedExpr } -func (c *cc) convertRollup_listContext(n *parser.Rollup_listContext) ast.Node { +func (c *cc) VisitRollup_list(n *parser.Rollup_listContext) interface{} { if n == nil || n.ROLLUP() == nil || n.LPAREN() == nil || n.RPAREN() == nil { - return todo("convertRollup_listContext", n) + return todo("VisitRollup_list", n) } var items []ast.Node if list := n.Ordinary_grouping_set_list(); list != nil { for _, ogs := range list.AllOrdinary_grouping_set() { - items = append(items, c.convert(ogs)) + og, ok := ogs.Accept(c).(ast.Node) + if !ok { + return todo("VisitRollup_list", ogs) + } + items = append(items, og) } } return &ast.GroupingSet{Kind: 1, Content: &ast.List{Items: items}} } -func (c *cc) convertCube_listContext(n *parser.Cube_listContext) ast.Node { +func (c *cc) VisitCube_list(n *parser.Cube_listContext) interface{} { if n == nil || n.CUBE() == nil || n.LPAREN() == nil || n.RPAREN() == nil { - return todo("convertCube_listContext", n) + return todo("VisitCube_list", n) } var items []ast.Node if list := n.Ordinary_grouping_set_list(); list != nil { for _, ogs := range list.AllOrdinary_grouping_set() { - items = append(items, c.convert(ogs)) + og, ok := ogs.Accept(c).(ast.Node) + if !ok { + return todo("VisitCube_list", ogs) + } + items = append(items, og) } } return &ast.GroupingSet{Kind: 2, Content: &ast.List{Items: items}} } -func (c *cc) convertGrouping_sets_specificationContext(n *parser.Grouping_sets_specificationContext) ast.Node { +func (c *cc) VisitGrouping_sets_specification(n *parser.Grouping_sets_specificationContext) interface{} { if n == nil || n.GROUPING() == nil || n.SETS() == nil || n.LPAREN() == nil || n.RPAREN() == nil { - return todo("convertGrouping_sets_specificationContext", n) + return todo("VisitGrouping_sets_specification", n) } var items []ast.Node if gel := n.Grouping_element_list(); gel != nil { for _, ge := range gel.AllGrouping_element() { - items = append(items, c.convert(ge)) + g, ok := ge.Accept(c).(ast.Node) + if !ok { + return todo("VisitGrouping_sets_specification", ge) + } + items = append(items, g) } } return &ast.GroupingSet{Kind: 3, Content: &ast.List{Items: items}} } -func (c *cc) convertResultColumn(n *parser.Result_columnContext) ast.Node { +func (c *cc) VisitResult_column(n *parser.Result_columnContext) interface{} { // todo: support opt_id_prefix target := &ast.ResTarget{ Location: c.pos(n.GetStart()), @@ -1274,11 +1476,15 @@ func (c *cc) convertResultColumn(n *parser.Result_columnContext) ast.Node { case n.ASTERISK() != nil: val = c.convertWildCardField(n) case iexpr != nil: - val = c.convert(iexpr) + temp, ok := iexpr.Accept(c).(ast.Node) + if !ok { + return todo("VisitResult_column", iexpr) + } + val = temp } if val == nil { - return nil + return todo("VisitResult_column", n) } switch { case n.AS() != nil && n.An_id_or_type() != nil: @@ -1291,30 +1497,27 @@ func (c *cc) convertResultColumn(n *parser.Result_columnContext) ast.Node { return target } -func (c *cc) convertJoinSource(n *parser.Join_sourceContext) ast.Node { - if n == nil { - return nil +func (c *cc) VisitJoin_source(n *parser.Join_sourceContext) interface{} { + if n == nil || len(n.AllFlatten_source()) == 0 { + return todo("VisitJoin_source", n) } fsList := n.AllFlatten_source() - if len(fsList) == 0 { - return nil - } joinOps := n.AllJoin_op() joinConstraints := n.AllJoin_constraint() // todo: add ANY support - leftNode := c.convertFlattenSource(fsList[0]) - if leftNode == nil { - return nil + leftNode, ok := fsList[0].Accept(c).(ast.Node) + if !ok { + return todo("VisitJoin_source", fsList[0]) } for i, jopCtx := range joinOps { if i+1 >= len(fsList) { break } - rightNode := c.convertFlattenSource(fsList[i+1]) - if rightNode == nil { - return leftNode + rightNode, ok := fsList[i+1].Accept(c).(ast.Node) + if !ok { + return todo("VisitJoin_source", fsList[i+1]) } jexpr := &ast.JoinExpr{ Larg: leftNode, @@ -1343,7 +1546,11 @@ func (c *cc) convertJoinSource(n *parser.Join_sourceContext) ast.Node { switch { case jc.ON() != nil: if exprCtx := jc.Expr(); exprCtx != nil { - jexpr.Quals = c.convert(exprCtx) + expr, ok := exprCtx.Accept(c).(ast.Node) + if !ok { + return todo("VisitJoin_source", exprCtx) + } + jexpr.Quals = expr } case jc.USING() != nil: if pureListCtx := jc.Pure_column_or_named_list(); pureListCtx != nil { @@ -1353,12 +1560,17 @@ func (c *cc) convertJoinSource(n *parser.Join_sourceContext) ast.Node { if anID := pureCtx.An_id(); anID != nil { using.Items = append(using.Items, NewIdentifier(parseAnId(anID))) } else if bp := pureCtx.Bind_parameter(); bp != nil { - bindPar := c.convert(bp) + bindPar, ok := bp.Accept(c).(ast.Node) + if !ok { + return todo("VisitJoin_source", bp) + } using.Items = append(using.Items, bindPar) } } jexpr.UsingClause = &using } + default: + return todo("VisitJoin_source", jc) } } } @@ -1367,31 +1579,25 @@ func (c *cc) convertJoinSource(n *parser.Join_sourceContext) ast.Node { return leftNode } -func (c *cc) convertFlattenSource(n parser.IFlatten_sourceContext) ast.Node { - if n == nil { - return nil - } - nss := n.Named_single_source() - if nss == nil { - return nil +func (c *cc) VisitFlatten_source(n *parser.Flatten_sourceContext) interface{} { + if n == nil || n.Named_single_source() == nil { + return todo("VisitFlatten_source", n) } - namedSingleSource, ok := nss.(*parser.Named_single_sourceContext) + namedSingleSource, ok := n.Named_single_source().Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitFlatten_source", n.Named_single_source()) } - return c.convertNamedSingleSource(namedSingleSource) + return namedSingleSource } -func (c *cc) convertNamedSingleSource(n *parser.Named_single_sourceContext) ast.Node { - ss := n.Single_source() - if ss == nil { - return nil +func (c *cc) VisitNamed_single_source(n *parser.Named_single_sourceContext) interface{} { + if n == nil || n.Single_source() == nil { + return todo("VisitNamed_single_source", n) } - SingleSource, ok := ss.(*parser.Single_sourceContext) + base, ok := n.Single_source().Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitNamed_single_source", n.Single_source()) } - base := c.convertSingleSource(SingleSource) if n.AS() != nil && n.An_id() != nil { aliasText := parseAnId(n.An_id()) @@ -1407,7 +1613,11 @@ func (c *cc) convertNamedSingleSource(n *parser.Named_single_sourceContext) ast. return base } -func (c *cc) convertSingleSource(n *parser.Single_sourceContext) ast.Node { +func (c *cc) VisitSingle_source(n *parser.Single_sourceContext) interface{} { + if n == nil { + return todo("VisitSingle_source", n) + } + if n.Table_ref() != nil { tableName := n.Table_ref().GetText() // !! debug !! return &ast.RangeVar{ @@ -1417,7 +1627,10 @@ func (c *cc) convertSingleSource(n *parser.Single_sourceContext) ast.Node { } if n.Select_stmt() != nil { - subquery := c.convert(n.Select_stmt()) + subquery, ok := n.Select_stmt().Accept(c).(ast.Node) + if !ok { + return todo("VisitSingle_source", n.Select_stmt()) + } return &ast.RangeSubselect{ Subquery: subquery, } @@ -1425,35 +1638,30 @@ func (c *cc) convertSingleSource(n *parser.Single_sourceContext) ast.Node { } // todo: Values stmt - return nil + return todo("VisitSingle_source", n) } -func (c *cc) convertBindParameter(n *parser.Bind_parameterContext) ast.Node { - // !!debug later!! - if n.DOLLAR() != nil { - if n.TRUE() != nil { - return &ast.A_Const{Val: &ast.Boolean{Boolval: true}, Location: c.pos(n.GetStart())} - } - if n.FALSE() != nil { - return &ast.A_Const{Val: &ast.Boolean{Boolval: false}, Location: c.pos(n.GetStart())} - } +func (c *cc) VisitBind_parameter(n *parser.Bind_parameterContext) interface{} { + if n == nil || n.DOLLAR() == nil { + return todo("VisitBind_parameter", n) + } - if an := n.An_id_or_type(); an != nil { - idText := parseAnIdOrType(an) - return &ast.A_Expr{ - Name: &ast.List{Items: []ast.Node{&ast.String{Str: "@"}}}, - Rexpr: &ast.String{Str: idText}, - Location: c.pos(n.GetStart()), - } - } - c.paramCount++ - return &ast.ParamRef{ - Number: c.paramCount, + if n.TRUE() != nil { + return &ast.A_Const{Val: &ast.Boolean{Boolval: true}, Location: c.pos(n.GetStart())} + } + if n.FALSE() != nil { + return &ast.A_Const{Val: &ast.Boolean{Boolval: false}, Location: c.pos(n.GetStart())} + } + + if an := n.An_id_or_type(); an != nil { + idText := parseAnIdOrType(an) + return &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "@"}}}, + Rexpr: &ast.String{Str: idText}, Location: c.pos(n.GetStart()), - Dollar: true, } } - return &ast.TODO{} + return todo("VisitBind_parameter", n) } func (c *cc) convertWildCardField(n *parser.Result_columnContext) *ast.ColumnRef { @@ -1472,129 +1680,154 @@ func (c *cc) convertWildCardField(n *parser.Result_columnContext) *ast.ColumnRef } } -func (c *cc) convertOptIdPrefix(ctx parser.IOpt_id_prefixContext) string { - if ctx == nil { +func (c *cc) convertOptIdPrefix(n parser.IOpt_id_prefixContext) string { + if n == nil { return "" } - if ctx.An_id() != nil { - return ctx.An_id().GetText() + if n.An_id() != nil { + return n.An_id().GetText() } return "" } -func (c *cc) convertCreate_table_stmtContext(n *parser.Create_table_stmtContext) ast.Node { +func (c *cc) VisitCreate_table_stmt(n *parser.Create_table_stmtContext) interface{} { stmt := &ast.CreateTableStmt{ Name: parseTableName(n.Simple_table_ref().Simple_table_ref_core()), IfNotExists: n.EXISTS() != nil, } - for _, idef := range n.AllCreate_table_entry() { - if def, ok := idef.(*parser.Create_table_entryContext); ok { + for _, def := range n.AllCreate_table_entry() { + switch { + case def.Column_schema() != nil: + temp, ok := def.Column_schema().Accept(c).(ast.Node) + if !ok { + return todo("VisitCreate_table_stmt", def.Column_schema()) + } + colCtx, ok := temp.(*ast.ColumnDef) + if !ok { + return todo("VisitCreate_table_stmt", def.Column_schema()) + } + stmt.Cols = append(stmt.Cols, colCtx) + case def.Table_constraint() != nil: + conCtx := def.Table_constraint() switch { - case def.Column_schema() != nil: - if colCtx, ok := def.Column_schema().(*parser.Column_schemaContext); ok { - colDef := c.convertColumnSchema(colCtx) - if colDef != nil { - stmt.Cols = append(stmt.Cols, colDef) - } - } - case def.Table_constraint() != nil: - if conCtx, ok := def.Table_constraint().(*parser.Table_constraintContext); ok { - switch { - case conCtx.PRIMARY() != nil && conCtx.KEY() != nil: - for _, cname := range conCtx.AllAn_id() { - for _, col := range stmt.Cols { - if col.Colname == parseAnId(cname) { - col.IsNotNull = true - } - } + case conCtx.PRIMARY() != nil && conCtx.KEY() != nil: + for _, cname := range conCtx.AllAn_id() { + for _, col := range stmt.Cols { + if col.Colname == parseAnId(cname) { + col.IsNotNull = true } - case conCtx.PARTITION() != nil && conCtx.BY() != nil: - _ = conCtx - // todo: partition by constraint - case conCtx.ORDER() != nil && conCtx.BY() != nil: - _ = conCtx - // todo: order by constraint } } - - case def.Table_index() != nil: - if indCtx, ok := def.Table_index().(*parser.Table_indexContext); ok { - _ = indCtx - // todo - } - case def.Family_entry() != nil: - if famCtx, ok := def.Family_entry().(*parser.Family_entryContext); ok { - _ = famCtx - // todo - } - case def.Changefeed() != nil: // таблица ориентированная - if cgfCtx, ok := def.Changefeed().(*parser.ChangefeedContext); ok { - _ = cgfCtx - // todo - } + case conCtx.PARTITION() != nil && conCtx.BY() != nil: + return todo("VisitCreate_table_stmt", conCtx) + case conCtx.ORDER() != nil && conCtx.BY() != nil: + return todo("VisitCreate_table_stmt", conCtx) } + + case def.Table_index() != nil: + return todo("VisitCreate_table_stmt", def.Table_index()) + case def.Family_entry() != nil: + return todo("VisitCreate_table_stmt", def.Family_entry()) + case def.Changefeed() != nil: // table-oriented + return todo("VisitCreate_table_stmt", def.Changefeed()) } } return stmt } -func (c *cc) convertColumnSchema(n *parser.Column_schemaContext) *ast.ColumnDef { - +func (c *cc) VisitColumn_schema(n *parser.Column_schemaContext) interface{} { + if n == nil { + return todo("VisitColumn_schema", n) + } col := &ast.ColumnDef{} if anId := n.An_id_schema(); anId != nil { col.Colname = identifier(parseAnIdSchema(anId)) } if tnb := n.Type_name_or_bind(); tnb != nil { - col.TypeName = c.convertTypeNameOrBind(tnb) + temp, ok := tnb.Accept(c).(ast.Node) + if !ok { + return todo("VisitColumn_schema", tnb) + } + typeName, ok := temp.(*ast.TypeName) + if !ok { + return todo("VisitColumn_schema", tnb) + } + col.TypeName = typeName } if colCons := n.Opt_column_constraints(); colCons != nil { col.IsNotNull = colCons.NOT() != nil && colCons.NULL() != nil - //todo: cover exprs if needed + + if colCons.DEFAULT() != nil && colCons.Expr() != nil { + defaultExpr, ok := colCons.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitColumn_schema", colCons.Expr()) + } + col.RawDefault = defaultExpr + } } // todo: family return col } -func (c *cc) convertTypeNameOrBind(n parser.IType_name_or_bindContext) *ast.TypeName { +func (c *cc) VisitType_name_or_bind(n *parser.Type_name_or_bindContext) interface{} { + if n == nil { + return todo("VisitType_name_or_bind", n) + } + if t := n.Type_name(); t != nil { - return c.convertTypeName(t) + temp, ok := t.Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_or_bind", t) + } + typeName, ok := temp.(*ast.TypeName) + if !ok { + return todo("VisitType_name_or_bind", t) + } + return typeName } else if b := n.Bind_parameter(); b != nil { return &ast.TypeName{Name: "BIND:" + identifier(parseAnIdOrType(b.An_id_or_type()))} } - return nil + return todo("VisitType_name_or_bind", n) } -func (c *cc) convertTypeName(n parser.IType_nameContext) *ast.TypeName { +func (c *cc) VisitType_name(n *parser.Type_nameContext) interface{} { if n == nil { - return nil + return todo("VisitType_name", n) } if composite := n.Type_name_composite(); composite != nil { - if node := c.convertTypeNameComposite(composite); node != nil { - if typeName, ok := node.(*ast.TypeName); ok { - return typeName - } + typeName, ok := composite.Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_or_bind", composite) } + return typeName } if decimal := n.Type_name_decimal(); decimal != nil { if integerOrBinds := decimal.AllInteger_or_bind(); len(integerOrBinds) >= 2 { + first, ok := integerOrBinds[0].Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name", decimal.Integer_or_bind(0)) + } + second, ok := integerOrBinds[1].Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name", decimal.Integer_or_bind(1)) + } return &ast.TypeName{ Name: "decimal", TypeOid: 0, Names: &ast.List{ Items: []ast.Node{ - c.convertIntegerOrBind(integerOrBinds[0]), - c.convertIntegerOrBind(integerOrBinds[1]), + first, + second, }, }, } } } - // Handle simple types if simple := n.Type_name_simple(); simple != nil { return &ast.TypeName{ Name: simple.GetText(), @@ -1602,41 +1835,49 @@ func (c *cc) convertTypeName(n parser.IType_nameContext) *ast.TypeName { } } - return nil + return todo("VisitType_name", n) } -func (c *cc) convertIntegerOrBind(n parser.IInteger_or_bindContext) ast.Node { +func (c *cc) VisitInteger_or_bind(n *parser.Integer_or_bindContext) interface{} { if n == nil { - return nil + return todo("VisitInteger_or_bind", n) } if integer := n.Integer(); integer != nil { val, err := parseIntegerValue(integer.GetText()) if err != nil { - return &ast.TODO{} + return todo("VisitInteger_or_bind", n.Integer()) } return &ast.Integer{Ival: val} } if bind := n.Bind_parameter(); bind != nil { - return c.convertBindParameter(bind.(*parser.Bind_parameterContext)) + temp, ok := bind.Accept(c).(ast.Node) + if !ok { + return todo("VisitInteger_or_bind", bind) + } + return temp } - return nil + return todo("VisitInteger_or_bind", n) } -func (c *cc) convertTypeNameComposite(n parser.IType_name_compositeContext) ast.Node { +func (c *cc) VisitType_name_composite(n *parser.Type_name_compositeContext) interface{} { if n == nil { - return nil + return todo("VisitType_name_composite", n) } if opt := n.Type_name_optional(); opt != nil { if typeName := opt.Type_name_or_bind(); typeName != nil { + tn, ok := typeName.Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_composite", typeName) + } return &ast.TypeName{ Name: "Optional", TypeOid: 0, Names: &ast.List{ - Items: []ast.Node{c.convertTypeNameOrBind(typeName)}, + Items: []ast.Node{tn}, }, } } @@ -1646,7 +1887,11 @@ func (c *cc) convertTypeNameComposite(n parser.IType_name_compositeContext) ast. if typeNames := tuple.AllType_name_or_bind(); len(typeNames) > 0 { var items []ast.Node for _, tn := range typeNames { - items = append(items, c.convertTypeNameOrBind(tn)) + tnNode, ok := tn.Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_composite", tn) + } + items = append(items, tnNode) } return &ast.TypeName{ Name: "Tuple", @@ -1688,11 +1933,15 @@ func (c *cc) convertTypeNameComposite(n parser.IType_name_compositeContext) ast. if list := n.Type_name_list(); list != nil { if typeName := list.Type_name_or_bind(); typeName != nil { + tn, ok := typeName.Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_composite", typeName) + } return &ast.TypeName{ Name: "List", TypeOid: 0, Names: &ast.List{ - Items: []ast.Node{c.convertTypeNameOrBind(typeName)}, + Items: []ast.Node{tn}, }, } } @@ -1700,37 +1949,41 @@ func (c *cc) convertTypeNameComposite(n parser.IType_name_compositeContext) ast. if stream := n.Type_name_stream(); stream != nil { if typeName := stream.Type_name_or_bind(); typeName != nil { + tn, ok := typeName.Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_composite", typeName) + } return &ast.TypeName{ Name: "Stream", TypeOid: 0, Names: &ast.List{ - Items: []ast.Node{c.convertTypeNameOrBind(typeName)}, + Items: []ast.Node{tn}, }, } } } if flow := n.Type_name_flow(); flow != nil { - if typeName := flow.Type_name_or_bind(); typeName != nil { - return &ast.TypeName{ - Name: "Flow", - TypeOid: 0, - Names: &ast.List{ - Items: []ast.Node{c.convertTypeNameOrBind(typeName)}, - }, - } - } + return todo("VisitType_name_composite", flow) } if dict := n.Type_name_dict(); dict != nil { if typeNames := dict.AllType_name_or_bind(); len(typeNames) >= 2 { + first, ok := typeNames[0].Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_composite", typeNames[0]) + } + second, ok := typeNames[1].Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_composite", typeNames[1]) + } return &ast.TypeName{ Name: "Dict", TypeOid: 0, Names: &ast.List{ Items: []ast.Node{ - c.convertTypeNameOrBind(typeNames[0]), - c.convertTypeNameOrBind(typeNames[1]), + first, + second, }, }, } @@ -1739,259 +1992,234 @@ func (c *cc) convertTypeNameComposite(n parser.IType_name_compositeContext) ast. if set := n.Type_name_set(); set != nil { if typeName := set.Type_name_or_bind(); typeName != nil { + tn, ok := typeName.Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_composite", typeName) + } return &ast.TypeName{ Name: "Set", TypeOid: 0, Names: &ast.List{ - Items: []ast.Node{c.convertTypeNameOrBind(typeName)}, + Items: []ast.Node{tn}, }, } } } - if enum := n.Type_name_enum(); enum != nil { - if typeTags := enum.AllType_name_tag(); len(typeTags) > 0 { - var items []ast.Node - for range typeTags { // todo: Handle enum tags - items = append(items, &ast.TODO{}) - } - return &ast.TypeName{ - Name: "Enum", - TypeOid: 0, - Names: &ast.List{Items: items}, - } - } + if enum := n.Type_name_enum(); enum != nil { // todo: handle enum + todo("VisitType_name_composite", enum) } - if resource := n.Type_name_resource(); resource != nil { - if typeTag := resource.Type_name_tag(); typeTag != nil { - // TODO: Handle resource tag - return &ast.TypeName{ - Name: "Resource", - TypeOid: 0, - Names: &ast.List{ - Items: []ast.Node{&ast.TODO{}}, - }, - } - } + if resource := n.Type_name_resource(); resource != nil { // todo: handle resource + todo("VisitType_name_composite", resource) } - if tagged := n.Type_name_tagged(); tagged != nil { - if typeName := tagged.Type_name_or_bind(); typeName != nil { - if typeTag := tagged.Type_name_tag(); typeTag != nil { - // TODO: Handle tagged type and tag - return &ast.TypeName{ - Name: "Tagged", - TypeOid: 0, - Names: &ast.List{ - Items: []ast.Node{ - c.convertTypeNameOrBind(typeName), - &ast.TODO{}, - }, - }, - } - } - } + if tagged := n.Type_name_tagged(); tagged != nil { // todo: handle tagged + todo("VisitType_name_composite", tagged) } - if callable := n.Type_name_callable(); callable != nil { - // TODO: Handle callable argument list and return type - return &ast.TypeName{ - Name: "Callable", - TypeOid: 0, - Names: &ast.List{ - Items: []ast.Node{&ast.TODO{}}, - }, - } + if callable := n.Type_name_callable(); callable != nil { // todo: handle callable + todo("VisitType_name_composite", callable) } - return nil + return todo("VisitType_name_composite", n) } -func (c *cc) convertSqlStmtCore(n parser.ISql_stmt_coreContext) ast.Node { +func (c *cc) VisitSql_stmt_core(n *parser.Sql_stmt_coreContext) interface{} { if n == nil { - return nil + return todo("VisitSql_stmt_core", n) } if stmt := n.Pragma_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Select_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Named_nodes_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_table_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) + } + if stmt := n.Named_nodes_stmt(); stmt != nil { + return stmt.Accept(c) + } + if stmt := n.Create_table_stmt(); stmt != nil { + return stmt.Accept(c) } if stmt := n.Drop_table_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Use_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Into_table_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Commit_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Update_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Delete_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Rollback_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Declare_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Import_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Export_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_table_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_external_table_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Do_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Define_action_or_subquery_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.If_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.For_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Values_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_user_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_user_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_group_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_group_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_role_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_object_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_object_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_object_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_external_data_source_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_external_data_source_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_external_data_source_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_replication_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_replication_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_topic_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_topic_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_topic_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Grant_permissions_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Revoke_permissions_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_table_store_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Upsert_object_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_view_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_view_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_replication_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_resource_pool_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_resource_pool_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_resource_pool_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_backup_collection_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_backup_collection_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_backup_collection_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Analyze_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_resource_pool_classifier_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_resource_pool_classifier_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_resource_pool_classifier_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Backup_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Restore_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_sequence_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } - return nil + return todo("VisitSql_stmt_core", n) } -func (c *cc) convertNamed_exprContext(n *parser.Named_exprContext) ast.Node { +func (c *cc) VisitNamed_expr(n *parser.Named_exprContext) interface{} { if n == nil || n.Expr() == nil { - return todo("convertNamed_exprContext", n) + return todo("VisitNamed_expr", n) } - expr := c.convert(n.Expr()) + + expr, ok := n.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitNamed_expr", n) + } + if n.AS() != nil && n.An_id_or_type() != nil { name := parseAnIdOrType(n.An_id_or_type()) return &ast.ResTarget{ @@ -2003,32 +2231,32 @@ func (c *cc) convertNamed_exprContext(n *parser.Named_exprContext) ast.Node { return expr } -func (c *cc) convertExpr(n *parser.ExprContext) ast.Node { +func (c *cc) VisitExpr(n *parser.ExprContext) interface{} { if n == nil { - return nil + return todo("VisitExpr", n) } if tn := n.Type_name_composite(); tn != nil { - return c.convertTypeNameComposite(tn) + return tn.Accept(c) } orSubs := n.AllOr_subexpr() if len(orSubs) == 0 { - return nil + return todo("VisitExpr", n) } - orSub, ok := orSubs[0].(*parser.Or_subexprContext) + left, ok := n.Or_subexpr(0).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitExpr", n) } - left := c.convertOrSubExpr(orSub) for i := 1; i < len(orSubs); i++ { - orSub, ok = orSubs[i].(*parser.Or_subexprContext) + + right, ok := orSubs[i].Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitExpr", n) } - right := c.convertOrSubExpr(orSub) + left = &ast.BoolExpr{ Boolop: ast.BoolExprTypeOr, Args: &ast.List{Items: []ast.Node{left, right}}, @@ -2038,26 +2266,23 @@ func (c *cc) convertExpr(n *parser.ExprContext) ast.Node { return left } -func (c *cc) convertOrSubExpr(n *parser.Or_subexprContext) ast.Node { - if n == nil { - return nil +func (c *cc) VisitOr_subexpr(n *parser.Or_subexprContext) interface{} { + if n == nil || len(n.AllAnd_subexpr()) == 0 { + return todo("VisitOr_subexpr", n) } - andSubs := n.AllAnd_subexpr() - if len(andSubs) == 0 { - return nil - } - andSub, ok := andSubs[0].(*parser.And_subexprContext) + + left, ok := n.And_subexpr(0).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitOr_subexpr", n) } - left := c.convertAndSubexpr(andSub) - for i := 1; i < len(andSubs); i++ { - andSub, ok = andSubs[i].(*parser.And_subexprContext) + for i := 1; i < len(n.AllAnd_subexpr()); i++ { + + right, ok := n.And_subexpr(i).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitOr_subexpr", n) } - right := c.convertAndSubexpr(andSub) + left = &ast.BoolExpr{ Boolop: ast.BoolExprTypeAnd, Args: &ast.List{Items: []ast.Node{left, right}}, @@ -2067,28 +2292,23 @@ func (c *cc) convertOrSubExpr(n *parser.Or_subexprContext) ast.Node { return left } -func (c *cc) convertAndSubexpr(n *parser.And_subexprContext) ast.Node { - if n == nil { - return nil - } - - xors := n.AllXor_subexpr() - if len(xors) == 0 { - return nil +func (c *cc) VisitAnd_subexpr(n *parser.And_subexprContext) interface{} { + if n == nil || len(n.AllXor_subexpr()) == 0 { + return todo("VisitAnd_subexpr", n) } - xor, ok := xors[0].(*parser.Xor_subexprContext) + left, ok := n.Xor_subexpr(0).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitAnd_subexpr", n) } - left := c.convertXorSubexpr(xor) - for i := 1; i < len(xors); i++ { - xor, ok = xors[i].(*parser.Xor_subexprContext) + for i := 1; i < len(n.AllXor_subexpr()); i++ { + + right, ok := n.Xor_subexpr(i).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitAnd_subexpr", n) } - right := c.convertXorSubexpr(xor) + left = &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: "XOR"}}}, Lexpr: left, @@ -2099,40 +2319,53 @@ func (c *cc) convertAndSubexpr(n *parser.And_subexprContext) ast.Node { return left } -func (c *cc) convertXorSubexpr(n *parser.Xor_subexprContext) ast.Node { - if n == nil { - return nil - } - es := n.Eq_subexpr() - if es == nil { - return nil +func (c *cc) VisitXor_subexpr(n *parser.Xor_subexprContext) interface{} { + if n == nil || n.Eq_subexpr() == nil { + return todo("VisitXor_subexpr", n) } - subExpr, ok := es.(*parser.Eq_subexprContext) + + base, ok := n.Eq_subexpr().Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitXor_subexpr", n) } - base := c.convertEqSubexpr(subExpr) - if cond := n.Cond_expr(); cond != nil { - condCtx, ok := cond.(*parser.Cond_exprContext) - if !ok { - return base - } + + if condCtx := n.Cond_expr(); condCtx != nil { switch { case condCtx.IN() != nil: if inExpr := condCtx.In_expr(); inExpr != nil { - return &ast.A_Expr{ - Name: &ast.List{Items: []ast.Node{&ast.String{Str: "IN"}}}, - Lexpr: base, - Rexpr: c.convert(inExpr), + temp, ok := inExpr.Accept(c).(ast.Node) + if !ok { + return todo("VisitXor_subexpr", inExpr) + } + list, ok := temp.(*ast.List) + if !ok { + return todo("VisitXor_subexpr", inExpr) + } + return &ast.In{ + Expr: base, + List: list.Items, + Not: condCtx.NOT() != nil, + Location: c.pos(n.GetStart()), } } case condCtx.BETWEEN() != nil: if eqSubs := condCtx.AllEq_subexpr(); len(eqSubs) >= 2 { + + first, ok := eqSubs[0].Accept(c).(ast.Node) + if !ok { + return todo("VisitXor_subexpr", n) + } + + second, ok := eqSubs[1].Accept(c).(ast.Node) + if !ok { + return todo("VisitXor_subexpr", n) + } + return &ast.BetweenExpr{ Expr: base, - Left: c.convert(eqSubs[0]), - Right: c.convert(eqSubs[1]), + Left: first, + Right: second, Not: condCtx.NOT() != nil, Location: c.pos(n.GetStart()), } @@ -2155,7 +2388,7 @@ func (c *cc) convertXorSubexpr(n *parser.Xor_subexprContext) ast.Node { Nulltesttype: 1, // IS NULL Location: c.pos(n.GetStart()), } - case condCtx.IS() != nil && condCtx.NOT() != nil && condCtx.NULL() != nil: + case condCtx.NOT() != nil && condCtx.NULL() != nil: return &ast.NullTest{ Arg: base, Nulltesttype: 2, // IS NOT NULL @@ -2165,10 +2398,16 @@ func (c *cc) convertXorSubexpr(n *parser.Xor_subexprContext) ast.Node { // debug!!! matchOp := condCtx.Match_op().GetText() if eqSubs := condCtx.AllEq_subexpr(); len(eqSubs) >= 1 { + + xpr, ok := eqSubs[0].Accept(c).(ast.Node) + if !ok { + return todo("VisitXor_subexpr", n) + } + expr := &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: matchOp}}}, Lexpr: base, - Rexpr: c.convert(eqSubs[0]), + Rexpr: xpr, } if condCtx.ESCAPE() != nil && len(eqSubs) >= 2 { //nolint // todo: Add ESCAPE support @@ -2177,25 +2416,43 @@ func (c *cc) convertXorSubexpr(n *parser.Xor_subexprContext) ast.Node { } case len(condCtx.AllEQUALS()) > 0 || len(condCtx.AllEQUALS2()) > 0 || len(condCtx.AllNOT_EQUALS()) > 0 || len(condCtx.AllNOT_EQUALS2()) > 0: - // debug!!! - var op string - switch { - case len(condCtx.AllEQUALS()) > 0: - op = "=" - case len(condCtx.AllEQUALS2()) > 0: - op = "==" - case len(condCtx.AllNOT_EQUALS()) > 0: - op = "!=" - case len(condCtx.AllNOT_EQUALS2()) > 0: - op = "<>" - } - if eqSubs := condCtx.AllEq_subexpr(); len(eqSubs) >= 1 { - return &ast.A_Expr{ - Name: &ast.List{Items: []ast.Node{&ast.String{Str: op}}}, - Lexpr: base, - Rexpr: c.convert(eqSubs[0]), + eqSubs := condCtx.AllEq_subexpr() + if len(eqSubs) >= 1 { + left := base + + ops := c.collectEqualityOps(condCtx) + + for i, eqSub := range eqSubs { + right, ok := eqSub.Accept(c).(ast.Node) + if !ok { + return todo("VisitXor_subexpr", condCtx) + } + + var op string + if i < len(ops) { + op = ops[i].GetText() + } else { + if len(condCtx.AllEQUALS()) > 0 { + op = "=" + } else if len(condCtx.AllEQUALS2()) > 0 { + op = "==" + } else if len(condCtx.AllNOT_EQUALS()) > 0 { + op = "!=" + } else if len(condCtx.AllNOT_EQUALS2()) > 0 { + op = "<>" + } + } + + left = &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: op}}}, + Lexpr: left, + Rexpr: right, + Location: c.pos(condCtx.GetStart()), + } } + return left } + return todo("VisitXor_subexpr", condCtx) case len(condCtx.AllDistinct_from_op()) > 0: // debug!!! distinctOps := condCtx.AllDistinct_from_op() @@ -2206,10 +2463,16 @@ func (c *cc) convertXorSubexpr(n *parser.Xor_subexprContext) ast.Node { if not { op = "IS NOT DISTINCT FROM" } + + xpr, ok := eqSubs[0].Accept(c).(ast.Node) + if !ok { + return todo("VisitXor_subexpr", n) + } + return &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: op}}}, Lexpr: base, - Rexpr: c.convert(eqSubs[0]), + Rexpr: xpr, } } } @@ -2218,26 +2481,24 @@ func (c *cc) convertXorSubexpr(n *parser.Xor_subexprContext) ast.Node { return base } -func (c *cc) convertEqSubexpr(n *parser.Eq_subexprContext) ast.Node { - if n == nil { - return nil - } - neqList := n.AllNeq_subexpr() - if len(neqList) == 0 { - return nil +func (c *cc) VisitEq_subexpr(n *parser.Eq_subexprContext) interface{} { + if n == nil || len(n.AllNeq_subexpr()) == 0 { + return todo("VisitEq_subexpr", n) } - neq, ok := neqList[0].(*parser.Neq_subexprContext) + + left, ok := n.Neq_subexpr(0).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitEq_subexpr", n) } - left := c.convertNeqSubexpr(neq) + ops := c.collectComparisonOps(n) - for i := 1; i < len(neqList); i++ { - neq, ok = neqList[i].(*parser.Neq_subexprContext) + for i := 1; i < len(n.AllNeq_subexpr()); i++ { + + right, ok := n.Neq_subexpr(i).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitEq_subexpr", n) } - right := c.convertNeqSubexpr(neq) + opText := ops[i-1].GetText() left = &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: opText}}}, @@ -2249,40 +2510,22 @@ func (c *cc) convertEqSubexpr(n *parser.Eq_subexprContext) ast.Node { return left } -func (c *cc) collectComparisonOps(n parser.IEq_subexprContext) []antlr.TerminalNode { - var ops []antlr.TerminalNode - for _, child := range n.GetChildren() { - if tn, ok := child.(antlr.TerminalNode); ok { - switch tn.GetText() { - case "<", "<=", ">", ">=": - ops = append(ops, tn) - } - } +func (c *cc) VisitNeq_subexpr(n *parser.Neq_subexprContext) interface{} { + if n == nil || len(n.AllBit_subexpr()) == 0 { + return todo("VisitNeq_subexpr", n) } - return ops -} -func (c *cc) convertNeqSubexpr(n *parser.Neq_subexprContext) ast.Node { - if n == nil { - return nil - } - bitList := n.AllBit_subexpr() - if len(bitList) == 0 { - return nil - } - - bl, ok := bitList[0].(*parser.Bit_subexprContext) + left, ok := n.Bit_subexpr(0).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitNeq_subexpr", n) } - left := c.convertBitSubexpr(bl) + ops := c.collectBitwiseOps(n) - for i := 1; i < len(bitList); i++ { - bl, ok = bitList[i].(*parser.Bit_subexprContext) + for i := 1; i < len(n.AllBit_subexpr()); i++ { + right, ok := n.Bit_subexpr(i).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitNeq_subexpr", n) } - right := c.convertBitSubexpr(bl) opText := ops[i-1].GetText() left = &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: opText}}}, @@ -2293,13 +2536,12 @@ func (c *cc) convertNeqSubexpr(n *parser.Neq_subexprContext) ast.Node { } if n.Double_question() != nil { - nextCtx := n.Neq_subexpr() - if nextCtx != nil { - neq, ok2 := nextCtx.(*parser.Neq_subexprContext) + if nextCtx := n.Neq_subexpr(); nextCtx != nil { + right, ok2 := nextCtx.Accept(c).(ast.Node) if !ok2 { - return nil + return todo("VisitNeq_subexpr", n) } - right := c.convertNeqSubexpr(neq) + left = &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: "??"}}}, Lexpr: left, @@ -2326,28 +2568,24 @@ func (c *cc) convertNeqSubexpr(n *parser.Neq_subexprContext) ast.Node { return left } -func (c *cc) collectBitwiseOps(ctx parser.INeq_subexprContext) []antlr.TerminalNode { - var ops []antlr.TerminalNode - children := ctx.GetChildren() - for _, child := range children { - if tn, ok := child.(antlr.TerminalNode); ok { - txt := tn.GetText() - switch txt { - case "<<", ">>", "<<|", ">>|", "&", "|", "^": - ops = append(ops, tn) - } - } +func (c *cc) VisitBit_subexpr(n *parser.Bit_subexprContext) interface{} { + if n == nil || len(n.AllAdd_subexpr()) == 0 { + return todo("VisitBit_subexpr", n) } - return ops -} -func (c *cc) convertBitSubexpr(n *parser.Bit_subexprContext) ast.Node { - addList := n.AllAdd_subexpr() - left := c.convertAddSubexpr(addList[0].(*parser.Add_subexprContext)) + left, ok := n.Add_subexpr(0).Accept(c).(ast.Node) + if !ok { + return todo("VisitBit_subexpr", n) + } ops := c.collectBitOps(n) - for i := 1; i < len(addList); i++ { - right := c.convertAddSubexpr(addList[i].(*parser.Add_subexprContext)) + for i := 1; i < len(n.AllAdd_subexpr()); i++ { + + right, ok := n.Add_subexpr(i).Accept(c).(ast.Node) + if !ok { + return todo("VisitBit_subexpr", n) + } + opText := ops[i-1].GetText() left = &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: opText}}}, @@ -2359,28 +2597,24 @@ func (c *cc) convertBitSubexpr(n *parser.Bit_subexprContext) ast.Node { return left } -func (c *cc) collectBitOps(ctx parser.IBit_subexprContext) []antlr.TerminalNode { - var ops []antlr.TerminalNode - children := ctx.GetChildren() - for _, child := range children { - if tn, ok := child.(antlr.TerminalNode); ok { - txt := tn.GetText() - switch txt { - case "+", "-": - ops = append(ops, tn) - } - } +func (c *cc) VisitAdd_subexpr(n *parser.Add_subexprContext) interface{} { + if n == nil || len(n.AllMul_subexpr()) == 0 { + return todo("VisitAdd_subexpr", n) } - return ops -} -func (c *cc) convertAddSubexpr(n *parser.Add_subexprContext) ast.Node { - mulList := n.AllMul_subexpr() - left := c.convertMulSubexpr(mulList[0].(*parser.Mul_subexprContext)) + left, ok := n.Mul_subexpr(0).Accept(c).(ast.Node) + if !ok { + return todo("VisitAdd_subexpr", n) + } ops := c.collectAddOps(n) - for i := 1; i < len(mulList); i++ { - right := c.convertMulSubexpr(mulList[i].(*parser.Mul_subexprContext)) + for i := 1; i < len(n.AllMul_subexpr()); i++ { + + right, ok := n.Mul_subexpr(i).Accept(c).(ast.Node) + if !ok { + return todo("VisitAdd_subexpr", n) + } + opText := ops[i-1].GetText() left = &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: opText}}}, @@ -2392,25 +2626,23 @@ func (c *cc) convertAddSubexpr(n *parser.Add_subexprContext) ast.Node { return left } -func (c *cc) collectAddOps(ctx parser.IAdd_subexprContext) []antlr.TerminalNode { - var ops []antlr.TerminalNode - for _, child := range ctx.GetChildren() { - if tn, ok := child.(antlr.TerminalNode); ok { - switch tn.GetText() { - case "*", "/", "%": - ops = append(ops, tn) - } - } +func (c *cc) VisitMul_subexpr(n *parser.Mul_subexprContext) interface{} { + if n == nil || len(n.AllCon_subexpr()) == 0 { + return todo("VisitMul_subexpr", n) } - return ops -} -func (c *cc) convertMulSubexpr(n *parser.Mul_subexprContext) ast.Node { - conList := n.AllCon_subexpr() - left := c.convertConSubexpr(conList[0].(*parser.Con_subexprContext)) + left, ok := n.Con_subexpr(0).Accept(c).(ast.Node) + if !ok { + return todo("VisitMul_subexpr", n) + } + + for i := 1; i < len(n.AllCon_subexpr()); i++ { + + right, ok := n.Con_subexpr(i).Accept(c).(ast.Node) + if !ok { + return todo("VisitMul_subexpr", n) + } - for i := 1; i < len(conList); i++ { - right := c.convertConSubexpr(conList[i].(*parser.Con_subexprContext)) left = &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: "||"}}}, Lexpr: left, @@ -2421,42 +2653,76 @@ func (c *cc) convertMulSubexpr(n *parser.Mul_subexprContext) ast.Node { return left } -func (c *cc) convertConSubexpr(n *parser.Con_subexprContext) ast.Node { +func (c *cc) VisitCon_subexpr(n *parser.Con_subexprContext) interface{} { + if n == nil || (n.Unary_op() == nil && n.Unary_subexpr() == nil) { + return todo("VisitCon_subexpr", n) + } + if opCtx := n.Unary_op(); opCtx != nil { op := opCtx.GetText() - operand := c.convertUnarySubexpr(n.Unary_subexpr().(*parser.Unary_subexprContext)) + operand, ok := n.Unary_subexpr().Accept(c).(ast.Node) + if !ok { + return todo("VisitCon_subexpr", opCtx) + } return &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: op}}}, Rexpr: operand, Location: c.pos(n.GetStart()), } } - return c.convertUnarySubexpr(n.Unary_subexpr().(*parser.Unary_subexprContext)) + + operand, ok := n.Unary_subexpr().Accept(c).(ast.Node) + if !ok { + return todo("VisitCon_subexpr", n.Unary_subexpr()) + } + return operand + } -func (c *cc) convertUnarySubexpr(n *parser.Unary_subexprContext) ast.Node { +func (c *cc) VisitUnary_subexpr(n *parser.Unary_subexprContext) interface{} { + if n == nil || (n.Unary_casual_subexpr() == nil && n.Json_api_expr() == nil) { + return todo("VisitUnary_subexpr", n) + } + if casual := n.Unary_casual_subexpr(); casual != nil { - return c.convertUnaryCasualSubexpr(casual.(*parser.Unary_casual_subexprContext)) + expr, ok := casual.Accept(c).(ast.Node) + if !ok { + return todo("VisitUnary_subexpr", casual) + } + return expr } if jsonExpr := n.Json_api_expr(); jsonExpr != nil { - return c.convertJsonApiExpr(jsonExpr.(*parser.Json_api_exprContext)) + expr, ok := jsonExpr.Accept(c).(ast.Node) + if !ok { + return todo("VisitUnary_subexpr", jsonExpr) + } + return expr } - return nil + + return todo("VisitUnary_subexpr", n) } -func (c *cc) convertJsonApiExpr(n *parser.Json_api_exprContext) ast.Node { - return todo("Json_api_exprContext", n) +func (c *cc) VisitJson_api_expr(n *parser.Json_api_exprContext) interface{} { + return todo("VisitJson_api_expr", n) } -func (c *cc) convertUnaryCasualSubexpr(n *parser.Unary_casual_subexprContext) ast.Node { +func (c *cc) VisitUnary_casual_subexpr(n *parser.Unary_casual_subexprContext) interface{} { var current ast.Node switch { case n.Id_expr() != nil: - current = c.convertIdExpr(n.Id_expr().(*parser.Id_exprContext)) + expr, ok := n.Id_expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitUnary_casual_subexpr", n.Id_expr()) + } + current = expr case n.Atom_expr() != nil: - current = c.convertAtomExpr(n.Atom_expr().(*parser.Atom_exprContext)) + expr, ok := n.Atom_expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitUnary_casual_subexpr", n.Atom_expr()) + } + current = expr default: - return todo("Unary_casual_subexprContext", n) + return todo("VisitUnary_casual_subexpr", n) } if suffix := n.Unary_subexpr_suffix(); suffix != nil { @@ -2478,17 +2744,24 @@ func (c *cc) processSuffixChain(base ast.Node, suffix *parser.Unary_subexpr_suff case antlr.TerminalNode: if elem.GetText() == "." { current = c.handleDotSuffix(current, suffix, &i) + } else { + return todo("Unary_subexpr_suffixContext", suffix) } + default: + return todo("Unary_subexpr_suffixContext", suffix) } } return current } func (c *cc) handleKeySuffix(base ast.Node, keyCtx *parser.Key_exprContext) ast.Node { - keyNode := c.convertKey_exprContext(keyCtx) + keyNode, ok := keyCtx.Accept(c).(ast.Node) + if !ok { + return todo("VisitKey_expr", keyCtx) + } ind, ok := keyNode.(*ast.A_Indirection) if !ok { - return todo("Key_exprContext", keyCtx) + return todo("VisitKey_expr", keyCtx) } if indirection, ok := base.(*ast.A_Indirection); ok { @@ -2505,9 +2778,13 @@ func (c *cc) handleKeySuffix(base ast.Node, keyCtx *parser.Key_exprContext) ast. } func (c *cc) handleInvokeSuffix(base ast.Node, invokeCtx *parser.Invoke_exprContext, idx int) ast.Node { - funcCall, ok := c.convertInvoke_exprContext(invokeCtx).(*ast.FuncCall) + temp, ok := invokeCtx.Accept(c).(ast.Node) if !ok { - return todo("Invoke_exprContext", invokeCtx) + return todo("VisitInvoke_expr", invokeCtx) + } + funcCall, ok := temp.(*ast.FuncCall) + if !ok { + return todo("VisitInvoke_expr", invokeCtx) } if idx == 0 { @@ -2535,7 +2812,7 @@ func (c *cc) handleInvokeSuffix(base ast.Node, invokeCtx *parser.Invoke_exprCont return funcCall } default: - return todo("Invoke_exprContext", invokeCtx) + return todo("VisitInvoke_expr", invokeCtx) } } @@ -2562,14 +2839,18 @@ func (c *cc) handleDotSuffix(base ast.Node, suffix *parser.Unary_subexpr_suffixC var field ast.Node switch v := next.(type) { case *parser.Bind_parameterContext: - field = c.convertBindParameter(v) + temp, ok := v.Accept(c).(ast.Node) + if !ok { + return todo("VisitBind_parameter", v) + } + field = temp case *parser.An_id_or_typeContext: field = &ast.String{Str: parseAnIdOrType(v)} case antlr.TerminalNode: if val, err := parseIntegerValue(v.GetText()); err == nil { field = &ast.A_Const{Val: &ast.Integer{Ival: val}} } else { - return &ast.TODO{} + return todo("Unary_subexpr_suffixContext", suffix) } } @@ -2586,16 +2867,19 @@ func (c *cc) handleDotSuffix(base ast.Node, suffix *parser.Unary_subexpr_suffixC } } -func (c *cc) convertKey_exprContext(n *parser.Key_exprContext) ast.Node { +func (c *cc) VisitKey_expr(n *parser.Key_exprContext) interface{} { if n.LBRACE_SQUARE() == nil || n.RBRACE_SQUARE() == nil || n.Expr() == nil { - return todo("Key_exprContext", n) + return todo("VisitKey_expr", n) } stmt := &ast.A_Indirection{ Indirection: &ast.List{}, } - expr := c.convert(n.Expr()) + expr, ok := n.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitKey_expr", n.Expr()) + } stmt.Indirection.Items = append(stmt.Indirection.Items, &ast.A_Indices{ Uidx: expr, @@ -2604,9 +2888,9 @@ func (c *cc) convertKey_exprContext(n *parser.Key_exprContext) ast.Node { return stmt } -func (c *cc) convertInvoke_exprContext(n *parser.Invoke_exprContext) ast.Node { +func (c *cc) VisitInvoke_expr(n *parser.Invoke_exprContext) interface{} { if n.LPAREN() == nil || n.RPAREN() == nil { - return todo("Invoke_exprContext", n) + return todo("VisitInvoke_expr", n) } distinct := false @@ -2625,7 +2909,10 @@ func (c *cc) convertInvoke_exprContext(n *parser.Invoke_exprContext) ast.Node { if nList := n.Named_expr_list(); nList != nil { for _, namedExpr := range nList.AllNamed_expr() { name := parseAnIdOrType(namedExpr.An_id_or_type()) - expr := c.convert(namedExpr.Expr()) + expr, ok := namedExpr.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitInvoke_expr", namedExpr.Expr()) + } var res ast.Node if rt, ok := expr.(*ast.ResTarget); ok { @@ -2652,7 +2939,10 @@ func (c *cc) convertInvoke_exprContext(n *parser.Invoke_exprContext) ast.Node { return stmt } -func (c *cc) convertIdExpr(n *parser.Id_exprContext) ast.Node { +func (c *cc) VisitId_expr(n *parser.Id_exprContext) interface{} { + if n == nil { + return todo("VisitId_expr", n) + } if id := n.Identifier(); id != nil { return &ast.ColumnRef{ Fields: &ast.List{ @@ -2663,25 +2953,43 @@ func (c *cc) convertIdExpr(n *parser.Id_exprContext) ast.Node { Location: c.pos(id.GetStart()), } } - return &ast.TODO{} + return todo("VisitId_expr", n) } -func (c *cc) convertAtomExpr(n *parser.Atom_exprContext) ast.Node { +func (c *cc) VisitAtom_expr(n *parser.Atom_exprContext) interface{} { + if n == nil { + return todo("VisitAtom_expr", n) + } + switch { - case n.An_id_or_type() != nil && n.NAMESPACE() != nil: - return NewIdentifier(parseAnIdOrType(n.An_id_or_type()) + "::" + parseIdOrType(n.Id_or_type())) case n.An_id_or_type() != nil: + if n.NAMESPACE() != nil { + return NewIdentifier(parseAnIdOrType(n.An_id_or_type()) + "::" + parseIdOrType(n.Id_or_type())) + } return NewIdentifier(parseAnIdOrType(n.An_id_or_type())) case n.Literal_value() != nil: - return c.convertLiteralValue(n.Literal_value().(*parser.Literal_valueContext)) + expr, ok := n.Literal_value().Accept(c).(ast.Node) + if !ok { + return todo("VisitAtom_expr", n.Literal_value()) + } + return expr case n.Bind_parameter() != nil: - return c.convertBindParameter(n.Bind_parameter().(*parser.Bind_parameterContext)) + expr, ok := n.Bind_parameter().Accept(c).(ast.Node) + if !ok { + return todo("VisitAtom_expr", n.Bind_parameter()) + } + return expr + // TODO: check other cases default: - return &ast.TODO{} + return todo("VisitAtom_expr", n) } } -func (c *cc) convertLiteralValue(n *parser.Literal_valueContext) ast.Node { +func (c *cc) VisitLiteral_value(n *parser.Literal_valueContext) interface{} { + if n == nil { + return todo("VisitLiteral_value", n) + } + switch { case n.Integer() != nil: text := n.Integer().GetText() @@ -2690,7 +2998,7 @@ func (c *cc) convertLiteralValue(n *parser.Literal_valueContext) ast.Node { if debug.Active { log.Printf("Failed to parse integer value '%s': %v", text, err) } - return &ast.TODO{} + return todo("VisitLiteral_value", n.Integer()) } return &ast.A_Const{Val: &ast.Integer{Ival: val}, Location: c.pos(n.GetStart())} @@ -2716,22 +3024,16 @@ func (c *cc) convertLiteralValue(n *parser.Literal_valueContext) ast.Node { return &ast.Null{} case n.CURRENT_TIME() != nil: - if debug.Active { - log.Printf("TODO: Implement CURRENT_TIME") - } - return &ast.TODO{} + log.Fatalf("CURRENT_TIME is not supported yet") + return todo("VisitLiteral_value", n) case n.CURRENT_DATE() != nil: - if debug.Active { - log.Printf("TODO: Implement CURRENT_DATE") - } - return &ast.TODO{} + log.Fatalf("CURRENT_DATE is not supported yet") + return todo("VisitLiteral_value", n) case n.CURRENT_TIMESTAMP() != nil: - if debug.Active { - log.Printf("TODO: Implement CURRENT_TIMESTAMP") - } - return &ast.TODO{} + log.Fatalf("CURRENT_TIMESTAMP is not supported yet") + return todo("VisitLiteral_value", n) case n.BLOB() != nil: blobText := n.BLOB().GetText() @@ -2744,205 +3046,36 @@ func (c *cc) convertLiteralValue(n *parser.Literal_valueContext) ast.Node { return &ast.TODO{} default: - if debug.Active { - log.Printf("Unknown literal value type: %T", n) - } - return &ast.TODO{} + return todo("VisitLiteral_value", n) } } -func (c *cc) convertSqlStmt(n *parser.Sql_stmtContext) ast.Node { - if n == nil { - return nil - } - // todo: handle explain - if core := n.Sql_stmt_core(); core != nil { - return c.convert(core) +func (c *cc) VisitSql_stmt(n *parser.Sql_stmtContext) interface{} { + if n == nil || n.Sql_stmt_core() == nil { + return todo("VisitSql_stmt", n) } - return nil -} - -func (c *cc) convert(node node) ast.Node { - switch n := node.(type) { - case *parser.Sql_stmtContext: - return c.convertSqlStmt(n) - - case *parser.Sql_stmt_coreContext: - return c.convertSqlStmtCore(n) - - case *parser.Create_table_stmtContext: - return c.convertCreate_table_stmtContext(n) - - case *parser.Select_stmtContext: - return c.convertSelectStmtContext(n) - - case *parser.Result_columnContext: - return c.convertResultColumn(n) - - case *parser.Join_sourceContext: - return c.convertJoinSource(n) - - case *parser.Flatten_sourceContext: - return c.convertFlattenSource(n) - - case *parser.Named_single_sourceContext: - return c.convertNamedSingleSource(n) - - case *parser.Single_sourceContext: - return c.convertSingleSource(n) - - case *parser.Bind_parameterContext: - return c.convertBindParameter(n) - - case *parser.ExprContext: - return c.convertExpr(n) - - case *parser.Or_subexprContext: - return c.convertOrSubExpr(n) - - case *parser.And_subexprContext: - return c.convertAndSubexpr(n) - - case *parser.Xor_subexprContext: - return c.convertXorSubexpr(n) - - case *parser.Eq_subexprContext: - return c.convertEqSubexpr(n) - - case *parser.Neq_subexprContext: - return c.convertNeqSubexpr(n) - - case *parser.Bit_subexprContext: - return c.convertBitSubexpr(n) - - case *parser.Add_subexprContext: - return c.convertAddSubexpr(n) - - case *parser.Mul_subexprContext: - return c.convertMulSubexpr(n) - - case *parser.Con_subexprContext: - return c.convertConSubexpr(n) - - case *parser.Unary_subexprContext: - return c.convertUnarySubexpr(n) - - case *parser.Unary_casual_subexprContext: - return c.convertUnaryCasualSubexpr(n) - - case *parser.Id_exprContext: - return c.convertIdExpr(n) - - case *parser.Atom_exprContext: - return c.convertAtomExpr(n) - - case *parser.Literal_valueContext: - return c.convertLiteralValue(n) - - case *parser.Json_api_exprContext: - return c.convertJsonApiExpr(n) - - case *parser.Type_name_compositeContext: - return c.convertTypeNameComposite(n) - - case *parser.Type_nameContext: - return c.convertTypeName(n) - - case *parser.Integer_or_bindContext: - return c.convertIntegerOrBind(n) - - case *parser.Type_name_or_bindContext: - return c.convertTypeNameOrBind(n) - - case *parser.Into_table_stmtContext: - return c.convertInto_table_stmtContext(n) - - case *parser.Values_stmtContext: - return c.convertValues_stmtContext(n) - - case *parser.Returning_columns_listContext: - return c.convertReturning_columns_listContext(n) - - case *parser.Delete_stmtContext: - return c.convertDelete_stmtContext(n) - - case *parser.Update_stmtContext: - return c.convertUpdate_stmtContext(n) - - case *parser.Alter_table_stmtContext: - return c.convertAlter_table_stmtContext(n) - - case *parser.Do_stmtContext: - return c.convertDo_stmtContext(n) - - case *parser.Drop_table_stmtContext: - return c.convertDrop_table_stmtContext(n) - - case *parser.Commit_stmtContext: - return c.convertCommit_stmtContext(n) - - case *parser.Rollback_stmtContext: - return c.convertRollback_stmtContext(n) - - case *parser.Pragma_valueContext: - return c.convertPragma_valueContext(n) - - case *parser.Pragma_stmtContext: - return c.convertPragma_stmtContext(n) - - case *parser.Use_stmtContext: - return c.convertUse_stmtContext(n) - - case *parser.Cluster_exprContext: - return c.convertCluster_exprContext(n) - - case *parser.Create_user_stmtContext: - return c.convertCreate_user_stmtContext(n) - - case *parser.Role_nameContext: - return c.convertRole_nameContext(n) - - case *parser.User_optionContext: - return c.convertUser_optionContext(n) - - case *parser.Create_group_stmtContext: - return c.convertCreate_group_stmtContext(n) - - case *parser.Alter_user_stmtContext: - return c.convertAlter_user_stmtContext(n) - - case *parser.Alter_group_stmtContext: - return c.convertAlter_group_stmtContext(n) - - case *parser.Drop_role_stmtContext: - return c.convertDrop_role_stmtContext(n) - - case *parser.Grouping_elementContext: - return c.convertGrouping_elementContext(n) - - case *parser.Ordinary_grouping_setContext: - return c.convertOrdinary_grouping_setContext(n) - - case *parser.Rollup_listContext: - return c.convertRollup_listContext(n) - - case *parser.Cube_listContext: - return c.convertCube_listContext(n) - - case *parser.Grouping_sets_specificationContext: - return c.convertGrouping_sets_specificationContext(n) - - case *parser.Named_exprContext: - return c.convertNamed_exprContext(n) + expr, ok := n.Sql_stmt_core().Accept(c).(ast.Node) + if !ok { + return todo("VisitSql_stmt", n.Sql_stmt_core()) + } - case *parser.Call_actionContext: - return c.convertCall_actionContext(n) + if n.EXPLAIN() != nil { + options := &ast.List{Items: []ast.Node{}} - case *parser.Inline_actionContext: - return c.convertInline_actionContext(n) + if n.QUERY() != nil && n.PLAN() != nil { + queryPlan := "QUERY PLAN" + options.Items = append(options.Items, &ast.DefElem{ + Defname: &queryPlan, + Arg: &ast.TODO{}, + }) + } - default: - return todo("convert(case=default)", n) + return &ast.ExplainStmt{ + Query: expr, + Options: options, + } } + + return expr } diff --git a/internal/engine/ydb/parse.go b/internal/engine/ydb/parse.go index 1c263924a5..8fbdd81ebb 100755 --- a/internal/engine/ydb/parse.go +++ b/internal/engine/ydb/parse.go @@ -64,7 +64,10 @@ func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) { loc := 0 for _, stmt := range stmtListCtx.AllSql_stmt() { converter := &cc{content: string(blob)} - out := converter.convert(stmt) + out, ok := stmt.Accept(converter).(ast.Node) + if !ok { + return nil, fmt.Errorf("expected ast.Node; got %T", out) + } if _, ok := out.(*ast.TODO); ok { loc = byteOffset(content, stmt.GetStop().GetStop() + 2) continue diff --git a/internal/engine/ydb/utils.go b/internal/engine/ydb/utils.go index f2023e8ba9..8f118df09b 100755 --- a/internal/engine/ydb/utils.go +++ b/internal/engine/ydb/utils.go @@ -156,7 +156,13 @@ func parseIntegerValue(text string) (int64, error) { } func (c *cc) extractRoleSpec(n parser.IRole_nameContext, roletype ast.RoleSpecType) (*ast.RoleSpec, bool, ast.Node) { - roleNode := c.convert(n) + if n == nil { + return nil, false, nil + } + roleNode, ok := n.Accept(c).(ast.Node) + if !ok { + return nil, false, nil + } roleSpec := &ast.RoleSpec{ Roletype: roletype, @@ -219,3 +225,73 @@ func emptySelectStmt() *ast.SelectStmt { LockingClause: &ast.List{}, } } + +func (c *cc) collectComparisonOps(n parser.IEq_subexprContext) []antlr.TerminalNode { + var ops []antlr.TerminalNode + for _, child := range n.GetChildren() { + if tn, ok := child.(antlr.TerminalNode); ok { + switch tn.GetText() { + case "<", "<=", ">", ">=": + ops = append(ops, tn) + } + } + } + return ops +} + +func (c *cc) collectBitwiseOps(ctx parser.INeq_subexprContext) []antlr.TerminalNode { + var ops []antlr.TerminalNode + children := ctx.GetChildren() + for _, child := range children { + if tn, ok := child.(antlr.TerminalNode); ok { + txt := tn.GetText() + switch txt { + case "<<", ">>", "<<|", ">>|", "&", "|", "^": + ops = append(ops, tn) + } + } + } + return ops +} + +func (c *cc) collectBitOps(ctx parser.IBit_subexprContext) []antlr.TerminalNode { + var ops []antlr.TerminalNode + children := ctx.GetChildren() + for _, child := range children { + if tn, ok := child.(antlr.TerminalNode); ok { + txt := tn.GetText() + switch txt { + case "+", "-": + ops = append(ops, tn) + } + } + } + return ops +} + +func (c *cc) collectAddOps(ctx parser.IAdd_subexprContext) []antlr.TerminalNode { + var ops []antlr.TerminalNode + for _, child := range ctx.GetChildren() { + if tn, ok := child.(antlr.TerminalNode); ok { + switch tn.GetText() { + case "*", "/", "%": + ops = append(ops, tn) + } + } + } + return ops +} + +func (c *cc) collectEqualityOps(ctx parser.ICond_exprContext) []antlr.TerminalNode { + var ops []antlr.TerminalNode + children := ctx.GetChildren() + for _, child := range children { + if tn, ok := child.(antlr.TerminalNode); ok { + switch tn.GetText() { + case "=", "==", "!=", "<>": + ops = append(ops, tn) + } + } + } + return ops +} From a680b1b919b5e85d57d7dccabbb13bff8b90cf7c Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov <150552906+1NepuNep1@users.noreply.github.com> Date: Fri, 26 Sep 2025 20:01:12 +0300 Subject: [PATCH 16/18] Massive Optional + Functions update (#13) * Massive Functions update * Fixed ydb_type nullable problem and added new funcs to ydb catalog * Update internal/engine/ydb/convert.go * Removed comment from query.go to extractBaseType method --- internal/codegen/golang/query.go | 4 +- internal/codegen/golang/ydb_type.go | 23 +- internal/engine/ydb/convert.go | 165 ++++- internal/engine/ydb/lib/aggregate.go | 533 ++++++++------- internal/engine/ydb/lib/basic.go | 832 ++++++++++++++++++----- internal/engine/ydb/lib/cpp.go | 27 + internal/engine/ydb/lib/cpp/datetime.go | 695 +++++++++++++++++++ internal/engine/ydb/lib/cpp/digest.go | 171 +++++ internal/engine/ydb/lib/cpp/hyperscan.go | 105 +++ internal/engine/ydb/lib/cpp/ip.go | 140 ++++ internal/engine/ydb/lib/cpp/math.go | 439 ++++++++++++ internal/engine/ydb/lib/cpp/pcre.go | 105 +++ internal/engine/ydb/lib/cpp/pire.go | 85 +++ internal/engine/ydb/lib/cpp/re2.go | 319 +++++++++ internal/engine/ydb/lib/cpp/string.go | 152 +++++ internal/engine/ydb/lib/cpp/unicode.go | 532 +++++++++++++++ internal/engine/ydb/lib/cpp/url.go | 413 +++++++++++ internal/engine/ydb/lib/cpp/yson.go | 632 +++++++++++++++++ internal/engine/ydb/lib/window.go | 163 +++++ internal/engine/ydb/stdlib.go | 6 +- internal/sql/ast/recursive_func_call.go | 33 - internal/sql/astutils/rewrite.go | 8 - internal/sql/astutils/walk.go | 20 - 23 files changed, 5114 insertions(+), 488 deletions(-) create mode 100644 internal/engine/ydb/lib/cpp.go create mode 100644 internal/engine/ydb/lib/cpp/datetime.go create mode 100644 internal/engine/ydb/lib/cpp/digest.go create mode 100644 internal/engine/ydb/lib/cpp/hyperscan.go create mode 100644 internal/engine/ydb/lib/cpp/ip.go create mode 100644 internal/engine/ydb/lib/cpp/math.go create mode 100644 internal/engine/ydb/lib/cpp/pcre.go create mode 100644 internal/engine/ydb/lib/cpp/pire.go create mode 100644 internal/engine/ydb/lib/cpp/re2.go create mode 100644 internal/engine/ydb/lib/cpp/string.go create mode 100644 internal/engine/ydb/lib/cpp/unicode.go create mode 100644 internal/engine/ydb/lib/cpp/url.go create mode 100644 internal/engine/ydb/lib/cpp/yson.go create mode 100644 internal/engine/ydb/lib/window.go delete mode 100644 internal/sql/ast/recursive_func_call.go diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 7cda1b7c2b..52be2ecceb 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -297,7 +297,9 @@ func (v QueryValue) YDBParamMapEntries() string { // ydbBuilderMethodForColumnType maps a YDB column data type to a ParamsBuilder method name. func ydbBuilderMethodForColumnType(dbType string) string { - switch strings.ToLower(dbType) { + baseType := extractBaseType(strings.ToLower(dbType)) + + switch baseType { case "bool": return "Bool" case "uint64": diff --git a/internal/codegen/golang/ydb_type.go b/internal/codegen/golang/ydb_type.go index 0ef665aee1..0a4db80a3b 100644 --- a/internal/codegen/golang/ydb_type.go +++ b/internal/codegen/golang/ydb_type.go @@ -12,9 +12,11 @@ import ( func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string { columnType := strings.ToLower(sdk.DataType(col.Type)) - notNull := col.NotNull || col.IsArray + notNull := (col.NotNull || col.IsArray) && !isNullableType(columnType) emitPointersForNull := options.EmitPointersForNullTypes + columnType = extractBaseType(columnType) + // https://ydb.tech/docs/ru/yql/reference/types/ // ydb-go-sdk doesn't support sql.Null* yet switch columnType { @@ -49,7 +51,7 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col } // return "sql.NullInt16" return "*int16" - case "int", "int32": //ydb doesn't have int type, but we need it to support untyped constants + case "int", "int32": //ydb doesn't have int type, but we need it to support untyped constants if notNull { return "int32" } @@ -159,7 +161,7 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col return "*string" } return "*string" - + case "date", "date32", "datetime", "timestamp", "tzdate", "tztimestamp", "tzdatetime": if notNull { return "time.Time" @@ -185,3 +187,18 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col } } + +// This function extracts the base type from optional types +func extractBaseType(typeStr string) string { + if strings.HasPrefix(typeStr, "optional<") && strings.HasSuffix(typeStr, ">") { + return strings.TrimSuffix(strings.TrimPrefix(typeStr, "optional<"), ">") + } + if strings.HasSuffix(typeStr, "?") { + return strings.TrimSuffix(typeStr, "?") + } + return typeStr +} + +func isNullableType(typeStr string) bool { + return strings.HasPrefix(typeStr, "optional<") && strings.HasSuffix(typeStr, ">") || strings.HasSuffix(typeStr, "?") +} diff --git a/internal/engine/ydb/convert.go b/internal/engine/ydb/convert.go index 8b67191ce6..0fa339fa56 100755 --- a/internal/engine/ydb/convert.go +++ b/internal/engine/ydb/convert.go @@ -1,6 +1,7 @@ package ydb import ( + "fmt" "log" "strconv" "strings" @@ -1787,7 +1788,15 @@ func (c *cc) VisitType_name_or_bind(n *parser.Type_name_or_bindContext) interfac } return typeName } else if b := n.Bind_parameter(); b != nil { - return &ast.TypeName{Name: "BIND:" + identifier(parseAnIdOrType(b.An_id_or_type()))} + param, ok := b.Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_or_bind", b) + } + return &ast.TypeName{ + Names: &ast.List{ + Items: []ast.Node{param}, + }, + } } return todo("VisitType_name_or_bind", n) } @@ -1797,6 +1806,8 @@ func (c *cc) VisitType_name(n *parser.Type_nameContext) interface{} { return todo("VisitType_name", n) } + questionCount := len(n.AllQUESTION()) + if composite := n.Type_name_composite(); composite != nil { typeName, ok := composite.Accept(c).(ast.Node) if !ok { @@ -1815,8 +1826,12 @@ func (c *cc) VisitType_name(n *parser.Type_nameContext) interface{} { if !ok { return todo("VisitType_name", decimal.Integer_or_bind(1)) } + name := "decimal" + if questionCount > 0 { + name = name + "?" + } return &ast.TypeName{ - Name: "decimal", + Name: name, TypeOid: 0, Names: &ast.List{ Items: []ast.Node{ @@ -1829,12 +1844,17 @@ func (c *cc) VisitType_name(n *parser.Type_nameContext) interface{} { } if simple := n.Type_name_simple(); simple != nil { + name := simple.GetText() + if questionCount > 0 { + name = name + "?" + } return &ast.TypeName{ - Name: simple.GetText(), + Name: name, TypeOid: 0, } } + // todo: handle multiple ? suffixes return todo("VisitType_name", n) } @@ -1868,19 +1888,7 @@ func (c *cc) VisitType_name_composite(n *parser.Type_name_compositeContext) inte } if opt := n.Type_name_optional(); opt != nil { - if typeName := opt.Type_name_or_bind(); typeName != nil { - tn, ok := typeName.Accept(c).(ast.Node) - if !ok { - return todo("VisitType_name_composite", typeName) - } - return &ast.TypeName{ - Name: "Optional", - TypeOid: 0, - Names: &ast.List{ - Items: []ast.Node{tn}, - }, - } - } + return opt.Accept(c) } if tuple := n.Type_name_tuple(); tuple != nil { @@ -2025,6 +2033,27 @@ func (c *cc) VisitType_name_composite(n *parser.Type_name_compositeContext) inte return todo("VisitType_name_composite", n) } +func (c *cc) VisitType_name_optional(n *parser.Type_name_optionalContext) interface{} { + if n == nil || n.Type_name_or_bind() == nil { + return todo("VisitType_name_optional", n) + } + + tn, ok := n.Type_name_or_bind().Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_optional", n.Type_name_or_bind()) + } + innerTypeName, ok := tn.(*ast.TypeName) + if !ok { + return todo("VisitType_name_optional", n.Type_name_or_bind()) + } + name := fmt.Sprintf("Optional<%s>", innerTypeName.Name) + return &ast.TypeName{ + Name: name, + TypeOid: 0, + Names: &ast.List{}, + } +} + func (c *cc) VisitSql_stmt_core(n *parser.Sql_stmt_coreContext) interface{} { if n == nil { return todo("VisitSql_stmt_core", n) @@ -2799,13 +2828,28 @@ func (c *cc) handleInvokeSuffix(base ast.Node, invokeCtx *parser.Invoke_exprCont } funcName := strings.Join(nameParts, ".") - if funcName == "coalesce" { + if funcName == "coalesce" || funcName == "nvl" { return &ast.CoalesceExpr{ Args: funcCall.Args, Location: baseNode.Location, } } + if funcName == "greatest" || funcName == "max_of" { + return &ast.MinMaxExpr{ + Op: ast.MinMaxOp(1), + Args: funcCall.Args, + Location: baseNode.Location, + } + } + if funcName == "least" || funcName == "min_of" { + return &ast.MinMaxExpr{ + Op: ast.MinMaxOp(2), + Args: funcCall.Args, + Location: baseNode.Location, + } + } + funcCall.Func = &ast.FuncName{Name: funcName} funcCall.Funcname.Items = append(funcCall.Funcname.Items, &ast.String{Str: funcName}) @@ -2816,15 +2860,12 @@ func (c *cc) handleInvokeSuffix(base ast.Node, invokeCtx *parser.Invoke_exprCont } } - stmt := &ast.RecursiveFuncCall{ - Func: base, - Funcname: funcCall.Funcname, - AggStar: funcCall.AggStar, - Location: funcCall.Location, - Args: funcCall.Args, - AggDistinct: funcCall.AggDistinct, + stmt := &ast.FuncExpr{ + Xpr: base, + Args: funcCall.Args, + Location: funcCall.Location, } - stmt.Funcname.Items = append(stmt.Funcname.Items, base) + return stmt } @@ -2943,16 +2984,42 @@ func (c *cc) VisitId_expr(n *parser.Id_exprContext) interface{} { if n == nil { return todo("VisitId_expr", n) } + + ref := &ast.ColumnRef{ + Fields: &ast.List{}, + Location: c.pos(n.GetStart()), + } + if id := n.Identifier(); id != nil { - return &ast.ColumnRef{ - Fields: &ast.List{ - Items: []ast.Node{ - NewIdentifier(id.GetText()), - }, - }, - Location: c.pos(id.GetStart()), - } + ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(id.GetText())) + return ref + } + + if keyword := n.Keyword_compat(); keyword != nil { + ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText())) + return ref + } + + if keyword := n.Keyword_alter_uncompat(); keyword != nil { + ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText())) + return ref + } + + if keyword := n.Keyword_in_uncompat(); keyword != nil { + ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText())) + return ref + } + + if keyword := n.Keyword_window_uncompat(); keyword != nil { + ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText())) + return ref + } + + if keyword := n.Keyword_hint_uncompat(); keyword != nil { + ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText())) + return ref } + return todo("VisitId_expr", n) } @@ -2979,12 +3046,44 @@ func (c *cc) VisitAtom_expr(n *parser.Atom_exprContext) interface{} { return todo("VisitAtom_expr", n.Bind_parameter()) } return expr + case n.Cast_expr() != nil: + expr, ok := n.Cast_expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitAtom_expr", n.Cast_expr()) + } + return expr // TODO: check other cases default: return todo("VisitAtom_expr", n) } } +func (c *cc) VisitCast_expr(n *parser.Cast_exprContext) interface{} { + if n == nil || n.CAST() == nil || n.Expr() == nil || n.AS() == nil || n.Type_name_or_bind() == nil { + return todo("VisitCast_expr", n) + } + + expr, ok := n.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitCast_expr", n.Expr()) + } + + temp, ok := n.Type_name_or_bind().Accept(c).(ast.Node) + if !ok { + return todo("VisitCast_expr", n.Type_name_or_bind()) + } + typeName, ok := temp.(*ast.TypeName) + if !ok { + return todo("VisitCast_expr", n.Type_name_or_bind()) + } + + return &ast.TypeCast{ + Arg: expr, + TypeName: typeName, + Location: c.pos(n.GetStart()), + } +} + func (c *cc) VisitLiteral_value(n *parser.Literal_valueContext) interface{} { if n == nil { return todo("VisitLiteral_value", n) diff --git a/internal/engine/ydb/lib/aggregate.go b/internal/engine/ydb/lib/aggregate.go index dfb3924e90..7c5d795eca 100644 --- a/internal/engine/ydb/lib/aggregate.go +++ b/internal/engine/ydb/lib/aggregate.go @@ -8,323 +8,400 @@ import ( func AggregateFunctions() []*catalog.Function { var funcs []*catalog.Function - // COUNT(*) - funcs = append(funcs, &catalog.Function{ - Name: "COUNT", - Args: []*catalog.Argument{}, - ReturnType: &ast.TypeName{Name: "Uint64"}, - }) + funcs = append(funcs, countFuncs()...) + funcs = append(funcs, minMaxFuncs()...) + funcs = append(funcs, sumFuncs()...) + funcs = append(funcs, avgFuncs()...) + funcs = append(funcs, countIfFuncs()...) + funcs = append(funcs, sumIfFuncs()...) + funcs = append(funcs, avgIfFuncs()...) + funcs = append(funcs, someFuncs()...) + funcs = append(funcs, countDistinctEstimateHLLFuncs()...) + funcs = append(funcs, maxByMinByFuncs()...) + funcs = append(funcs, stddevVarianceFuncs()...) + funcs = append(funcs, correlationCovarianceFuncs()...) + funcs = append(funcs, percentileMedianFuncs()...) + funcs = append(funcs, boolAndOrXorFuncs()...) + funcs = append(funcs, bitAndOrXorFuncs()...) - // COUNT(T) и COUNT(T?) - for _, typ := range types { - funcs = append(funcs, &catalog.Function{ - Name: "COUNT", - Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, - }, - ReturnType: &ast.TypeName{Name: "Uint64"}, - }) - funcs = append(funcs, &catalog.Function{ + // TODO: Aggregate_List, Top, Bottom, Top_By, Bottom_By, TopFreq, Mode, + // Histogram LinearHistogram, LogarithmicHistogram, LogHistogram, CDF, + // SessionStart, AGGREGATE_BY, MULTI_AGGREGATE_BY + + return funcs +} + +func countFuncs() []*catalog.Function { + return []*catalog.Function{ + { Name: "COUNT", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}, Mode: ast.FuncParamVariadic}, + {Type: &ast.TypeName{Name: "any"}}, }, ReturnType: &ast.TypeName{Name: "Uint64"}, - }) + }, } +} - // MIN и MAX - for _, typ := range types { - funcs = append(funcs, &catalog.Function{ +func minMaxFuncs() []*catalog.Function { + return []*catalog.Function{ + { Name: "MIN", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "any"}}, }, - ReturnType: &ast.TypeName{Name: typ}, - ReturnTypeNullable: true, - }) - funcs = append(funcs, &catalog.Function{ + ReturnType: &ast.TypeName{Name: "any"}, + }, + { Name: "MAX", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "any"}}, }, - ReturnType: &ast.TypeName{Name: typ}, - ReturnTypeNullable: true, - }) + ReturnType: &ast.TypeName{Name: "any"}, + }, } +} - // SUM для unsigned типов - for _, typ := range unsignedTypes { - funcs = append(funcs, &catalog.Function{ +func sumFuncs() []*catalog.Function { + return []*catalog.Function{ + { Name: "SUM", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "any"}}, }, - ReturnType: &ast.TypeName{Name: "Uint64"}, + ReturnType: &ast.TypeName{Name: "any"}, ReturnTypeNullable: true, - }) + }, } +} - // SUM для signed типов - for _, typ := range signedTypes { - funcs = append(funcs, &catalog.Function{ - Name: "SUM", +func avgFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "AVG", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "any"}}, }, - ReturnType: &ast.TypeName{Name: "Int64"}, + ReturnType: &ast.TypeName{Name: "any"}, ReturnTypeNullable: true, - }) + }, } +} - // SUM для float/double - for _, typ := range []string{"float", "double"} { - funcs = append(funcs, &catalog.Function{ - Name: "SUM", +func countIfFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "COUNT_IF", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "Bool"}}, }, - ReturnType: &ast.TypeName{Name: typ}, + ReturnType: &ast.TypeName{Name: "Uint64"}, ReturnTypeNullable: true, - }) + }, } +} - // AVG для целочисленных типов - for _, typ := range append(unsignedTypes, signedTypes...) { - funcs = append(funcs, &catalog.Function{ - Name: "AVG", +func sumIfFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "SUM_IF", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "Bool"}}, }, - ReturnType: &ast.TypeName{Name: "Double"}, + ReturnType: &ast.TypeName{Name: "any"}, ReturnTypeNullable: true, - }) + }, } +} - // AVG для float/double - for _, typ := range []string{"float", "double"} { - funcs = append(funcs, &catalog.Function{ - Name: "AVG", +func avgIfFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "AVG_IF", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "Bool"}}, }, - ReturnType: &ast.TypeName{Name: typ}, + ReturnType: &ast.TypeName{Name: "any"}, ReturnTypeNullable: true, - }) + }, } +} - // COUNT_IF - funcs = append(funcs, &catalog.Function{ - Name: "COUNT_IF", - Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: "Bool"}}, - }, - ReturnType: &ast.TypeName{Name: "Uint64"}, - ReturnTypeNullable: true, - }) +func someFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "SOME", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} - // SUM_IF для unsigned - for _, typ := range unsignedTypes { - funcs = append(funcs, &catalog.Function{ - Name: "SUM_IF", +func countDistinctEstimateHLLFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "CountDistinctEstimate", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, - {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "any"}}, }, ReturnType: &ast.TypeName{Name: "Uint64"}, ReturnTypeNullable: true, - }) - } - - // SUM_IF для signed - for _, typ := range signedTypes { - funcs = append(funcs, &catalog.Function{ - Name: "SUM_IF", + }, + { + Name: "HyperLogLog", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, - {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "any"}}, }, - ReturnType: &ast.TypeName{Name: "Int64"}, + ReturnType: &ast.TypeName{Name: "Uint64"}, ReturnTypeNullable: true, - }) + }, + { + Name: "HLL", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + ReturnTypeNullable: true, + }, } +} - // SUM_IF для float/double - for _, typ := range []string{"float", "double"} { - funcs = append(funcs, &catalog.Function{ - Name: "SUM_IF", +func maxByMinByFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "MAX_BY", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, - {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, }, - ReturnType: &ast.TypeName{Name: typ}, + ReturnType: &ast.TypeName{Name: "any"}, ReturnTypeNullable: true, - }) + }, + { + Name: "MIN_BY", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, } + // todo: min/max_by with third argument returning list +} - // AVG_IF для целочисленных - for _, typ := range append(unsignedTypes, signedTypes...) { - funcs = append(funcs, &catalog.Function{ - Name: "AVG_IF", +func stddevVarianceFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "STDDEV", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, - {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Double"}}, }, ReturnType: &ast.TypeName{Name: "Double"}, ReturnTypeNullable: true, - }) - } - - // AVG_IF для float/double - for _, typ := range []string{"float", "double"} { - funcs = append(funcs, &catalog.Function{ - Name: "AVG_IF", + }, + { + Name: "STDDEV_POPULATION", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, - {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Double"}}, }, - ReturnType: &ast.TypeName{Name: typ}, + ReturnType: &ast.TypeName{Name: "Double"}, ReturnTypeNullable: true, - }) - } - - // SOME - for _, typ := range types { - funcs = append(funcs, &catalog.Function{ - Name: "SOME", + }, + { + Name: "POPULATION_STDDEV", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "Double"}}, }, - ReturnType: &ast.TypeName{Name: typ}, + ReturnType: &ast.TypeName{Name: "Double"}, ReturnTypeNullable: true, - }) - } - - // AGGREGATE_LIST и AGGREGATE_LIST_DISTINCT - for _, typ := range types { - funcs = append(funcs, &catalog.Function{ - Name: "AGGREGATE_LIST", + }, + { + Name: "STDDEV_SAMPLE", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "Double"}}, }, - ReturnType: &ast.TypeName{Name: "List<" + typ + ">"}, - }) - funcs = append(funcs, &catalog.Function{ - Name: "AGGREGATE_LIST_DISTINCT", + ReturnType: &ast.TypeName{Name: "Double"}, + ReturnTypeNullable: true, + }, + { + Name: "STDDEVSAMP", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "Double"}}, }, - ReturnType: &ast.TypeName{Name: "List<" + typ + ">"}, - }) - } - - // BOOL_AND, BOOL_OR, BOOL_XOR - boolAggrs := []string{"BOOL_AND", "BOOL_OR", "BOOL_XOR"} - for _, name := range boolAggrs { - funcs = append(funcs, &catalog.Function{ - Name: name, + ReturnType: &ast.TypeName{Name: "Double"}, + ReturnTypeNullable: true, + }, + { + Name: "VARIANCE", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Double"}}, }, - ReturnType: &ast.TypeName{Name: "Bool"}, + ReturnType: &ast.TypeName{Name: "Double"}, ReturnTypeNullable: true, - }) - } - - // BIT_AND, BIT_OR, BIT_XOR - bitAggrs := []string{"BIT_AND", "BIT_OR", "BIT_XOR"} - for _, typ := range append(unsignedTypes, signedTypes...) { - for _, name := range bitAggrs { - funcs = append(funcs, &catalog.Function{ - Name: name, - Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, - }, - ReturnType: &ast.TypeName{Name: typ}, - ReturnTypeNullable: true, - }) - } - } - - // STDDEV и VARIANCE - stdDevVariants := []struct { - name string - returnType string - }{ - {"STDDEV", "Double"}, - {"VARIANCE", "Double"}, - {"STDDEV_SAMPLE", "Double"}, - {"VARIANCE_SAMPLE", "Double"}, - {"STDDEV_POPULATION", "Double"}, - {"VARIANCE_POPULATION", "Double"}, - } - for _, variant := range stdDevVariants { - funcs = append(funcs, &catalog.Function{ - Name: variant.name, + }, + { + Name: "VARIANCE_POPULATION", Args: []*catalog.Argument{ {Type: &ast.TypeName{Name: "Double"}}, }, - ReturnType: &ast.TypeName{Name: variant.returnType}, + ReturnType: &ast.TypeName{Name: "Double"}, ReturnTypeNullable: true, - }) + }, + { + Name: "POPULATION_VARIANCE", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + ReturnTypeNullable: true, + }, + { + Name: "VARPOP", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + ReturnTypeNullable: true, + }, + { + Name: "VARIANCE_SAMPLE", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + ReturnTypeNullable: true, + }, } +} - // CORRELATION и COVARIANCE - corrCovar := []string{"CORRELATION", "COVARIANCE", "COVARIANCE_SAMPLE", "COVARIANCE_POPULATION"} - for _, name := range corrCovar { - funcs = append(funcs, &catalog.Function{ - Name: name, +func correlationCovarianceFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "CORRELATION", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + ReturnTypeNullable: true, + }, + { + Name: "COVARIANCE", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + ReturnTypeNullable: true, + }, + { + Name: "COVARIANCE_SAMPLE", Args: []*catalog.Argument{ {Type: &ast.TypeName{Name: "Double"}}, {Type: &ast.TypeName{Name: "Double"}}, }, ReturnType: &ast.TypeName{Name: "Double"}, ReturnTypeNullable: true, - }) + }, + { + Name: "COVARIANCE_POPULATION", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + ReturnTypeNullable: true, + }, } +} - // HISTOGRAM - funcs = append(funcs, &catalog.Function{ - Name: "HISTOGRAM", - Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: "Double"}}, - }, - ReturnType: &ast.TypeName{Name: "HistogramStruct"}, - ReturnTypeNullable: true, - }) - - // TOP и BOTTOM - topBottom := []string{"TOP", "BOTTOM"} - for _, name := range topBottom { - for _, typ := range types { - funcs = append(funcs, &catalog.Function{ - Name: name, - Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, - {Type: &ast.TypeName{Name: "Uint32"}}, - }, - ReturnType: &ast.TypeName{Name: "List<" + typ + ">"}, - }) - } +func percentileMedianFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "PERCENTILE", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "MEDIAN", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "MEDIAN", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, } +} - // MAX_BY и MIN_BY - minMaxBy := []string{"MAX_BY", "MIN_BY"} - for _, name := range minMaxBy { - for _, typ := range types { - funcs = append(funcs, &catalog.Function{ - Name: name, - Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, - {Type: &ast.TypeName{Name: "any"}}, - }, - ReturnType: &ast.TypeName{Name: typ}, - ReturnTypeNullable: true, - }) - } +func boolAndOrXorFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "BOOL_AND", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + ReturnTypeNullable: true, + }, + { + Name: "BOOL_OR", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + ReturnTypeNullable: true, + }, + { + Name: "BOOL_XOR", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + ReturnTypeNullable: true, + }, } +} - // ... (добавьте другие агрегатные функции по аналогии) - - return funcs +func bitAndOrXorFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "BIT_AND", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "BIT_OR", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "BIT_XOR", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + } } diff --git a/internal/engine/ydb/lib/basic.go b/internal/engine/ydb/lib/basic.go index 08c0011787..d5cbfda950 100644 --- a/internal/engine/ydb/lib/basic.go +++ b/internal/engine/ydb/lib/basic.go @@ -5,199 +5,709 @@ import ( "github.com/sqlc-dev/sqlc/internal/sql/catalog" ) -var types = []string{ - "bool", - "int8", "int16", "int32", "int64", - "uint8", "uint16", "uint32", "uint64", - "float", "double", - "string", "utf8", - "any", -} - -var ( - unsignedTypes = []string{"uint8", "uint16", "uint32", "uint64"} - signedTypes = []string{"int8", "int16", "int32", "int64"} - numericTypes = append(append(unsignedTypes, signedTypes...), "float", "double") -) - func BasicFunctions() []*catalog.Function { var funcs []*catalog.Function - for _, typ := range types { - // COALESCE, NVL - funcs = append(funcs, &catalog.Function{ - Name: "COALESCE", + funcs = append(funcs, lengthFuncs()...) + funcs = append(funcs, substringFuncs()...) + funcs = append(funcs, findFuncs()...) + funcs = append(funcs, rfindFuncs()...) + funcs = append(funcs, startsWithFuncs()...) + funcs = append(funcs, endsWithFuncs()...) + funcs = append(funcs, ifFuncs()...) + funcs = append(funcs, nanvlFuncs()...) + funcs = append(funcs, randomFuncs()...) + funcs = append(funcs, currentUtcFuncs()...) + funcs = append(funcs, currentTzFuncs()...) + funcs = append(funcs, addTimezoneFuncs()...) + funcs = append(funcs, removeTimezoneFuncs()...) + funcs = append(funcs, versionFuncs()...) + funcs = append(funcs, ensureFuncs()...) + funcs = append(funcs, assumeStrictFuncs()...) + funcs = append(funcs, likelyFuncs()...) + funcs = append(funcs, evaluateFuncs()...) + funcs = append(funcs, simpleTypesLiteralsFuncs()...) + funcs = append(funcs, toFromBytesFuncs()...) + funcs = append(funcs, byteAtFuncs()...) + funcs = append(funcs, testClearSetFlipBitFuncs()...) + funcs = append(funcs, absFuncs()...) + funcs = append(funcs, justUnwrapNothingFuncs()...) + funcs = append(funcs, pickleUnpickleFuncs()...) + + // todo: implement functions: + // Udf, AsTuple, AsStruct, AsList, AsDict, AsSet, AsListStrict, AsDictStrict, AsSetStrict, + // Variant, AsVariant, Visit, VisitOrDefault, VariantItem, Way, DynamicVariant, + // Enum, AsEnum, AsTagged, Untag, TableRow, Callable, + // StaticMap, StaticZip, StaticFold, StaticFold1, + // AggregationFactory, AggregateTransformInput, AggregateTransformOutput, AggregateFlatten + + return funcs +} + +func lengthFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "LENGTH", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }, + { + Name: "LEN", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }, + } +} + +func substringFuncs() []*catalog.Function { + funcs := []*catalog.Function{ + { + Name: "Substring", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "Substring", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + } + return funcs +} + +func findFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Find", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }, + { + Name: "Find", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }, + } +} + +func rfindFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "RFind", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }, + { + Name: "RFind", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }, + } +} + +func startsWithFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "StartsWith", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + } +} + +func endsWithFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "EndsWith", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + } +} + +func ifFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "IF", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "IF", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: false, + }, + } +} + +func nanvlFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "NANVL", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + } +} + +func randomFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Random", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}, + Mode: ast.FuncParamVariadic, + }, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "RandomNumber", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}, + Mode: ast.FuncParamVariadic, + }, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + { + Name: "RandomUuid", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}, + Mode: ast.FuncParamVariadic, + }, + }, + ReturnType: &ast.TypeName{Name: "Uuid"}, + }, + } +} + +func currentUtcFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "CurrentUtcDate", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, - {Type: &ast.TypeName{Name: typ}}, { - Type: &ast.TypeName{Name: typ}, + Type: &ast.TypeName{Name: "any"}, Mode: ast.FuncParamVariadic, }, }, - ReturnType: &ast.TypeName{Name: typ}, - ReturnTypeNullable: false, - }) - funcs = append(funcs, &catalog.Function{ - Name: "NVL", + ReturnType: &ast.TypeName{Name: "Date"}, + }, + { + Name: "CurrentUtcDatetime", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, - {Type: &ast.TypeName{Name: typ}}, { - Type: &ast.TypeName{Name: typ}, + Type: &ast.TypeName{Name: "any"}, Mode: ast.FuncParamVariadic, }, }, - ReturnType: &ast.TypeName{Name: typ}, - ReturnTypeNullable: false, - }) + ReturnType: &ast.TypeName{Name: "Datetime"}, + }, + { + Name: "CurrentUtcTimestamp", + Args: []*catalog.Argument{ + { + Type: &ast.TypeName{Name: "any"}, + Mode: ast.FuncParamVariadic, + }, + }, + ReturnType: &ast.TypeName{Name: "Timestamp"}, + }, + } +} - // IF(Bool, T, T) -> T - funcs = append(funcs, &catalog.Function{ - Name: "IF", +func currentTzFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "CurrentTzDate", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + { + Type: &ast.TypeName{Name: "any"}, + Mode: ast.FuncParamVariadic, + }, + }, + ReturnType: &ast.TypeName{Name: "TzDate"}, + }, + { + Name: "CurrentTzDatetime", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + { + Type: &ast.TypeName{Name: "any"}, + Mode: ast.FuncParamVariadic, + }, + }, + ReturnType: &ast.TypeName{Name: "TzDatetime"}, + }, + { + Name: "CurrentTzTimestamp", Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + { + Type: &ast.TypeName{Name: "any"}, + Mode: ast.FuncParamVariadic, + }, + }, + ReturnType: &ast.TypeName{Name: "TzTimestamp"}, + }, + } +} + +func addTimezoneFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "AddTimezone", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func removeTimezoneFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "RemoveTimezone", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func versionFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Version", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "String"}, + }, + } +} + +func ensureFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Ensure", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, {Type: &ast.TypeName{Name: "Bool"}}, - {Type: &ast.TypeName{Name: typ}}, - {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "String"}}, }, - ReturnType: &ast.TypeName{Name: typ}, - ReturnTypeNullable: false, - }) + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "EnsureType", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "EnsureConvertibleTo", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} - // LENGTH, LEN - funcs = append(funcs, &catalog.Function{ - Name: "LENGTH", +func assumeStrictFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "AssumeStrict", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "any"}}, }, - ReturnType: &ast.TypeName{Name: "Uint32"}, - ReturnTypeNullable: true, - }) - funcs = append(funcs, &catalog.Function{ - Name: "LEN", + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func likelyFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Likely", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "any"}}, }, - ReturnType: &ast.TypeName{Name: "Uint32"}, - ReturnTypeNullable: true, - }) + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + } +} - // StartsWith, EndsWith - funcs = append(funcs, &catalog.Function{ - Name: "StartsWith", +func evaluateFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "EvaluateExpr", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, - {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "EvaluateAtom", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func simpleTypesLiteralsFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Bool", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, }, ReturnType: &ast.TypeName{Name: "Bool"}, - }) - funcs = append(funcs, &catalog.Function{ - Name: "EndsWith", + }, + { + Name: "Uint8", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint8"}, + }, + { + Name: "Int32", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Int32"}, + }, + { + Name: "Uint32", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }, + { + Name: "Int64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Int64"}, + }, + { + Name: "Uint64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + { + Name: "Float", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Float"}, + }, + { + Name: "Double", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, - {Type: &ast.TypeName{Name: typ}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Decimal", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Uint8"}}, // precision + {Type: &ast.TypeName{Name: "Uint8"}}, // scale + }, + ReturnType: &ast.TypeName{Name: "Decimal"}, + }, + { + Name: "String", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "Utf8", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Yson", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Yson"}, + }, + { + Name: "Json", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Json"}, + }, + { + Name: "Date", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Date"}, + }, + { + Name: "Datetime", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Datetime"}, + }, + { + Name: "Timestamp", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Timestamp"}, + }, + { + Name: "Interval", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Interval"}, + }, + { + Name: "TzDate", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "TzDate"}, + }, + { + Name: "TzDatetime", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "TzDatetime"}, + }, + { + Name: "TzTimestamp", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "TzTimestamp"}, + }, + { + Name: "Uuid", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Uuid"}, + }, + } +} + +func toFromBytesFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "ToBytes", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "FromBytes", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func byteAtFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "ByteAt", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint8"}, + }, + } +} + +func testClearSetFlipBitFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "TestBit", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "Uint8"}}, }, ReturnType: &ast.TypeName{Name: "Bool"}, - }) - - // ABS(T) -> T - } - - // SUBSTRING - funcs = append(funcs, &catalog.Function{ - Name: "Substring", - Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: "String"}}, - }, - ReturnType: &ast.TypeName{Name: "String"}, - }) - funcs = append(funcs, &catalog.Function{ - Name: "Substring", - Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: "String"}}, - {Type: &ast.TypeName{Name: "Uint32"}}, - }, - ReturnType: &ast.TypeName{Name: "String"}, - }) - funcs = append(funcs, &catalog.Function{ - Name: "Substring", - Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: "String"}}, - {Type: &ast.TypeName{Name: "Uint32"}}, - {Type: &ast.TypeName{Name: "Uint32"}}, - }, - ReturnType: &ast.TypeName{Name: "String"}, - }) - - // FIND / RFIND - for _, name := range []string{"FIND", "RFIND"} { - for _, typ := range []string{"String", "Utf8"} { - funcs = append(funcs, &catalog.Function{ - Name: name, - Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, - {Type: &ast.TypeName{Name: typ}}, - }, - ReturnType: &ast.TypeName{Name: "Uint32"}, - }) - funcs = append(funcs, &catalog.Function{ - Name: name, - Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, - {Type: &ast.TypeName{Name: typ}}, - {Type: &ast.TypeName{Name: "Uint32"}}, - }, - ReturnType: &ast.TypeName{Name: "Uint32"}, - }) - } + }, + { + Name: "ClearBit", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "Uint8"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "SetBit", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "Uint8"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "FlipBit", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "Uint8"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, } +} - for _, typ := range numericTypes { - funcs = append(funcs, &catalog.Function{ +func absFuncs() []*catalog.Function { + return []*catalog.Function{ + { Name: "Abs", Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: typ}}, - }, - ReturnType: &ast.TypeName{Name: typ}, - }) - } - - // NANVL - funcs = append(funcs, &catalog.Function{ - Name: "NANVL", - Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: "Float"}}, - {Type: &ast.TypeName{Name: "Float"}}, - }, - ReturnType: &ast.TypeName{Name: "Float"}, - }) - funcs = append(funcs, &catalog.Function{ - Name: "NANVL", - Args: []*catalog.Argument{ - {Type: &ast.TypeName{Name: "Double"}}, - {Type: &ast.TypeName{Name: "Double"}}, - }, - ReturnType: &ast.TypeName{Name: "Double"}, - }) - - // Random* - funcs = append(funcs, &catalog.Function{ - Name: "Random", - Args: []*catalog.Argument{}, - ReturnType: &ast.TypeName{Name: "Double"}, - }) - funcs = append(funcs, &catalog.Function{ - Name: "RandomNumber", - Args: []*catalog.Argument{}, - ReturnType: &ast.TypeName{Name: "Uint64"}, - }) - funcs = append(funcs, &catalog.Function{ - Name: "RandomUuid", - Args: []*catalog.Argument{}, - ReturnType: &ast.TypeName{Name: "Uuid"}, - }) - - // todo: add all remain functions + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} - return funcs +func justUnwrapNothingFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Just", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "Unwrap", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Unwrap", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Nothing", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + } +} + +func pickleUnpickleFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Pickle", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "StablePickle", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "Unpickle", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + } } diff --git a/internal/engine/ydb/lib/cpp.go b/internal/engine/ydb/lib/cpp.go new file mode 100644 index 0000000000..9f076aba98 --- /dev/null +++ b/internal/engine/ydb/lib/cpp.go @@ -0,0 +1,27 @@ +package lib + +import ( + "github.com/sqlc-dev/sqlc/internal/engine/ydb/lib/cpp" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func CppFunctions() []*catalog.Function { + var funcs []*catalog.Function + + funcs = append(funcs, cpp.DateTimeFunctions()...) + funcs = append(funcs, cpp.DigestFunctions()...) + funcs = append(funcs, cpp.HyperscanFunctions()...) + funcs = append(funcs, cpp.IpFunctions()...) + funcs = append(funcs, cpp.MathFunctions()...) + funcs = append(funcs, cpp.PcreFunctions()...) + funcs = append(funcs, cpp.PireFunctions()...) + funcs = append(funcs, cpp.Re2Functions()...) + funcs = append(funcs, cpp.StringFunctions()...) + funcs = append(funcs, cpp.UnicodeFunctions()...) + funcs = append(funcs, cpp.UrlFunctions()...) + funcs = append(funcs, cpp.YsonFunctions()...) + + // TODO: Histogram library, KNN library, PostgeSQL library + + return funcs +} diff --git a/internal/engine/ydb/lib/cpp/datetime.go b/internal/engine/ydb/lib/cpp/datetime.go new file mode 100644 index 0000000000..ca6f6bc6b6 --- /dev/null +++ b/internal/engine/ydb/lib/cpp/datetime.go @@ -0,0 +1,695 @@ +package cpp + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func DateTimeFunctions() []*catalog.Function { + var funcs []*catalog.Function + + funcs = append(funcs, dateTimeMakeFuncs()...) + funcs = append(funcs, dateTimeGetFuncs()...) + funcs = append(funcs, dateTimeUpdateFuncs()...) + funcs = append(funcs, dateTimeFromFuncs()...) + funcs = append(funcs, dateTimeToFuncs()...) + funcs = append(funcs, dateTimeIntervalFuncs()...) + funcs = append(funcs, dateTimeStartEndFuncs()...) + funcs = append(funcs, dateTimeFormatFuncs()...) + funcs = append(funcs, dateTimeParseFuncs()...) + + return funcs +} + +func dateTimeMakeFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "DateTime::MakeDate", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Date"}, + }, + { + Name: "DateTime::MakeDate32", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Date32"}, + }, + { + Name: "DateTime::MakeTzDate32", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "TzDate32"}, + }, + { + Name: "DateTime::MakeDatetime", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Datetime"}, + }, + { + Name: "DateTime::MakeTzDatetime", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "TzDatetime"}, + }, + { + Name: "DateTime::MakeDatetime64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Datetime64"}, + }, + { + Name: "DateTime::MakeTzDatetime64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "TzDatetime64"}, + }, + { + Name: "DateTime::MakeTimestamp", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Timestamp"}, + }, + { + Name: "DateTime::MakeTzTimestamp", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "TzTimestamp"}, + }, + { + Name: "DateTime::MakeTimestamp64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Timestamp64"}, + }, + { + Name: "DateTime::MakeTzTimestamp64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "TzTimestamp64"}, + }, + } +} + +func dateTimeGetFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "DateTime::GetYear", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint16"}, + }, + { + Name: "DateTime::GetYear", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Int32"}, + }, + { + Name: "DateTime::GetDayOfYear", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint16"}, + }, + { + Name: "DateTime::GetMonth", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint8"}, + }, + { + Name: "DateTime::GetMonthName", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "DateTime::GetWeekOfYear", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint8"}, + }, + { + Name: "DateTime::GetWeekOfYearIso8601", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint8"}, + }, + { + Name: "DateTime::GetDayOfMonth", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint8"}, + }, + { + Name: "DateTime::GetDayOfWeek", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint8"}, + }, + { + Name: "DateTime::GetDayOfWeekName", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "DateTime::GetHour", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint8"}, + }, + { + Name: "DateTime::GetMinute", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint8"}, + }, + { + Name: "DateTime::GetSecond", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint8"}, + }, + { + Name: "DateTime::GetMillisecondOfSecond", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }, + { + Name: "DateTime::GetMicrosecondOfSecond", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }, + { + Name: "DateTime::GetTimezoneId", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint16"}, + }, + { + Name: "DateTime::GetTimezoneName", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + } +} + +func dateTimeUpdateFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "DateTime::Update", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::Update", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::Update", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::Update", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::Update", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::Update", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::Update", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::Update", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + } +} + +func dateTimeFromFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "DateTime::FromSeconds", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Uint32"}}, + }, + ReturnType: &ast.TypeName{Name: "Timestamp"}, + }, + { + Name: "DateTime::FromSeconds64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int64"}}, + }, + ReturnType: &ast.TypeName{Name: "Timestamp64"}, + }, + { + Name: "DateTime::FromMilliseconds", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Uint64"}}, + }, + ReturnType: &ast.TypeName{Name: "Timestamp"}, + }, + { + Name: "DateTime::FromMilliseconds64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int64"}}, + }, + ReturnType: &ast.TypeName{Name: "Timestamp64"}, + }, + { + Name: "DateTime::FromMicroseconds", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Uint64"}}, + }, + ReturnType: &ast.TypeName{Name: "Timestamp"}, + }, + { + Name: "DateTime::FromMicroseconds64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int64"}}, + }, + ReturnType: &ast.TypeName{Name: "Timestamp64"}, + }, + } +} + +func dateTimeToFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "DateTime::ToSeconds", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "DateTime::ToMilliseconds", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "DateTime::ToMicroseconds", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func dateTimeIntervalFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "DateTime::ToDays", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "DateTime::ToHours", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "DateTime::ToMinutes", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "DateTime::ToSeconds", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "DateTime::ToMilliseconds", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "DateTime::ToMicroseconds", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "DateTime::IntervalFromDays", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int32"}}, + }, + ReturnType: &ast.TypeName{Name: "Interval"}, + }, + { + Name: "DateTime::Interval64FromDays", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int32"}}, + }, + ReturnType: &ast.TypeName{Name: "Interval64"}, + }, + { + Name: "DateTime::IntervalFromHours", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int32"}}, + }, + ReturnType: &ast.TypeName{Name: "Interval"}, + }, + { + Name: "DateTime::Interval64FromHours", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int64"}}, + }, + ReturnType: &ast.TypeName{Name: "Interval64"}, + }, + { + Name: "DateTime::IntervalFromMinutes", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int32"}}, + }, + ReturnType: &ast.TypeName{Name: "Interval"}, + }, + { + Name: "DateTime::Interval64FromMinutes", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int64"}}, + }, + ReturnType: &ast.TypeName{Name: "Interval64"}, + }, + { + Name: "DateTime::IntervalFromSeconds", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int64"}}, + }, + ReturnType: &ast.TypeName{Name: "Interval"}, + }, + { + Name: "DateTime::Interval64FromSeconds", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int64"}}, + }, + ReturnType: &ast.TypeName{Name: "Interval64"}, + }, + { + Name: "DateTime::IntervalFromMilliseconds", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int64"}}, + }, + ReturnType: &ast.TypeName{Name: "Interval"}, + }, + { + Name: "DateTime::Interval64FromMilliseconds", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int64"}}, + }, + ReturnType: &ast.TypeName{Name: "Interval64"}, + }, + { + Name: "DateTime::IntervalFromMicroseconds", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int64"}}, + }, + ReturnType: &ast.TypeName{Name: "Interval"}, + }, + { + Name: "DateTime::Interval64FromMicroseconds", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int64"}}, + }, + ReturnType: &ast.TypeName{Name: "Interval64"}, + }, + } +} + +func dateTimeStartEndFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "DateTime::StartOfYear", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::EndOfYear", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::StartOfQuarter", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::EndOfQuarter", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::StartOfMonth", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::EndOfMonth", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::StartOfWeek", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::EndOfWeek", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::StartOfDay", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::EndOfDay", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::StartOf", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::EndOf", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + } +} + +func dateTimeFormatFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "DateTime::Format", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func dateTimeParseFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "DateTime::Parse", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "DateTime::Parse64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "DateTime::ParseRfc822", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::ParseIso8601", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::ParseHttp", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "DateTime::ParseX509", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + } +} diff --git a/internal/engine/ydb/lib/cpp/digest.go b/internal/engine/ydb/lib/cpp/digest.go new file mode 100644 index 0000000000..dccdb8509b --- /dev/null +++ b/internal/engine/ydb/lib/cpp/digest.go @@ -0,0 +1,171 @@ +package cpp + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func DigestFunctions() []*catalog.Function { + var funcs []*catalog.Function + + funcs = append(funcs, digestCrcFuncs()...) + funcs = append(funcs, digestFnvFuncs()...) + funcs = append(funcs, digestMurmurFuncs()...) + funcs = append(funcs, digestCityFuncs()...) + + return funcs +} + +func digestCrcFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Digest::Crc32c", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }, + { + Name: "Digest::Crc64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + { + Name: "Digest::Crc64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + } +} + +func digestFnvFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Digest::Fnv32", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }, + { + Name: "Digest::Fnv32", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }, + { + Name: "Digest::Fnv64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + { + Name: "Digest::Fnv64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + } +} + +func digestMurmurFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Digest::MurMurHash", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + { + Name: "Digest::MurMurHash", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + { + Name: "Digest::MurMurHash32", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }, + { + Name: "Digest::MurMurHash32", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }, + { + Name: "Digest::MurMurHash2A", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + { + Name: "Digest::MurMurHash2A", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + { + Name: "Digest::MurMurHash2A32", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }, + { + Name: "Digest::MurMurHash2A32", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint32"}, + }, + } +} + +func digestCityFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Digest::CityHash", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + { + Name: "Digest::CityHash", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + { + Name: "Digest::CityHash128", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} diff --git a/internal/engine/ydb/lib/cpp/hyperscan.go b/internal/engine/ydb/lib/cpp/hyperscan.go new file mode 100644 index 0000000000..be3aa968e2 --- /dev/null +++ b/internal/engine/ydb/lib/cpp/hyperscan.go @@ -0,0 +1,105 @@ +package cpp + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func HyperscanFunctions() []*catalog.Function { + var funcs []*catalog.Function + + funcs = append(funcs, hyperscanGrepFuncs()...) + funcs = append(funcs, hyperscanMatchFuncs()...) + funcs = append(funcs, hyperscanBacktrackingFuncs()...) + funcs = append(funcs, hyperscanMultiFuncs()...) + funcs = append(funcs, hyperscanCaptureFuncs()...) + funcs = append(funcs, hyperscanReplaceFuncs()...) + + return funcs +} + +func hyperscanGrepFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Hyperscan::Grep", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func hyperscanMatchFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Hyperscan::Match", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func hyperscanBacktrackingFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Hyperscan::BacktrackingGrep", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Hyperscan::BacktrackingMatch", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func hyperscanMultiFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Hyperscan::MultiGrep", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Hyperscan::MultiMatch", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func hyperscanCaptureFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Hyperscan::Capture", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func hyperscanReplaceFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Hyperscan::Replace", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} diff --git a/internal/engine/ydb/lib/cpp/ip.go b/internal/engine/ydb/lib/cpp/ip.go new file mode 100644 index 0000000000..a644da910c --- /dev/null +++ b/internal/engine/ydb/lib/cpp/ip.go @@ -0,0 +1,140 @@ +package cpp + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func IpFunctions() []*catalog.Function { + var funcs []*catalog.Function + + funcs = append(funcs, ipFromStringFuncs()...) + funcs = append(funcs, ipToStringFuncs()...) + funcs = append(funcs, ipCheckFuncs()...) + funcs = append(funcs, ipConvertFuncs()...) + funcs = append(funcs, ipSubnetFuncs()...) + funcs = append(funcs, ipMatchFuncs()...) + + return funcs +} + +func ipFromStringFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Ip::FromString", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Ip::SubnetFromString", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + } +} + +func ipToStringFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Ip::ToString", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Ip::ToString", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + } +} + +func ipCheckFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Ip::IsIPv4", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Ip::IsIPv6", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Ip::IsEmbeddedIPv4", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + } +} + +func ipConvertFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Ip::ConvertToIPv6", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + } +} + +func ipSubnetFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Ip::GetSubnet", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "Ip::GetSubnet", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Uint8"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "Ip::GetSubnetByMask", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + } +} + +func ipMatchFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Ip::SubnetMatch", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + } +} diff --git a/internal/engine/ydb/lib/cpp/math.go b/internal/engine/ydb/lib/cpp/math.go new file mode 100644 index 0000000000..288464ad0d --- /dev/null +++ b/internal/engine/ydb/lib/cpp/math.go @@ -0,0 +1,439 @@ +package cpp + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func MathFunctions() []*catalog.Function { + var funcs []*catalog.Function + + funcs = append(funcs, mathConstantsFuncs()...) + funcs = append(funcs, mathCheckFuncs()...) + funcs = append(funcs, mathUnaryFuncs()...) + funcs = append(funcs, mathBinaryFuncs()...) + funcs = append(funcs, mathLdexpFuncs()...) + funcs = append(funcs, mathRoundFuncs()...) + funcs = append(funcs, mathFuzzyEqualsFuncs()...) + funcs = append(funcs, mathModRemFuncs()...) + funcs = append(funcs, mathRoundingModeFuncs()...) + funcs = append(funcs, mathNearbyIntFuncs()...) + + return funcs +} + +func mathConstantsFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Math::Pi", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::E", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Eps", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + } +} + +func mathCheckFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Math::IsInf", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Math::IsNaN", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Math::IsFinite", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + } +} + +func mathUnaryFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Math::Abs", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Acos", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Asin", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Asinh", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Atan", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Cbrt", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Ceil", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Cos", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Cosh", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Erf", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::ErfInv", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::ErfcInv", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Exp", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Exp2", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Fabs", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Floor", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Lgamma", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Rint", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Sigmoid", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Sin", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Sinh", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Sqrt", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Tan", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Tanh", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Tgamma", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Trunc", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Log", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Log2", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Log10", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + } +} + +func mathBinaryFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Math::Atan2", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Fmod", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Hypot", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Pow", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Remainder", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + } +} + +func mathLdexpFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Math::Ldexp", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + {Type: &ast.TypeName{Name: "Int32"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + } +} + +func mathRoundFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Math::Round", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + { + Name: "Math::Round", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + {Type: &ast.TypeName{Name: "Int32"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + } +} + +func mathFuzzyEqualsFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Math::FuzzyEquals", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Math::FuzzyEquals", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + {Type: &ast.TypeName{Name: "Double"}}, + {Type: &ast.TypeName{Name: "Double"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + } +} + +func mathModRemFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Math::Mod", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int64"}}, + {Type: &ast.TypeName{Name: "Int64"}}, + }, + ReturnType: &ast.TypeName{Name: "Int64"}, + ReturnTypeNullable: true, + }, + { + Name: "Math::Rem", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Int64"}}, + {Type: &ast.TypeName{Name: "Int64"}}, + }, + ReturnType: &ast.TypeName{Name: "Int64"}, + ReturnTypeNullable: true, + }, + } +} + +func mathRoundingModeFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Math::RoundDownward", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Math::RoundToNearest", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Math::RoundTowardZero", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Math::RoundUpward", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func mathNearbyIntFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Math::NearbyInt", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Double"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Int64"}, + ReturnTypeNullable: true, + }, + } +} diff --git a/internal/engine/ydb/lib/cpp/pcre.go b/internal/engine/ydb/lib/cpp/pcre.go new file mode 100644 index 0000000000..4b313ff80f --- /dev/null +++ b/internal/engine/ydb/lib/cpp/pcre.go @@ -0,0 +1,105 @@ +package cpp + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func PcreFunctions() []*catalog.Function { + var funcs []*catalog.Function + + funcs = append(funcs, pcreGrepFuncs()...) + funcs = append(funcs, pcreMatchFuncs()...) + funcs = append(funcs, pcreBacktrackingFuncs()...) + funcs = append(funcs, pcreMultiFuncs()...) + funcs = append(funcs, pcreCaptureFuncs()...) + funcs = append(funcs, pcreReplaceFuncs()...) + + return funcs +} + +func pcreGrepFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Pcre::Grep", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func pcreMatchFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Pcre::Match", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func pcreBacktrackingFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Pcre::BacktrackingGrep", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Pcre::BacktrackingMatch", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func pcreMultiFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Pcre::MultiGrep", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Pcre::MultiMatch", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func pcreCaptureFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Pcre::Capture", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func pcreReplaceFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Pcre::Replace", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} diff --git a/internal/engine/ydb/lib/cpp/pire.go b/internal/engine/ydb/lib/cpp/pire.go new file mode 100644 index 0000000000..ae7eece256 --- /dev/null +++ b/internal/engine/ydb/lib/cpp/pire.go @@ -0,0 +1,85 @@ +package cpp + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func PireFunctions() []*catalog.Function { + var funcs []*catalog.Function + + funcs = append(funcs, pireGrepFuncs()...) + funcs = append(funcs, pireMatchFuncs()...) + funcs = append(funcs, pireMultiFuncs()...) + funcs = append(funcs, pireCaptureFuncs()...) + funcs = append(funcs, pireReplaceFuncs()...) + + return funcs +} + +func pireGrepFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Pire::Grep", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func pireMatchFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Pire::Match", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func pireMultiFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Pire::MultiGrep", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Pire::MultiMatch", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func pireCaptureFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Pire::Capture", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func pireReplaceFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Pire::Replace", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} diff --git a/internal/engine/ydb/lib/cpp/re2.go b/internal/engine/ydb/lib/cpp/re2.go new file mode 100644 index 0000000000..667c0f57e0 --- /dev/null +++ b/internal/engine/ydb/lib/cpp/re2.go @@ -0,0 +1,319 @@ +package cpp + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func Re2Functions() []*catalog.Function { + var funcs []*catalog.Function + + funcs = append(funcs, re2GrepFuncs()...) + funcs = append(funcs, re2MatchFuncs()...) + funcs = append(funcs, re2CaptureFuncs()...) + funcs = append(funcs, re2FindAndConsumeFuncs()...) + funcs = append(funcs, re2ReplaceFuncs()...) + funcs = append(funcs, re2CountFuncs()...) + funcs = append(funcs, re2OptionsFuncs()...) + + return funcs +} + +func re2GrepFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Re2::Grep", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Grep", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func re2MatchFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Re2::Match", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Match", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func re2CaptureFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Re2::Capture", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Capture", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func re2FindAndConsumeFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Re2::FindAndConsume", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::FindAndConsume", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func re2ReplaceFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Re2::Replace", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Replace", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func re2CountFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Re2::Count", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Count", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func re2OptionsFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Re2::Options", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Options", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Options", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Options", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Options", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Options", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Options", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Options", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Options", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Options", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Options", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Options", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Options", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Re2::Options", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} diff --git a/internal/engine/ydb/lib/cpp/string.go b/internal/engine/ydb/lib/cpp/string.go new file mode 100644 index 0000000000..291dd10eec --- /dev/null +++ b/internal/engine/ydb/lib/cpp/string.go @@ -0,0 +1,152 @@ +package cpp + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func StringFunctions() []*catalog.Function { + var funcs []*catalog.Function + + funcs = append(funcs, stringBase32Funcs()...) + funcs = append(funcs, stringBase64Funcs()...) + funcs = append(funcs, stringEscapeFuncs()...) + funcs = append(funcs, stringHexFuncs()...) + funcs = append(funcs, stringHtmlFuncs()...) + funcs = append(funcs, stringCgiFuncs()...) + + return funcs +} + +func stringBase32Funcs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "String::Base32Encode", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "String::Base32Decode", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "String::Base32StrictDecode", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + } +} + +func stringBase64Funcs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "String::Base64Encode", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "String::Base64Decode", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "String::Base64StrictDecode", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + } +} + +func stringEscapeFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "String::EscapeC", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "String::UnescapeC", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + } +} + +func stringHexFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "String::HexEncode", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "String::HexDecode", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + } +} + +func stringHtmlFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "String::EncodeHtml", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "String::DecodeHtml", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + } +} + +func stringCgiFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "String::CgiEscape", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "String::CgiUnescape", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + } +} diff --git a/internal/engine/ydb/lib/cpp/unicode.go b/internal/engine/ydb/lib/cpp/unicode.go new file mode 100644 index 0000000000..e8c967020d --- /dev/null +++ b/internal/engine/ydb/lib/cpp/unicode.go @@ -0,0 +1,532 @@ +package cpp + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func UnicodeFunctions() []*catalog.Function { + var funcs []*catalog.Function + + funcs = append(funcs, unicodeCheckFuncs()...) + funcs = append(funcs, unicodeLengthFuncs()...) + funcs = append(funcs, unicodeFindFuncs()...) + funcs = append(funcs, unicodeSubstringFuncs()...) + funcs = append(funcs, unicodeNormalizeFuncs()...) + funcs = append(funcs, unicodeTranslitFuncs()...) + funcs = append(funcs, unicodeLevensteinFuncs()...) + funcs = append(funcs, unicodeFoldFuncs()...) + funcs = append(funcs, unicodeReplaceFuncs()...) + funcs = append(funcs, unicodeRemoveFuncs()...) + funcs = append(funcs, unicodeCodePointFuncs()...) + funcs = append(funcs, unicodeReverseFuncs()...) + funcs = append(funcs, unicodeCaseFuncs()...) + funcs = append(funcs, unicodeSplitJoinFuncs()...) + funcs = append(funcs, unicodeToUint64Funcs()...) + funcs = append(funcs, unicodeStripFuncs()...) + funcs = append(funcs, unicodeIsFuncs()...) + funcs = append(funcs, unicodeIsUnicodeSetFuncs()...) + + return funcs +} + +func unicodeCheckFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::IsUtf", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + } +} + +func unicodeLengthFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::GetLength", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + } +} + +func unicodeFindFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::Find", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + ReturnTypeNullable: true, + }, + { + Name: "Unicode::Find", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + ReturnTypeNullable: true, + }, + { + Name: "Unicode::RFind", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + ReturnTypeNullable: true, + }, + { + Name: "Unicode::RFind", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + ReturnTypeNullable: true, + }, + } +} + +func unicodeSubstringFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::Substring", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + } +} + +func unicodeNormalizeFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::Normalize", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Unicode::NormalizeNFD", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Unicode::NormalizeNFC", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Unicode::NormalizeNFKD", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Unicode::NormalizeNFKC", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + } +} + +func unicodeTranslitFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::Translit", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Unicode::Translit", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + } +} + +func unicodeLevensteinFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::LevensteinDistance", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + } +} + +func unicodeFoldFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::Fold", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Unicode::Fold", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Unicode::Fold", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Unicode::Fold", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Unicode::Fold", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Unicode::Fold", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + } +} + +func unicodeReplaceFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::ReplaceAll", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Unicode::ReplaceFirst", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Unicode::ReplaceLast", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + } +} + +func unicodeRemoveFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::RemoveAll", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Unicode::RemoveFirst", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Unicode::RemoveLast", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + } +} + +func unicodeCodePointFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::ToCodePointList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Unicode::FromCodePointList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + } +} + +func unicodeReverseFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::Reverse", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + } +} + +func unicodeCaseFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::ToLower", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Unicode::ToUpper", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + { + Name: "Unicode::ToTitle", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + } +} + +func unicodeSplitJoinFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::SplitToList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Unicode::SplitToList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Unicode::SplitToList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Unicode::SplitToList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Uint64"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Unicode::JoinFromList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + } +} + +func unicodeToUint64Funcs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::ToUint64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + { + Name: "Unicode::ToUint64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Uint16"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + { + Name: "Unicode::TryToUint64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + ReturnTypeNullable: true, + }, + { + Name: "Unicode::TryToUint64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Uint16"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + ReturnTypeNullable: true, + }, + } +} + +func unicodeStripFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::Strip", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Utf8"}, + }, + } +} + +func unicodeIsFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::IsAscii", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Unicode::IsSpace", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Unicode::IsUpper", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Unicode::IsLower", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Unicode::IsAlpha", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Unicode::IsAlnum", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Unicode::IsHex", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + } +} + +func unicodeIsUnicodeSetFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Unicode::IsUnicodeSet", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Utf8"}}, + {Type: &ast.TypeName{Name: "Utf8"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + } +} diff --git a/internal/engine/ydb/lib/cpp/url.go b/internal/engine/ydb/lib/cpp/url.go new file mode 100644 index 0000000000..151115a8f0 --- /dev/null +++ b/internal/engine/ydb/lib/cpp/url.go @@ -0,0 +1,413 @@ +package cpp + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func UrlFunctions() []*catalog.Function { + var funcs []*catalog.Function + + funcs = append(funcs, urlNormalizeFuncs()...) + funcs = append(funcs, urlEncodeDecodeFuncs()...) + funcs = append(funcs, urlParseFuncs()...) + funcs = append(funcs, urlGetFuncs()...) + funcs = append(funcs, urlDomainFuncs()...) + funcs = append(funcs, urlCutFuncs()...) + funcs = append(funcs, urlPunycodeFuncs()...) + funcs = append(funcs, urlQueryStringFuncs()...) + + return funcs +} + +func urlNormalizeFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Url::Normalize", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Url::NormalizeWithDefaultHttpScheme", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + } +} + +func urlEncodeDecodeFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Url::Encode", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Url::Decode", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + } +} + +func urlParseFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Url::Parse", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func urlGetFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Url::GetScheme", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "Url::GetHost", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Url::GetHostPort", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Url::GetSchemeHost", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Url::GetSchemeHostPort", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Url::GetPort", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Url::GetTail", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Url::GetPath", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Url::GetFragment", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Url::GetCGIParam", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Url::GetDomain", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Uint8"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + } +} + +func urlDomainFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Url::GetTLD", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "Url::IsKnownTLD", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Url::IsWellKnownTLD", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Url::GetDomainLevel", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + { + Name: "Url::GetSignificantDomain", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "Url::GetSignificantDomain", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "Url::GetOwner", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + } +} + +func urlCutFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Url::CutScheme", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Url::CutWWW", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Url::CutWWW2", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Url::CutQueryStringAndFragment", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + } +} + +func urlPunycodeFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Url::HostNameToPunycode", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Url::ForceHostNameToPunycode", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "Url::PunycodeToHostName", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Url::ForcePunycodeToHostName", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "Url::CanBePunycodeHostName", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + } +} + +func urlQueryStringFuncs() []*catalog.Function { + // fixme: rewrite with containers if possible + return []*catalog.Function{ + { + Name: "Url::QueryStringToList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Url::QueryStringToList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Url::QueryStringToList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Url::QueryStringToList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Url::QueryStringToList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Url::QueryStringToDict", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Url::QueryStringToDict", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Url::QueryStringToDict", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Url::QueryStringToDict", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Url::QueryStringToDict", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Url::BuildQueryString", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + { + Name: "Url::BuildQueryString", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + }, + } +} diff --git a/internal/engine/ydb/lib/cpp/yson.go b/internal/engine/ydb/lib/cpp/yson.go new file mode 100644 index 0000000000..78332e0f29 --- /dev/null +++ b/internal/engine/ydb/lib/cpp/yson.go @@ -0,0 +1,632 @@ +package cpp + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func YsonFunctions() []*catalog.Function { + var funcs []*catalog.Function + + funcs = append(funcs, ysonParseFuncs()...) + funcs = append(funcs, ysonFromFuncs()...) + funcs = append(funcs, ysonWithAttributesFuncs()...) + funcs = append(funcs, ysonEqualsFuncs()...) + funcs = append(funcs, ysonGetHashFuncs()...) + funcs = append(funcs, ysonIsFuncs()...) + funcs = append(funcs, ysonGetLengthFuncs()...) + funcs = append(funcs, ysonConvertToFuncs()...) + funcs = append(funcs, ysonConvertToListFuncs()...) + funcs = append(funcs, ysonConvertToDictFuncs()...) + funcs = append(funcs, ysonContainsFuncs()...) + funcs = append(funcs, ysonLookupFuncs()...) + funcs = append(funcs, ysonYPathFuncs()...) + funcs = append(funcs, ysonAttributesFuncs()...) + funcs = append(funcs, ysonSerializeFuncs()...) + funcs = append(funcs, ysonOptionsFuncs()...) + + return funcs +} + +func ysonParseFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Yson::Parse", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Yson"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Yson::ParseJson", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Json"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Yson::ParseJsonDecodeUtf8", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Json"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Yson::Parse", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::ParseJson", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::ParseJsonDecodeUtf8", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + } +} + +func ysonFromFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Yson::From", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func ysonWithAttributesFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Yson::WithAttributes", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + } +} + +func ysonEqualsFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Yson::Equals", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + } +} + +func ysonGetHashFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Yson::GetHash", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + } +} + +func ysonIsFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Yson::IsEntity", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Yson::IsString", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Yson::IsDouble", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Yson::IsUint64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Yson::IsInt64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Yson::IsBool", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Yson::IsList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + { + Name: "Yson::IsDict", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + }, + } +} + +func ysonGetLengthFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Yson::GetLength", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + ReturnTypeNullable: true, + }, + } +} + +func ysonConvertToFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Yson::ConvertTo", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Yson::ConvertToBool", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::ConvertToInt64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Int64"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::ConvertToUint64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::ConvertToDouble", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::ConvertToString", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::ConvertToList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func ysonConvertToListFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Yson::ConvertToBoolList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Yson::ConvertToInt64List", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Yson::ConvertToUint64List", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Yson::ConvertToDoubleList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Yson::ConvertToStringList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func ysonConvertToDictFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Yson::ConvertToDict", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Yson::ConvertToBoolDict", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Yson::ConvertToInt64Dict", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Yson::ConvertToUint64Dict", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Yson::ConvertToDoubleDict", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Yson::ConvertToStringDict", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func ysonContainsFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Yson::Contains", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + ReturnTypeNullable: true, + }, + } +} + +func ysonLookupFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Yson::Lookup", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::LookupBool", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::LookupInt64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Int64"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::LookupUint64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::LookupDouble", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::LookupString", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::LookupDict", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::LookupList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + } +} + +func ysonYPathFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Yson::YPath", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::YPathBool", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Bool"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::YPathInt64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Int64"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::YPathUint64", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::YPathDouble", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::YPathString", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "String"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::YPathDict", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::YPathList", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "String"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + } +} + +func ysonAttributesFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Yson::Attributes", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} + +func ysonSerializeFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Yson::Serialize", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Yson"}, + }, + { + Name: "Yson::SerializeText", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Yson"}, + }, + { + Name: "Yson::SerializePretty", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Yson"}, + }, + { + Name: "Yson::SerializeJson", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Json"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::SerializeJson", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Json"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::SerializeJson", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "Json"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::SerializeJson", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "Json"}, + ReturnTypeNullable: true, + }, + { + Name: "Yson::SerializeJson", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "Json"}, + ReturnTypeNullable: true, + }, + } +} + +func ysonOptionsFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "Yson::Options", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Yson::Options", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + { + Name: "Yson::Options", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Bool"}}, + {Type: &ast.TypeName{Name: "Bool"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} diff --git a/internal/engine/ydb/lib/window.go b/internal/engine/ydb/lib/window.go new file mode 100644 index 0000000000..7c339217ad --- /dev/null +++ b/internal/engine/ydb/lib/window.go @@ -0,0 +1,163 @@ +package lib + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func WindowFunctions() []*catalog.Function { + var funcs []*catalog.Function + + funcs = append(funcs, rowNumberFuncs()...) + funcs = append(funcs, lagLeadFuncs()...) + funcs = append(funcs, firstLastValueFuncs()...) + funcs = append(funcs, nthValueFuncs()...) + funcs = append(funcs, rankFuncs()...) + funcs = append(funcs, ntileFuncs()...) + funcs = append(funcs, cumeDistFuncs()...) + funcs = append(funcs, sessionStartFuncs()...) + + return funcs +} + +func rowNumberFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "ROW_NUMBER", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + } +} + +func lagLeadFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "LAG", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "LAG", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "LEAD", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "LEAD", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "Uint32"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + } +} + +func firstLastValueFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "FIRST_VALUE", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + { + Name: "LAST_VALUE", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + } +} + +func nthValueFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "NTH_VALUE", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "any"}, + ReturnTypeNullable: true, + }, + } +} + +func rankFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "RANK", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + { + Name: "DENSE_RANK", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + { + Name: "PERCENT_RANK", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "any"}}, + }, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + } +} + +func ntileFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "NTILE", + Args: []*catalog.Argument{ + {Type: &ast.TypeName{Name: "Uint64"}}, + }, + ReturnType: &ast.TypeName{Name: "Uint64"}, + }, + } +} + +func cumeDistFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "CUME_DIST", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "Double"}, + }, + } +} + +func sessionStartFuncs() []*catalog.Function { + return []*catalog.Function{ + { + Name: "SESSION_START", + Args: []*catalog.Argument{}, + ReturnType: &ast.TypeName{Name: "any"}, + }, + } +} diff --git a/internal/engine/ydb/stdlib.go b/internal/engine/ydb/stdlib.go index 21dc21242b..69ef6b223e 100644 --- a/internal/engine/ydb/stdlib.go +++ b/internal/engine/ydb/stdlib.go @@ -8,11 +8,15 @@ import ( func defaultSchema(name string) *catalog.Schema { s := &catalog.Schema{ Name: name, - Funcs: make([]*catalog.Function, 0, 128), + Funcs: []*catalog.Function{}, } s.Funcs = append(s.Funcs, lib.BasicFunctions()...) s.Funcs = append(s.Funcs, lib.AggregateFunctions()...) + s.Funcs = append(s.Funcs, lib.WindowFunctions()...) + s.Funcs = append(s.Funcs, lib.CppFunctions()...) + + // TODO: add container functions if return s } diff --git a/internal/sql/ast/recursive_func_call.go b/internal/sql/ast/recursive_func_call.go deleted file mode 100644 index 1c7c0a8125..0000000000 --- a/internal/sql/ast/recursive_func_call.go +++ /dev/null @@ -1,33 +0,0 @@ -package ast - -type RecursiveFuncCall struct { - Func Node - Funcname *List - Args *List - AggOrder *List - AggFilter Node - AggWithinGroup bool - AggStar bool - AggDistinct bool - FuncVariadic bool - Over *WindowDef - Location int -} - -func (n *RecursiveFuncCall) Pos() int { - return n.Location -} - -func (n *RecursiveFuncCall) Format(buf *TrackedBuffer) { - if n == nil { - return - } - buf.astFormat(n.Func) - buf.WriteString("(") - if n.AggStar { - buf.WriteString("*") - } else { - buf.astFormat(n.Args) - } - buf.WriteString(")") -} diff --git a/internal/sql/astutils/rewrite.go b/internal/sql/astutils/rewrite.go index bcc7c17e40..372d0ee7a2 100644 --- a/internal/sql/astutils/rewrite.go +++ b/internal/sql/astutils/rewrite.go @@ -1013,14 +1013,6 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. a.apply(n, "Roles", nil, n.Roles) a.apply(n, "Newrole", nil, n.Newrole) - case *ast.RecursiveFuncCall: - a.apply(n, "Func", nil, n.Func) - a.apply(n, "Funcname", nil, n.Funcname) - a.apply(n, "Args", nil, n.Args) - a.apply(n, "AggOrder", nil, n.AggOrder) - a.apply(n, "AggFilter", nil, n.AggFilter) - a.apply(n, "Over", nil, n.Over) - case *ast.RefreshMatViewStmt: a.apply(n, "Relation", nil, n.Relation) diff --git a/internal/sql/astutils/walk.go b/internal/sql/astutils/walk.go index dfc313fda1..e7b78d126b 100644 --- a/internal/sql/astutils/walk.go +++ b/internal/sql/astutils/walk.go @@ -1734,26 +1734,6 @@ func Walk(f Visitor, node ast.Node) { Walk(f, n.Newrole) } - case *ast.RecursiveFuncCall: - if n.Func != nil { - Walk(f, n.Func) - } - if n.Funcname != nil { - Walk(f, n.Funcname) - } - if n.Args != nil { - Walk(f, n.Args) - } - if n.AggOrder != nil { - Walk(f, n.AggOrder) - } - if n.AggFilter != nil { - Walk(f, n.AggFilter) - } - if n.Over != nil { - Walk(f, n.Over) - } - case *ast.RefreshMatViewStmt: if n.Relation != nil { Walk(f, n.Relation) From a69429ce0005ffb7ee1298b732409727325138cf Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov <150552906+1NepuNep1@users.noreply.github.com> Date: Wed, 8 Oct 2025 17:36:16 +0300 Subject: [PATCH 17/18] Added simple types in param builder and ydbType (#14) --- internal/codegen/golang/query.go | 10 +++++++--- internal/codegen/golang/ydb_type.go | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 52be2ecceb..1ca3e1615e 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -304,15 +304,15 @@ func ydbBuilderMethodForColumnType(dbType string) string { return "Bool" case "uint64": return "Uint64" - case "int64": + case "int64", "bigserial", "serial8": return "Int64" case "uint32": return "Uint32" - case "int32": + case "int32", "serial", "serial4": return "Int32" case "uint16": return "Uint16" - case "int16": + case "int16", "smallserial","serial2": return "Int16" case "uint8": return "Uint8" @@ -342,6 +342,10 @@ func ydbBuilderMethodForColumnType(dbType string) string { return "TzDatetime" case "tztimestamp": return "TzTimestamp" + case "uuid": + return "UUID" + case "yson": + return "YSON" //TODO: support other types default: diff --git a/internal/codegen/golang/ydb_type.go b/internal/codegen/golang/ydb_type.go index 0a4db80a3b..3a01cc8de7 100644 --- a/internal/codegen/golang/ydb_type.go +++ b/internal/codegen/golang/ydb_type.go @@ -171,6 +171,24 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col } return "*time.Time" + case "uuid": + if notNull { + return "uuid.UUID" + } + if emitPointersForNull { + return "*uuid.UUID" + } + return "*uuid.UUID" + + case "yson": + if notNull { + return "[]byte" + } + if emitPointersForNull { + return "*[]byte" + } + return "*[]byte" + case "null": // return "sql.Null" return "interface{}" From 816eda801855b270d5b6b0e58e25d50c6a97219d Mon Sep 17 00:00:00 2001 From: Viktor Pentyukhov <150552906+1NepuNep1@users.noreply.github.com> Date: Thu, 9 Oct 2025 16:28:52 +0300 Subject: [PATCH 18/18] Ydb-go-sdk codegen: Containers support in ydb.ParamsBuilder() (#15) * Added Arrays support + rewrites some internal logic to handle sqlc.arg/narg/slice in ydb * Supported containertype named parameters * Rewrited params to handle compex types and comment if type is interface{} --- examples/authors/ydb/query.sql.go | 34 +- internal/codegen/golang/query.go | 187 +++++-- .../templates/ydb-go-sdk/queryCode.tmpl | 70 ++- internal/codegen/golang/ydb_type.go | 3 +- internal/engine/ydb/convert.go | 489 +++++++++++++----- internal/sql/rewrite/parameters.go | 4 +- 6 files changed, 568 insertions(+), 219 deletions(-) diff --git a/examples/authors/ydb/query.sql.go b/examples/authors/ydb/query.sql.go index e9b6b332a4..6d3c6743a9 100644 --- a/examples/authors/ydb/query.sql.go +++ b/examples/authors/ydb/query.sql.go @@ -24,16 +24,12 @@ type CreateOrUpdateAuthorParams struct { } func (q *Queries) CreateOrUpdateAuthor(ctx context.Context, arg CreateOrUpdateAuthorParams, opts ...query.ExecuteOption) error { + parameters := ydb.ParamsBuilder() + parameters = parameters.Param("$p0").Uint64(arg.P0) + parameters = parameters.Param("$p1").Text(arg.P1) + parameters = parameters.Param("$p2").BeginOptional().Text(arg.P2).EndOptional() err := q.db.Exec(ctx, createOrUpdateAuthor, - append(opts, - query.WithParameters( - ydb.ParamsBuilder(). - Param("$p0").Uint64(arg.P0). - Param("$p1").Text(arg.P1). - Param("$p2").BeginOptional().Text(arg.P2).EndOptional(). - Build(), - ), - )..., + append(opts, query.WithParameters(parameters.Build()))..., ) if err != nil { return xerrors.WithStackTrace(err) @@ -46,14 +42,10 @@ DELETE FROM authors WHERE id = $p0 ` func (q *Queries) DeleteAuthor(ctx context.Context, p0 uint64, opts ...query.ExecuteOption) error { + parameters := ydb.ParamsBuilder() + parameters = parameters.Param("$p0").Uint64(p0) err := q.db.Exec(ctx, deleteAuthor, - append(opts, - query.WithParameters( - ydb.ParamsBuilder(). - Param("$p0").Uint64(p0). - Build(), - ), - )..., + append(opts, query.WithParameters(parameters.Build()))..., ) if err != nil { return xerrors.WithStackTrace(err) @@ -79,14 +71,10 @@ WHERE id = $p0 LIMIT 1 ` func (q *Queries) GetAuthor(ctx context.Context, p0 uint64, opts ...query.ExecuteOption) (Author, error) { + parameters := ydb.ParamsBuilder() + parameters = parameters.Param("$p0").Uint64(p0) row, err := q.db.QueryRow(ctx, getAuthor, - append(opts, - query.WithParameters( - ydb.ParamsBuilder(). - Param("$p0").Uint64(p0). - Build(), - ), - )..., + append(opts, query.WithParameters(parameters.Build()))..., ) var i Author if err != nil { diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 1ca3e1615e..3e789f4c60 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -128,6 +128,27 @@ func (v QueryValue) UniqueFields() []Field { return fields } +// YDBUniqueFieldsWithComments returns unique fields for YDB struct generation with comments for interface{} fields +func (v QueryValue) YDBUniqueFieldsWithComments() []Field { + seen := map[string]struct{}{} + fields := make([]Field, 0, len(v.Struct.Fields)) + + for _, field := range v.Struct.Fields { + if _, found := seen[field.Name]; found { + continue + } + seen[field.Name] = struct{}{} + + if strings.HasSuffix(field.Type, "interface{}") { + field.Comment = "// sqlc couldn't resolve type, pass via params" + } + + fields = append(fields, field) + } + + return fields +} + func (v QueryValue) Params() string { if v.isEmpty() { return "" @@ -198,7 +219,7 @@ func (v QueryValue) HasSqlcSlices() bool { func (v QueryValue) Scan() string { var out []string if v.Struct == nil { - if strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() { + if strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() && !v.SQLDriver.IsYDBGoSDK() { out = append(out, "pq.Array(&"+v.Name+")") } else { out = append(out, "&"+v.Name) @@ -209,7 +230,7 @@ func (v QueryValue) Scan() string { // append any embedded fields if len(f.EmbedFields) > 0 { for _, embed := range f.EmbedFields { - if strings.HasPrefix(embed.Type, "[]") && embed.Type != "[]byte" && !v.SQLDriver.IsPGX() { + if strings.HasPrefix(embed.Type, "[]") && embed.Type != "[]byte" && !v.SQLDriver.IsPGX() && !v.SQLDriver.IsYDBGoSDK() { out = append(out, "pq.Array(&"+v.Name+"."+f.Name+"."+embed.Name+")") } else { out = append(out, "&"+v.Name+"."+f.Name+"."+embed.Name) @@ -218,7 +239,7 @@ func (v QueryValue) Scan() string { continue } - if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() { + if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() && !v.SQLDriver.IsYDBGoSDK() { out = append(out, "pq.Array(&"+v.Name+"."+f.Name+")") } else { out = append(out, "&"+v.Name+"."+f.Name) @@ -269,32 +290,6 @@ func addDollarPrefix(name string) string { return "$" + name } -// YDBParamMapEntries returns entries for a map[string]any literal for YDB parameters. -func (v QueryValue) YDBParamMapEntries() string { - if v.isEmpty() { - return "" - } - - var parts []string - for _, field := range v.getParameterFields() { - if field.Column != nil && field.Column.IsNamedParam { - name := field.Column.GetName() - if name != "" { - key := fmt.Sprintf("%q", addDollarPrefix(name)) - variable := v.VariableForField(field) - parts = append(parts, key+": "+escape(variable)) - } - } - } - - if len(parts) == 0 { - return "" - } - - parts = append(parts, "") - return "\n" + strings.Join(parts, ",\n") -} - // ydbBuilderMethodForColumnType maps a YDB column data type to a ParamsBuilder method name. func ydbBuilderMethodForColumnType(dbType string) string { baseType := extractBaseType(strings.ToLower(dbType)) @@ -353,52 +348,89 @@ func ydbBuilderMethodForColumnType(dbType string) string { } } -// YDBParamsBuilder emits Go code that constructs YDB params using ParamsBuilder. -func (v QueryValue) YDBParamsBuilder() string { +// ydbIterateNamedParams iterates over named parameters and calls the provided function for each one. +// The function receives the field and method name, and should return true to continue iteration. +func (v QueryValue) ydbIterateNamedParams(fn func(field Field, method string) bool) bool { if v.isEmpty() { - return "" + return false } - var lines []string - for _, field := range v.getParameterFields() { - if field.Column != nil && field.Column.IsNamedParam { + if field.Column != nil { name := field.Column.GetName() if name == "" { continue } - paramName := fmt.Sprintf("%q", addDollarPrefix(name)) - variable := escape(v.VariableForField(field)) var method string if field.Column != nil && field.Column.Type != nil { method = ydbBuilderMethodForColumnType(sdk.DataType(field.Column.Type)) } - goType := field.Type - isPtr := strings.HasPrefix(goType, "*") - if isPtr { - goType = strings.TrimPrefix(goType, "*") + if !fn(field, method) { + return false } + } + } + return true +} - if method == "" { - panic(fmt.Sprintf("unknown YDB column type for param %s (goType=%s)", name, goType)) - } +// YDBHasComplexContainers returns true if there are complex container types that sqlc cannot handle automatically. +func (v QueryValue) YDBHasComplexContainers() bool { + hasComplex := false + v.ydbIterateNamedParams(func(field Field, method string) bool { + if method == "" { + hasComplex = true + return false + } + return true + }) + return hasComplex +} - if isPtr { - lines = append(lines, fmt.Sprintf("\t\t\tParam(%s).BeginOptional().%s(%s).EndOptional().", paramName, method, variable)) - } else { - lines = append(lines, fmt.Sprintf("\t\t\tParam(%s).%s(%s).", paramName, method, variable)) - } +// YDBParamsBuilder emits Go code that constructs YDB params using ParamsBuilder. +func (v QueryValue) YDBParamsBuilder() string { + var lines []string + + v.ydbIterateNamedParams(func(field Field, method string) bool { + if method == "" { + return true } - } - if len(lines) == 0 { - return "" - } + name := field.Column.GetName() + paramName := fmt.Sprintf("%q", addDollarPrefix(name)) + variable := escape(v.VariableForField(field)) + + goType := field.Type + isPtr := strings.HasPrefix(goType, "*") + isArray := field.Column.IsArray || field.Column.IsSqlcSlice + + if isArray { + lines = append(lines, fmt.Sprintf("\tvar list = parameters.Param(%s).BeginList()", paramName)) + lines = append(lines, fmt.Sprintf("\tfor _, param := range %s {", variable)) + lines = append(lines, fmt.Sprintf("\t\tlist = list.Add().%s(param)", method)) + lines = append(lines, "\t}") + lines = append(lines, "\tparameters = list.EndList()") + } else if isPtr { + lines = append(lines, fmt.Sprintf("\tparameters = parameters.Param(%s).BeginOptional().%s(%s).EndOptional()", paramName, method, variable)) + } else { + lines = append(lines, fmt.Sprintf("\tparameters = parameters.Param(%s).%s(%s)", paramName, method, variable)) + } + + return true + }) params := strings.Join(lines, "\n") - return fmt.Sprintf("\nquery.WithParameters(\n\t\tydb.ParamsBuilder().\n%s\n\t\t\tBuild(),\n\t\t),\n", params) + return fmt.Sprintf("\tparameters := ydb.ParamsBuilder()\n%s", params) +} + +func (v QueryValue) YDBHasParams() bool { + hasParams := false + v.ydbIterateNamedParams(func(field Field, method string) bool { + hasParams = true + return false + }) + return hasParams } func (v QueryValue) getParameterFields() []Field { @@ -415,6 +447,53 @@ func (v QueryValue) getParameterFields() []Field { return v.Struct.Fields } +// YDBPair returns the argument name and type for YDB query methods, filtering out interface{} parameters +// that are handled by manualParams instead. +func (v QueryValue) YDBPair() string { + if v.isEmpty() { + return "" + } + + var out []string + for _, arg := range v.YDBPairs() { + out = append(out, arg.Name+" "+arg.Type) + } + return strings.Join(out, ",") +} + +// YDBPairs returns the argument name and type for YDB query methods, filtering out interface{} parameters +// that are handled by manualParams instead. +func (v QueryValue) YDBPairs() []Argument { + if v.isEmpty() { + return nil + } + + if !v.EmitStruct() && v.IsStruct() { + var out []Argument + for _, f := range v.Struct.Fields { + if strings.HasSuffix(f.Type, "interface{}") { + continue + } + out = append(out, Argument{ + Name: escape(toLowerCase(f.Name)), + Type: f.Type, + }) + } + return out + } + + if strings.HasSuffix(v.Typ, "interface{}") { + return nil + } + + return []Argument{ + { + Name: escape(v.Name), + Type: v.DefineType(), + }, + } +} + // A struct used to generate methods and fields on the Queries struct type Query struct { Cmd string diff --git a/internal/codegen/golang/templates/ydb-go-sdk/queryCode.tmpl b/internal/codegen/golang/templates/ydb-go-sdk/queryCode.tmpl index c56fc953f8..9c60ded05f 100644 --- a/internal/codegen/golang/templates/ydb-go-sdk/queryCode.tmpl +++ b/internal/codegen/golang/templates/ydb-go-sdk/queryCode.tmpl @@ -6,8 +6,8 @@ const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} {{$.Q}} {{if .Arg.EmitStruct}} -type {{.Arg.Type}} struct { {{- range .Arg.UniqueFields}} - {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} +type {{.Arg.Type}} struct { {{- range .Arg.YDBUniqueFieldsWithComments}} + {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}}{{if .Comment}} {{.Comment}}{{end}} {{- end}} } {{end}} @@ -22,13 +22,27 @@ type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} {{if eq .Cmd ":one"}} {{range .Comments}}//{{.}} {{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBArgument}}db DBTX, {{end}}{{if not .Arg.IsEmpty}}{{.Arg.Pair}}, {{end}}opts ...query.ExecuteOption) ({{.Ret.DefineType}}, error) { +{{if .Arg.YDBHasComplexContainers}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBArgument}}db DBTX, {{end}}{{if .Arg.YDBPair}}{{.Arg.YDBPair}}, {{end}}params ydb.Params, opts ...query.ExecuteOption) ({{.Ret.DefineType}}, error) { +{{else}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBArgument}}db DBTX, {{end}}{{if .Arg.YDBPair}}{{.Arg.YDBPair}}, {{end}}opts ...query.ExecuteOption) ({{.Ret.DefineType}}, error) { +{{end -}} {{- $dbArg := "q.db" }}{{- if $.EmitMethodsWithDBArgument }}{{- $dbArg = "db" }}{{- end -}} - {{- if .Arg.IsEmpty }} + {{- if .Arg.IsEmpty -}} row, err := {{$dbArg}}.QueryRow(ctx, {{.ConstantName}}, opts...) - {{- else }} + {{- else -}} + {{- .Arg.YDBParamsBuilder}} + {{- if .Arg.YDBHasComplexContainers }} + for name, value := range params.Range() { + parameters = parameters.Param(name).Any(value) + } + {{- end }} row, err := {{$dbArg}}.QueryRow(ctx, {{.ConstantName}}, - append(opts, {{.Arg.YDBParamsBuilder}})..., + {{- if .Arg.YDBHasParams }} + append(opts, query.WithParameters(parameters.Build()))..., + {{- else }} + opts..., + {{- end }} ) {{- end }} {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} @@ -58,13 +72,27 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBA {{if eq .Cmd ":many"}} {{range .Comments}}//{{.}} {{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBArgument}}db DBTX, {{end}}{{if not .Arg.IsEmpty}}{{.Arg.Pair}}, {{end}}opts ...query.ExecuteOption) ([]{{.Ret.DefineType}}, error) { +{{if .Arg.YDBHasComplexContainers}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBArgument}}db DBTX, {{end}}{{if .Arg.YDBPair}}{{.Arg.YDBPair}}, {{end}}params ydb.Params, opts ...query.ExecuteOption) ([]{{.Ret.DefineType}}, error) { +{{else}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBArgument}}db DBTX, {{end}}{{if .Arg.YDBPair}}{{.Arg.YDBPair}}, {{end}}opts ...query.ExecuteOption) ([]{{.Ret.DefineType}}, error) { +{{end}} {{- $dbArg := "q.db" }}{{- if $.EmitMethodsWithDBArgument }}{{- $dbArg = "db" }}{{- end -}} - {{- if .Arg.IsEmpty }} + {{- if .Arg.IsEmpty -}} result, err := {{$dbArg}}.QueryResultSet(ctx, {{.ConstantName}}, opts...) - {{- else }} + {{- else -}} + {{- .Arg.YDBParamsBuilder}} + {{- if .Arg.YDBHasComplexContainers }} + for name, value := range params.Range() { + parameters = parameters.Param(name).Any(value) + } + {{- end }} result, err := {{$dbArg}}.QueryResultSet(ctx, {{.ConstantName}}, - append(opts, {{.Arg.YDBParamsBuilder}})..., + {{- if .Arg.YDBHasParams }} + append(opts, query.WithParameters(parameters.Build()))..., + {{- else }} + opts..., + {{- end }} ) {{- end }} if err != nil { @@ -111,13 +139,27 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBA {{if eq .Cmd ":exec"}} {{range .Comments}}//{{.}} {{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBArgument}}db DBTX, {{end}}{{if not .Arg.IsEmpty}}{{.Arg.Pair}}, {{end}}opts ...query.ExecuteOption) error { +{{if .Arg.YDBHasComplexContainers}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBArgument}}db DBTX, {{end}}{{if .Arg.YDBPair}}{{.Arg.YDBPair}}, {{end}}params ydb.Params, opts ...query.ExecuteOption) error { +{{else}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBArgument}}db DBTX, {{end}}{{if .Arg.YDBPair}}{{.Arg.YDBPair}}, {{end}}opts ...query.ExecuteOption) error { +{{end -}} {{- $dbArg := "q.db" }}{{- if $.EmitMethodsWithDBArgument }}{{- $dbArg = "db" }}{{- end -}} - {{- if .Arg.IsEmpty }} + {{- if .Arg.IsEmpty -}} err := {{$dbArg}}.Exec(ctx, {{.ConstantName}}, opts...) - {{- else }} + {{- else -}} + {{- .Arg.YDBParamsBuilder}} + {{- if .Arg.YDBHasComplexContainers }} + for name, value := range params.Range() { + parameters = parameters.Param(name).Any(value) + } + {{- end }} err := {{$dbArg}}.Exec(ctx, {{.ConstantName}}, - append(opts, {{.Arg.YDBParamsBuilder}})..., + {{- if .Arg.YDBHasParams }} + append(opts, query.WithParameters(parameters.Build()))..., + {{- else }} + opts..., + {{- end }} ) {{- end }} if err != nil { diff --git a/internal/codegen/golang/ydb_type.go b/internal/codegen/golang/ydb_type.go index 3a01cc8de7..ada576795d 100644 --- a/internal/codegen/golang/ydb_type.go +++ b/internal/codegen/golang/ydb_type.go @@ -12,7 +12,8 @@ import ( func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string { columnType := strings.ToLower(sdk.DataType(col.Type)) - notNull := (col.NotNull || col.IsArray) && !isNullableType(columnType) + isArray := col.IsArray || col.IsSqlcSlice + notNull := (col.NotNull || isArray) && (!isNullableType(columnType) || isArray) emitPointersForNull := options.EmitPointersForNullTypes columnType = extractBaseType(columnType) diff --git a/internal/engine/ydb/convert.go b/internal/engine/ydb/convert.go index 0fa339fa56..fb818a5d5b 100755 --- a/internal/engine/ydb/convert.go +++ b/internal/engine/ydb/convert.go @@ -1754,6 +1754,11 @@ func (c *cc) VisitColumn_schema(n *parser.Column_schemaContext) interface{} { if !ok { return todo("VisitColumn_schema", tnb) } + if typeName.ArrayBounds != nil && len(typeName.ArrayBounds.Items) > 0 { + col.IsArray = true + col.ArrayDims = len(typeName.ArrayBounds.Items) + typeName.ArrayBounds = nil + } col.TypeName = typeName } if colCons := n.Opt_column_constraints(); colCons != nil { @@ -1792,6 +1797,7 @@ func (c *cc) VisitType_name_or_bind(n *parser.Type_name_or_bindContext) interfac if !ok { return todo("VisitType_name_or_bind", b) } + // FIXME: this is not working right now for type definitions return &ast.TypeName{ Names: &ast.List{ Items: []ast.Node{param}, @@ -1892,81 +1898,36 @@ func (c *cc) VisitType_name_composite(n *parser.Type_name_compositeContext) inte } if tuple := n.Type_name_tuple(); tuple != nil { - if typeNames := tuple.AllType_name_or_bind(); len(typeNames) > 0 { - var items []ast.Node - for _, tn := range typeNames { - tnNode, ok := tn.Accept(c).(ast.Node) - if !ok { - return todo("VisitType_name_composite", tn) - } - items = append(items, tnNode) - } - return &ast.TypeName{ - Name: "Tuple", - TypeOid: 0, - Names: &ast.List{Items: items}, - } - } + return tuple.Accept(c) } if struct_ := n.Type_name_struct(); struct_ != nil { if structArgs := struct_.AllStruct_arg(); len(structArgs) > 0 { - var items []ast.Node - for range structArgs { - // TODO: Handle struct field names and types - items = append(items, &ast.TODO{}) - } return &ast.TypeName{ - Name: "Struct", + Name: "any", TypeOid: 0, - Names: &ast.List{Items: items}, } } } if variant := n.Type_name_variant(); variant != nil { if variantArgs := variant.AllVariant_arg(); len(variantArgs) > 0 { - var items []ast.Node - for range variantArgs { - // TODO: Handle variant arguments - items = append(items, &ast.TODO{}) - } return &ast.TypeName{ - Name: "Variant", + Name: "any", TypeOid: 0, - Names: &ast.List{Items: items}, } } } if list := n.Type_name_list(); list != nil { - if typeName := list.Type_name_or_bind(); typeName != nil { - tn, ok := typeName.Accept(c).(ast.Node) - if !ok { - return todo("VisitType_name_composite", typeName) - } - return &ast.TypeName{ - Name: "List", - TypeOid: 0, - Names: &ast.List{ - Items: []ast.Node{tn}, - }, - } - } + return list.Accept(c) } if stream := n.Type_name_stream(); stream != nil { - if typeName := stream.Type_name_or_bind(); typeName != nil { - tn, ok := typeName.Accept(c).(ast.Node) - if !ok { - return todo("VisitType_name_composite", typeName) - } + if stream.Type_name_or_bind() != nil { return &ast.TypeName{ - Name: "Stream", + Name: "any", TypeOid: 0, - Names: &ast.List{ - Items: []ast.Node{tn}, - }, } } } @@ -1976,40 +1937,19 @@ func (c *cc) VisitType_name_composite(n *parser.Type_name_compositeContext) inte } if dict := n.Type_name_dict(); dict != nil { - if typeNames := dict.AllType_name_or_bind(); len(typeNames) >= 2 { - first, ok := typeNames[0].Accept(c).(ast.Node) - if !ok { - return todo("VisitType_name_composite", typeNames[0]) - } - second, ok := typeNames[1].Accept(c).(ast.Node) - if !ok { - return todo("VisitType_name_composite", typeNames[1]) - } + if dict.AllType_name_or_bind() != nil { return &ast.TypeName{ - Name: "Dict", + Name: "any", TypeOid: 0, - Names: &ast.List{ - Items: []ast.Node{ - first, - second, - }, - }, } } } if set := n.Type_name_set(); set != nil { - if typeName := set.Type_name_or_bind(); typeName != nil { - tn, ok := typeName.Accept(c).(ast.Node) - if !ok { - return todo("VisitType_name_composite", typeName) - } + if set.Type_name_or_bind() != nil { return &ast.TypeName{ - Name: "Set", + Name: "any", TypeOid: 0, - Names: &ast.List{ - Items: []ast.Node{tn}, - }, } } } @@ -2050,10 +1990,76 @@ func (c *cc) VisitType_name_optional(n *parser.Type_name_optionalContext) interf return &ast.TypeName{ Name: name, TypeOid: 0, - Names: &ast.List{}, } } +func (c *cc) VisitType_name_list(n *parser.Type_name_listContext) interface{} { + if n == nil || n.Type_name_or_bind() == nil { + return todo("VisitType_name_list", n) + } + + tn, ok := n.Type_name_or_bind().Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_list", n.Type_name_or_bind()) + } + innerTypeName, ok := tn.(*ast.TypeName) + if !ok { + return todo("VisitType_name_list", n.Type_name_or_bind()) + } + + if innerTypeName.ArrayBounds != nil { + return &ast.TypeName{ + Name: "any", + TypeOid: 0, + } + } + + return &ast.TypeName{ + Name: innerTypeName.Name, + TypeOid: 0, + ArrayBounds: &ast.List{ + Items: []ast.Node{&ast.TODO{}}, + }, + } +} + +func (c *cc) VisitType_name_tuple(n *parser.Type_name_tupleContext) interface{} { + if n == nil || len(n.AllType_name_or_bind()) == 0 { + return todo("VisitType_name_tuple", n) + } + + var items []ast.Node + for _, tn := range n.AllType_name_or_bind() { + tnNode, ok := tn.Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_tuple", tn) + } + items = append(items, tnNode) + } + + var typeName string + for _, node := range items { + switch innerTypeName := node.(type) { + case *ast.TypeName: + if typeName == "" { + typeName = innerTypeName.Name + } else if typeName != innerTypeName.Name { + typeName = "any" + break + } + default: + typeName = "any" + } + } + + return &ast.TypeName{ + Name: typeName, + TypeOid: 0, + ArrayBounds: &ast.List{Items: []ast.Node{&ast.TODO{}}}, + Location: c.pos(n.GetStart()), + } + +} func (c *cc) VisitSql_stmt_core(n *parser.Sql_stmt_coreContext) interface{} { if n == nil { return todo("VisitSql_stmt_core", n) @@ -2359,21 +2365,16 @@ func (c *cc) VisitXor_subexpr(n *parser.Xor_subexprContext) interface{} { } if condCtx := n.Cond_expr(); condCtx != nil { - switch { case condCtx.IN() != nil: if inExpr := condCtx.In_expr(); inExpr != nil { - temp, ok := inExpr.Accept(c).(ast.Node) - if !ok { - return todo("VisitXor_subexpr", inExpr) - } - list, ok := temp.(*ast.List) + node, ok := inExpr.Accept(c).(ast.Node) if !ok { return todo("VisitXor_subexpr", inExpr) } return &ast.In{ Expr: base, - List: list.Items, + List: []ast.Node{node}, Not: condCtx.NOT() != nil, Location: c.pos(n.GetStart()), } @@ -2708,6 +2709,145 @@ func (c *cc) VisitCon_subexpr(n *parser.Con_subexprContext) interface{} { } +func (c *cc) VisitIn_expr(n *parser.In_exprContext) interface{} { + if n == nil || n.In_unary_subexpr() == nil { + return todo("VisitIn_expr", n) + } + return n.In_unary_subexpr().Accept(c) +} + +func (c *cc) VisitIn_unary_subexpr(n *parser.In_unary_subexprContext) interface{} { + if n == nil || (n.In_unary_casual_subexpr() == nil && n.Json_api_expr() == nil) { + return todo("VisitIn_unary_subexpr", n) + } + if unary := n.In_unary_casual_subexpr(); unary != nil { + expr, ok := unary.Accept(c).(ast.Node) + if !ok { + return todo("VisitIn_unary_subexpr", unary) + } + return expr + } + jsonExpr, ok := n.Json_api_expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitIn_unary_subexpr", n.Json_api_expr()) + } + return jsonExpr +} + +func (c *cc) VisitIn_unary_casual_subexpr(n *parser.In_unary_casual_subexprContext) interface{} { + var current ast.Node + switch { + case n.Id_expr_in() != nil: + expr, ok := n.Id_expr_in().Accept(c).(ast.Node) + if !ok { + return todo("VisitIn_unary_casual_subexpr", n.Id_expr_in()) + } + current = expr + case n.In_atom_expr() != nil: + expr, ok := n.In_atom_expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitIn_unary_casual_subexpr", n.In_atom_expr()) + } + current = expr + default: + return todo("VisitIn_unary_casual_subexpr", n) + } + + if suffix := n.Unary_subexpr_suffix(); suffix != nil { + current = c.processSuffixChain(current, suffix.(*parser.Unary_subexpr_suffixContext)) + } + + return current +} + +func (c *cc) VisitId_expr_in(n *parser.Id_expr_inContext) interface{} { + if n == nil { + return todo("VisitId_expr", n) + } + + ref := &ast.ColumnRef{ + Fields: &ast.List{}, + Location: c.pos(n.GetStart()), + } + + if id := n.Identifier(); id != nil { + ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(id.GetText())) + return ref + } + + if keyword := n.Keyword_compat(); keyword != nil { + ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText())) + return ref + } + + if keyword := n.Keyword_alter_uncompat(); keyword != nil { + ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText())) + return ref + } + + if keyword := n.Keyword_window_uncompat(); keyword != nil { + ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText())) + return ref + } + + if keyword := n.Keyword_hint_uncompat(); keyword != nil { + ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText())) + return ref + } + + return todo("VisitId_expr_in", n) +} + +func (c *cc) VisitIn_atom_expr(n *parser.In_atom_exprContext) interface{} { + if n == nil { + return todo("VisitAtom_expr", n) + } + + switch { + case n.An_id_or_type() != nil: + if n.NAMESPACE() != nil { + return NewIdentifier(parseAnIdOrType(n.An_id_or_type()) + "::" + parseIdOrType(n.Id_or_type())) + } + return NewIdentifier(parseAnIdOrType(n.An_id_or_type())) + case n.Literal_value() != nil: + expr, ok := n.Literal_value().Accept(c).(ast.Node) + if !ok { + return todo("VisitAtom_expr", n.Literal_value()) + } + return expr + case n.Bind_parameter() != nil: + expr, ok := n.Bind_parameter().Accept(c).(ast.Node) + if !ok { + return todo("VisitAtom_expr", n.Bind_parameter()) + } + return expr + case n.Cast_expr() != nil: + expr, ok := n.Cast_expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitAtom_expr", n.Cast_expr()) + } + return expr + + case n.LPAREN() != nil && n.Select_stmt() != nil && n.RPAREN() != nil: + selectStmt, ok := n.Select_stmt().Accept(c).(ast.Node) + if !ok { + return todo("VisitAtom_expr", n.Select_stmt()) + } + return selectStmt + + case n.List_literal() != nil: + list, ok := n.List_literal().Accept(c).(ast.Node) + if !ok { + return todo("VisitAtom_expr", n.List_literal()) + } + return list + + // TODO: check other cases + default: + return todo("VisitAtom_expr", n) + } +} + func (c *cc) VisitUnary_subexpr(n *parser.Unary_subexprContext) interface{} { if n == nil || (n.Unary_casual_subexpr() == nil && n.Json_api_expr() == nil) { return todo("VisitUnary_subexpr", n) @@ -2769,7 +2909,7 @@ func (c *cc) processSuffixChain(base ast.Node, suffix *parser.Unary_subexpr_suff case *parser.Key_exprContext: current = c.handleKeySuffix(current, elem) case *parser.Invoke_exprContext: - current = c.handleInvokeSuffix(current, elem, i) + current = c.handleInvokeSuffix(current, elem) case antlr.TerminalNode: if elem.GetText() == "." { current = c.handleDotSuffix(current, suffix, &i) @@ -2806,7 +2946,7 @@ func (c *cc) handleKeySuffix(base ast.Node, keyCtx *parser.Key_exprContext) ast. } } -func (c *cc) handleInvokeSuffix(base ast.Node, invokeCtx *parser.Invoke_exprContext, idx int) ast.Node { +func (c *cc) handleInvokeSuffix(base ast.Node, invokeCtx *parser.Invoke_exprContext) ast.Node { temp, ok := invokeCtx.Accept(c).(ast.Node) if !ok { return todo("VisitInvoke_expr", invokeCtx) @@ -2816,48 +2956,50 @@ func (c *cc) handleInvokeSuffix(base ast.Node, invokeCtx *parser.Invoke_exprCont return todo("VisitInvoke_expr", invokeCtx) } - if idx == 0 { - switch baseNode := base.(type) { - case *ast.ColumnRef: - if len(baseNode.Fields.Items) > 0 { - var nameParts []string - for _, item := range baseNode.Fields.Items { - if s, ok := item.(*ast.String); ok { - nameParts = append(nameParts, s.Str) - } + switch baseNode := base.(type) { + case *ast.ColumnRef: + if len(baseNode.Fields.Items) > 0 { + var nameParts []string + for _, item := range baseNode.Fields.Items { + if s, ok := item.(*ast.String); ok { + nameParts = append(nameParts, s.Str) } - funcName := strings.Join(nameParts, ".") + } + funcCall.Func = &ast.FuncName{} + if len(nameParts) == 2 { + funcCall.Func.Schema = nameParts[0] + funcCall.Func.Name = nameParts[1] + } else { + funcCall.Func.Name = strings.Join(nameParts, ".") + } - if funcName == "coalesce" || funcName == "nvl" { - return &ast.CoalesceExpr{ - Args: funcCall.Args, - Location: baseNode.Location, - } + if funcCall.Func.Name == "coalesce" || funcCall.Func.Name == "nvl" { + return &ast.CoalesceExpr{ + Args: funcCall.Args, + Location: baseNode.Location, } + } - if funcName == "greatest" || funcName == "max_of" { - return &ast.MinMaxExpr{ - Op: ast.MinMaxOp(1), - Args: funcCall.Args, - Location: baseNode.Location, - } + if funcCall.Func.Name == "greatest" || funcCall.Func.Name == "max_of" { + return &ast.MinMaxExpr{ + Op: ast.MinMaxOp(1), + Args: funcCall.Args, + Location: baseNode.Location, } - if funcName == "least" || funcName == "min_of" { - return &ast.MinMaxExpr{ - Op: ast.MinMaxOp(2), - Args: funcCall.Args, - Location: baseNode.Location, - } + } + if funcCall.Func.Name == "least" || funcCall.Func.Name == "min_of" { + return &ast.MinMaxExpr{ + Op: ast.MinMaxOp(2), + Args: funcCall.Args, + Location: baseNode.Location, } - - funcCall.Func = &ast.FuncName{Name: funcName} - funcCall.Funcname.Items = append(funcCall.Funcname.Items, &ast.String{Str: funcName}) - - return funcCall } - default: - return todo("VisitInvoke_expr", invokeCtx) + funcCall.Funcname.Items = append(funcCall.Funcname.Items, &ast.String{Str: funcCall.Func.Name}) + funcCall.Location = baseNode.Location + return funcCall } + default: + return todo("VisitInvoke_expr", invokeCtx) } stmt := &ast.FuncExpr{ @@ -3029,29 +3171,65 @@ func (c *cc) VisitAtom_expr(n *parser.Atom_exprContext) interface{} { } switch { - case n.An_id_or_type() != nil: - if n.NAMESPACE() != nil { - return NewIdentifier(parseAnIdOrType(n.An_id_or_type()) + "::" + parseIdOrType(n.Id_or_type())) - } - return NewIdentifier(parseAnIdOrType(n.An_id_or_type())) case n.Literal_value() != nil: expr, ok := n.Literal_value().Accept(c).(ast.Node) if !ok { return todo("VisitAtom_expr", n.Literal_value()) } return expr + case n.Bind_parameter() != nil: expr, ok := n.Bind_parameter().Accept(c).(ast.Node) if !ok { return todo("VisitAtom_expr", n.Bind_parameter()) } return expr + + case n.Lambda() != nil: + expr, ok := n.Lambda().Accept(c).(ast.Node) + if !ok { + return todo("VisitAtom_expr", n.Lambda()) + } + return expr + case n.Cast_expr() != nil: expr, ok := n.Cast_expr().Accept(c).(ast.Node) if !ok { return todo("VisitAtom_expr", n.Cast_expr()) } return expr + + case n.Exists_expr() != nil: + return todo("VisitAtom_expr", n.Exists_expr()) + + case n.Case_expr() != nil: + return todo("VisitAtom_expr", n.Case_expr()) + + case n.An_id_or_type() != nil: + if n.NAMESPACE() != nil { + return NewIdentifier(parseAnIdOrType(n.An_id_or_type()) + "::" + parseIdOrType(n.Id_or_type())) + } + return NewIdentifier(parseAnIdOrType(n.An_id_or_type())) + + case n.Value_constructor() != nil: + return todo("VisitAtom_expr", n.Value_constructor()) + + case n.Bitcast_expr() != nil: + return todo("VisitAtom_expr", n.Bitcast_expr()) + + case n.List_literal() != nil: + list, ok := n.List_literal().Accept(c).(ast.Node) + if !ok { + return todo("VisitAtom_expr", n.List_literal()) + } + return list + + case n.Dict_literal() != nil: + return todo("VisitAtom_expr", n.Dict_literal()) + + case n.Struct_literal() != nil: + return todo("VisitAtom_expr", n.Struct_literal()) + // TODO: check other cases default: return todo("VisitAtom_expr", n) @@ -3067,7 +3245,7 @@ func (c *cc) VisitCast_expr(n *parser.Cast_exprContext) interface{} { if !ok { return todo("VisitCast_expr", n.Expr()) } - + temp, ok := n.Type_name_or_bind().Accept(c).(ast.Node) if !ok { return todo("VisitCast_expr", n.Type_name_or_bind()) @@ -3084,6 +3262,27 @@ func (c *cc) VisitCast_expr(n *parser.Cast_exprContext) interface{} { } } +func (c *cc) VisitList_literal(n *parser.List_literalContext) interface{} { + if n == nil || n.LBRACE_SQUARE() == nil || n.RBRACE_SQUARE() == nil || n.Expr_list() == nil { + return todo("VisitList_literal", n) + } + + array := &ast.A_ArrayExpr{ + Elements: &ast.List{}, + Location: c.pos(n.GetStart()), + } + + for _, item := range n.Expr_list().AllExpr() { + expr, ok := item.Accept(c).(ast.Node) + if !ok { + return todo("VisitList_literal", item) + } + array.Elements.Items = append(array.Elements.Items, expr) + } + + return array +} + func (c *cc) VisitLiteral_value(n *parser.Literal_valueContext) interface{} { if n == nil { return todo("VisitLiteral_value", n) @@ -3149,6 +3348,44 @@ func (c *cc) VisitLiteral_value(n *parser.Literal_valueContext) interface{} { } } +func (c *cc) VisitLambda(n *parser.LambdaContext) interface{} { + if n == nil || n.Smart_parenthesis() == nil { + return todo("VisitLambda", n) + } + + if n.ARROW() != nil { + log.Panicln("Lambda stmts are not supported in SQLC") + return todo("VisitLambda", n) + } + + lambdaBody, ok := n.Smart_parenthesis().Accept(c).(ast.Node) + if !ok { + return todo("VisitLambda", n.Smart_parenthesis()) + } + + return lambdaBody +} + +func (c *cc) VisitSmart_parenthesis(n *parser.Smart_parenthesisContext) interface{} { + if n == nil || n.Named_expr_list() == nil || n.LPAREN() == nil || n.RPAREN() == nil { + return todo("VisitSmart_parenthesis", n) + } + + var args ast.List + for _, namedExpr := range n.Named_expr_list().AllNamed_expr() { + expr, ok := namedExpr.Accept(c).(ast.Node) + if !ok { + return todo("VisitSmart_parenthesis", namedExpr) + } + args.Items = append(args.Items, expr) + } + + return &ast.A_ArrayExpr{ + Elements: &args, + Location: c.pos(n.GetStart()), + } +} + func (c *cc) VisitSql_stmt(n *parser.Sql_stmtContext) interface{} { if n == nil || n.Sql_stmt_core() == nil { return todo("VisitSql_stmt", n) diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index 9146d17e08..5f213ae6c4 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -102,7 +102,9 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, }) var replace string - if engine == config.EngineMySQL || engine == config.EngineSQLite || !dollar { + if engine == config.EngineYDB { + replace = fmt.Sprintf("$%s", param.Name()) + } else if engine == config.EngineMySQL || engine == config.EngineSQLite || !dollar { if param.IsSqlcSlice() { // This sequence is also replicated in internal/codegen/golang.Field // since it's needed during template generation for replacement