diff --git a/.github/actions/exa-cluster/action.yaml b/.github/actions/exa-cluster/action.yaml index 615e00af..92800d13 100644 --- a/.github/actions/exa-cluster/action.yaml +++ b/.github/actions/exa-cluster/action.yaml @@ -8,12 +8,9 @@ inputs: description: "Number of nodes to spawn in the cluster" required: true outputs: - no-tls-url: - description: "Connection string for the database with TLS disabled" - value: ${{ steps.connection-strings.outputs.no-tls-url }} - tls-url: - description: "Connection string for the database with TLS enabled" - value: ${{ steps.connection-strings.outputs.tls-url }} + url: + description: "Connection string for the database" + value: ${{ steps.connection-string.outputs.url }} runs: using: "composite" steps: @@ -39,13 +36,38 @@ runs: - name: Initialize Exasol cluster run: | + ############################# + #### Create initial node #### + ############################# + docker run -v $HOME/sqlx1:/exa --rm --privileged -i exasol/docker-db:${{ inputs.exasol-version }} init-sc --template --num-nodes ${{ inputs.num-nodes }} docker run --rm -v $HOME/sqlx1:/exa exasol/docker-db:${{ inputs.exasol-version }} exaconf modify-volume -n DataVolume1 -s 4GiB sudo truncate -s 6G $HOME/sqlx1/data/storage/dev.1 + + #################### + #### Clone node #### + #################### for (( i=2; i<=${{ inputs.num-nodes }}; i++ )); do sudo cp -R $HOME/sqlx1 $HOME/sqlx$i done + + ################################ + #### Generate TLS 1.3 certs #### + ################################ + + sudo openssl genpkey -algorithm RSA -out ~/sqlx1/etc/ssl/ssl.ca.key -pkeyopt rsa_keygen_bits:2048 + sudo openssl genpkey -algorithm RSA -out ~/sqlx1/etc/ssl/ssl.key -pkeyopt rsa_keygen_bits:2048 + + echo "basicConstraints=CA:FALSE" > $HOME/v3.ext + echo "keyUsage=digitalSignature,keyEncipherment" >> $HOME/v3.ext + echo "extendedKeyUsage=clientAuth,serverAuth" >> $HOME/v3.ext + echo "subjectAltName=DNS:exacluster.local" >> $HOME/v3.ext + + sudo openssl req -x509 -new -nodes -key $HOME/sqlx1/etc/ssl/ssl.ca.key -sha256 -days 3650 -out $HOME/sqlx1/etc/ssl/ssl.ca -subj "/CN=exacluster.local" + sudo openssl req -new -key $HOME/sqlx1/etc/ssl/ssl.key -out $HOME/ssl.csr -subj "/CN=exacluster.local" + sudo openssl x509 -req -in $HOME/ssl.csr -CA $HOME/sqlx1/etc/ssl/ssl.ca -CAkey $HOME/sqlx1/etc/ssl/ssl.ca.key -CAcreateserial -out $HOME/sqlx1/etc/ssl/ssl.crt -days 365 -sha256 -extfile $HOME/v3.ext + shell: bash - name: Start Exasol cluster @@ -67,10 +89,7 @@ runs: done shell: bash - - name: Create connection strings - id: connection-strings - run: | - DATABASE_URL="exa://sys:exasol@10.10.10.11..1$NUM_NODES:8563" - echo "no-tls-url=$DATABASE_URL?ssl-mode=disabled" >> $GITHUB_OUTPUT - echo "tls-url=$DATABASE_URL?ssl-mode=required" >> $GITHUB_OUTPUT + - name: Create connection string + id: connection-string + run: echo "url=exa://sys:exasol@10.10.10.11..1$NUM_NODES:8563" >> $GITHUB_OUTPUT shell: bash diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 990bd938..f4a6fdd8 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -30,74 +30,43 @@ jobs: - name: Check format run: cargo +nightly fmt --check - clippy: - name: Clippy - needs: format + docs: + name: Docs runs-on: ubuntu-latest - strategy: - matrix: - etl: - [ - "--features etl_native_tls", - "--features etl_rustls", - "--features etl", - "", - ] - other: ["--features compression,migrate,rust_decimal,uuid,chrono", ""] steps: - uses: actions/checkout@v4 - uses: Swatinem/rust-cache@v2 + - uses: dtolnay/rust-toolchain@1.86.0 
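The composite action now exposes a single `url` output, with TLS and compression toggled downstream via the `ssl-mode` and `compression` query parameters seen in the jobs below. A minimal sketch of how a consumer derives connection options from that output — assuming `ExaConnectOptions` remains re-exported at the crate root as in published sqlx-exasol versions:

```rust
// Sketch, not part of the workflow: build connect options from the action's
// single `url` output. `ssl-mode` and `compression` values are the ones the
// CI jobs below append; `ExaConnectOptions` at the root is an assumption.
use sqlx::ConnectOptions;
use sqlx_exasol::ExaConnectOptions;

fn options_from_env() -> Result<ExaConnectOptions, Box<dyn std::error::Error>> {
    // e.g. exa://sys:exasol@10.10.10.11..12:8563
    let base = std::env::var("DATABASE_URL")?;
    let url = url::Url::parse(&format!("{base}?ssl-mode=required&compression=disabled"))?;
    Ok(ExaConnectOptions::from_url(&url)?)
}
```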
- - uses: dtolnay/rust-toolchain@1.85.0 - with: - components: clippy - - - name: Run clippy - run: cargo clippy --tests ${{ matrix.etl }} ${{ matrix.other }} - env: - RUSTFLAGS: -D warnings - - check_windows: - name: Check Windows builds - needs: clippy - runs-on: windows-latest - steps: - - uses: actions/checkout@v4 - - uses: Swatinem/rust-cache@v2 - - - uses: dtolnay/rust-toolchain@1.85.0 - with: - components: clippy - - - name: Run clippy - run: cargo clippy --tests --features compression,migrate,etl,rust_decimal,uuid,chrono + - name: Check format + run: cargo doc --workspace --no-deps --document-private-items --all-features env: - RUSTFLAGS: -D warnings - # See: https://aws.github.io/aws-lc-rs/resources.html#troubleshooting - AWS_LC_SYS_NO_ASM: 1 + RUSTDOCFLAGS: -D warnings - check_mac_os: - name: Check MacOS builds - needs: clippy - runs-on: macos-latest + clippy: + name: Clippy + needs: [format, docs] + runs-on: ubuntu-latest + strategy: + matrix: + features: ["--all-features", ""] steps: - uses: actions/checkout@v4 - uses: Swatinem/rust-cache@v2 - - uses: dtolnay/rust-toolchain@1.85.0 + - uses: dtolnay/rust-toolchain@1.86.0 with: components: clippy - name: Run clippy - run: cargo clippy --tests --features compression,migrate,etl,rust_decimal,uuid,chrono + run: cargo clippy --tests ${{ matrix.features }} env: RUSTFLAGS: -D warnings - - connection_tests: - name: Connection tests + + io_tests: + name: IO tests needs: clippy runs-on: ubuntu-latest - steps: - uses: actions/checkout@v4 @@ -112,26 +81,55 @@ jobs: exasol-version: ${{ env.EXASOL_VERSION }} num-nodes: ${{ env.NUM_NODES }} - - uses: dtolnay/rust-toolchain@1.85.0 + - uses: dtolnay/rust-toolchain@1.86.0 - uses: Swatinem/rust-cache@v2 - - name: Connection tests + - name: Test IO combos (no TLS, no compression) timeout-minutes: ${{ fromJSON(env.TESTS_TIMEOUT) }} - run: cargo test --features migrate,rust_decimal,uuid,chrono -- --nocapture + run: cargo test --features runtime-tokio,etl -- it_works_with_io_combo --ignored --nocapture env: - DATABASE_URL: ${{ steps.exa-cluster.outputs.no-tls-url }} - - - name: Connection tests with compression + DATABASE_URL: ${{ steps.exa-cluster.outputs.url }} + SQLX_OFFLINE: true + + - name: Test IO combos (no TLS, with compression) + timeout-minutes: ${{ fromJSON(env.TESTS_TIMEOUT) }} + run: cargo test --features runtime-tokio,etl,compression -- it_works_with_io_combo --ignored --nocapture + env: + DATABASE_URL: ${{ steps.exa-cluster.outputs.url }} + SQLX_OFFLINE: true + + - name: Test IO combos (NativeTLS, no compression) + timeout-minutes: ${{ fromJSON(env.TESTS_TIMEOUT) }} + run: cargo test --features runtime-tokio,etl,tls-native-tls -- it_works_with_io_combo --ignored --nocapture + env: + DATABASE_URL: ${{ steps.exa-cluster.outputs.url }} + SQLX_OFFLINE: true + + - name: Test IO combos (NativeTLS, with compression) + timeout-minutes: ${{ fromJSON(env.TESTS_TIMEOUT) }} + run: cargo test --features runtime-tokio,etl,tls-native-tls,compression -- it_works_with_io_combo --ignored --nocapture + env: + DATABASE_URL: ${{ steps.exa-cluster.outputs.url }} + SQLX_OFFLINE: true + + - name: Test IO combos (Rustls, no compression) + timeout-minutes: ${{ fromJSON(env.TESTS_TIMEOUT) }} + run: cargo test --features runtime-tokio,etl,tls-rustls-aws-lc-rs -- it_works_with_io_combo --ignored --nocapture + env: + DATABASE_URL: ${{ steps.exa-cluster.outputs.url }} + SQLX_OFFLINE: true + + - name: Test IO combos (Rustls, with compression) timeout-minutes: ${{ fromJSON(env.TESTS_TIMEOUT) }} - run: cargo test 
--features migrate,compression -- --ignored --nocapture + run: cargo test --features runtime-tokio,etl,tls-rustls-aws-lc-rs,compression -- it_works_with_io_combo --ignored --nocapture env: - DATABASE_URL: ${{ steps.exa-cluster.outputs.no-tls-url }} + DATABASE_URL: ${{ steps.exa-cluster.outputs.url }} + SQLX_OFFLINE: true - tls_connection_tests: - name: TLS connection tests + driver_tests: + name: Driver tests needs: clippy runs-on: ubuntu-latest - steps: - uses: actions/checkout@v4 @@ -146,26 +144,65 @@ jobs: exasol-version: ${{ env.EXASOL_VERSION }} num-nodes: ${{ env.NUM_NODES }} - - uses: dtolnay/rust-toolchain@1.85.0 + - uses: dtolnay/rust-toolchain@1.86.0 - uses: Swatinem/rust-cache@v2 + + - name: Drop database (non-existent) + run: cargo run -p sqlx-exasol-cli -- database drop -y + env: + DATABASE_URL: ${{ steps.exa-cluster.outputs.url }}/test + + - name: Setup database + run: cargo run -p sqlx-exasol-cli -- database create + env: + DATABASE_URL: ${{ steps.exa-cluster.outputs.url }}/test + + - name: Run migrations + run: cargo run -p sqlx-exasol-cli -- migrate run --source tests/migrations_compile_time + env: + DATABASE_URL: ${{ steps.exa-cluster.outputs.url }}/test + + - name: Run prepare + run: cargo run -p sqlx-exasol-cli -- prepare -- --tests --all-features + env: + DATABASE_URL: ${{ steps.exa-cluster.outputs.url }}/test - - name: TLS connection tests + - name: Run tests (preferred crates) + run: cargo test --features runtime-tokio,bigdecimal,time,uuid -- test_compile_time --ignored --nocapture + env: + DATABASE_URL: ${{ steps.exa-cluster.outputs.url }}/test + + - name: Run tests (chrono) + run: cargo test --features runtime-tokio,chrono -- test_compile_time_chrono --ignored --nocapture + env: + DATABASE_URL: ${{ steps.exa-cluster.outputs.url }}/test + + - name: Run tests (rust_decimal) + run: cargo test --features runtime-tokio,rust_decimal -- test_compile_time_rust_decimal --ignored --nocapture + env: + DATABASE_URL: ${{ steps.exa-cluster.outputs.url }}/test + + - name: Unit tests timeout-minutes: ${{ fromJSON(env.TESTS_TIMEOUT) }} - run: cargo test --features migrate,rust_decimal,uuid,chrono -- --nocapture + run: cargo test -p sqlx-exasol-impl --features migrate -- --nocapture env: - DATABASE_URL: ${{ steps.exa-cluster.outputs.tls-url }} + DATABASE_URL: ${{ steps.exa-cluster.outputs.url }}/test - - name: TLS connection tests with compression + - name: Integration connection tests timeout-minutes: ${{ fromJSON(env.TESTS_TIMEOUT) }} - run: cargo test --features migrate,compression -- --ignored --nocapture + run: cargo test --all-features -- --nocapture + env: + DATABASE_URL: ${{ steps.exa-cluster.outputs.url }}/test?ssl-mode=disabled&compression=disabled + + - name: Drop database (existent) + run: cargo run -p sqlx-exasol-cli -- database drop -y env: - DATABASE_URL: ${{ steps.exa-cluster.outputs.tls-url }} + DATABASE_URL: ${{ steps.exa-cluster.outputs.url }}/test etl_tests: name: ETL tests needs: clippy runs-on: ubuntu-latest - steps: - uses: actions/checkout@v4 @@ -180,33 +217,32 @@ jobs: exasol-version: ${{ env.EXASOL_VERSION }} num-nodes: ${{ env.NUM_NODES }} - - uses: dtolnay/rust-toolchain@1.85.0 + - uses: dtolnay/rust-toolchain@1.86.0 - uses: Swatinem/rust-cache@v2 - - name: ETL tests + - name: ETL tests (no TLS) timeout-minutes: ${{ fromJSON(env.TESTS_TIMEOUT) }} - run: cargo test --features migrate,compression,etl -- --ignored --nocapture --test-threads `nproc` - env: - DATABASE_URL: ${{ steps.exa-cluster.outputs.no-tls-url }} - - - name: ETL without TLS feature but TLS 
connection (should fail) - run: cargo test --features migrate,etl -- --ignored --nocapture --test-threads `nproc` || true + run: | + export DATABASE_URL="$BASE_DATABASE_URL?ssl-mode=disabled" + cargo test --features runtime-tokio,compression,etl -- test_etl --ignored --nocapture env: - DATABASE_URL: ${{ steps.exa-cluster.outputs.tls-url }} - - - name: Tests compilation failure if both ETL TLS features are enabled - run: cargo test --features etl_native_tls,etl_rustls || true - env: - DATABASE_URL: ${{ steps.exa-cluster.outputs.tls-url }} - - - name: Native-TLS ETL tests + BASE_DATABASE_URL: ${{ steps.exa-cluster.outputs.url }} + SQLX_OFFLINE: true + + - name: ETL tests (NativeTLS) timeout-minutes: ${{ fromJSON(env.TESTS_TIMEOUT) }} - run: cargo test --features migrate,compression,etl_native_tls -- --ignored --nocapture --test-threads `nproc` + run: | + export DATABASE_URL="$BASE_DATABASE_URL?ssl-mode=required" + cargo test --features runtime-tokio,compression,etl,tls-native-tls -- test_etl --ignored --nocapture env: - DATABASE_URL: ${{ steps.exa-cluster.outputs.tls-url }} - - - name: Rustls ETL tests + BASE_DATABASE_URL: ${{ steps.exa-cluster.outputs.url }} + SQLX_OFFLINE: true + + - name: ETL tests (Rustls) timeout-minutes: ${{ fromJSON(env.TESTS_TIMEOUT) }} - run: cargo test --features migrate,compression,etl_rustls -- --ignored --nocapture --test-threads `nproc` + run: | + export DATABASE_URL="$BASE_DATABASE_URL?ssl-mode=required" + cargo test --features runtime-tokio,compression,etl,tls-rustls-aws-lc-rs -- test_etl --ignored --nocapture env: - DATABASE_URL: ${{ steps.exa-cluster.outputs.tls-url }} + BASE_DATABASE_URL: ${{ steps.exa-cluster.outputs.url }} + SQLX_OFFLINE: true diff --git a/.sqlx/query-055568969cccd119506aca28b76acd233df57fe24ac09745cac91093ba373e15.json b/.sqlx/query-055568969cccd119506aca28b76acd233df57fe24ac09745cac91093ba373e15.json new file mode 100644 index 00000000..dd9a26fa --- /dev/null +++ b/.sqlx/query-055568969cccd119506aca28b76acd233df57fe24ac09745cac91093ba373e15.json @@ -0,0 +1,24 @@ +{ + "db_name": "Exasol", + "query": "SELECT column_i8 FROM compile_time_tests;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "column_i8", + "dataType": { + "type": "DECIMAL", + "precision": 3, + "scale": 0 + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "055568969cccd119506aca28b76acd233df57fe24ac09745cac91093ba373e15" +} diff --git a/.sqlx/query-062a5e324734a06c158364ca58ef70099acd533ec5fe97f594f3bf6436f65265.json b/.sqlx/query-062a5e324734a06c158364ca58ef70099acd533ec5fe97f594f3bf6436f65265.json new file mode 100644 index 00000000..43b83867 --- /dev/null +++ b/.sqlx/query-062a5e324734a06c158364ca58ef70099acd533ec5fe97f594f3bf6436f65265.json @@ -0,0 +1,18 @@ +{ + "db_name": "Exasol", + "query": "INSERT INTO compile_time_tests (column_char_ascii) VALUES(?);", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "type": "CHAR", + "size": 16, + "characterSet": "ASCII" + } + ] + }, + "nullable": [] + }, + "hash": "062a5e324734a06c158364ca58ef70099acd533ec5fe97f594f3bf6436f65265" +} diff --git a/.sqlx/query-13b0544bc6ffe194e012de3ab2b3f7fbf1bea457355063c72f38cef3d4f15a75.json b/.sqlx/query-13b0544bc6ffe194e012de3ab2b3f7fbf1bea457355063c72f38cef3d4f15a75.json new file mode 100644 index 00000000..3c536d39 --- /dev/null +++ b/.sqlx/query-13b0544bc6ffe194e012de3ab2b3f7fbf1bea457355063c72f38cef3d4f15a75.json @@ -0,0 +1,23 @@ +{ + "db_name": "Exasol", + "query": "SELECT column_geometry FROM compile_time_tests;", 
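These `.sqlx` files are the offline cache written by the `prepare` step above: each entry pins a query's column and parameter metadata so the macros can type-check with `SQLX_OFFLINE=true` instead of a live database. A hedged sketch of the kind of test an entry like this one backs — the `sqlx_exasol::query!` path is an assumption inferred from the new `sqlx-exasol-macros` crate:

```rust
// Hypothetical usage backed by the cache entry above; the macro path and the
// GEOMETRY-to-String mapping are assumptions, not confirmed by this diff.
async fn read_geometry(conn: &mut sqlx_exasol::ExaConnection) -> Result<(), sqlx::Error> {
    let row = sqlx_exasol::query!("SELECT column_geometry FROM compile_time_tests;")
        .fetch_one(conn)
        .await?;
    // `nullable` is unknown (`null`) in the metadata, so the field is an Option.
    let _wkt: Option<String> = row.column_geometry;
    Ok(())
}
```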
+ "describe": { + "columns": [ + { + "ordinal": 0, + "name": "column_geometry", + "dataType": { + "type": "GEOMETRY", + "srid": 0 + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "13b0544bc6ffe194e012de3ab2b3f7fbf1bea457355063c72f38cef3d4f15a75" +} diff --git a/.sqlx/query-261a60d57dcba09d6eb87d34075af62881f6ffa7bd84f59cf36a924e1491bf86.json b/.sqlx/query-261a60d57dcba09d6eb87d34075af62881f6ffa7bd84f59cf36a924e1491bf86.json new file mode 100644 index 00000000..67ef910f --- /dev/null +++ b/.sqlx/query-261a60d57dcba09d6eb87d34075af62881f6ffa7bd84f59cf36a924e1491bf86.json @@ -0,0 +1,23 @@ +{ + "db_name": "Exasol", + "query": "SELECT column_interval_ytm FROM compile_time_tests;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "column_interval_ytm", + "dataType": { + "type": "INTERVAL YEAR TO MONTH", + "precision": 2 + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "261a60d57dcba09d6eb87d34075af62881f6ffa7bd84f59cf36a924e1491bf86" +} diff --git a/.sqlx/query-28e676d849d1616a5367689cd08b3c2e75485ca780ac68b8958655519a2e9c4e.json b/.sqlx/query-28e676d849d1616a5367689cd08b3c2e75485ca780ac68b8958655519a2e9c4e.json new file mode 100644 index 00000000..574dc658 --- /dev/null +++ b/.sqlx/query-28e676d849d1616a5367689cd08b3c2e75485ca780ac68b8958655519a2e9c4e.json @@ -0,0 +1,22 @@ +{ + "db_name": "Exasol", + "query": "SELECT column_f64 FROM compile_time_tests;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "column_f64", + "dataType": { + "type": "DOUBLE" + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "28e676d849d1616a5367689cd08b3c2e75485ca780ac68b8958655519a2e9c4e" +} diff --git a/.sqlx/query-3668c47214ddace8738583bac9b3decc8732b53fc485e97236f9334aa2fd7695.json b/.sqlx/query-3668c47214ddace8738583bac9b3decc8732b53fc485e97236f9334aa2fd7695.json new file mode 100644 index 00000000..dcee825b --- /dev/null +++ b/.sqlx/query-3668c47214ddace8738583bac9b3decc8732b53fc485e97236f9334aa2fd7695.json @@ -0,0 +1,18 @@ +{ + "db_name": "Exasol", + "query": "INSERT INTO compile_time_tests (column_i64) VALUES(?);", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "type": "DECIMAL", + "precision": 20, + "scale": 0 + } + ] + }, + "nullable": [] + }, + "hash": "3668c47214ddace8738583bac9b3decc8732b53fc485e97236f9334aa2fd7695" +} diff --git a/.sqlx/query-3a440d2f78336da17378656559d1111aeecb59f6d05c212c0e782e7b7934f6c8.json b/.sqlx/query-3a440d2f78336da17378656559d1111aeecb59f6d05c212c0e782e7b7934f6c8.json new file mode 100644 index 00000000..62eca999 --- /dev/null +++ b/.sqlx/query-3a440d2f78336da17378656559d1111aeecb59f6d05c212c0e782e7b7934f6c8.json @@ -0,0 +1,24 @@ +{ + "db_name": "Exasol", + "query": "SELECT column_char_ascii FROM compile_time_tests;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "column_char_ascii", + "dataType": { + "type": "CHAR", + "size": 16, + "characterSet": "ASCII" + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "3a440d2f78336da17378656559d1111aeecb59f6d05c212c0e782e7b7934f6c8" +} diff --git a/.sqlx/query-3d7b290979aed5237a60b63986548fb3c56f294eca83b2352028492f1bbca75b.json b/.sqlx/query-3d7b290979aed5237a60b63986548fb3c56f294eca83b2352028492f1bbca75b.json new file mode 100644 index 00000000..02b86ba1 --- /dev/null +++ b/.sqlx/query-3d7b290979aed5237a60b63986548fb3c56f294eca83b2352028492f1bbca75b.json @@ -0,0 +1,16 @@ +{ + "db_name": "Exasol", + "query": "INSERT 
INTO compile_time_tests (column_date) VALUES(?);", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "type": "DATE" + } + ] + }, + "nullable": [] + }, + "hash": "3d7b290979aed5237a60b63986548fb3c56f294eca83b2352028492f1bbca75b" +} diff --git a/.sqlx/query-3de26dc7f09e58f035e070de3abb78014ae692a51eceb2e7f54ead6605153e11.json b/.sqlx/query-3de26dc7f09e58f035e070de3abb78014ae692a51eceb2e7f54ead6605153e11.json new file mode 100644 index 00000000..12af51cc --- /dev/null +++ b/.sqlx/query-3de26dc7f09e58f035e070de3abb78014ae692a51eceb2e7f54ead6605153e11.json @@ -0,0 +1,18 @@ +{ + "db_name": "Exasol", + "query": "INSERT INTO compile_time_tests (column_geometry) VALUES(?);", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "type": "VARCHAR", + "size": 2000000, + "characterSet": "UTF8" + } + ] + }, + "nullable": [] + }, + "hash": "3de26dc7f09e58f035e070de3abb78014ae692a51eceb2e7f54ead6605153e11" +} diff --git a/.sqlx/query-41b6bedd2a8b479ed69e612db316dbf9d6878594d1996a79e095fda38cb3a9af.json b/.sqlx/query-41b6bedd2a8b479ed69e612db316dbf9d6878594d1996a79e095fda38cb3a9af.json new file mode 100644 index 00000000..7a560837 --- /dev/null +++ b/.sqlx/query-41b6bedd2a8b479ed69e612db316dbf9d6878594d1996a79e095fda38cb3a9af.json @@ -0,0 +1,24 @@ +{ + "db_name": "Exasol", + "query": "SELECT column_varchar_utf8 FROM compile_time_tests;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "column_varchar_utf8", + "dataType": { + "type": "VARCHAR", + "size": 16, + "characterSet": "UTF8" + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "41b6bedd2a8b479ed69e612db316dbf9d6878594d1996a79e095fda38cb3a9af" +} diff --git a/.sqlx/query-41e61e0dabe984c30c95013d0a1683696a6fcf21759cef358c48b253efaee977.json b/.sqlx/query-41e61e0dabe984c30c95013d0a1683696a6fcf21759cef358c48b253efaee977.json new file mode 100644 index 00000000..6da82eab --- /dev/null +++ b/.sqlx/query-41e61e0dabe984c30c95013d0a1683696a6fcf21759cef358c48b253efaee977.json @@ -0,0 +1,24 @@ +{ + "db_name": "Exasol", + "query": "SELECT column_i32 FROM compile_time_tests;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "column_i32", + "dataType": { + "type": "DECIMAL", + "precision": 10, + "scale": 0 + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "41e61e0dabe984c30c95013d0a1683696a6fcf21759cef358c48b253efaee977" +} diff --git a/.sqlx/query-450079655abddaeec26a10d192cffc83489d22f870c0c6291269525b7fe3dcb0.json b/.sqlx/query-450079655abddaeec26a10d192cffc83489d22f870c0c6291269525b7fe3dcb0.json new file mode 100644 index 00000000..e1fd39de --- /dev/null +++ b/.sqlx/query-450079655abddaeec26a10d192cffc83489d22f870c0c6291269525b7fe3dcb0.json @@ -0,0 +1,22 @@ +{ + "db_name": "Exasol", + "query": "SELECT column_date FROM compile_time_tests;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "column_date", + "dataType": { + "type": "DATE" + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "450079655abddaeec26a10d192cffc83489d22f870c0c6291269525b7fe3dcb0" +} diff --git a/.sqlx/query-5a75a4ca4710d23f95967028da7b3f556df2ca97272dee3e1b02f6c66f0bf587.json b/.sqlx/query-5a75a4ca4710d23f95967028da7b3f556df2ca97272dee3e1b02f6c66f0bf587.json new file mode 100644 index 00000000..39ddb537 --- /dev/null +++ b/.sqlx/query-5a75a4ca4710d23f95967028da7b3f556df2ca97272dee3e1b02f6c66f0bf587.json @@ -0,0 +1,18 @@ +{ + "db_name": "Exasol", + "query": "INSERT INTO compile_time_tests (column_i16) 
VALUES(?);", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "type": "DECIMAL", + "precision": 5, + "scale": 0 + } + ] + }, + "nullable": [] + }, + "hash": "5a75a4ca4710d23f95967028da7b3f556df2ca97272dee3e1b02f6c66f0bf587" +} diff --git a/.sqlx/query-5e1cd1e73985ddd52fa73e8d29176dfe4c7c5007b6dfe9e63fdbb05979d708a8.json b/.sqlx/query-5e1cd1e73985ddd52fa73e8d29176dfe4c7c5007b6dfe9e63fdbb05979d708a8.json new file mode 100644 index 00000000..cad173de --- /dev/null +++ b/.sqlx/query-5e1cd1e73985ddd52fa73e8d29176dfe4c7c5007b6dfe9e63fdbb05979d708a8.json @@ -0,0 +1,18 @@ +{ + "db_name": "Exasol", + "query": "INSERT INTO compile_time_tests (column_varchar_utf8) VALUES(?);", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "type": "VARCHAR", + "size": 16, + "characterSet": "UTF8" + } + ] + }, + "nullable": [] + }, + "hash": "5e1cd1e73985ddd52fa73e8d29176dfe4c7c5007b6dfe9e63fdbb05979d708a8" +} diff --git a/.sqlx/query-653ed8098c4599902a2af5d9f0bee1179e8e875a80446e04f6b7447f94eb2fd1.json b/.sqlx/query-653ed8098c4599902a2af5d9f0bee1179e8e875a80446e04f6b7447f94eb2fd1.json new file mode 100644 index 00000000..05514d19 --- /dev/null +++ b/.sqlx/query-653ed8098c4599902a2af5d9f0bee1179e8e875a80446e04f6b7447f94eb2fd1.json @@ -0,0 +1,23 @@ +{ + "db_name": "Exasol", + "query": "SELECT column_uuid FROM compile_time_tests;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "column_uuid", + "dataType": { + "type": "HASHTYPE", + "size": 32 + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "653ed8098c4599902a2af5d9f0bee1179e8e875a80446e04f6b7447f94eb2fd1" +} diff --git a/.sqlx/query-6dbc5fa76ade6b4e760b40d0f98a2cd1022b35cdbeef78bc626c20a5d2bbb6eb.json b/.sqlx/query-6dbc5fa76ade6b4e760b40d0f98a2cd1022b35cdbeef78bc626c20a5d2bbb6eb.json new file mode 100644 index 00000000..b49422ec --- /dev/null +++ b/.sqlx/query-6dbc5fa76ade6b4e760b40d0f98a2cd1022b35cdbeef78bc626c20a5d2bbb6eb.json @@ -0,0 +1,18 @@ +{ + "db_name": "Exasol", + "query": "INSERT INTO compile_time_tests (column_decimal) VALUES(?);", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "type": "DECIMAL", + "precision": 36, + "scale": 28 + } + ] + }, + "nullable": [] + }, + "hash": "6dbc5fa76ade6b4e760b40d0f98a2cd1022b35cdbeef78bc626c20a5d2bbb6eb" +} diff --git a/.sqlx/query-815b980e8e162a452678484dd753e341a4806e206599b5c5d1d08d4b7b66a8ad.json b/.sqlx/query-815b980e8e162a452678484dd753e341a4806e206599b5c5d1d08d4b7b66a8ad.json new file mode 100644 index 00000000..511c8535 --- /dev/null +++ b/.sqlx/query-815b980e8e162a452678484dd753e341a4806e206599b5c5d1d08d4b7b66a8ad.json @@ -0,0 +1,24 @@ +{ + "db_name": "Exasol", + "query": "SELECT column_i64 FROM compile_time_tests;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "column_i64", + "dataType": { + "type": "DECIMAL", + "precision": 20, + "scale": 0 + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "815b980e8e162a452678484dd753e341a4806e206599b5c5d1d08d4b7b66a8ad" +} diff --git a/.sqlx/query-81b4ade0a5a7215dd3017b100782145518492f8b6f662018fcccfe66356777b9.json b/.sqlx/query-81b4ade0a5a7215dd3017b100782145518492f8b6f662018fcccfe66356777b9.json new file mode 100644 index 00000000..e30723ca --- /dev/null +++ b/.sqlx/query-81b4ade0a5a7215dd3017b100782145518492f8b6f662018fcccfe66356777b9.json @@ -0,0 +1,16 @@ +{ + "db_name": "Exasol", + "query": "INSERT INTO compile_time_tests (column_bool) VALUES(?);", + "describe": { + "columns": [], + "parameters": { 
+ "Left": [ + { + "type": "BOOLEAN" + } + ] + }, + "nullable": [] + }, + "hash": "81b4ade0a5a7215dd3017b100782145518492f8b6f662018fcccfe66356777b9" +} diff --git a/.sqlx/query-8321407ac20b9c436b7df04c8e365c016777f44011bc9854a0e9e60e1398eeba.json b/.sqlx/query-8321407ac20b9c436b7df04c8e365c016777f44011bc9854a0e9e60e1398eeba.json new file mode 100644 index 00000000..f7c1ab83 --- /dev/null +++ b/.sqlx/query-8321407ac20b9c436b7df04c8e365c016777f44011bc9854a0e9e60e1398eeba.json @@ -0,0 +1,23 @@ +{ + "db_name": "Exasol", + "query": "SELECT column_hashtype FROM compile_time_tests;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "column_hashtype", + "dataType": { + "type": "HASHTYPE", + "size": 30 + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "8321407ac20b9c436b7df04c8e365c016777f44011bc9854a0e9e60e1398eeba" +} diff --git a/.sqlx/query-840ff949e0c9f6fbdd705430b5b24956567b399761d6466933b5f026661a3319.json b/.sqlx/query-840ff949e0c9f6fbdd705430b5b24956567b399761d6466933b5f026661a3319.json new file mode 100644 index 00000000..7e2608e6 --- /dev/null +++ b/.sqlx/query-840ff949e0c9f6fbdd705430b5b24956567b399761d6466933b5f026661a3319.json @@ -0,0 +1,24 @@ +{ + "db_name": "Exasol", + "query": "SELECT column_varchar_ascii FROM compile_time_tests;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "column_varchar_ascii", + "dataType": { + "type": "VARCHAR", + "size": 16, + "characterSet": "ASCII" + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "840ff949e0c9f6fbdd705430b5b24956567b399761d6466933b5f026661a3319" +} diff --git a/.sqlx/query-893ae1614d107a99ef9f206996dfa9b50676cf8c32860d07648b564c16a17af7.json b/.sqlx/query-893ae1614d107a99ef9f206996dfa9b50676cf8c32860d07648b564c16a17af7.json new file mode 100644 index 00000000..3ff42760 --- /dev/null +++ b/.sqlx/query-893ae1614d107a99ef9f206996dfa9b50676cf8c32860d07648b564c16a17af7.json @@ -0,0 +1,18 @@ +{ + "db_name": "Exasol", + "query": "INSERT INTO compile_time_tests (column_i32) VALUES(?);", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "type": "DECIMAL", + "precision": 10, + "scale": 0 + } + ] + }, + "nullable": [] + }, + "hash": "893ae1614d107a99ef9f206996dfa9b50676cf8c32860d07648b564c16a17af7" +} diff --git a/.sqlx/query-91e31f8260972a96aad75a4438b2727a48449fd70fd28d44d2d9d67c5e0059ec.json b/.sqlx/query-91e31f8260972a96aad75a4438b2727a48449fd70fd28d44d2d9d67c5e0059ec.json new file mode 100644 index 00000000..9cca2500 --- /dev/null +++ b/.sqlx/query-91e31f8260972a96aad75a4438b2727a48449fd70fd28d44d2d9d67c5e0059ec.json @@ -0,0 +1,18 @@ +{ + "db_name": "Exasol", + "query": "INSERT INTO compile_time_tests (column_i8) VALUES(?);", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "type": "DECIMAL", + "precision": 3, + "scale": 0 + } + ] + }, + "nullable": [] + }, + "hash": "91e31f8260972a96aad75a4438b2727a48449fd70fd28d44d2d9d67c5e0059ec" +} diff --git a/.sqlx/query-990ff677752a7d412718de81c7aec9df7ed783f6a57ea07f4ee615083afd9656.json b/.sqlx/query-990ff677752a7d412718de81c7aec9df7ed783f6a57ea07f4ee615083afd9656.json new file mode 100644 index 00000000..a204d2ca --- /dev/null +++ b/.sqlx/query-990ff677752a7d412718de81c7aec9df7ed783f6a57ea07f4ee615083afd9656.json @@ -0,0 +1,18 @@ +{ + "db_name": "Exasol", + "query": "INSERT INTO compile_time_tests (column_char_utf8) VALUES(?);", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "type": "CHAR", + "size": 16, + "characterSet": "UTF8" + } 
+ ] + }, + "nullable": [] + }, + "hash": "990ff677752a7d412718de81c7aec9df7ed783f6a57ea07f4ee615083afd9656" +} diff --git a/.sqlx/query-9c1c57fd0b7de481c4290d26f7e5b67cfcb55a1a66161cee89f670dd113f3c71.json b/.sqlx/query-9c1c57fd0b7de481c4290d26f7e5b67cfcb55a1a66161cee89f670dd113f3c71.json new file mode 100644 index 00000000..59a47690 --- /dev/null +++ b/.sqlx/query-9c1c57fd0b7de481c4290d26f7e5b67cfcb55a1a66161cee89f670dd113f3c71.json @@ -0,0 +1,17 @@ +{ + "db_name": "Exasol", + "query": "INSERT INTO compile_time_tests (column_hashtype) VALUES(?);", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "type": "HASHTYPE", + "size": 30 + } + ] + }, + "nullable": [] + }, + "hash": "9c1c57fd0b7de481c4290d26f7e5b67cfcb55a1a66161cee89f670dd113f3c71" +} diff --git a/.sqlx/query-b67ec9400282c642ee1ca8d9ac7ee9977eb5099667c1ed999da283d010b43ed1.json b/.sqlx/query-b67ec9400282c642ee1ca8d9ac7ee9977eb5099667c1ed999da283d010b43ed1.json new file mode 100644 index 00000000..862337d0 --- /dev/null +++ b/.sqlx/query-b67ec9400282c642ee1ca8d9ac7ee9977eb5099667c1ed999da283d010b43ed1.json @@ -0,0 +1,16 @@ +{ + "db_name": "Exasol", + "query": "INSERT INTO compile_time_tests (column_f64) VALUES(?);", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "type": "DOUBLE" + } + ] + }, + "nullable": [] + }, + "hash": "b67ec9400282c642ee1ca8d9ac7ee9977eb5099667c1ed999da283d010b43ed1" +} diff --git a/.sqlx/query-b7c21c8b73969b30dbf4768febd07a416bb655a1f96f20206abfec1fabc48f6a.json b/.sqlx/query-b7c21c8b73969b30dbf4768febd07a416bb655a1f96f20206abfec1fabc48f6a.json new file mode 100644 index 00000000..b2bdfa59 --- /dev/null +++ b/.sqlx/query-b7c21c8b73969b30dbf4768febd07a416bb655a1f96f20206abfec1fabc48f6a.json @@ -0,0 +1,24 @@ +{ + "db_name": "Exasol", + "query": "SELECT column_decimal FROM compile_time_tests;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "column_decimal", + "dataType": { + "type": "DECIMAL", + "precision": 36, + "scale": 28 + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "b7c21c8b73969b30dbf4768febd07a416bb655a1f96f20206abfec1fabc48f6a" +} diff --git a/.sqlx/query-d365eb26b4c8f44b659383e2e16d116eda810ec34c915172f0981d90b94ca5c3.json b/.sqlx/query-d365eb26b4c8f44b659383e2e16d116eda810ec34c915172f0981d90b94ca5c3.json new file mode 100644 index 00000000..a4ce9801 --- /dev/null +++ b/.sqlx/query-d365eb26b4c8f44b659383e2e16d116eda810ec34c915172f0981d90b94ca5c3.json @@ -0,0 +1,17 @@ +{ + "db_name": "Exasol", + "query": "INSERT INTO compile_time_tests (column_interval_ytm) VALUES(?);", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "type": "INTERVAL YEAR TO MONTH", + "precision": 2 + } + ] + }, + "nullable": [] + }, + "hash": "d365eb26b4c8f44b659383e2e16d116eda810ec34c915172f0981d90b94ca5c3" +} diff --git a/.sqlx/query-d8a7b48efe1f22930a701ec953a2acbc5b014d5e7020fe7188a4b647b8f589e7.json b/.sqlx/query-d8a7b48efe1f22930a701ec953a2acbc5b014d5e7020fe7188a4b647b8f589e7.json new file mode 100644 index 00000000..18a76c41 --- /dev/null +++ b/.sqlx/query-d8a7b48efe1f22930a701ec953a2acbc5b014d5e7020fe7188a4b647b8f589e7.json @@ -0,0 +1,24 @@ +{ + "db_name": "Exasol", + "query": "SELECT column_i16 FROM compile_time_tests;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "column_i16", + "dataType": { + "type": "DECIMAL", + "precision": 5, + "scale": 0 + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": 
"d8a7b48efe1f22930a701ec953a2acbc5b014d5e7020fe7188a4b647b8f589e7" +} diff --git a/.sqlx/query-e61fc2c50d09b4f2acf877477e262aff2b3a3914ee7aa12278adc53083ef81fa.json b/.sqlx/query-e61fc2c50d09b4f2acf877477e262aff2b3a3914ee7aa12278adc53083ef81fa.json new file mode 100644 index 00000000..093d6b77 --- /dev/null +++ b/.sqlx/query-e61fc2c50d09b4f2acf877477e262aff2b3a3914ee7aa12278adc53083ef81fa.json @@ -0,0 +1,17 @@ +{ + "db_name": "Exasol", + "query": "INSERT INTO compile_time_tests (column_uuid) VALUES(?);", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "type": "HASHTYPE", + "size": 32 + } + ] + }, + "nullable": [] + }, + "hash": "e61fc2c50d09b4f2acf877477e262aff2b3a3914ee7aa12278adc53083ef81fa" +} diff --git a/.sqlx/query-e84e3b6f74283a5215059b5e88cfc7f2b875ccafe30391751d3e432a7c3e601d.json b/.sqlx/query-e84e3b6f74283a5215059b5e88cfc7f2b875ccafe30391751d3e432a7c3e601d.json new file mode 100644 index 00000000..5686697b --- /dev/null +++ b/.sqlx/query-e84e3b6f74283a5215059b5e88cfc7f2b875ccafe30391751d3e432a7c3e601d.json @@ -0,0 +1,22 @@ +{ + "db_name": "Exasol", + "query": "SELECT column_bool FROM compile_time_tests;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "column_bool", + "dataType": { + "type": "BOOLEAN" + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "e84e3b6f74283a5215059b5e88cfc7f2b875ccafe30391751d3e432a7c3e601d" +} diff --git a/.sqlx/query-f0e049d4d9e125c5459855bc5a9f1c12e4809ab4d02e8879044c2c060214c7d8.json b/.sqlx/query-f0e049d4d9e125c5459855bc5a9f1c12e4809ab4d02e8879044c2c060214c7d8.json new file mode 100644 index 00000000..82c07491 --- /dev/null +++ b/.sqlx/query-f0e049d4d9e125c5459855bc5a9f1c12e4809ab4d02e8879044c2c060214c7d8.json @@ -0,0 +1,18 @@ +{ + "db_name": "Exasol", + "query": "INSERT INTO compile_time_tests (column_varchar_ascii) VALUES(?);", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "type": "VARCHAR", + "size": 16, + "characterSet": "ASCII" + } + ] + }, + "nullable": [] + }, + "hash": "f0e049d4d9e125c5459855bc5a9f1c12e4809ab4d02e8879044c2c060214c7d8" +} diff --git a/.sqlx/query-fd240e87b541a74dc85c6990ebf44eb91b458c4076d3e25049a9ad8ac706fafa.json b/.sqlx/query-fd240e87b541a74dc85c6990ebf44eb91b458c4076d3e25049a9ad8ac706fafa.json new file mode 100644 index 00000000..12e39fc9 --- /dev/null +++ b/.sqlx/query-fd240e87b541a74dc85c6990ebf44eb91b458c4076d3e25049a9ad8ac706fafa.json @@ -0,0 +1,24 @@ +{ + "db_name": "Exasol", + "query": "SELECT column_char_utf8 FROM compile_time_tests;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "column_char_utf8", + "dataType": { + "type": "CHAR", + "size": 16, + "characterSet": "UTF8" + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "fd240e87b541a74dc85c6990ebf44eb91b458c4076d3e25049a9ad8ac706fafa" +} diff --git a/.zed/settings.json b/.zed/settings.json index 76f75727..0f9cdcae 100644 --- a/.zed/settings.json +++ b/.zed/settings.json @@ -7,14 +7,7 @@ "rust-analyzer": { "initialization_options": { "cargo": { - "features": [ - "etl_native_tls", - "compression", - "uuid", - "chrono", - "rust_decimal", - "migrate" - ] + "features": "all" } } } diff --git a/Cargo.toml b/Cargo.toml index f4c67f7c..eb408e86 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,92 +1,209 @@ -[package] -name = "sqlx-exasol" -version = "0.8.6-hotifx1" -edition = "2024" -authors = ["bobozaur"] -rust-version = "1.85.0" -description = "Exasol driver for the SQLx framework." 
+[workspace] +members = [".", "sqlx-exasol-cli", "sqlx-exasol-impl", "sqlx-exasol-macros"] + +[workspace.package] +version = "0.9.0-alpha.1" license = "MIT OR Apache-2.0" +edition = "2024" +rust-version = "1.86.0" repository = "https://github.com/bobozaur/sqlx-exasol" keywords = ["database", "sql", "exasol", "sqlx", "driver"] +categories = ["database", "asynchronous"] +authors = ["Bogdan Mircea "] + +[package] +name = "sqlx-exasol" +documentation = "https://docs.rs/sqlx-exasol" +description = "Exasol driver for the SQLx framework." +version.workspace = true +license.workspace = true +edition.workspace = true +rust-version.workspace = true +authors.workspace = true +repository.workspace = true exclude = ["tests/*"] -categories = ["database"] [package.metadata.docs.rs] -features = ["etl", "chrono", "rust_decimal", "uuid", "compression", "migrate"] +features = [ + "etl", + "bigdecimal", + "chrono", + "rust_decimal", + "time", + "uuid", + "compression", + "migrate", + "macros", +] [features] -# ############################################ -# ########### User-facing features ########### -# ############################################ - -compression = ["dep:async-compression"] -uuid = ["dep:uuid"] -chrono = ["dep:chrono"] -rust_decimal = ["dep:rust_decimal"] -migrate = ["sqlx-core/migrate", "dep:dotenvy", "dep:hex"] -etl = ["dep:http-body-util", "dep:hyper", "dep:futures-channel"] -etl_rustls = ["dep:rustls", "dep:rcgen", "etl"] -etl_native_tls = ["dep:native-tls", "dep:rcgen", "etl"] - -# ############################################ -# ############################################ -# ############################################ +default = ["any", "macros", "migrate", "json"] -[dependencies] -arrayvec = "0.7" -async-tungstenite = "0.29" -base64 = "0.22" -futures-io = "0.3" -futures-util = "0.3" -futures-core = "0.3" -lru = "0.12" -rand = "0.8" -rsa = "0.9" -serde = { version = "1", features = ["derive", "rc"] } -serde_json = { version = "1", features = ["raw_value"] } -sqlx-core = "=0.8.6" -thiserror = "1" -tracing = { version = "0.1", features = ["log"] } -url = "2" - -# Feature flagged optional dependencies -async-compression = { version = "0.4", features = [ +derive = ["sqlx/derive", "sqlx-exasol-macros?/derive"] +macros = ["derive", "sqlx/macros", "sqlx-exasol-macros/macros"] +migrate = [ + "sqlx/migrate", + "sqlx-exasol-impl/migrate", + "sqlx-exasol-macros?/migrate", +] + +# Base runtime features without TLS +runtime-async-std = ["sqlx/runtime-async-std"] +runtime-tokio = ["sqlx/runtime-tokio"] + +# TLS features +tls-native-tls = ["sqlx/tls-native-tls", "sqlx-exasol-impl/native-tls"] +tls-rustls-aws-lc-rs = [ + "sqlx/tls-rustls-aws-lc-rs", + "sqlx-exasol-impl/rustls-aws-lc-rs", +] +tls-rustls-ring-webpki = [ + "sqlx/tls-rustls-ring-webpki", + "sqlx-exasol-impl/rustls-ring", +] +tls-rustls-ring-native-roots = [ + "sqlx/tls-rustls-ring-native-roots", + "sqlx-exasol-impl/rustls-ring", +] + +# Database +any = ["sqlx/any", "sqlx-exasol-impl/any"] + +# Types +bigdecimal = [ + "sqlx/bigdecimal", + "sqlx-exasol-impl/bigdecimal", + "sqlx-exasol-macros?/bigdecimal", +] +chrono = [ + "sqlx/chrono", + "sqlx-exasol-impl/chrono", + "sqlx-exasol-macros?/chrono", +] +geo-types = ["sqlx-exasol-impl/geo-types", "sqlx-exasol-macros?/geo-types"] +json = ["sqlx/json", "sqlx-exasol-impl/json", "sqlx-exasol-macros?/json"] +rust_decimal = [ + "sqlx/rust_decimal", + "sqlx-exasol-impl/rust_decimal", + "sqlx-exasol-macros?/rust_decimal", +] +time = ["sqlx/time", "sqlx-exasol-impl/time", "sqlx-exasol-macros?/time"] 
+uuid = ["sqlx/uuid", "sqlx-exasol-impl/uuid", "sqlx-exasol-macros?/uuid"] + +# Driver specific features +compression = ["sqlx-exasol-impl/compression"] +etl = ["sqlx-exasol-impl/etl"] + +[workspace.dependencies] +# Internal +sqlx-exasol = { path = "." } +sqlx-exasol-macros = { path = "sqlx-exasol-macros" } +sqlx-exasol-impl = { path = "sqlx-exasol-impl" } + +# External +anyhow = { version = "1", default-features = false } +arrayvec = { version = "0.7", default-features = false } +async-compression = { version = "0.4", default-features = false, features = [ "futures-io", "gzip", "zlib", -], optional = true } -chrono = { version = "0.4", features = ["serde"], optional = true } -dotenvy = { version = "0.15", optional = true } -futures-channel = { version = "0.3", features = ["sink"], optional = true } -hex = { version = "0.4", optional = true } -http-body-util = { version = "0.1", optional = true } -hyper = { version = "1.4", features = ["server", "http1"], optional = true } -native-tls = { version = "0.2", optional = true } -uuid = { version = "1", features = ["serde"], optional = true } -rcgen = { version = "0.13", optional = true } -rust_decimal = { version = "1", optional = true } +] } +async-tungstenite = { version = "0.29", default-features = false, features = [ + "handshake", + "futures-03-sink", +] } +base64 = { version = "0.22", default-features = false } +bigdecimal = { version = "0.4", default-features = false, features = [ + "std", + "serde-json", +] } +chrono = { version = "0.4", default-features = false, features = ["serde"] } +clap = { version = "4", default-features = false, features = ["derive"] } +console = { version = "0.15", default-features = false } +dotenvy = { version = "0.15", default-features = false } +futures-io = { version = "0.3", default-features = false } +futures-util = { version = "0.3", default-features = false } +futures-channel = { version = "0.3", default-features = false, features = [ + "sink", +] } +futures-core = { version = "0.3", default-features = false } +geo-types = { version = "0.7", default-features = false, features = ["std"] } +hex = { version = "0.4", default-features = false, features = ["std"] } +http-body-util = { version = "0.1", default-features = false } +hyper = { version = "1.4", default-features = false, features = [ + "server", + "http1", +] } +lru = { version = "0.12", default-features = false } +native-tls = { version = "0.2", default-features = false } +paste = { version = "1", default-features = false } +quote = { version = "1", default-features = false } +rand = { version = "0.8", default-features = false, features = [ + "std", + "std_rng", +] } +rcgen = { version = "0.13", default-features = false, features = ["pem"] } +rsa = { version = "0.9", default-features = false, features = ["pem", "std"] } +rust_decimal = { version = "1", default-features = false, features = ["serde"] } rustls = { version = "0.23", default-features = false, features = [ "std", "tls12", -], optional = true } - -[dev-dependencies] -anyhow = "1" -paste = "1" -rustls = "0.23" -sqlx = { version = "=0.8.6", features = [ - "runtime-tokio", - "tls-native-tls", +] } +serde = { version = "1", default-features = false, features = ["derive", "rc"] } +serde_json = { version = "1", default-features = false, features = [ + "std", + "raw_value", +] } +sqlx = { git = "https://github.com/launchbadge/sqlx", rev = "e77f32ea5e597d8dba6ad09b7722384ab6ed2d06", default-features = false } +sqlx-cli = { git = "https://github.com/launchbadge/sqlx", rev = 
"e77f32ea5e597d8dba6ad09b7722384ab6ed2d06", default-features = false } +sqlx-core = { git = "https://github.com/launchbadge/sqlx", rev = "e77f32ea5e597d8dba6ad09b7722384ab6ed2d06", default-features = false, features = [ + "offline", "migrate", ] } -tokio = { version = "1", features = ["full"] } +sqlx-macros-core = { git = "https://github.com/launchbadge/sqlx", rev = "e77f32ea5e597d8dba6ad09b7722384ab6ed2d06", default-features = false } +syn = { version = "2", default-features = false, features = [ + "parsing", + "proc-macro", +] } +thiserror = { version = "1", default-features = false } +time = { version = "0.3", default-features = false, features = [ + "std", + "serde", + "formatting", + "parsing", + "macros", +] } +tokio = { version = "1", default-features = false, features = [ + "rt-multi-thread", +] } +tracing = { version = "0.1", default-features = false, features = ["log"] } +url = { version = "2", default-features = false } +uuid = { version = "1", default-features = false, features = ["serde"] } +wkt = { version = "0.14", default-features = false, features = [ + "geo-types", + "serde", +] } + +[dependencies] +sqlx-exasol-macros = { workspace = true, optional = true } +sqlx-exasol-impl = { workspace = true } +sqlx = { workspace = true } + +[dev-dependencies] +anyhow = { workspace = true } +dotenvy = { workspace = true } +futures-util = { workspace = true } +paste = { workspace = true } +serde = { workspace = true } +time = { workspace = true } +url = { workspace = true } -[lints.clippy] +[workspace.lints.clippy] all = { level = "warn", priority = -1 } pedantic = { level = "warn", priority = -1 } module_name_repetitions = "allow" -[lints.rust] +[workspace.lints.rust] rust_2018_idioms = { level = "warn", priority = -1 } rust_2021_compatibility = { level = "warn", priority = -1 } meta_variable_misuse = "warn" diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 00000000..cf6d0f55 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "1.86.0" diff --git a/sqlx-exasol-cli/Cargo.toml b/sqlx-exasol-cli/Cargo.toml new file mode 100644 index 00000000..e7e8d71f --- /dev/null +++ b/sqlx-exasol-cli/Cargo.toml @@ -0,0 +1,47 @@ +[package] +name = "sqlx-exasol-cli" +description = "Command-line utility for sqlx-exasol." 
+version.workspace = true +license.workspace = true +edition.workspace = true +rust-version.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +authors.workspace = true +default-run = "sqlx-exasol" + +[[bin]] +name = "sqlx-exasol" +path = "src/sqlx-exasol.rs" + +# enables invocation as `cargo sqlx-exasol`; required for `prepare` subcommand +[[bin]] +name = "cargo-sqlx-exasol" +path = "src/cargo-sqlx-exasol.rs" + +[features] +default = ["native-tls", "completions", "sqlx-toml"] + +rustls = ["sqlx-cli/rustls"] +native-tls = ["sqlx-cli/native-tls"] + +openssl-vendored = ["sqlx-cli/openssl-vendored"] + +completions = ["sqlx-cli/completions"] + +sqlx-toml = ["sqlx-cli/sqlx-toml"] + +[dependencies] +clap = { workspace = true } +console = { workspace = true } +sqlx-exasol = { workspace = true, features = [ + "runtime-tokio", + "migrate", + "any", +] } +sqlx-cli = { workspace = true } +tokio = { workspace = true } + +[lints] +workspace = true diff --git a/sqlx-exasol-cli/src/cargo-sqlx-exasol.rs b/sqlx-exasol-cli/src/cargo-sqlx-exasol.rs new file mode 100644 index 00000000..9ef0eb6f --- /dev/null +++ b/sqlx-exasol-cli/src/cargo-sqlx-exasol.rs @@ -0,0 +1,23 @@ +use clap::Parser; +use console::style; +use sqlx_cli::Opt; +use sqlx_exasol::any::DRIVER; + +/// Cargo invokes this binary as `cargo-sqlx-exasol sqlx-exasol ` +#[derive(Parser, Debug)] +#[clap(bin_name = "cargo")] +enum Cli { + SqlxExasol(Opt), +} + +#[tokio::main] +async fn main() { + sqlx_cli::maybe_apply_dotenv(); + sqlx_exasol::any::install_drivers(&[DRIVER]).expect("driver installation failed"); + let Cli::SqlxExasol(opt) = Cli::parse(); + + if let Err(error) = sqlx_cli::run(opt).await { + println!("{} {}", style("error:").bold().red(), error); + std::process::exit(1); + } +} diff --git a/sqlx-exasol-cli/src/sqlx-exasol.rs b/sqlx-exasol-cli/src/sqlx-exasol.rs new file mode 100644 index 00000000..86911bd0 --- /dev/null +++ b/sqlx-exasol-cli/src/sqlx-exasol.rs @@ -0,0 +1,16 @@ +use clap::Parser; +use console::style; +use sqlx_cli::Opt; +use sqlx_exasol::any::DRIVER; + +#[tokio::main] +async fn main() { + sqlx_cli::maybe_apply_dotenv(); + sqlx_exasol::any::install_drivers(&[DRIVER]).expect("driver installation failed"); + let opt = Opt::parse(); + + if let Err(error) = sqlx_cli::run(opt).await { + println!("{} {}", style("error:").bold().red(), error); + std::process::exit(1); + } +} diff --git a/sqlx-exasol-impl/Cargo.toml b/sqlx-exasol-impl/Cargo.toml new file mode 100644 index 00000000..82f94ec3 --- /dev/null +++ b/sqlx-exasol-impl/Cargo.toml @@ -0,0 +1,84 @@ +[package] +name = "sqlx-exasol-impl" +description = "Driver implementation for sqlx-exasol. Not meant to be used directly." 
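Both CLI binaries above register the Exasol driver before handing control to `sqlx-cli`. The same registration works in application code that wants to go through sqlx's `Any` database; a sketch using only the calls the binaries themselves make:

```rust
use sqlx::{any::AnyConnection, Connection};

// Register the Exasol `Any` driver once, then connect through sqlx's
// database-agnostic layer, as the CLI binaries do internally.
async fn connect_any(database_url: &str) -> Result<AnyConnection, Box<dyn std::error::Error>> {
    sqlx_exasol::any::install_drivers(&[sqlx_exasol::any::DRIVER])?;
    Ok(AnyConnection::connect(database_url).await?)
}
```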
+version.workspace = true +license.workspace = true +edition.workspace = true +rust-version.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +authors.workspace = true + +[features] +# SQLx inherited features +any = ["sqlx-core/any"] +migrate = ["sqlx-core/migrate", "dep:dotenvy", "dep:hex"] +macros = ["dep:sqlx-macros-core", "sqlx-macros-core?/macros"] + +# Type Integration features +bigdecimal = ["sqlx-core/bigdecimal", "dep:bigdecimal"] +chrono = ["sqlx-core/chrono", "dep:chrono"] +geo-types = ["dep:geo-types", "dep:wkt"] +json = ["sqlx-core/json"] +rust_decimal = ["sqlx-core/rust_decimal", "dep:rust_decimal"] +time = ["sqlx-core/time", "dep:time"] +uuid = ["sqlx-core/uuid", "dep:uuid"] + +# TLS features +tls = [] +native-tls = [ + "dep:native-tls", + "sqlx-core/_tls-native-tls", + "tls", + "rcgen?/aws_lc_rs", +] +rustls-aws-lc-rs = ["rustls", "rcgen?/aws_lc_rs"] +rustls-ring = ["rustls", "rcgen?/ring"] +rustls = ["dep:rustls", "sqlx-core/_tls-rustls", "tls"] + +# Driver specific features +compression = ["dep:async-compression"] +etl = ["dep:http-body-util", "dep:hyper", "dep:futures-channel", "dep:rcgen"] + +[dependencies] +arrayvec = { workspace = true } +async-tungstenite = { workspace = true } +base64 = { workspace = true } +futures-io = { workspace = true } +futures-util = { workspace = true } +futures-core = { workspace = true } +lru = { workspace = true } +rand = { workspace = true } +rsa = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +sqlx-core = { workspace = true } +thiserror = { workspace = true } +tracing = { workspace = true } +url = { workspace = true } + +# Feature flagged optional dependencies +async-compression = { workspace = true, optional = true } +bigdecimal = { workspace = true, optional = true } +chrono = { workspace = true, optional = true } +dotenvy = { workspace = true, optional = true } +futures-channel = { workspace = true, optional = true } +geo-types = { workspace = true, optional = true } +hex = { workspace = true, optional = true } +http-body-util = { workspace = true, optional = true } +hyper = { workspace = true, optional = true } +native-tls = { workspace = true, optional = true } +rcgen = { workspace = true, optional = true } +rustls = { workspace = true, optional = true } +rust_decimal = { workspace = true, optional = true } +sqlx-macros-core = { workspace = true, optional = true } +time = { workspace = true, optional = true } +uuid = { workspace = true, optional = true } +wkt = { workspace = true, optional = true } + +[dev-dependencies] +sqlx = { workspace = true, features = ["runtime-tokio", "macros", "migrate"] } + +[lints] +workspace = true diff --git a/sqlx-exasol-impl/src/any.rs b/sqlx-exasol-impl/src/any.rs new file mode 100644 index 00000000..e537efe7 --- /dev/null +++ b/sqlx-exasol-impl/src/any.rs @@ -0,0 +1,316 @@ +use std::future; + +use futures_core::{future::BoxFuture, stream::BoxStream}; +use futures_util::{stream, FutureExt, StreamExt, TryStreamExt}; +use sqlx_core::{ + any::{ + Any, AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, + AnyRow, AnyStatement, AnyTypeInfo, AnyTypeInfoKind, AnyValue, AnyValueKind, + }, + arguments::Arguments, + column::Column, + connection::{ConnectOptions, Connection}, + database::Database, + decode::Decode, + describe::Describe, + error::BoxDynError, + executor::Executor, + logger::QueryLogger, + row::Row, + sql_str::SqlStr, + transaction::TransactionManager, + value::ValueRef, + 
Either, +}; + +use crate::{ + connection::{ + stream::ResultStream, + websocket::future::{Execute, ExecuteBatch, ExecutePrepared}, + }, + type_info::ExaDataType, + ExaArguments, ExaColumn, ExaConnectOptions, ExaConnection, ExaQueryResult, ExaRow, + ExaTransactionManager, ExaTypeInfo, ExaValueRef, Exasol, SqlxError, SqlxResult, +}; + +sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Exasol); + +impl AnyConnectionBackend for ExaConnection { + fn name(&self) -> &str { + ::NAME + } + + fn close(self: Box) -> BoxFuture<'static, SqlxResult<()>> { + Connection::close(*self).boxed() + } + + fn close_hard(self: Box) -> BoxFuture<'static, SqlxResult<()>> { + Connection::close_hard(*self).boxed() + } + + fn ping(&mut self) -> BoxFuture<'_, SqlxResult<()>> { + Connection::ping(self).boxed() + } + + fn begin(&mut self, statement: Option) -> BoxFuture<'_, SqlxResult<()>> { + ExaTransactionManager::begin(self, statement).boxed() + } + + fn commit(&mut self) -> BoxFuture<'_, SqlxResult<()>> { + ExaTransactionManager::commit(self).boxed() + } + + fn rollback(&mut self) -> BoxFuture<'_, SqlxResult<()>> { + ExaTransactionManager::rollback(self).boxed() + } + + fn start_rollback(&mut self) { + ExaTransactionManager::start_rollback(self); + } + + fn get_transaction_depth(&self) -> usize { + ExaTransactionManager::get_transaction_depth(self) + } + + fn shrink_buffers(&mut self) { + Connection::shrink_buffers(self); + } + + fn flush(&mut self) -> BoxFuture<'_, SqlxResult<()>> { + Connection::flush(self).boxed() + } + + fn should_flush(&self) -> bool { + Connection::should_flush(self) + } + + #[cfg(feature = "migrate")] + fn as_migrate( + &mut self, + ) -> SqlxResult<&mut (dyn sqlx_core::migrate::Migrate + Send + 'static)> { + Ok(self) + } + + fn fetch_many<'q>( + &'q mut self, + sql: SqlStr, + persistent: bool, + arguments: Option>, + ) -> BoxStream<'q, SqlxResult>> { + let logger = QueryLogger::new(sql, self.log_settings.clone()); + let sql = logger.sql().clone(); + + let arguments = match arguments.as_ref().map(convert_arguments_to).transpose() { + Ok(arguments) => arguments, + Err(error) => { + return stream::once(future::ready(Err(sqlx_core::Error::Encode(error)))).boxed() + } + }; + + let filter_fn = |step| async move { + match step { + Either::Left(qr) => Ok(Some(Either::Left(map_result(qr)))), + Either::Right(row) => AnyRow::try_from(&row).map(Either::Right).map(Some), + } + }; + + if let Some(arguments) = arguments { + let future = ExecutePrepared::new(sql, persistent, arguments); + ResultStream::new(&mut self.ws, logger, future) + .try_filter_map(filter_fn) + .boxed() + } else { + let future = ExecuteBatch::new(sql); + ResultStream::new(&mut self.ws, logger, future) + .try_filter_map(filter_fn) + .boxed() + } + } + + fn fetch_optional<'q>( + &'q mut self, + sql: SqlStr, + persistent: bool, + arguments: Option>, + ) -> BoxFuture<'q, SqlxResult>> { + let logger = QueryLogger::new(sql, self.log_settings.clone()); + let sql = logger.sql().clone(); + + let arguments = arguments + .as_ref() + .map(convert_arguments_to) + .transpose() + .map_err(sqlx_core::Error::Encode); + + Box::pin(async move { + let arguments = arguments?; + + let mut stream = if let Some(arguments) = arguments { + let future = ExecutePrepared::new(sql, persistent, arguments); + ResultStream::new(&mut self.ws, logger, future) + } else { + let future = Execute::new(sql); + ResultStream::new(&mut self.ws, logger, future) + }; + + while let Some(result) = stream.try_next().await? 
{ + if let Either::Right(row) = result { + return Ok(Some(AnyRow::try_from(&row)?)); + } + } + + Ok(None) + }) + } + + #[expect(unused_lifetimes, reason = "recent trait changes")] + fn prepare_with<'c, 'q: 'c>( + &'c mut self, + sql: SqlStr, + _parameters: &[AnyTypeInfo], + ) -> BoxFuture<'c, SqlxResult> { + Box::pin(async move { + let statement = Executor::prepare_with(self, sql, &[]).await?; + let column_names = statement.metadata.column_names.clone(); + AnyStatement::try_from_statement(statement, column_names) + }) + } + + fn describe(&mut self, sql: SqlStr) -> BoxFuture<'_, SqlxResult>> { + Box::pin(async move { + let describe = Executor::describe(self, sql).await?; + describe.try_into_any() + }) + } +} + +impl<'a> TryFrom<&'a ExaTypeInfo> for AnyTypeInfo { + type Error = SqlxError; + + fn try_from(type_info: &'a ExaTypeInfo) -> Result { + Ok(AnyTypeInfo { + kind: match &type_info.data_type { + ExaDataType::Boolean => AnyTypeInfoKind::Bool, + ExaDataType::Decimal(_) => AnyTypeInfoKind::BigInt, + ExaDataType::Double => AnyTypeInfoKind::Double, + ExaDataType::Char { .. } | ExaDataType::Varchar { .. } => AnyTypeInfoKind::Text, + _ => { + return Err(sqlx_core::Error::AnyDriverError( + format!("Any driver does not support Exasol type {type_info:?}").into(), + )) + } + }, + }) + } +} + +impl<'a> TryFrom<&'a ExaColumn> for AnyColumn { + type Error = sqlx_core::Error; + + fn try_from(column: &'a ExaColumn) -> Result { + let type_info = AnyTypeInfo::try_from(&column.data_type)?; + + Ok(AnyColumn { + ordinal: column.ordinal, + name: column.name.to_string().into(), + type_info, + }) + } +} + +impl<'a> TryFrom<&'a ExaRow> for AnyRow { + type Error = sqlx_core::Error; + + fn try_from(row: &'a ExaRow) -> Result { + fn decode<'r, T: Decode<'r, Exasol>>(valueref: ExaValueRef<'r>) -> SqlxResult { + Decode::decode(valueref).map_err(SqlxError::decode) + } + + let mut row_out = AnyRow { + column_names: row.column_names.clone(), + columns: Vec::with_capacity(row.columns().len()), + values: Vec::with_capacity(row.columns().len()), + }; + + for col in row.columns() { + let i = col.ordinal(); + + let any_col = AnyColumn::try_from(col)?; + + let value = row.try_get_raw(i)?; + + // Map based on the _value_ type info, not the column type info. 
+            let type_info = AnyTypeInfo::try_from(value.type_info().as_ref()).map_err(|e| {
+                SqlxError::ColumnDecode {
+                    index: col.ordinal().to_string(),
+                    source: e.into(),
+                }
+            })?;
+
+            let value_kind = match type_info.kind {
+                k if value.is_null() => AnyValueKind::Null(k),
+                AnyTypeInfoKind::Null => AnyValueKind::Null(AnyTypeInfoKind::Null),
+                AnyTypeInfoKind::Bool => AnyValueKind::Bool(decode(value)?),
+                AnyTypeInfoKind::SmallInt => AnyValueKind::SmallInt(decode(value)?),
+                AnyTypeInfoKind::Integer => AnyValueKind::Integer(decode(value)?),
+                AnyTypeInfoKind::BigInt => AnyValueKind::BigInt(decode(value)?),
+                AnyTypeInfoKind::Double => AnyValueKind::Double(decode(value)?),
+                AnyTypeInfoKind::Text => AnyValueKind::Text(decode::<String>(value)?.into()),
+                a => Err(SqlxError::decode(format!(
+                    "data type {a:?} is not supported by the `any` driver"
+                )))?,
+            };
+
+            row_out.columns.push(any_col);
+            row_out.values.push(AnyValue { kind: value_kind });
+        }
+
+        Ok(row_out)
+    }
+}
+
+impl<'a> TryFrom<&'a AnyConnectOptions> for ExaConnectOptions {
+    type Error = sqlx_core::Error;
+
+    fn try_from(any_opts: &'a AnyConnectOptions) -> Result<Self, Self::Error> {
+        let mut opts = Self::from_url(&any_opts.database_url)?;
+        opts.log_settings = any_opts.log_settings.clone();
+        Ok(opts)
+    }
+}
+
+fn map_result(result: ExaQueryResult) -> AnyQueryResult {
+    AnyQueryResult {
+        rows_affected: result.rows_affected(),
+        last_insert_id: None,
+    }
+}
+
+fn convert_arguments_to<'q, 'a>(args: &'a AnyArguments<'q>) -> Result<ExaArguments, BoxDynError>
+where
+    'q: 'a,
+{
+    let mut out = ExaArguments::default();
+
+    for arg in &args.values.0 {
+        match arg {
+            AnyValueKind::Null(AnyTypeInfoKind::Null) => out.add(Option::<bool>::None), /* data type does not matter here */
+            AnyValueKind::Null(AnyTypeInfoKind::Bool) => out.add(Option::<bool>::None),
+            AnyValueKind::Null(AnyTypeInfoKind::SmallInt) => out.add(Option::<i16>::None),
+            AnyValueKind::Null(AnyTypeInfoKind::Integer) => out.add(Option::<i32>::None),
+            AnyValueKind::Null(AnyTypeInfoKind::BigInt) => out.add(Option::<i64>::None),
+            AnyValueKind::Null(AnyTypeInfoKind::Real) => out.add(Option::<f32>::None),
+            AnyValueKind::Null(AnyTypeInfoKind::Text) => out.add(Option::<String>::None),
+            AnyValueKind::Bool(b) => out.add(b),
+            AnyValueKind::SmallInt(i) => out.add(i),
+            AnyValueKind::Integer(i) => out.add(i),
+            AnyValueKind::BigInt(i) => out.add(i),
+            AnyValueKind::Double(d) => out.add(d),
+            AnyValueKind::Text(t) => out.add(&**t),
+            a => Err(format!(
+                "Exasol does not support `any` driver data type {a:?}"
+            ))?,
+        }?;
+    }
+    Ok(out)
+}
diff --git a/sqlx-exasol-impl/src/arguments/geo_types.rs b/sqlx-exasol-impl/src/arguments/geo_types.rs
new file mode 100644
index 00000000..a0c8b764
--- /dev/null
+++ b/sqlx-exasol-impl/src/arguments/geo_types.rs
@@ -0,0 +1,28 @@
+use geo_types::{CoordNum, Geometry};
+use wkt::ToWkt;
+
+use crate::arguments::ExaBuffer;
+
+impl ExaBuffer {
+    /// Serializes a [`geo_types::Geometry`] value as a WKT string.
+    pub fn append_geometry<T>(&mut self, value: &Geometry<T>) -> std::io::Result<()>
+    where
+        T: CoordNum + std::fmt::Display,
+    {
+        self.col_params_counter += 1;
+
+        // SAFETY: `wkt` will only write valid UTF-8.
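+        // (WKT output is plain ASCII, so writing it into the string's underlying
+        // `Vec<u8>` preserves the UTF-8 invariant.)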
+        let writer = unsafe { self.buffer.as_mut_vec() };
+
+        // Open geometry string
+        writer.push(b'"');
+
+        // Serialize geometry data
+        let res = value.write_wkt(&mut *writer);
+
+        // Close geometry string
+        writer.push(b'"');
+
+        res
+    }
+}
diff --git a/sqlx-exasol-impl/src/arguments/json.rs b/sqlx-exasol-impl/src/arguments/json.rs
new file mode 100644
index 00000000..ef57c349
--- /dev/null
+++ b/sqlx-exasol-impl/src/arguments/json.rs
@@ -0,0 +1,123 @@
+use serde::Serialize;
+use serde_json::{ser::CharEscape, Error as JsonError, Serializer};
+
+use crate::arguments::ExaBuffer;
+
+impl ExaBuffer {
+    /// Serializes a JSON value as a string containing JSON.
+    pub fn append_json<T>(&mut self, value: T) -> Result<(), JsonError>
+    where
+        T: Serialize,
+    {
+        self.col_params_counter += 1;
+
+        // SAFETY: `serde_json` will only write valid UTF-8.
+        let writer = unsafe { self.buffer.as_mut_vec() };
+
+        // Open the string containing JSON
+        writer.push(b'"');
+
+        // Serialize
+        let mut serializer = Serializer::with_formatter(&mut *writer, ExaJsonStrFormatter);
+        let res = value.serialize(&mut serializer);
+
+        // Close string containing JSON
+        writer.push(b'"');
+
+        res
+    }
+}
+
+/// Used to create strings containing JSON by double escaping special characters in a single
+/// serialization pass.
+struct ExaJsonStrFormatter;
+
+impl serde_json::ser::Formatter for ExaJsonStrFormatter {
+    /// Escapes the string beginning
+    fn begin_string<W>(&mut self, writer: &mut W) -> std::io::Result<()>
+    where
+        W: ?Sized + std::io::Write,
+    {
+        writer.write_all(br#"\""#)
+    }
+
+    /// Escapes the string end
+    fn end_string<W>(&mut self, writer: &mut W) -> std::io::Result<()>
+    where
+        W: ?Sized + std::io::Write,
+    {
+        writer.write_all(br#"\""#)
+    }
+
+    /// Escapes the special character and its inherent escape.
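+    ///
+    /// For example, a tab in the payload is written out as `\\t`, so that after the
+    /// outer string is parsed the usual JSON escape `\t` is recovered.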
+    fn write_char_escape<W>(
+        &mut self,
+        writer: &mut W,
+        char_escape: CharEscape,
+    ) -> std::io::Result<()>
+    where
+        W: ?Sized + std::io::Write,
+    {
+        #[allow(clippy::needless_raw_string_hashes, reason = "false positive")]
+        let s: &[u8] = match char_escape {
+            CharEscape::Quote => br#"\\\""#,
+            CharEscape::ReverseSolidus => br#"\\\\"#,
+            CharEscape::Solidus => br#"\\/"#,
+            CharEscape::Backspace => br#"\\b"#,
+            CharEscape::FormFeed => br#"\\f"#,
+            CharEscape::LineFeed => br#"\\n"#,
+            CharEscape::CarriageReturn => br#"\\r"#,
+            CharEscape::Tab => br#"\\t"#,
+            CharEscape::AsciiControl(byte) => {
+                static HEX_DIGITS: [u8; 16] = *b"0123456789abcdef";
+                let bytes = &[
+                    b'\\',
+                    b'\\',
+                    b'u',
+                    b'0',
+                    b'0',
+                    HEX_DIGITS[(byte >> 4) as usize],
+                    HEX_DIGITS[(byte & 0xF) as usize],
+                ];
+                return writer.write_all(bytes);
+            }
+        };
+
+        writer.write_all(s)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use serde_json::json;
+
+    use super::*;
+
+    #[test]
+    fn test_json_string() {
+        let mut string = String::new();
+        for i in 0..15u8 {
+            string.push(' ');
+            string.push(i as char);
+        }
+
+        let res = json!({
+            "field1": 1,
+            "field2": string
+        });
+
+        let mut provided = String::new();
+
+        let writer = unsafe { provided.as_mut_vec() };
+        writer.push(b'"');
+        let mut serializer = Serializer::with_formatter(&mut *writer, ExaJsonStrFormatter);
+        res.serialize(&mut serializer).unwrap();
+        writer.push(b'"');
+
+        let expected = serde_json::to_string(&res)
+            .and_then(|s| serde_json::to_string(&s))
+            .unwrap();
+
+        assert_eq!(expected, provided);
+    }
+}
diff --git a/src/arguments.rs b/sqlx-exasol-impl/src/arguments/mod.rs
similarity index 97%
rename from src/arguments.rs
rename to sqlx-exasol-impl/src/arguments/mod.rs
index 626e1608..2eb3343f 100644
--- a/src/arguments.rs
+++ b/sqlx-exasol-impl/src/arguments/mod.rs
@@ -1,3 +1,8 @@
+#[cfg(feature = "geo-types")]
+mod geo_types;
+#[cfg(feature = "json")]
+mod json;
+
 use serde::Serialize;
 use serde_json::Error as SerdeError;
 use sqlx_core::{arguments::Arguments, encode::Encode, error::BoxDynError, types::Type};
@@ -77,7 +82,7 @@ impl ExaBuffer {
     pub fn append_iter<'q, I, T>(&mut self, iter: I) -> Result<(), BoxDynError>
     where
         I: IntoIterator<Item = T>,
-        T: 'q + Encode<'q, Exasol> + Type<Exasol>,
+        T: 'q + Encode<'q, Exasol>,
     {
         let mut iter = iter.into_iter();
@@ -137,7 +142,7 @@ impl ExaBuffer {
             Some(n) if n == count => (),
             Some(n) => Err(ExaProtocolError::ParameterLengthMismatch(count, n))?,
             None => self.first_col_params_num = Some(count),
-        };
+        }
 
         Ok(())
     }
diff --git a/sqlx-exasol-impl/src/column.rs b/sqlx-exasol-impl/src/column.rs
new file mode 100644
index 00000000..26f1ed05
--- /dev/null
+++ b/sqlx-exasol-impl/src/column.rs
@@ -0,0 +1,59 @@
+use std::{borrow::Cow, fmt::Display};
+
+use serde::{Deserialize, Deserializer, Serialize};
+use sqlx_core::{column::Column, database::Database, ext::ustr::UStr};
+
+use crate::{database::Exasol, type_info::ExaTypeInfo};
+
+/// Implementor of [`Column`].
+#[derive(Debug, Clone, Deserialize, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ExaColumn {
+    #[serde(default)]
+    pub(crate) ordinal: usize,
+    #[serde(deserialize_with = "ExaColumn::lowercase_name")]
+    pub(crate) name: UStr,
+    pub(crate) data_type: ExaTypeInfo,
+}
+
+impl ExaColumn {
+    fn lowercase_name<'de, D>(deserializer: D) -> Result<UStr, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        /// Intermediate type used to take advantage of [`Cow`] borrowed deserialization when
+        /// possible.
+        ///
+        /// In regular usage an owned buffer is available, so a borrowed [`str`] could be
+        /// used, but for offline query deserialization a reader with a short-lived buffer
+        /// appears to be used, so a plain string slice would fail to deserialize.
+        #[derive(Deserialize)]
+        struct CowStr<'a>(#[serde(borrow)] Cow<'a, str>);
+
+        CowStr::deserialize(deserializer)
+            .map(|c| c.0.to_lowercase())
+            .map(From::from)
+    }
+}
+
+impl Display for ExaColumn {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}: {}", self.name, self.data_type)
+    }
+}
+
+impl Column for ExaColumn {
+    type Database = Exasol;
+
+    fn ordinal(&self) -> usize {
+        self.ordinal
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn type_info(&self) -> &<Self::Database as Database>::TypeInfo {
+        &self.data_type
+    }
+}
diff --git a/src/connection/etl/error.rs b/sqlx-exasol-impl/src/connection/etl/error.rs
similarity index 100%
rename from src/connection/etl/error.rs
rename to sqlx-exasol-impl/src/connection/etl/error.rs
diff --git a/src/connection/etl/export/compression.rs b/sqlx-exasol-impl/src/connection/etl/export/compression.rs
similarity index 77%
rename from src/connection/etl/export/compression.rs
rename to sqlx-exasol-impl/src/connection/etl/export/compression.rs
index 72f4c6e7..3d5d710c 100644
--- a/src/connection/etl/export/compression.rs
+++ b/sqlx-exasol-impl/src/connection/etl/export/compression.rs
@@ -21,18 +21,18 @@ pub enum ExaExportReader {
 }
 
 impl ExaExportReader {
+    #[allow(unused_variables, reason = "conditionally compiled")]
     pub fn new(socket: ExaSocket, with_compression: bool) -> Self {
         let reader = ExaReader::new(socket);
 
-        match with_compression {
-            #[cfg(feature = "compression")]
-            true => {
-                let mut reader = GzipDecoder::new(reader);
-                reader.multiple_members(true);
-                Self::Compressed(reader)
-            }
-            _ => Self::Plain(reader),
+        #[cfg(feature = "compression")]
+        if with_compression {
+            let mut reader = GzipDecoder::new(reader);
+            reader.multiple_members(true);
+            return Self::Compressed(reader);
         }
+
+        Self::Plain(reader)
     }
 }
diff --git a/src/connection/etl/export/mod.rs b/sqlx-exasol-impl/src/connection/etl/export/mod.rs
similarity index 99%
rename from src/connection/etl/export/mod.rs
rename to sqlx-exasol-impl/src/connection/etl/export/mod.rs
index 985b0354..f94ebbd9 100644
--- a/src/connection/etl/export/mod.rs
+++ b/sqlx-exasol-impl/src/connection/etl/export/mod.rs
@@ -43,7 +43,7 @@ impl AsyncRead for ExaExport {
                     let reader = ExaExportReader::new(socket, *with_compression);
                     self.set(Self(ExaExportState::Poll(reader)));
                 }
-            };
+            }
         }
     }
 }
diff --git a/src/connection/etl/export/options.rs b/sqlx-exasol-impl/src/connection/etl/export/options.rs
similarity index 99%
rename from src/connection/etl/export/options.rs
rename to sqlx-exasol-impl/src/connection/etl/export/options.rs
index e6f3794e..799fd442 100644
--- a/src/connection/etl/export/options.rs
+++ b/sqlx-exasol-impl/src/connection/etl/export/options.rs
@@ -150,7 +150,7 @@ impl EtlJob for ExportBuilder<'_> {
                 query.push_str(qr);
                 query.push_str("\n)");
             }
-        };
+        }
 
         query.push(' ');
 
diff --git a/src/connection/etl/export/reader.rs b/sqlx-exasol-impl/src/connection/etl/export/reader.rs
similarity index 95%
rename from src/connection/etl/export/reader.rs
rename to sqlx-exasol-impl/src/connection/etl/export/reader.rs
index ec1dfacd..7803d5a6 100644
--- a/src/connection/etl/export/reader.rs
+++ b/sqlx-exasol-impl/src/connection/etl/export/reader.rs
@@ -1,4 +1,5 @@
 use std::{
+    fmt,
     future::Future,
     io,
     pin::Pin,
@@ -73,7 +74,7 @@ impl Future for ExportFuture {
         match ready!(self.0.poll_unpin(cx)) {
             Ok(()) => (),
             Err(e) => return Poll::Ready(Err(e)),
-        };
+        }
 
         let response = Response::builder()
             .status(StatusCode::OK)
@@ -119,7 +120,6 @@ impl Stream for ReaderStream {
     }
 }
 
-#[derive(Debug)]
 pub struct ExaReader {
     reader: ExportReader,
     conn: ExportConnection,
@@ -229,6 +229,18 @@ impl ExaReader {
     }
 }
 
+// Not derived so that the reader field is ignored, as it is mainly just a data buffer that takes
+// a lot of space.
+impl fmt::Debug for ExaReader {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let ExaReader { conn, state, .. } = self;
+        f.debug_struct("ExaReader")
+            .field("conn", &conn)
+            .field("state", &state)
+            .finish()
+    }
+}
+
 #[derive(Debug)]
 enum ExaReaderState {
     Reading,
diff --git a/src/connection/etl/import/compression.rs b/sqlx-exasol-impl/src/connection/etl/import/compression.rs
similarity index 88%
rename from src/connection/etl/import/compression.rs
rename to sqlx-exasol-impl/src/connection/etl/import/compression.rs
index 4dd86717..df8aaa28 100644
--- a/src/connection/etl/import/compression.rs
+++ b/sqlx-exasol-impl/src/connection/etl/import/compression.rs
@@ -21,14 +21,16 @@ pub enum ExaImportWriter {
 }
 
 impl ExaImportWriter {
+    #[allow(unused_variables, reason = "conditionally compiled")]
     pub fn new(socket: ExaSocket, buffer_size: usize, with_compression: bool) -> Self {
         let writer = ExaWriter::new(socket, buffer_size);
 
-        match with_compression {
-            #[cfg(feature = "compression")]
-            true => Self::Compressed(GzipEncoder::new(writer)),
-            _ => Self::Plain(writer),
+        #[cfg(feature = "compression")]
+        if with_compression {
+            return Self::Compressed(GzipEncoder::new(writer));
         }
+
+        Self::Plain(writer)
     }
 }
diff --git a/src/connection/etl/import/mod.rs b/sqlx-exasol-impl/src/connection/etl/import/mod.rs
similarity index 100%
rename from src/connection/etl/import/mod.rs
rename to sqlx-exasol-impl/src/connection/etl/import/mod.rs
diff --git a/src/connection/etl/import/options.rs b/sqlx-exasol-impl/src/connection/etl/import/options.rs
similarity index 100%
rename from src/connection/etl/import/options.rs
rename to sqlx-exasol-impl/src/connection/etl/import/options.rs
diff --git a/src/connection/etl/import/writer.rs b/sqlx-exasol-impl/src/connection/etl/import/writer.rs
similarity index 91%
rename from src/connection/etl/import/writer.rs
rename to sqlx-exasol-impl/src/connection/etl/import/writer.rs
index f674b5bc..5efd9a86 100644
--- a/src/connection/etl/import/writer.rs
+++ b/sqlx-exasol-impl/src/connection/etl/import/writer.rs
@@ -1,4 +1,5 @@
 use std::{
+    fmt,
     future::Future,
     io,
     pin::Pin,
@@ -88,7 +89,7 @@ impl Service<Request<Incoming>> for ImportService {
     }
 }
 
-#[derive(Debug)]
 pub struct ExaWriter {
     conn: ImportConnection,
     buffer: BytesMut,
@@ -152,6 +152,24 @@ impl ExaWriter {
     }
 }
 
+// Not derived so that the buffer field is ignored, as it is mainly just a data buffer that takes
+// a lot of space.
+impl fmt::Debug for ExaWriter {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let ExaWriter {
+            conn,
+            max_buf_size,
+            sink,
+            ..
+        } = self;
+        f.debug_struct("ExaWriter")
+            .field("conn", &conn)
+            .field("max_buf_size", &max_buf_size)
+            .field("sink", &sink)
+            .finish()
+    }
+}
+
 impl AsyncWrite for ExaWriter {
     fn poll_write(
         self: Pin<&mut Self>,
diff --git a/src/connection/etl/job/maybe_tls/mod.rs b/sqlx-exasol-impl/src/connection/etl/job/maybe_tls/mod.rs
similarity index 64%
rename from src/connection/etl/job/maybe_tls/mod.rs
rename to sqlx-exasol-impl/src/connection/etl/job/maybe_tls/mod.rs
index febc535f..060bbee3 100644
--- a/src/connection/etl/job/maybe_tls/mod.rs
+++ b/sqlx-exasol-impl/src/connection/etl/job/maybe_tls/mod.rs
@@ -1,4 +1,4 @@
-#[cfg(any(feature = "etl_native_tls", feature = "etl_rustls"))]
+#[cfg(feature = "tls")]
 pub mod tls;
 
 use futures_util::FutureExt;
@@ -7,25 +7,25 @@ use sqlx_core::net::{Socket, WithSocket};
 use crate::{
     connection::websocket::socket::WithExaSocket,
     etl::job::{SocketSetup, WithSocketMaker},
-    SqlxError, SqlxResult,
+    SqlxResult,
 };
 
 /// Implementor of [`WithSocketMaker`] that abstracts away the TLS/non-TLS socket creation.
 pub enum WithMaybeTlsSocketMaker {
     NonTls,
-    #[cfg(any(feature = "etl_native_tls", feature = "etl_rustls"))]
+    #[cfg(feature = "tls")]
     Tls(tls::WithTlsSocketMaker),
 }
 
 impl WithMaybeTlsSocketMaker {
+    #[allow(unused_variables, reason = "conditionally compiled")]
     pub fn new(with_tls: bool) -> SqlxResult<Self> {
-        match with_tls {
-            #[cfg(any(feature = "etl_native_tls", feature = "etl_rustls"))]
-            true => tls::with_worker().map(Self::Tls),
-            #[allow(unreachable_patterns, reason = "reachable with no TLS feature ")]
-            true => Err(SqlxError::Tls("No ETL TLS feature set".into())),
-            false => Ok(Self::NonTls),
+        #[cfg(feature = "tls")]
+        if with_tls {
+            return tls::with_worker().map(Self::Tls);
         }
+
+        Ok(Self::NonTls)
     }
 }
 
@@ -35,7 +35,7 @@ impl WithSocketMaker for WithMaybeTlsSocketMaker {
     fn make_with_socket(&self, with_socket: WithExaSocket) -> Self::WithSocket {
         match self {
             Self::NonTls => WithMaybeTlsSocket::NonTls(with_socket),
-            #[cfg(any(feature = "etl_native_tls", feature = "etl_rustls"))]
+            #[cfg(feature = "tls")]
             Self::Tls(w) => WithMaybeTlsSocket::Tls(w.make_with_socket(with_socket)),
         }
     }
@@ -44,7 +44,7 @@
 /// Implementor of [`WithSocket`] that abstracts away the TLS/non-TLS socket creation.
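 /// Dispatch to the matching variant is done by [`WithSocketMaker::make_with_socket`],
 /// so callers never handle the TLS distinction directly.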
 pub enum WithMaybeTlsSocket {
     NonTls(WithExaSocket),
-    #[cfg(any(feature = "etl_native_tls", feature = "etl_rustls"))]
+    #[cfg(feature = "tls")]
     Tls(tls::WithTlsSocket),
 }
 
@@ -54,7 +54,7 @@ impl WithSocket for WithMaybeTlsSocket {
     async fn with_socket<S: Socket>(self, socket: S) -> Self::Output {
         match self {
             WithMaybeTlsSocket::NonTls(w) => w.with_socket(socket).map(Ok).boxed(),
-            #[cfg(any(feature = "etl_native_tls", feature = "etl_rustls"))]
+            #[cfg(feature = "tls")]
             WithMaybeTlsSocket::Tls(w) => w.with_socket(socket).await,
         }
     }
diff --git a/src/connection/etl/job/maybe_tls/tls/mod.rs b/sqlx-exasol-impl/src/connection/etl/job/maybe_tls/tls/mod.rs
similarity index 76%
rename from src/connection/etl/job/maybe_tls/tls/mod.rs
rename to sqlx-exasol-impl/src/connection/etl/job/maybe_tls/tls/mod.rs
index ee951fc6..a650f543 100644
--- a/src/connection/etl/job/maybe_tls/tls/mod.rs
+++ b/sqlx-exasol-impl/src/connection/etl/job/maybe_tls/tls/mod.rs
@@ -1,6 +1,6 @@
-#[cfg(feature = "etl_native_tls")]
+#[cfg(feature = "native-tls")]
 mod native_tls;
-#[cfg(feature = "etl_rustls")]
+#[cfg(feature = "rustls")]
 mod rustls;
 mod sync_socket;
 
@@ -12,17 +12,14 @@ use rsa::{
 
 use crate::{error::ToSqlxError, SqlxError, SqlxResult};
 
-#[cfg(all(feature = "etl_native_tls", feature = "etl_rustls"))]
-compile_error!("Only enable one of 'etl_antive_tls' or 'etl_rustls' features");
-
-#[cfg(feature = "etl_native_tls")]
+#[cfg(feature = "native-tls")]
 pub type WithTlsSocketMaker = native_tls::WithNativeTlsSocketMaker;
-#[cfg(feature = "etl_rustls")]
+#[cfg(all(feature = "rustls", not(feature = "native-tls")))]
 pub type WithTlsSocketMaker = rustls::WithRustlsSocketMaker;
 
-#[cfg(feature = "etl_native_tls")]
+#[cfg(feature = "native-tls")]
 pub type WithTlsSocket = native_tls::WithNativeTlsSocket;
-#[cfg(feature = "etl_rustls")]
+#[cfg(all(feature = "rustls", not(feature = "native-tls")))]
 pub type WithTlsSocket = rustls::WithRustlsSocket;
 
 /// Returns the dedicated [`impl WithSocketMaker`] for the chosen TLS implementation.
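 /// A fresh RSA key pair and self-signed certificate are generated on every call, so
 /// each ETL job gets its own ephemeral TLS identity.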
@@ -42,9 +39,9 @@ pub fn with_worker() -> SqlxResult<WithTlsSocketMaker> {
         .self_signed(&key_pair)
         .map_err(ToSqlxError::to_sqlx_err)?;
 
-    #[cfg(feature = "etl_native_tls")]
+    #[cfg(feature = "native-tls")]
     return native_tls::WithNativeTlsSocketMaker::new(&cert, &key_pair);
 
-    #[cfg(feature = "etl_rustls")]
+    #[cfg(all(feature = "rustls", not(feature = "native-tls")))]
     return rustls::WithRustlsSocketMaker::new(&cert, &key_pair);
 }
diff --git a/src/connection/etl/job/maybe_tls/tls/native_tls.rs b/sqlx-exasol-impl/src/connection/etl/job/maybe_tls/tls/native_tls.rs
similarity index 99%
rename from src/connection/etl/job/maybe_tls/tls/native_tls.rs
rename to sqlx-exasol-impl/src/connection/etl/job/maybe_tls/tls/native_tls.rs
index 8a37734e..6ee460f7 100644
--- a/src/connection/etl/job/maybe_tls/tls/native_tls.rs
+++ b/sqlx-exasol-impl/src/connection/etl/job/maybe_tls/tls/native_tls.rs
@@ -72,7 +72,7 @@ impl WithNativeTlsSocket {
                 poll_fn(|cx| h.get_mut().poll_write_ready(cx)).await?;
                 res = h.handshake();
             }
-        };
+        }
     }
 }
 
@@ -113,7 +113,7 @@ where
         match self.0.shutdown() {
             Err(e) if e.kind() == io::ErrorKind::WouldBlock => (),
            ready => return Poll::Ready(ready),
-        };
+        }
 
         ready!(self.0.get_mut().poll_read_ready(cx))?;
         self.0.get_mut().poll_write_ready(cx)
diff --git a/src/connection/etl/job/maybe_tls/tls/rustls.rs b/sqlx-exasol-impl/src/connection/etl/job/maybe_tls/tls/rustls.rs
similarity index 98%
rename from src/connection/etl/job/maybe_tls/tls/rustls.rs
rename to sqlx-exasol-impl/src/connection/etl/job/maybe_tls/tls/rustls.rs
index 40ec4d4b..2db4f1da 100644
--- a/src/connection/etl/job/maybe_tls/tls/rustls.rs
+++ b/sqlx-exasol-impl/src/connection/etl/job/maybe_tls/tls/rustls.rs
@@ -25,6 +25,7 @@ use crate::{
 pub struct WithRustlsSocketMaker(Arc<ServerConfig>);
 
 impl WithRustlsSocketMaker {
+    #[cfg_attr(feature = "native-tls", allow(dead_code))]
     pub fn new(cert: &Certificate, key_pair: &KeyPair) -> SqlxResult<Self> {
         tracing::trace!("creating 'rustls' socket spawner");
 
@@ -121,7 +122,7 @@ where
             match self.state.complete_io(&mut self.inner) {
                 Err(e) if e.kind() == io::ErrorKind::WouldBlock => (),
                 ready => return Poll::Ready(ready.map(|_| ())),
-            };
+            }
 
             ready!(self.inner.poll_read_ready(cx))?;
         }
 
@@ -132,7 +133,7 @@ where
             match self.state.complete_io(&mut self.inner) {
                 Err(e) if e.kind() == io::ErrorKind::WouldBlock => (),
                 ready => return Poll::Ready(ready.map(|_| ())),
-            };
+            }
 
             ready!(self.inner.poll_write_ready(cx))?;
         }
diff --git a/src/connection/etl/job/maybe_tls/tls/sync_socket.rs b/sqlx-exasol-impl/src/connection/etl/job/maybe_tls/tls/sync_socket.rs
similarity index 100%
rename from src/connection/etl/job/maybe_tls/tls/sync_socket.rs
rename to sqlx-exasol-impl/src/connection/etl/job/maybe_tls/tls/sync_socket.rs
diff --git a/src/connection/etl/job/mod.rs b/sqlx-exasol-impl/src/connection/etl/job/mod.rs
similarity index 97%
rename from src/connection/etl/job/mod.rs
rename to sqlx-exasol-impl/src/connection/etl/job/mod.rs
index 353e4fd0..d3f78482 100644
--- a/src/connection/etl/job/mod.rs
+++ b/sqlx-exasol-impl/src/connection/etl/job/mod.rs
@@ -9,7 +9,10 @@ use std::{
 
 use arrayvec::ArrayString;
 use futures_core::future::BoxFuture;
-use sqlx_core::net::{Socket, WithSocket};
+use sqlx_core::{
+    net::{Socket, WithSocket},
+    sql_str::{AssertSqlSafe, SqlSafeStr},
+};
 
 use crate::{
     connection::websocket::{
@@ -118,8 +121,8 @@ pub trait EtlJob: Sized + Send + Sync {
             .await?;
 
         // Query execution driving future to be returned and awaited alongside the worker IO
-        let query = self.query(addrs, with_tls, with_compression);
-        let future = ExecuteEtl(ExaRoundtrip::new(Execute(query.into()))).future(&mut conn.ws);
+        let query = AssertSqlSafe(self.query(addrs, with_tls, with_compression)).into_sql_str();
+        let future = ExecuteEtl(ExaRoundtrip::new(Execute(query))).future(&mut conn.ws);
 
         Ok((EtlQuery(future), workers))
     }
diff --git a/src/connection/etl/mod.rs b/sqlx-exasol-impl/src/connection/etl/mod.rs
similarity index 97%
rename from src/connection/etl/mod.rs
rename to sqlx-exasol-impl/src/connection/etl/mod.rs
index 05d222a8..07e8251f 100644
--- a/src/connection/etl/mod.rs
+++ b/sqlx-exasol-impl/src/connection/etl/mod.rs
@@ -16,7 +16,7 @@
 //! ends when all the workers receive EOF. They can be dropped afterwards.
 //!
 //! ETL jobs can use TLS, compression, or both and will do so in a
-//! consistent manner with the [`ExaConnection`] they are executed on.
+//! consistent manner with the [`crate::ExaConnection`] they are executed on.
 //! That means that if the connection uses TLS / compression, so will the ETL job.
 //!
 //! **NOTE:** Trying to run ETL jobs with TLS without an ETL TLS feature flag results
@@ -131,7 +131,7 @@ impl Future for EtlQuery<'_> {
 
 /// Implementor of [`WebSocketFuture`] that executes an owned ETL query.
 #[derive(Debug)]
-struct ExecuteEtl(ExaRoundtrip<Execute<'static>, SingleResult>);
+struct ExecuteEtl(ExaRoundtrip<Execute, SingleResult>);
 
 impl WebSocketFuture for ExecuteEtl {
     type Output = ExaQueryResult;
diff --git a/src/connection/executor.rs b/sqlx-exasol-impl/src/connection/executor.rs
similarity index 75%
rename from src/connection/executor.rs
rename to sqlx-exasol-impl/src/connection/executor.rs
index 6ef2626c..9708c1ae 100644
--- a/src/connection/executor.rs
+++ b/sqlx-exasol-impl/src/connection/executor.rs
@@ -1,11 +1,10 @@
-use std::borrow::Cow;
-
 use futures_core::{future::BoxFuture, stream::BoxStream};
 use futures_util::{FutureExt, StreamExt, TryStreamExt};
 use sqlx_core::{
     describe::Describe,
     executor::{Execute, Executor},
     logger::QueryLogger,
+    sql_str::SqlStr,
     Either,
 };
 
@@ -81,6 +80,40 @@ impl<'c> Executor<'c> for &'c mut ExaConnection {
         }
     }
 
+    fn fetch_all<'e, 'q: 'e, E>(self, query: E) -> BoxFuture<'e, SqlxResult<Vec<ExaRow>>>
+    where
+        'c: 'e,
+        E: 'q + Execute<'q, Self::Database>,
+    {
+        match self.fetch_impl(query) {
+            Ok(stream) => stream
+                .try_filter_map(|v| std::future::ready(Ok(v.right())))
+                .try_collect()
+                .boxed(),
+            Err(e) => std::future::ready(Err(e)).boxed(),
+        }
+    }
+
+    fn fetch_one<'e, 'q: 'e, E>(self, query: E) -> BoxFuture<'e, SqlxResult<ExaRow>>
+    where
+        'c: 'e,
+        E: 'q + Execute<'q, Self::Database>,
+    {
+        let stream = match self.fetch_impl(query) {
+            Ok(stream) => stream,
+            Err(e) => return std::future::ready(Err(e)).boxed(),
+        };
+
+        Box::pin(async move {
+            stream
+                .try_filter_map(|v| std::future::ready(Ok(v.right())))
+                .try_next()
+                .await
+                .transpose()
+                .unwrap_or(Err(SqlxError::RowNotFound))
+        })
+    }
+
     fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, SqlxResult<Option<ExaRow>>>
     where
         'q: 'e,
@@ -100,20 +133,21 @@ impl<'c> Executor<'c> for &'c mut ExaConnection {
         })
     }
 
-    fn prepare_with<'e, 'q>(
+    fn prepare_with<'e>(
         self,
-        sql: &'q str,
+        sql: SqlStr,
         _parameters: &'e [ExaTypeInfo],
-    ) -> BoxFuture<'e, SqlxResult<ExaStatement<'q>>>
+    ) -> BoxFuture<'e, SqlxResult<ExaStatement>>
     where
-        'q: 'e,
         'c: 'e,
     {
         Box::pin(async move {
-            let prepared = GetOrPrepare::new(sql, true).future(&mut self.ws).await?;
+            let prepared = GetOrPrepare::new(sql.clone(), true)
+                .future(&mut self.ws)
+                .await?;
 
             Ok(ExaStatement {
-                sql: Cow::Borrowed(sql),
+                sql,
                 metadata: ExaStatementMetadata::new(
                     prepared.columns.clone(),
                     prepared.parameters.clone(),
@@ -123,9 +157,8 @@ impl<'c> Executor<'c> for &'c mut ExaConnection {
     }
 
     /// Exasol does not provide nullability information, unfortunately.
-    fn describe<'e, 'q>(self, sql: &'q str) -> BoxFuture<'e, SqlxResult<Describe<Exasol>>>
+    fn describe<'e>(self, sql: SqlStr) -> BoxFuture<'e, SqlxResult<Describe<Exasol>>>
     where
-        'q: 'e,
         'c: 'e,
     {
         Box::pin(async move {
@@ -153,10 +186,10 @@ impl ExaConnection {
         'c: 'e,
         E: 'q + Execute<'q, Exasol>,
     {
-        let sql = query.sql();
         let persist = query.persistent();
-        let logger = QueryLogger::new(sql, self.log_settings.clone());
         let arguments = query.take_arguments().map_err(SqlxError::Encode)?;
+        let logger = QueryLogger::new(query.sql(), self.log_settings.clone());
+        let sql = logger.sql().clone();
 
         if let Some(arguments) = arguments {
             let future = ExecutePrepared::new(sql, persist, arguments);
@@ -173,10 +206,10 @@ impl ExaConnection {
         'c: 'e,
         E: 'q + Execute<'q, Exasol>,
     {
-        let sql = query.sql();
         let persist = query.persistent();
-        let logger = QueryLogger::new(sql, self.log_settings.clone());
         let arguments = query.take_arguments().map_err(SqlxError::Encode)?;
+        let logger = QueryLogger::new(query.sql(), self.log_settings.clone());
+        let sql = logger.sql().clone();
 
         if let Some(arguments) = arguments {
             let future = ExecutePrepared::new(sql, persist, arguments);
diff --git a/src/connection/mod.rs b/sqlx-exasol-impl/src/connection/mod.rs
similarity index 51%
rename from src/connection/mod.rs
rename to sqlx-exasol-impl/src/connection/mod.rs
index ef775b1d..35f078ea 100644
--- a/src/connection/mod.rs
+++ b/sqlx-exasol-impl/src/connection/mod.rs
@@ -1,16 +1,16 @@
 #[cfg(feature = "etl")]
 pub mod etl;
 mod executor;
-mod stream;
+pub mod stream;
 pub mod websocket;
 
 use std::net::SocketAddr;
 
-use futures_core::future::BoxFuture;
-use futures_util::{FutureExt, SinkExt};
+use futures_util::SinkExt;
 use rand::{seq::SliceRandom, thread_rng};
 use sqlx_core::{
     connection::{Connection, LogSettings},
+    executor::Executor,
     transaction::Transaction,
 };
 use websocket::{socket::WithExaSocket, ExaWebSocket};
@@ -30,8 +30,8 @@ use crate::{
 #[derive(Debug)]
 pub struct ExaConnection {
     pub(crate) ws: ExaWebSocket,
+    pub(crate) log_settings: LogSettings,
     session_info: SessionInfo,
-    log_settings: LogSettings,
 }
 
 impl ExaConnection {
@@ -100,7 +100,7 @@ impl ExaConnection {
             };
 
             // Break if we successfully connect a websocket.
-            match ExaWebSocket::new(host, opts.port, socket, opts.into(), with_tls).await {
+            match ExaWebSocket::new(host, opts.port, socket, opts.try_into()?, with_tls).await {
                 Ok(ws) => {
                     ws_result = Ok(ws);
                     break;
@@ -111,14 +111,25 @@ impl ExaConnection {
         }
 
         let (ws, session_info) = ws_result?;
-        let con = Self {
+        let mut con = Self {
             ws,
             log_settings: LogSettings::default(),
             session_info,
         };
 
+        con.configure_session().await?;
+
         Ok(con)
     }
+
+    /// Sets session parameters for the open connection.
+    async fn configure_session(&mut self) -> SqlxResult<()> {
+        // We rely on this for consistently sized HASHTYPE column output,
+        // which lets UUID be used reliably in compile-time checked queries.
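+        // With the HEX format, a 16-byte HASHTYPE value always surfaces as 32 hex
+        // characters, which is the shape the UUID decoding expects.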
+ self.execute("ALTER SESSION SET HASHTYPE_FORMAT = 'HEX';") + .await?; + Ok(()) + } } impl Connection for ExaConnection { @@ -126,43 +137,39 @@ impl Connection for ExaConnection { type Options = ExaConnectOptions; - fn close(mut self) -> BoxFuture<'static, SqlxResult<()>> { - Box::pin(async move { - Disconnect::default().future(&mut self.ws).await?; - self.ws.close().await?; - Ok(()) - }) + async fn close(mut self) -> SqlxResult<()> { + Disconnect::default().future(&mut self.ws).await?; + self.ws.close().await?; + Ok(()) } - fn close_hard(mut self) -> BoxFuture<'static, SqlxResult<()>> { - Box::pin(async move { self.ws.close().await }) + async fn close_hard(mut self) -> SqlxResult<()> { + self.ws.close().await } - fn ping(&mut self) -> BoxFuture<'_, SqlxResult<()>> { - self.ws.ping().boxed() + async fn ping(&mut self) -> SqlxResult<()> { + self.ws.ping().await } - fn begin(&mut self) -> BoxFuture<'_, SqlxResult>> + async fn begin(&mut self) -> SqlxResult> where Self: Sized, { - Transaction::begin(self, None) + Transaction::begin(self, None).await } fn shrink_buffers(&mut self) {} - fn flush(&mut self) -> BoxFuture<'_, SqlxResult<()>> { - Box::pin(async { - if let Some(future) = self.ws.pending_close.take() { - future.future(&mut self.ws).await?; - } + async fn flush(&mut self) -> SqlxResult<()> { + if let Some(future) = self.ws.pending_close.take() { + future.future(&mut self.ws).await?; + } - if let Some(future) = self.ws.pending_rollback.take() { - future.future(&mut self.ws).await?; - } + if let Some(future) = self.ws.pending_rollback.take() { + future.future(&mut self.ws).await?; + } - Ok(()) - }) + Ok(()) } fn should_flush(&self) -> bool { @@ -176,70 +183,31 @@ impl Connection for ExaConnection { self.ws.statement_cache.len() } - fn clear_cached_statements(&mut self) -> BoxFuture<'_, SqlxResult<()>> + async fn clear_cached_statements(&mut self) -> SqlxResult<()> where Self::Database: sqlx_core::database::HasStatementCache, { - Box::pin(async { - while let Some((_, prep)) = self.ws.statement_cache.pop_lru() { - ClosePrepared::new(prep.statement_handle) - .future(&mut self.ws) - .await?; - } + while let Some((_, prep)) = self.ws.statement_cache.pop_lru() { + ClosePrepared::new(prep.statement_handle) + .future(&mut self.ws) + .await?; + } - Ok(()) - }) + Ok(()) } } #[cfg(test)] #[cfg(feature = "migrate")] +#[allow(clippy::large_futures, reason = "silencing clippy")] mod tests { use std::num::NonZeroUsize; use futures_util::TryStreamExt; - use sqlx::{query, Connection, Executor}; + use sqlx::Executor; use sqlx_core::{error::BoxDynError, pool::PoolOptions}; - use crate::{ExaConnectOptions, ExaQueryResult, Exasol}; - - #[cfg(feature = "compression")] - #[ignore] - #[sqlx::test] - async fn test_compression_feature( - pool_opts: PoolOptions, - mut exa_opts: ExaConnectOptions, - ) -> Result<(), BoxDynError> { - exa_opts.compression = true; - - let pool = pool_opts.connect_with(exa_opts).await?; - let mut con = pool.acquire().await?; - let schema = "TEST_SWITCH_SCHEMA"; - - con.execute(format!("CREATE SCHEMA IF NOT EXISTS {schema};").as_str()) - .await?; - - let new_schema: String = sqlx::query_scalar("SELECT CURRENT_SCHEMA") - .fetch_one(&mut *con) - .await?; - - con.execute(format!("DROP SCHEMA IF EXISTS {schema} CASCADE;").as_str()) - .await?; - - assert_eq!(schema, new_schema); - - Ok(()) - } - - #[cfg(not(feature = "compression"))] - #[sqlx::test] - async fn test_compression_no_feature( - pool_opts: PoolOptions, - mut exa_opts: ExaConnectOptions, - ) { - exa_opts.compression = true; - 
-        assert!(pool_opts.connect_with(exa_opts).await.is_err());
-    }
+    use crate::{ExaConnectOptions, Exasol};
 
     #[sqlx::test]
     async fn test_stmt_cache(
@@ -258,11 +226,11 @@ mod tests {
         assert!(!con.as_ref().ws.statement_cache.contains(sql1));
         assert!(!con.as_ref().ws.statement_cache.contains(sql2));
 
-        query(sql1).execute(&mut *con).await?;
+        sqlx::query(sql1).execute(&mut *con).await?;
         assert!(con.as_ref().ws.statement_cache.contains(sql1));
         assert!(!con.as_ref().ws.statement_cache.contains(sql2));
 
-        query(sql2).execute(&mut *con).await?;
+        sqlx::query(sql2).execute(&mut *con).await?;
         assert!(!con.as_ref().ws.statement_cache.contains(sql1));
         assert!(con.as_ref().ws.statement_cache.contains(sql2));
 
@@ -275,23 +243,7 @@ mod tests {
         mut exa_opts: ExaConnectOptions,
     ) -> Result<(), BoxDynError> {
         exa_opts.schema = None;
-        let pool = pool_opts.connect_with(exa_opts).await?;
-        let mut con = pool.acquire().await?;
-
-        let schema: Option<String> = sqlx::query_scalar("SELECT CURRENT_SCHEMA")
-            .fetch_one(&mut *con)
-            .await?;
-
-        assert!(schema.is_none());
-
-        Ok(())
-    }
-
-    #[sqlx::test]
-    async fn test_schema_selected(
-        pool_opts: PoolOptions<Exasol>,
-        exa_opts: ExaConnectOptions,
-    ) -> Result<(), BoxDynError> {
         let pool = pool_opts.connect_with(exa_opts).await?;
         let mut con = pool.acquire().await?;
 
@@ -299,141 +251,8 @@ mod tests {
             .fetch_one(&mut *con)
             .await?;
 
-        assert!(schema.is_some());
-
-        Ok(())
-    }
-
-    #[sqlx::test]
-    async fn test_schema_switch(
-        pool_opts: PoolOptions<Exasol>,
-        exa_opts: ExaConnectOptions,
-    ) -> Result<(), BoxDynError> {
-        let pool = pool_opts.connect_with(exa_opts).await?;
-        let mut con = pool.acquire().await?;
-        let schema = "TEST_SWITCH_SCHEMA";
-
-        con.execute(format!("CREATE SCHEMA IF NOT EXISTS {schema};").as_str())
-            .await?;
-
-        let new_schema: String = sqlx::query_scalar("SELECT CURRENT_SCHEMA")
-            .fetch_one(&mut *con)
-            .await?;
-
-        con.execute(format!("DROP SCHEMA IF EXISTS {schema} CASCADE;").as_str())
-            .await?;
-
-        assert_eq!(schema, new_schema);
-
-        Ok(())
-    }
-
-    #[sqlx::test]
-    async fn test_schema_switch_from_attr(
-        pool_opts: PoolOptions<Exasol>,
-        exa_opts: ExaConnectOptions,
-    ) -> Result<(), BoxDynError> {
-        let pool = pool_opts.connect_with(exa_opts).await?;
-        let mut con = pool.acquire().await?;
-
-        let orig_schema: String = sqlx::query_scalar("SELECT CURRENT_SCHEMA")
-            .fetch_one(&mut *con)
-            .await?;
-
-        let schema = "TEST_SWITCH_SCHEMA";
-
-        con.execute(format!("CREATE SCHEMA IF NOT EXISTS {schema};").as_str())
-            .await?;
-
-        con.attributes_mut().set_current_schema(orig_schema.clone());
-        con.flush_attributes().await?;
-
-        let new_schema: String = sqlx::query_scalar("SELECT CURRENT_SCHEMA")
-            .fetch_one(&mut *con)
-            .await?;
-
-        assert_eq!(orig_schema, new_schema);
-
-        Ok(())
-    }
-
-    #[sqlx::test]
-    async fn test_schema_close_and_empty_attr(
-        pool_opts: PoolOptions<Exasol>,
-        exa_opts: ExaConnectOptions,
-    ) -> Result<(), BoxDynError> {
-        let pool = pool_opts.connect_with(exa_opts).await?;
-        let mut con = pool.acquire().await?;
-
-        let orig_schema: String = sqlx::query_scalar("SELECT CURRENT_SCHEMA")
-            .fetch_one(&mut *con)
-            .await?;
-
-        assert_eq!(
-            con.attributes().current_schema(),
-            Some(orig_schema.as_str())
-        );
-
-        con.execute("CLOSE SCHEMA").await?;
-        assert_eq!(con.attributes().current_schema(), None);
-
-        Ok(())
-    }
-
-    #[sqlx::test]
-    async fn test_comment_stmts(
-        pool_opts: PoolOptions<Exasol>,
-        exa_opts: ExaConnectOptions,
-    ) -> Result<(), BoxDynError> {
-        let pool = pool_opts.connect_with(exa_opts).await?;
-        let mut con = pool.acquire().await?;
-
-        con.execute_many("/* this is a comment */")
this is a comment */") - .try_collect::() - .await?; - con.execute("-- this is a comment").await?; - - Ok(()) - } - - #[sqlx::test] - async fn test_connection_flush_on_drop( - pool_opts: PoolOptions, - exa_opts: ExaConnectOptions, - ) -> Result<(), BoxDynError> { - // Only allow one connection - let pool = pool_opts.max_connections(1).connect_with(exa_opts).await?; - pool.execute("CREATE TABLE TRANSACTIONS_TEST ( col DECIMAL(1, 0) );") - .await?; - - { - let mut conn = pool.acquire().await?; - let mut tx = conn.begin().await?; - tx.execute("INSERT INTO TRANSACTIONS_TEST VALUES(1)") - .await?; - } - - let mut conn = pool.acquire().await?; - { - let mut tx = conn.begin().await?; - tx.execute("INSERT INTO TRANSACTIONS_TEST VALUES(1)") - .await?; - } - - { - let mut tx = conn.begin().await?; - tx.execute("INSERT INTO TRANSACTIONS_TEST VALUES(1)") - .await?; - } - - drop(conn); - - let inserted = pool - .fetch_all("SELECT * FROM TRANSACTIONS_TEST") - .await? - .len(); + assert!(schema.is_none()); - assert_eq!(inserted, 0); Ok(()) } @@ -445,11 +264,11 @@ mod tests { // Only allow one connection let pool = pool_opts.max_connections(1).connect_with(exa_opts).await?; let mut conn = pool.acquire().await?; - conn.execute("CREATE TABLE CLOSE_RESULTS_TEST ( col DECIMAL(1, 0) );") + conn.execute("CREATE TABLE CLOSE_RESULTS_TEST ( col DECIMAL(3, 0) );") .await?; - query("INSERT INTO CLOSE_RESULTS_TEST VALUES(?)") - .bind(vec![1; 10000]) + sqlx::query("INSERT INTO CLOSE_RESULTS_TEST VALUES(?)") + .bind(vec![1i8; 10000]) .execute(&mut *conn) .await?; diff --git a/src/connection/stream.rs b/sqlx-exasol-impl/src/connection/stream.rs similarity index 94% rename from src/connection/stream.rs rename to sqlx-exasol-impl/src/connection/stream.rs index e7cee25c..5119ef20 100644 --- a/src/connection/stream.rs +++ b/sqlx-exasol-impl/src/connection/stream.rs @@ -14,7 +14,7 @@ use std::{ use futures_core::ready; use futures_util::Stream; use serde_json::Value; -use sqlx_core::{logger::QueryLogger, Either, HashMap}; +use sqlx_core::{ext::ustr::UStr, logger::QueryLogger, Either, HashMap}; use crate::{ column::ExaColumn, @@ -35,18 +35,18 @@ use crate::{ /// from the database. /// /// This is the top of the hierarchy and the actual type used to stream rows from a [`ResultSet`]. -pub struct ResultStream<'a> { - ws: &'a mut ExaWebSocket, - logger: QueryLogger<'a>, +pub struct ResultStream<'ws> { + ws: &'ws mut ExaWebSocket, + logger: QueryLogger, result_set_handles: Vec, - state: ResultStreamState<'a>, + state: ResultStreamState, had_err: bool, } -impl<'a> ResultStream<'a> { - pub fn new(ws: &'a mut ExaWebSocket, logger: QueryLogger<'a>, future: F) -> Self +impl<'ws> ResultStream<'ws> { + pub fn new(ws: &'ws mut ExaWebSocket, logger: QueryLogger, future: F) -> Self where - ResultStreamState<'a>: From, + ResultStreamState: From, { Self { ws, @@ -132,27 +132,27 @@ impl Drop for ResultStream<'_> { /// State used to distinguish between the initial query execution and the subsequent streaming of /// rows. 
-pub enum ResultStreamState<'a> {
-    Execute(Execute<'a>),
-    ExecuteBatch(ExecuteBatch<'a>),
-    ExecutePrepared(ExecutePrepared<'a>),
+pub enum ResultStreamState {
+    Execute(Execute),
+    ExecuteBatch(ExecuteBatch),
+    ExecutePrepared(ExecutePrepared),
     Stream(MultiResultStream),
 }
 
-impl<'a> From<Execute<'a>> for ResultStreamState<'a> {
-    fn from(value: Execute<'a>) -> Self {
+impl From<Execute> for ResultStreamState {
+    fn from(value: Execute) -> Self {
         Self::Execute(value)
     }
 }
 
-impl<'a> From<ExecuteBatch<'a>> for ResultStreamState<'a> {
-    fn from(value: ExecuteBatch<'a>) -> Self {
+impl From<ExecuteBatch> for ResultStreamState {
+    fn from(value: ExecuteBatch) -> Self {
         Self::ExecuteBatch(value)
     }
 }
 
-impl<'a> From<ExecutePrepared<'a>> for ResultStreamState<'a> {
-    fn from(value: ExecutePrepared<'a>) -> Self {
+impl From<ExecutePrepared> for ResultStreamState {
+    fn from(value: ExecutePrepared) -> Self {
         Self::ExecutePrepared(value)
     }
 }
@@ -445,7 +445,7 @@ enum MultiChunkStreamState {
 /// This is the lowest level of the streaming hierarchy and merely iterates over an already
 /// retrieved chunk of rows, not dealing at all with async operations.
 struct ChunkIter {
-    column_names: Arc<HashMap<Arc<str>, usize>>,
+    column_names: Arc<HashMap<UStr, usize>>,
     columns: Arc<[ExaColumn]>,
     chunk_rows_total: usize,
     chunk_rows_pos: usize,
diff --git a/src/connection/websocket/future.rs b/sqlx-exasol-impl/src/connection/websocket/future.rs
similarity index 66%
rename from src/connection/websocket/future.rs
rename to sqlx-exasol-impl/src/connection/websocket/future.rs
index ae929d22..356e04af 100644
--- a/src/connection/websocket/future.rs
+++ b/sqlx-exasol-impl/src/connection/websocket/future.rs
@@ -19,6 +19,7 @@ use serde::{
     de::{DeserializeOwned, IgnoredAny},
     Serialize,
 };
+use sqlx_core::sql_str::SqlStr;
 
 use crate::{
     connection::{
@@ -80,20 +81,20 @@ pub trait WebSocketFuture: Unpin + Sized {
 
 /// Implementor of [`WebSocketFuture`] that executes a prepared statement.
 #[derive(Debug)]
-pub struct ExecutePrepared<'a> {
+pub struct ExecutePrepared {
     arguments: ExaArguments,
-    state: ExecutePreparedState<'a>,
+    state: ExecutePreparedState,
 }
 
-impl<'a> ExecutePrepared<'a> {
-    pub fn new(sql: &'a str, persist: bool, arguments: ExaArguments) -> Self {
+impl ExecutePrepared {
+    pub fn new(sql: SqlStr, persist: bool, arguments: ExaArguments) -> Self {
         let future = GetOrPrepare::new(sql, persist);
         let state = ExecutePreparedState::GetOrPrepare(future);
         Self { arguments, state }
     }
 }
 
-impl WebSocketFuture for ExecutePrepared<'_> {
+impl WebSocketFuture for ExecutePrepared {
     type Output = MultiResultStream;
 
     fn poll_unpin(
@@ -105,19 +106,6 @@ impl WebSocketFuture for ExecutePrepared<'_> {
             match &mut self.state {
                 ExecutePreparedState::GetOrPrepare(future) => {
                     let prepared = ready!(future.poll_unpin(cx, ws))?;
-
-                    // Check the compatibility between provided parameter data types
-                    // and the ones expected by the database.
-                    let iter = std::iter::zip(prepared.parameters.as_ref(), &self.arguments.types);
-                    for (expected, provided) in iter {
-                        if !expected.compatible(provided) {
-                            return Err(ExaProtocolError::DatatypeMismatch(
-                                expected.name,
-                                provided.name,
-                            ))?;
-                        }
-                    }
-
                     let buf = std::mem::take(&mut self.arguments.buf);
                     let command = ExecutePreparedStmt::new(
                         prepared.statement_handle,
@@ -137,118 +125,23 @@ impl WebSocketFuture for ExecutePrepared<'_> {
 }
 
 #[derive(Debug)]
-enum ExecutePreparedState<'a> {
-    GetOrPrepare(GetOrPrepare<'a>),
+enum ExecutePreparedState {
+    GetOrPrepare(GetOrPrepare),
     ExecutePrepared(ExaRoundtrip<ExecutePreparedStmt, MultiResults>),
 }
 
 /// Implementor of [`WebSocketFuture`] that executes a batch of SQL statements.
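 ///
 /// The SQL is only split into individual statements at request serialization time, since
 /// the `executeBatch` command expects one text entry per statement.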
 #[derive(Debug)]
-pub struct ExecuteBatch<'a>(ExaRoundtrip<request::ExecuteBatch<'a>, MultiResults>);
+pub struct ExecuteBatch(ExaRoundtrip<request::ExecuteBatch, MultiResults>);
 
-impl<'a> ExecuteBatch<'a> {
-    pub fn new(sql: &'a str) -> Self {
-        let request = request::ExecuteBatch(Self::split_query(sql));
+impl ExecuteBatch {
+    pub fn new(sql: SqlStr) -> Self {
+        let request = request::ExecuteBatch(sql);
         Self(ExaRoundtrip::new(request))
     }
-
-    /// Splits a SQL query into individual statements.
-    ///
-    /// The splitting follows the following logic:
-    /// - trim the query to remove leading and trailing whitespace
-    /// - parse each character and store the string slice up to a ';' that is not inside a line or
-    ///   block comment and not contained within single or double quotes
-    /// - register the next statement start as the next non-whitespace character after a split,
-    ///   essentially ignoring whitespace between statements (but retaining comments)
-    /// - add the remainder string slice after the last ';' if it is not empty; this means that the
-    ///   last statement could be a comment only, but that is okay as Exasol does not complain.
-    fn split_query(query: &str) -> Vec<&str> {
-        #[derive(Clone, Copy)]
-        enum Inside {
-            Statement,
-            LineComment,
-            BlockComment,
-            DoubleQuote,
-            SingleQuote,
-            Whitespace,
-        }
-
-        let query = query.trim();
-        let mut chars = query.char_indices().peekable();
-        let mut state = Inside::Statement;
-        let mut statements = Vec::new();
-        let mut start = 0;
-
-        while let Some((i, c)) = chars.next() {
-            let mut peek = || chars.peek().map(|(_, c)| *c);
-            let is_whitespace = |p: Option<char>| p.is_some_and(char::is_whitespace);
-
-            #[allow(clippy::match_same_arms, reason = "better readability if split")]
-            match (state, c) {
-                // Line comment start
-                (Inside::Statement, '-') if Some('-') == peek() => {
-                    chars.next();
-                    state = Inside::LineComment;
-                }
-                // Block comment start
-                (Inside::Statement, '/') if Some('*') == peek() => {
-                    chars.next();
-                    state = Inside::BlockComment;
-                }
-                // Double quote start
-                (Inside::Statement, '"') => state = Inside::DoubleQuote,
-                // Single quote start
-                (Inside::Statement, '\'') => state = Inside::SingleQuote,
-                // Statement end
-                (Inside::Statement, ';') => {
-                    statements.push(&query[start..=i]);
-                    start = i + 1;
-
-                    // Whitespace between statements start
-                    if is_whitespace(peek()) {
-                        state = Inside::Whitespace;
-                    }
-                }
-                // Skip escaped double quote
-                (Inside::DoubleQuote, '"') if Some('"') == peek() => {
-                    chars.next();
-                }
-                // Skip escaped single quote
-                (Inside::SingleQuote, '\'') if Some('\'') == peek() => {
-                    chars.next();
-                }
-                // Double quote end
-                (Inside::DoubleQuote, '"') => state = Inside::Statement,
-                // Single quote end
-                (Inside::SingleQuote, '\'') => state = Inside::Statement,
-                // Line comment end
-                (Inside::LineComment, '\n') => state = Inside::Statement,
-                // Block comment end
-                (Inside::BlockComment, '*') if Some('/') == peek() => {
-                    chars.next();
-                    state = Inside::Statement;
-                }
-                // Whitespace between statements end
-                (Inside::Whitespace, _) if !is_whitespace(peek()) => {
-                    start = i + 1;
-                    state = Inside::Statement;
-                }
-                _ => (),
-            }
-        }
-
-        // Add final part if anything remains after the last `;`.
-        // NOTE: Exasol does not complain about trailing comments, but only empty queries.
-        let remaining = &query[start..];
-        if !remaining.is_empty() {
-            statements.push(remaining);
-        }
-
-        statements
-    }
 }
 
-impl WebSocketFuture for ExecuteBatch<'_> {
+impl WebSocketFuture for ExecuteBatch {
     type Output = MultiResultStream;
 
     fn poll_unpin(
@@ -263,15 +156,15 @@ impl WebSocketFuture for ExecuteBatch<'_> {
 
 /// Implementor of [`WebSocketFuture`] that executes a single SQL statement.
 #[derive(Debug)]
-pub struct Execute<'a>(ExaRoundtrip<request::Execute<'a>, SingleResult>);
+pub struct Execute(ExaRoundtrip<request::Execute, SingleResult>);
 
-impl<'a> Execute<'a> {
-    pub fn new(sql: &'a str) -> Self {
-        Self(ExaRoundtrip::new(request::Execute(sql.into())))
+impl Execute {
+    pub fn new(sql: SqlStr) -> Self {
+        Self(ExaRoundtrip::new(request::Execute(sql)))
     }
 }
 
-impl WebSocketFuture for Execute<'_> {
+impl WebSocketFuture for Execute {
     type Output = MultiResultStream;
 
     fn poll_unpin(
@@ -286,14 +179,14 @@ impl WebSocketFuture for Execute<'_> {
 
 /// Implementor of [`WebSocketFuture`] that retrieves a prepared statement from the cache, storing
 /// it first on cache miss.
 #[derive(Debug)]
-pub struct GetOrPrepare<'a> {
-    sql: &'a str,
+pub struct GetOrPrepare {
+    sql: SqlStr,
     persist: bool,
-    state: GetOrPrepareState<'a>,
+    state: GetOrPrepareState,
 }
 
-impl<'a> GetOrPrepare<'a> {
-    pub fn new(sql: &'a str, persist: bool) -> Self {
+impl GetOrPrepare {
+    pub fn new(sql: SqlStr, persist: bool) -> Self {
         Self {
             sql,
             persist,
@@ -302,7 +195,7 @@ impl GetOrPrepare {
         }
     }
 }
 
-impl WebSocketFuture for GetOrPrepare<'_> {
+impl WebSocketFuture for GetOrPrepare {
     type Output = PreparedStatement;
 
     fn poll_unpin(
@@ -313,11 +206,13 @@ impl WebSocketFuture for GetOrPrepare<'_> {
         loop {
             match &mut self.state {
                 GetOrPrepareState::GetCached => {
-                    self.state = match ws.statement_cache.get(self.sql).cloned() {
+                    self.state = match ws.statement_cache.get(self.sql.as_str()).cloned() {
                         // Cache hit, simply return
                         Some(prepared) => return Poll::Ready(Ok(prepared)),
                        // Cache miss, switch state and prepare statement
-                        None => GetOrPrepareState::CreatePrepared(CreatePrepared::new(self.sql)),
+                        None => {
+                            GetOrPrepareState::CreatePrepared(CreatePrepared::new(self.sql.clone()))
+                        }
                     }
                 }
                 GetOrPrepareState::CreatePrepared(future) => {
@@ -334,7 +229,7 @@ impl WebSocketFuture for GetOrPrepare<'_> {
                     // Otherwise we go to simply retrieving it from the cache.
                     self.state = ws
                         .statement_cache
-                        .push(self.sql.to_owned(), prepared)
+                        .push(self.sql.as_str().to_owned(), prepared)
                         .map(|(_, p)| p.statement_handle)
                         .map(ClosePrepared::new)
                         .map_or(
@@ -352,9 +247,9 @@ impl WebSocketFuture for GetOrPrepare<'_> {
 }
 
 #[derive(Debug)]
-enum GetOrPrepareState<'a> {
+enum GetOrPrepareState {
     GetCached,
-    CreatePrepared(CreatePrepared<'a>),
+    CreatePrepared(CreatePrepared),
     ClosePrepared(ClosePrepared),
 }
 
@@ -384,19 +279,19 @@ impl WebSocketFuture for FetchChunk {
 
 /// database support, a statement gets prepared and closed right after to retrieve the statement
 /// information.
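 ///
 /// In other words: prepare the statement, capture its column and parameter metadata,
 /// then immediately close it again. Roughly, assuming hypothetical helper names:
 ///
 /// ```ignore
 /// let prepared = create_prepared(ws, sql).await?;        // createPreparedStatement
 /// let described = DescribeStatement::from(&prepared);    // capture the metadata
 /// close_prepared(ws, prepared.statement_handle).await?;  // free server resources
 /// ```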
 #[derive(Debug)]
-pub enum Describe<'a> {
-    CreatePrepared(ExaRoundtrip<CreatePreparedStmt<'a>, DescribeStatement>),
+pub enum Describe {
+    CreatePrepared(ExaRoundtrip<CreatePreparedStmt, DescribeStatement>),
     ClosePrepared(ClosePrepared, DescribeStatement),
     Finished,
 }
 
-impl<'a> Describe<'a> {
-    pub fn new(sql: &'a str) -> Self {
+impl Describe {
+    pub fn new(sql: SqlStr) -> Self {
         Self::CreatePrepared(ExaRoundtrip::new(CreatePreparedStmt(sql)))
     }
 }
 
-impl WebSocketFuture for Describe<'_> {
+impl WebSocketFuture for Describe {
     type Output = DescribeStatement;
 
     fn poll_unpin(
@@ -425,15 +320,15 @@ impl WebSocketFuture for Describe<'_> {
 
 /// Implementor of [`WebSocketFuture`] that creates a prepared statement.
 #[derive(Debug)]
-pub struct CreatePrepared<'a>(ExaRoundtrip<CreatePreparedStmt<'a>, PreparedStatement>);
+pub struct CreatePrepared(ExaRoundtrip<CreatePreparedStmt, PreparedStatement>);
 
-impl<'a> CreatePrepared<'a> {
-    pub fn new(sql: &'a str) -> Self {
+impl CreatePrepared {
+    pub fn new(sql: SqlStr) -> Self {
         Self(ExaRoundtrip::new(CreatePreparedStmt(sql)))
     }
 }
 
-impl WebSocketFuture for CreatePrepared<'_> {
+impl WebSocketFuture for CreatePrepared {
     type Output = PreparedStatement;
 
     fn poll_unpin(
@@ -550,11 +445,12 @@ impl WebSocketFuture for GetAttributes {
 
 /// Implementor of [`WebSocketFuture`] that executes a `COMMIT;` statement.
 #[derive(Debug)]
-pub struct Commit(ExaRoundtrip<request::Execute<'static>, Option<SingleResult>>);
+pub struct Commit(ExaRoundtrip<request::Execute, Option<SingleResult>>);
 
 impl Default for Commit {
     fn default() -> Self {
-        Self(ExaRoundtrip::new(request::Execute("COMMIT;".into())))
+        let sql = SqlStr::from_static("COMMIT;");
+        Self(ExaRoundtrip::new(request::Execute(sql)))
     }
 }
 
@@ -581,11 +477,12 @@ impl WebSocketFuture for Commit {
 
 /// Implementor of [`WebSocketFuture`] that executes a `ROLLBACK;` statement.
 #[derive(Debug)]
-pub struct Rollback(ExaRoundtrip<request::Execute<'static>, Option<SingleResult>>);
+pub struct Rollback(ExaRoundtrip<request::Execute, Option<SingleResult>>);
 
 impl Default for Rollback {
     fn default() -> Self {
-        Self(ExaRoundtrip::new(request::Execute("ROLLBACK;".into())))
+        let sql = SqlStr::from_static("ROLLBACK;");
+        Self(ExaRoundtrip::new(request::Execute(sql)))
     }
 }
 
@@ -828,164 +725,3 @@ where
         }
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::ExecuteBatch;
-
-    #[test]
-    fn test_simple_statements() {
-        assert_eq!(
-            ExecuteBatch::split_query("SELECT * FROM users; SELECT * FROM orders;"),
-            vec!["SELECT * FROM users;", "SELECT * FROM orders;"]
-        );
-    }
-
-    #[test]
-    fn test_semicolon_in_single_quote() {
-        assert_eq!(
-            ExecuteBatch::split_query("SELECT ';' AS val; SELECT 'abc;def' AS val2;"),
-            vec!["SELECT ';' AS val;", "SELECT 'abc;def' AS val2;"]
-        );
-    }
-
-    #[test]
-    fn test_semicolon_in_double_quote() {
-        assert_eq!(
-            ExecuteBatch::split_query("SELECT \"col;name\" FROM table;"),
-            vec!["SELECT \"col;name\" FROM table;"]
-        );
-    }
-
-    #[test]
-    fn test_semicolon_in_line_comment() {
-        assert_eq!(
-            ExecuteBatch::split_query(
-                "SELECT 1; -- this is a comment; with a semicolon\nSELECT 2;"
-            ),
-            vec![
-                "SELECT 1;",
-                "-- this is a comment; with a semicolon\nSELECT 2;"
-            ]
-        );
-    }
-
-    #[test]
-    fn test_semicolon_in_block_comment() {
-        assert_eq!(
-            ExecuteBatch::split_query("SELECT 1; /* multi-line ; comment */ SELECT 2;"),
-            vec!["SELECT 1;", "/* multi-line ; comment */ SELECT 2;"]
-        );
-    }
-
-    #[test]
-    fn test_escaped_quotes() {
-        assert_eq!(
-            ExecuteBatch::split_query(
-                "SELECT 'It''s a test; really'; SELECT \"escaped\"\"quote\" FROM dual;"
-            ),
-            vec![
-                "SELECT 'It''s a test; really';",
-                "SELECT \"escaped\"\"quote\" FROM dual;"
-            ]
-        );
-    }
-
-    #[test]
-    fn test_trailing_semicolon_and_whitespace() {
-        assert_eq!(
-            ExecuteBatch::split_query("SELECT 1;; \n \n;"),
-            vec!["SELECT 1;", ";", ";"]
-        );
-    }
-
-    #[test]
-    fn test_leading_semicolon() {
-        assert_eq!(
-            ExecuteBatch::split_query(";SELECT 1;"),
-            vec![";", "SELECT 1;"]
-        );
-    }
-
-    #[test]
-    fn test_leading_semicolon_and_whitespace() {
-        assert_eq!(
-            ExecuteBatch::split_query(" ; SELECT 1;"),
-            vec![";", "SELECT 1;"]
-        );
-    }
-
-    #[test]
-    fn test_no_semicolon() {
-        assert_eq!(ExecuteBatch::split_query("SELECT 1"), vec!["SELECT 1"]);
-    }
-
-    #[test]
-    fn test_no_whitespace_between_statements() {
-        assert_eq!(
-            ExecuteBatch::split_query("SELECT 1;SELECT 2"),
-            vec!["SELECT 1;", "SELECT 2"]
-        );
-    }
-
-    #[test]
-    fn test_no_whitespace_between_stmt_and_comment() {
-        assert_eq!(
-            ExecuteBatch::split_query("SELECT 1;/*testing*/SELECT 2;"),
-            vec!["SELECT 1;", "/*testing*/SELECT 2;"]
-        );
-    }
-
-    #[test]
-    fn test_trailing_comment() {
-        assert_eq!(
-            ExecuteBatch::split_query("SELECT 1;/*testing*/"),
-            vec!["SELECT 1;", "/*testing*/"]
-        );
-    }
-
-    #[test]
-    fn test_whitespace_between_statements() {
-        let query = "
-            /* Writing some comments */
-            SELECT 1;
-
-            -- Then writing some more comments
-            SELECT 2;
-        ";
-        assert_eq!(
-            ExecuteBatch::split_query(query),
-            vec![
-                "/* Writing some comments */
-            SELECT 1;",
-                "-- Then writing some more comments
-            SELECT 2;"
-            ]
-        );
-    }
-
-    #[test]
-    fn test_empty_input() {
-        assert_eq!(ExecuteBatch::split_query(""), Vec::<&str>::new());
-    }
-
-    #[test]
-    fn test_mixed_content() {
-        let query = r#"
-            SELECT 'test;--'; -- line comment with ;
-            /* block comment ;
-               over lines */
-            SELECT "str;with;semicolons";
-        "#;
-        assert_eq!(
-            ExecuteBatch::split_query(query),
-            vec![
-                "SELECT 'test;--';",
-                r#"-- line comment with ;
-            /* block comment ;
-               over lines */
-            SELECT "str;with;semicolons";"#
-            ]
-        );
-    }
-}
diff --git a/src/connection/websocket/mod.rs b/sqlx-exasol-impl/src/connection/websocket/mod.rs
similarity index 100%
rename from src/connection/websocket/mod.rs
rename to sqlx-exasol-impl/src/connection/websocket/mod.rs
diff --git a/src/connection/websocket/request.rs b/sqlx-exasol-impl/src/connection/websocket/request.rs
similarity index 57%
rename from src/connection/websocket/request.rs
rename to sqlx-exasol-impl/src/connection/websocket/request.rs
index d82685c4..a2847c97 100644
--- a/src/connection/websocket/request.rs
+++ b/sqlx-exasol-impl/src/connection/websocket/request.rs
@@ -7,6 +7,7 @@ use base64::{engine::general_purpose::STANDARD as STD_BASE64_ENGINE, Engine};
 use rsa::{Pkcs1v15Encrypt, RsaPublicKey};
 use serde::{Serialize, Serializer};
 use serde_json::value::RawValue;
+use sqlx_core::sql_str::SqlStr;
 
 use crate::{
     arguments::ExaBuffer, options::ProtocolVersion, responses::ExaRwAttributes, ExaAttributes,
@@ -198,20 +199,17 @@ impl Serialize for WithAttributes<'_, Fetch> {
 }
 
 /// Request to execute a single SQL statement.
-// This is internall used in the IMPORT/EXPORT jobs as well, since they rely on query execution too.
-// However, in these scenarios the query is an owned string, hence the usage of [`Cow`] here to
-// support that.
 #[derive(Clone, Debug)]
-pub struct Execute<'a>(pub Cow<'a, str>);
+pub struct Execute(pub SqlStr);
 
-impl Serialize for WithAttributes<'_, Execute<'_>> {
+impl Serialize for WithAttributes<'_, Execute> {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where
         S: Serializer,
     {
         let command = Command::Execute {
             attributes: self.needs_send.then_some(self.attributes),
-            sql_text: self.request.0.as_ref(),
+            sql_text: self.request.0.as_str(),
         };
 
         command.serialize(serializer)
@@ -220,16 +218,115 @@ impl Serialize for WithAttributes<'_, Execute> {
 
 /// Request to execute a batch of SQL statements.
 #[derive(Clone, Debug)]
-pub struct ExecuteBatch<'a>(pub Vec<&'a str>);
+pub struct ExecuteBatch(pub SqlStr);
 
-impl Serialize for WithAttributes<'_, ExecuteBatch<'_>> {
+impl ExecuteBatch {
+    /// Splits a SQL query into individual statements.
+    ///
+    /// The splitting follows this logic:
+    /// - trim the query to remove leading and trailing whitespace
+    /// - parse each character and store the string slice up to a ';' that is not inside a line or
+    ///   block comment and not contained within single or double quotes
+    /// - register the next statement start as the next non-whitespace character after a split,
+    ///   essentially ignoring whitespace between statements (but retaining comments)
+    /// - add the remainder string slice after the last ';' if it is not empty; this means that the
+    ///   last statement could be a comment only, but that is okay as Exasol does not complain.
+    fn split_query(&self) -> Vec<&str> {
+        #[derive(Clone, Copy)]
+        enum Inside {
+            Statement,
+            LineComment,
+            BlockComment,
+            DoubleQuote,
+            SingleQuote,
+            Whitespace,
+        }
+
+        let query = self.0.as_str().trim();
+        // NOTE: Using [`char`] as the iterator element for the sake of `char::is_whitespace`,
+        // which is more exhaustive than `u8::is_ascii_whitespace`.
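+        //
+        // For example, "SELECT 1; -- c;\nSELECT 2;" splits into
+        // ["SELECT 1;", "-- c;\nSELECT 2;"]: the semicolon inside the line comment
+        // does not terminate a statement.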
+ let mut chars = query.char_indices().peekable(); + let mut state = Inside::Statement; + let mut statements = Vec::new(); + let mut start = 0; + + while let Some((i, c)) = chars.next() { + let mut peek = || chars.peek().map(|(_, c)| *c); + let is_whitespace = |p: Option| p.is_some_and(char::is_whitespace); + + #[allow(clippy::match_same_arms, reason = "better readability if split")] + match (state, c) { + // Line comment start + (Inside::Statement, '-') if Some('-') == peek() => { + chars.next(); + state = Inside::LineComment; + } + // Block comment start + (Inside::Statement, '/') if Some('*') == peek() => { + chars.next(); + state = Inside::BlockComment; + } + // Double quote start + (Inside::Statement, '"') => state = Inside::DoubleQuote, + // Single quote start + (Inside::Statement, '\'') => state = Inside::SingleQuote, + // Statement end + (Inside::Statement, ';') => { + statements.push(&query[start..=i]); + start = i + 1; + + // Whitespace between statements start + if is_whitespace(peek()) { + state = Inside::Whitespace; + } + } + // Skip escaped double quote + (Inside::DoubleQuote, '"') if Some('"') == peek() => { + chars.next(); + } + // Skip escaped single quote + (Inside::SingleQuote, '\'') if Some('\'') == peek() => { + chars.next(); + } + // Double quote end + (Inside::DoubleQuote, '"') => state = Inside::Statement, + // Single quote end + (Inside::SingleQuote, '\'') => state = Inside::Statement, + // Line comment end + (Inside::LineComment, '\n') => state = Inside::Statement, + // Block comment end + (Inside::BlockComment, '*') if Some('/') == peek() => { + chars.next(); + state = Inside::Statement; + } + // Whitespace between statements end + (Inside::Whitespace, _) if !is_whitespace(peek()) => { + start = i + 1; + state = Inside::Statement; + } + _ => (), + } + } + + // Add final part if anything remains after the last `;`. + // NOTE: Exasol does not complain about trailing comments, but only empty queries. + let remaining = &query[start..]; + if !remaining.is_empty() { + statements.push(remaining); + } + + statements + } +} + +impl Serialize for WithAttributes<'_, ExecuteBatch> { fn serialize(&self, serializer: S) -> Result where S: Serializer, { let command = Command::ExecuteBatch { attributes: self.needs_send.then_some(self.attributes), - sql_texts: &self.request.0, + sql_texts: &self.request.split_query(), }; command.serialize(serializer) @@ -238,16 +335,16 @@ impl Serialize for WithAttributes<'_, ExecuteBatch<'_>> { /// Request to create a prepared statement. #[derive(Clone, Debug)] -pub struct CreatePreparedStmt<'a>(pub &'a str); +pub struct CreatePreparedStmt(pub SqlStr); -impl Serialize for WithAttributes<'_, CreatePreparedStmt<'_>> { +impl Serialize for WithAttributes<'_, CreatePreparedStmt> { fn serialize(&self, serializer: S) -> Result where S: Serializer, { let command = Command::CreatePreparedStatement { attributes: self.needs_send.then_some(self.attributes), - sql_text: self.request.0, + sql_text: self.request.0.as_str(), }; command.serialize(serializer) @@ -324,7 +421,7 @@ pub struct ExaLoginRequest<'a> { impl ExaLoginRequest<'_> { /// Encrypts the password with the provided key. /// - /// When connecting using [`Login::Credentials`], Exasol first sends out a public key to encrypt + /// When connecting using [`LoginCreds`], Exasol first sends out a public key to encrypt /// the password with. pub fn encrypt_password(&mut self, key: &RsaPublicKey) -> SqlxResult<()> { let LoginRef::Credentials { password, .. 
} = &mut self.login else { @@ -443,12 +540,35 @@ enum Command<'a> { num_columns: usize, num_rows: usize, #[serde(skip_serializing_if = "<[ExaTypeInfo]>::is_empty")] + #[serde(serialize_with = "serialize_params")] columns: &'a [ExaTypeInfo], #[serde(skip_serializing_if = "PreparedStmtData::is_empty")] data: &'a PreparedStmtData, }, } +/// Thin serialization wrapper that helps respect the serialization format of +/// prepared statement parameters (same layout as [`crate::ExaColumn`]). +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +struct PreparedStmtParam<'a> { + data_type: &'a ExaTypeInfo, +} + +impl<'a> From<&'a ExaTypeInfo> for PreparedStmtParam<'a> { + fn from(data_type: &'a ExaTypeInfo) -> Self { + Self { data_type } + } +} + +/// Serialization helper function that maps a [`ExaTypeInfo`] reference to [`PreparedStmtParam`]. +fn serialize_params(params: &[ExaTypeInfo], serializer: S) -> Result +where + S: Serializer, +{ + serializer.collect_seq(params.iter().map(PreparedStmtParam::from)) +} + /// Type containing the parameters data to be passed as part of executing a prepared statement. /// It ensures the parameter sequence in the [`ExaBuffer`] is appropriately ended. #[derive(Debug, Clone)] @@ -485,3 +605,183 @@ impl From for PreparedStmtData { } } } + +#[cfg(test)] +mod tests { + use sqlx_core::sql_str::SqlStr; + + use super::ExecuteBatch; + + #[test] + fn test_simple_statements() { + assert_eq!( + ExecuteBatch(SqlStr::from_static( + "SELECT * FROM users; SELECT * FROM orders;" + )) + .split_query(), + vec!["SELECT * FROM users;", "SELECT * FROM orders;"] + ); + } + + #[test] + fn test_semicolon_in_single_quote() { + assert_eq!( + ExecuteBatch(SqlStr::from_static( + "SELECT ';' AS val; SELECT 'abc;def' AS val2;" + )) + .split_query(), + vec!["SELECT ';' AS val;", "SELECT 'abc;def' AS val2;"] + ); + } + + #[test] + fn test_semicolon_in_double_quote() { + assert_eq!( + ExecuteBatch(SqlStr::from_static("SELECT \"col;name\" FROM table;")).split_query(), + vec!["SELECT \"col;name\" FROM table;"] + ); + } + + #[test] + fn test_semicolon_in_line_comment() { + assert_eq!( + ExecuteBatch(SqlStr::from_static( + "SELECT 1; -- this is a comment; with a semicolon\nSELECT 2;" + )) + .split_query(), + vec![ + "SELECT 1;", + "-- this is a comment; with a semicolon\nSELECT 2;" + ] + ); + } + + #[test] + fn test_semicolon_in_block_comment() { + assert_eq!( + ExecuteBatch(SqlStr::from_static( + "SELECT 1; /* multi-line ; comment */ SELECT 2;" + )) + .split_query(), + vec!["SELECT 1;", "/* multi-line ; comment */ SELECT 2;"] + ); + } + + #[test] + fn test_escaped_quotes() { + assert_eq!( + ExecuteBatch(SqlStr::from_static( + "SELECT 'It''s a test; really'; SELECT \"escaped\"\"quote\" FROM dual;" + )) + .split_query(), + vec![ + "SELECT 'It''s a test; really';", + "SELECT \"escaped\"\"quote\" FROM dual;" + ] + ); + } + + #[test] + fn test_trailing_semicolon_and_whitespace() { + assert_eq!( + ExecuteBatch(SqlStr::from_static("SELECT 1;; \n \n;")).split_query(), + vec!["SELECT 1;", ";", ";"] + ); + } + + #[test] + fn test_leading_semicolon() { + assert_eq!( + ExecuteBatch(SqlStr::from_static(";SELECT 1;")).split_query(), + vec![";", "SELECT 1;"] + ); + } + + #[test] + fn test_leading_semicolon_and_whitespace() { + assert_eq!( + ExecuteBatch(SqlStr::from_static(" ; SELECT 1;")).split_query(), + vec![";", "SELECT 1;"] + ); + } + + #[test] + fn test_no_semicolon() { + assert_eq!( + ExecuteBatch(SqlStr::from_static("SELECT 1")).split_query(), + vec!["SELECT 1"] + ); + } + + #[test] + fn 
test_no_whitespace_between_statements() { + assert_eq!( + ExecuteBatch(SqlStr::from_static("SELECT 1;SELECT 2")).split_query(), + vec!["SELECT 1;", "SELECT 2"] + ); + } + + #[test] + fn test_no_whitespace_between_stmt_and_comment() { + assert_eq!( + ExecuteBatch(SqlStr::from_static("SELECT 1;/*testing*/SELECT 2;")).split_query(), + vec!["SELECT 1;", "/*testing*/SELECT 2;"] + ); + } + + #[test] + fn test_trailing_comment() { + assert_eq!( + ExecuteBatch(SqlStr::from_static("SELECT 1;/*testing*/")).split_query(), + vec!["SELECT 1;", "/*testing*/"] + ); + } + + #[test] + fn test_whitespace_between_statements() { + let query = " + /* Writing some comments */ + SELECT 1; + + -- Then writing some more comments + SELECT 2; + "; + assert_eq!( + ExecuteBatch(SqlStr::from_static(query)).split_query(), + vec![ + "/* Writing some comments */ + SELECT 1;", + "-- Then writing some more comments + SELECT 2;" + ] + ); + } + + #[test] + fn test_empty_input() { + assert_eq!( + ExecuteBatch(SqlStr::from_static("")).split_query(), + Vec::<&str>::new() + ); + } + + #[test] + fn test_mixed_content() { + let query = r#" + SELECT 'test;--'; -- line comment with ; + /* block comment ; + over lines */ + SELECT "str;with;semicolons"; + "#; + assert_eq!( + ExecuteBatch(SqlStr::from_static(query)).split_query(), + vec![ + "SELECT 'test;--';", + r#"-- line comment with ; + /* block comment ; + over lines */ + SELECT "str;with;semicolons";"# + ] + ); + } +} diff --git a/src/connection/websocket/socket.rs b/sqlx-exasol-impl/src/connection/websocket/socket.rs similarity index 100% rename from src/connection/websocket/socket.rs rename to sqlx-exasol-impl/src/connection/websocket/socket.rs diff --git a/src/connection/websocket/tls.rs b/sqlx-exasol-impl/src/connection/websocket/tls.rs similarity index 100% rename from src/connection/websocket/tls.rs rename to sqlx-exasol-impl/src/connection/websocket/tls.rs diff --git a/src/connection/websocket/transport/compressed.rs b/sqlx-exasol-impl/src/connection/websocket/transport/compressed.rs similarity index 81% rename from src/connection/websocket/transport/compressed.rs rename to sqlx-exasol-impl/src/connection/websocket/transport/compressed.rs index 6a2d987a..1d3320e1 100644 --- a/src/connection/websocket/transport/compressed.rs +++ b/sqlx-exasol-impl/src/connection/websocket/transport/compressed.rs @@ -25,7 +25,7 @@ pub struct CompressedWebSocket { /// Future for the currently decoding message. decoding: Option>>>, /// Future for the currently encoding message. - encoding: Option>>>, + encoding: EncodingState, } impl Stream for CompressedWebSocket { @@ -75,31 +75,37 @@ impl Sink for CompressedWebSocket { fn start_send(mut self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> { // Sanity check - if self.encoding.is_some() { + if !matches!(self.encoding, EncodingState::Ready) { return Err(ExaProtocolError::SendNotReady)?; } // Register the item for compression. let bytes = item.into_bytes().into_boxed_slice().into(); - self.encoding = Some(Compression::new(bytes, 0)); + self.encoding = EncodingState::Buffered(Compression::new(bytes, 0)); Ok(()) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { - if let Some(future) = self.encoding.as_mut() { + match &mut self.encoding { // Compress the last registered item. 
- let bytes = ready!(future.poll_unpin(cx))?; - self.encoding = None; - self.inner - .start_send_unpin(Message::Binary(bytes)) - .map_err(ToSqlxError::to_sqlx_err)?; - } else { + EncodingState::Buffered(future) => { + let bytes = ready!(future.poll_unpin(cx))?; + self.encoding = EncodingState::NeedsFlush; + self.inner + .start_send_unpin(Message::Binary(bytes)) + .map_err(ToSqlxError::to_sqlx_err)?; + } // Flush the compressed message. - return self - .inner - .poll_flush_unpin(cx) - .map_err(ToSqlxError::to_sqlx_err); + EncodingState::NeedsFlush => { + ready!(self + .inner + .poll_flush_unpin(cx) + .map_err(ToSqlxError::to_sqlx_err))?; + + self.encoding = EncodingState::Ready; + } + EncodingState::Ready => return Poll::Ready(Ok(())), } } } @@ -116,11 +122,21 @@ impl From for CompressedWebSocket { Self { inner: value.0, decoding: None, - encoding: None, + encoding: EncodingState::Ready, } } } +/// Enum containing the message encoding state. +/// Necessary because blindly flushing without sending any data does not play well +/// with `rustls`, although `native-tls` does not have a problem with that. +#[derive(Debug)] +enum EncodingState { + Buffered(Compression>>), + NeedsFlush, + Ready, +} + /// Future for awaiting the compression/decompression of a message. #[derive(Debug)] struct Compression { diff --git a/src/connection/websocket/transport/mod.rs b/sqlx-exasol-impl/src/connection/websocket/transport/mod.rs similarity index 93% rename from src/connection/websocket/transport/mod.rs rename to sqlx-exasol-impl/src/connection/websocket/transport/mod.rs index 6fa84e85..d3334011 100644 --- a/src/connection/websocket/transport/mod.rs +++ b/sqlx-exasol-impl/src/connection/websocket/transport/mod.rs @@ -30,11 +30,14 @@ pub enum MaybeCompressedWebSocket { impl MaybeCompressedWebSocket { /// Consumes `self` to output a possibly different variant, depending on whether compression is /// wanted and enabled. 
+ #[allow(unused_variables, reason = "conditionally compiled")] pub fn maybe_compress(self, use_compression: bool) -> Self { - match (self, use_compression) { + match self { #[cfg(feature = "compression")] - (Self::Plain(plain), true) => MaybeCompressedWebSocket::Compressed(plain.into()), - (ws, _) => ws, + Self::Plain(plain) if use_compression => { + MaybeCompressedWebSocket::Compressed(plain.into()) + } + ws => ws, } } diff --git a/src/connection/websocket/transport/uncompressed.rs b/sqlx-exasol-impl/src/connection/websocket/transport/uncompressed.rs similarity index 100% rename from src/connection/websocket/transport/uncompressed.rs rename to sqlx-exasol-impl/src/connection/websocket/transport/uncompressed.rs diff --git a/src/database.rs b/sqlx-exasol-impl/src/database.rs similarity index 95% rename from src/database.rs rename to sqlx-exasol-impl/src/database.rs index 2050607b..9c357e2c 100644 --- a/src/database.rs +++ b/sqlx-exasol-impl/src/database.rs @@ -41,7 +41,7 @@ impl Database for Exasol { type ArgumentBuffer<'q> = ExaBuffer; - type Statement<'q> = ExaStatement<'q>; + type Statement = ExaStatement; } impl HasStatementCache for Exasol {} diff --git a/src/error.rs b/sqlx-exasol-impl/src/error.rs similarity index 94% rename from src/error.rs rename to sqlx-exasol-impl/src/error.rs index bc9ff8ff..6e0ad4f5 100644 --- a/src/error.rs +++ b/sqlx-exasol-impl/src/error.rs @@ -5,7 +5,7 @@ use rsa::errors::Error as RsaError; use serde_json::error::Error as JsonError; use thiserror::Error as ThisError; -use crate::{type_info::DataTypeName, SqlxError}; +use crate::SqlxError; /// Enum representing protocol implementation errors. #[derive(Debug, ThisError)] @@ -20,8 +20,6 @@ pub enum ExaProtocolError { SendNotReady, #[error("no response received")] NoResponse, - #[error("type mismatch: expected SQL type `{0}` but was provided `{1}`")] - DatatypeMismatch(DataTypeName, DataTypeName), #[error("server closed connection; info: {0}")] WebSocketClosed(CloseError), #[error("feature 'compression' must be enabled to use compression")] diff --git a/sqlx-exasol-impl/src/lib.rs b/sqlx-exasol-impl/src/lib.rs new file mode 100644 index 00000000..40acc753 --- /dev/null +++ b/sqlx-exasol-impl/src/lib.rs @@ -0,0 +1,75 @@ +#![cfg_attr(not(test), warn(unused_crate_dependencies))] +//! **EXASOL** database driver. 
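+//!
+//! A minimal connection sketch (illustrative only: it assumes the public
+//! `sqlx-exasol` facade crate re-exports the items defined here and that a
+//! server is reachable at the given, hypothetical URL):
+//!
+//! ```rust,no_run
+//! use std::str::FromStr;
+//!
+//! use sqlx_exasol::{ExaConnectOptions, ExaPool, ExaPoolOptions};
+//!
+//! async fn connect() -> Result<ExaPool, Box<dyn std::error::Error>> {
+//!     let opts = ExaConnectOptions::from_str("exa://sys:exasol@localhost:8563")?;
+//!     let pool = ExaPoolOptions::new().max_connections(5).connect_with(opts).await?;
+//!     Ok(pool)
+//! }
+//! ```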
+
+#[cfg(feature = "native-tls")]
+use native_tls as _;
+#[cfg(feature = "tls")]
+use rcgen as _;
+#[cfg(feature = "rustls")]
+use rustls as _;
+
+#[cfg(feature = "any")]
+pub mod any;
+mod arguments;
+mod column;
+mod connection;
+mod database;
+mod error;
+#[cfg(feature = "migrate")]
+mod migrate;
+mod options;
+mod query_result;
+mod responses;
+mod row;
+mod statement;
+#[cfg(feature = "migrate")]
+mod testing;
+mod transaction;
+#[cfg(feature = "macros")]
+mod type_checking;
+mod type_info;
+pub mod types;
+mod value;
+
+pub use arguments::ExaArguments;
+pub use column::ExaColumn;
+#[cfg(feature = "etl")]
+pub use connection::etl;
+pub use connection::ExaConnection;
+pub use database::Exasol;
+pub use options::{ExaCompressionMode, ExaConnectOptions, ExaConnectOptionsBuilder, ExaSslMode};
+pub use query_result::ExaQueryResult;
+pub use responses::{ExaAttributes, ExaDatabaseError, SessionInfo};
+pub use row::ExaRow;
+use sqlx_core::{
+    executor::Executor, impl_acquire, impl_column_index_for_row, impl_column_index_for_statement,
+    impl_into_arguments_for_arguments,
+};
+pub use statement::ExaStatement;
+pub use transaction::ExaTransactionManager;
+#[doc(hidden)]
+#[cfg(feature = "macros")]
+pub use type_checking::QUERY_DRIVER;
+pub use type_info::ExaTypeInfo;
+pub use value::{ExaValue, ExaValueRef};
+
+/// An alias for [`Pool`][sqlx_core::pool::Pool], specialized for Exasol.
+pub type ExaPool = sqlx_core::pool::Pool<Exasol>;
+
+/// An alias for [`PoolOptions`][sqlx_core::pool::PoolOptions], specialized for Exasol.
+pub type ExaPoolOptions = sqlx_core::pool::PoolOptions<Exasol>;
+
+/// An alias for [`Executor<'_, Database = Exasol>`][Executor].
+pub trait ExaExecutor<'c>: Executor<'c, Database = Exasol> {}
+impl<'c, T: Executor<'c, Database = Exasol>> ExaExecutor<'c> for T {}
+
+impl_into_arguments_for_arguments!(ExaArguments);
+impl_acquire!(Exasol, ExaConnection);
+impl_column_index_for_row!(ExaRow);
+impl_column_index_for_statement!(ExaStatement);
+
+// ###################
+// ##### Aliases #####
+// ###################
+type SqlxError = sqlx_core::Error;
+type SqlxResult<T> = sqlx_core::Result<T>;
diff --git a/sqlx-exasol-impl/src/migrate.rs b/sqlx-exasol-impl/src/migrate.rs
new file mode 100644
index 00000000..e60a2056
--- /dev/null
+++ b/sqlx-exasol-impl/src/migrate.rs
@@ -0,0 +1,257 @@
+use std::{
+    str::FromStr,
+    time::{Duration, Instant},
+};
+
+use futures_core::future::BoxFuture;
+use sqlx_core::{
+    connection::{ConnectOptions, Connection},
+    executor::Executor,
+    migrate::{AppliedMigration, Migrate, MigrateDatabase, MigrateError, Migration},
+    sql_str::AssertSqlSafe,
+};
+
+use crate::{
+    connection::{
+        websocket::future::{ExecuteBatch, WebSocketFuture},
+        ExaConnection,
+    },
+    database::Exasol,
+    options::ExaConnectOptions,
+    SqlxError, SqlxResult,
+};
+
+const LOCK_WARN: &str = "Exasol does not support database locking!";
+
+fn parse_for_maintenance(url: &str) -> SqlxResult<(ExaConnectOptions, String)> {
+    let mut options = ExaConnectOptions::from_str(url)?;
+
+    let database = options.schema.ok_or_else(|| {
+        SqlxError::Configuration("DATABASE_URL does not specify a database".into())
+    })?;
+
+    // Connect without a schema for the create/drop commands.
+    options.schema = None;
+
+    Ok((options, database))
+}
+
+impl MigrateDatabase for Exasol {
+    async fn create_database(url: &str) -> SqlxResult<()> {
+        let (options, database) = parse_for_maintenance(url)?;
+        // Escape double quotes because we'll quote the database identifier.
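+        // e.g., a schema named `my"schema` is embedded as `"my""schema"`, which
+        // Exasol unescapes back to the original identifier.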
+ let database = database.replace('"', "\"\""); + let mut conn = options.connect().await?; + + let query = format!(r#"CREATE SCHEMA "{database}";"#); + conn.execute(AssertSqlSafe(query)).await?; + + Ok(()) + } + + async fn database_exists(url: &str) -> SqlxResult { + let (options, database) = parse_for_maintenance(url)?; + let mut conn = options.connect().await?; + + let query = "SELECT true FROM exa_schemas WHERE schema_name = ?"; + let exists: bool = sqlx_core::query_scalar::query_scalar(query) + .bind(database) + .fetch_optional(&mut conn) + .await? + .unwrap_or_default(); + + Ok(exists) + } + + async fn drop_database(url: &str) -> SqlxResult<()> { + let (options, database) = parse_for_maintenance(url)?; + // Escape double quotes because we'll quote the database identifier. + let database = database.replace('"', "\"\""); + let mut conn = options.connect().await?; + + let query = format!(r#"DROP SCHEMA IF EXISTS "{database}" CASCADE;"#); + conn.execute(AssertSqlSafe(query)).await?; + + Ok(()) + } +} + +impl Migrate for ExaConnection { + fn create_schema_if_not_exists<'e>( + &'e mut self, + schema_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async move { + let query = format!(r#"CREATE SCHEMA IF NOT EXISTS "{schema_name}";"#); + self.execute(AssertSqlSafe(query)).await?; + Ok(()) + }) + } + + fn ensure_migrations_table<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async move { + let query = format!( + r#" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + version DECIMAL(20, 0), + description CLOB NOT NULL, + installed_on TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + success BOOLEAN NOT NULL, + checksum CLOB NOT NULL, + execution_time DECIMAL(20, 0) NOT NULL + );"# + ); + self.execute(AssertSqlSafe(query)).await?; + Ok(()) + }) + } + + fn dirty_version<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { + Box::pin(async move { + let query = format!( + r#" + SELECT version + FROM "{table_name}" + WHERE success = false + ORDER BY version + LIMIT 1; + "# + ); + let row: Option<(i64,)> = sqlx_core::query_as::query_as(AssertSqlSafe(query)) + .fetch_optional(self) + .await?; + Ok(row.map(|r| r.0)) + }) + } + + fn list_applied_migrations<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { + Box::pin(async move { + let query = format!( + r#" + SELECT version, checksum + FROM "{table_name}" + ORDER BY version + "# + ); + + let rows: Vec<(i64, String)> = sqlx_core::query_as::query_as(AssertSqlSafe(query)) + .fetch_all(self) + .await?; + let mut migrations = Vec::with_capacity(rows.len()); + + for (version, checksum) in rows { + let checksum = hex::decode(checksum) + .map_err(From::from) + .map_err(MigrateError::Source)? 
+ .into(); + + let migration = AppliedMigration { version, checksum }; + migrations.push(migration); + } + + Ok(migrations) + }) + } + + fn lock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { + Box::pin(async move { + tracing::warn!("{LOCK_WARN}"); + Ok(()) + }) + } + + fn unlock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { + Box::pin(async move { + tracing::warn!("{LOCK_WARN}"); + Ok(()) + }) + } + + fn apply<'e>( + &'e mut self, + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { + Box::pin(async move { + let mut tx = self.begin().await?; + let start = Instant::now(); + + ExecuteBatch::new(migration.sql.clone()) + .future(&mut tx.ws) + .await?; + + let checksum = hex::encode(&*migration.checksum); + + let query = format!( + r#" + INSERT INTO "{table_name}" ( version, description, success, checksum, execution_time ) + VALUES ( ?, ?, TRUE, ?, -1 ); + "# + ); + + let _ = sqlx_core::query::query(AssertSqlSafe(query)) + .bind(migration.version) + .bind(&*migration.description) + .bind(checksum) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + + let elapsed = start.elapsed(); + + let query = format!( + r#" + UPDATE "{table_name}" + SET execution_time = ? + WHERE version = ? + "# + ); + + #[allow(clippy::cast_possible_truncation)] + let _ = sqlx_core::query::query(AssertSqlSafe(query)) + .bind(elapsed.as_nanos() as i64) + .bind(migration.version) + .execute(self) + .await?; + + Ok(elapsed) + }) + } + + fn revert<'e>( + &'e mut self, + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { + Box::pin(async move { + let mut tx = self.begin().await?; + let start = Instant::now(); + + ExecuteBatch::new(migration.sql.clone()) + .future(&mut tx.ws) + .await?; + + let query = format!(r#" DELETE FROM "{table_name}" WHERE version = ? "#); + let _ = sqlx_core::query::query(AssertSqlSafe(query)) + .bind(migration.version) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + + let elapsed = start.elapsed(); + + Ok(elapsed) + }) + } +} diff --git a/src/options/builder.rs b/sqlx-exasol-impl/src/options/builder.rs similarity index 91% rename from src/options/builder.rs rename to sqlx-exasol-impl/src/options/builder.rs index cb452b4b..003e4a07 100644 --- a/src/options/builder.rs +++ b/sqlx-exasol-impl/src/options/builder.rs @@ -6,7 +6,7 @@ use super::{ error::ExaConfigError, ssl_mode::ExaSslMode, ExaConnectOptions, Login, ProtocolVersion, DEFAULT_CACHE_CAPACITY, DEFAULT_FETCH_SIZE, DEFAULT_PORT, }; -use crate::SqlxResult; +use crate::{options::compression::ExaCompressionMode, SqlxResult}; /// Builder for [`ExaConnectOptions`]. 
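+///
+/// A minimal usage sketch, mirroring the tests in `options/mod.rs` (the facade
+/// crate is assumed to re-export these items as `sqlx_exasol`):
+///
+/// ```rust,no_run
+/// use sqlx_exasol::ExaConnectOptions;
+///
+/// let options = ExaConnectOptions::builder()
+///     .host("localhost".to_string())
+///     .port(8563)
+///     .username("sys".to_string())
+///     .password("exasol".to_string())
+///     .build()
+///     .unwrap();
+/// ```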
#[derive(Clone, Debug)] @@ -26,7 +26,7 @@ pub struct ExaConnectOptionsBuilder { protocol_version: ProtocolVersion, fetch_size: usize, query_timeout: u64, - compression: bool, + compression_mode: ExaCompressionMode, feedback_interval: u64, } @@ -45,10 +45,10 @@ impl Default for ExaConnectOptionsBuilder { access_token: None, refresh_token: None, schema: None, - protocol_version: ProtocolVersion::V3, + protocol_version: ProtocolVersion::default(), fetch_size: DEFAULT_FETCH_SIZE, query_timeout: 0, - compression: false, + compression_mode: ExaCompressionMode::default(), feedback_interval: 1, } } @@ -69,10 +69,11 @@ impl ExaConnectOptionsBuilder { (Some(username), None, None) => Login::Credentials { username, password }, (None, Some(access_token), None) => Login::AccessToken { access_token }, (None, None, Some(refresh_token)) => Login::RefreshToken { refresh_token }, + (None, None, None) => return Err(ExaConfigError::MissingAuthMethod.into()), _ => return Err(ExaConfigError::MultipleAuthMethods.into()), }; - let hosts = Self::parse_hostname(hostname); + let hosts = Self::parse_hostname(&hostname); let mut hosts_details = Vec::with_capacity(hosts.len()); for host in hosts { @@ -88,12 +89,13 @@ impl ExaConnectOptionsBuilder { ssl_client_cert: self.ssl_client_cert, ssl_client_key: self.ssl_client_key, statement_cache_capacity: self.statement_cache_capacity, + hostname, login, schema: self.schema, protocol_version: self.protocol_version, fetch_size: self.fetch_size, query_timeout: self.query_timeout, - compression: self.compression, + compression_mode: self.compression_mode, feedback_interval: self.feedback_interval, log_settings: LogSettings::default(), }; @@ -173,12 +175,6 @@ impl ExaConnectOptionsBuilder { self } - #[must_use = "call build() to get connection options"] - pub fn protocol_version(mut self, protocol_version: ProtocolVersion) -> Self { - self.protocol_version = protocol_version; - self - } - #[must_use = "call build() to get connection options"] pub fn fetch_size(mut self, fetch_size: usize) -> Self { self.fetch_size = fetch_size; @@ -192,14 +188,8 @@ impl ExaConnectOptionsBuilder { } #[must_use = "call build() to get connection options"] - pub fn compression(mut self, compression: bool) -> Self { - let feature_flag = cfg!(feature = "compression"); - - if feature_flag && !compression { - tracing::warn!("compression cannot be enabled without the 'compression' feature"); - } - - self.compression = compression && feature_flag; + pub fn compression_mode(mut self, compression_mode: ExaCompressionMode) -> Self { + self.compression_mode = compression_mode; self } @@ -215,9 +205,9 @@ impl ExaConnectOptionsBuilder { /// /// We do expect the range to be in the ascending order though, so `hostname4..1.com` won't /// work. - fn parse_hostname(hostname: String) -> Vec { + fn parse_hostname(hostname: &str) -> Vec { // If multiple hosts could not be generated, then the given hostname is the only one. - Self::_parse_hostname(&hostname).unwrap_or_else(|| vec![hostname]) + Self::_parse_hostname(hostname).unwrap_or_else(|| vec![hostname.to_owned()]) } /// This method is used to attempt to generate multiple hosts out of the given hostname. 
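// Editor's illustration — a self-contained sketch of the range expansion that
// `_parse_hostname` performs and that the tests below exercise. `expand` is a
// hypothetical stand-in, not the crate's implementation; it assumes ASCII
// hostnames and expands the first `..` range it finds. An inverted range such
// as `myhost127..125.com` yields an empty vector, matching the tests.
fn expand(hostname: &str) -> Option<Vec<String>> {
    let (head, tail) = hostname.split_once("..")?;
    // Digits immediately before `..` form the range start.
    let num_at = head.rfind(|c: char| !c.is_ascii_digit()).map_or(0, |i| i + 1);
    let (prefix, start) = head.split_at(num_at);
    // Digits immediately after `..` form the range end.
    let num_to = tail.find(|c: char| !c.is_ascii_digit()).unwrap_or(tail.len());
    let (end, suffix) = tail.split_at(num_to);
    let (start, end) = (start.parse::<u64>().ok()?, end.parse::<u64>().ok()?);
    Some((start..=end).map(|n| format!("{prefix}{n}{suffix}")).collect())
}

fn main() {
    let hosts = expand("myhost1..3.com").unwrap();
    assert_eq!(hosts, ["myhost1.com", "myhost2.com", "myhost3.com"]);
    assert!(expand("myhost127..125.com").unwrap().is_empty());
    assert_eq!(expand("myhost.com"), None); // no range: caller falls back to the input
}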
@@ -303,7 +293,7 @@ mod tests { fn test_simple_ip() { let hostname = "10.10.10.10"; - let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned()); + let generated = ExaConnectOptionsBuilder::parse_hostname(hostname); assert_eq!(generated, vec!(hostname)); } @@ -312,7 +302,7 @@ mod tests { let hostname = "10.10.10.1..3"; let expected = vec!["10.10.10.1", "10.10.10.2", "10.10.10.3"]; - let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned()); + let generated = ExaConnectOptionsBuilder::parse_hostname(hostname); assert_eq!(generated, expected); } @@ -321,7 +311,7 @@ mod tests { let hostname = "1..3.10.10.10"; let expected = vec!["1.10.10.10", "2.10.10.10", "3.10.10.10"]; - let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned()); + let generated = ExaConnectOptionsBuilder::parse_hostname(hostname); assert_eq!(generated, expected); } @@ -329,7 +319,7 @@ mod tests { fn test_simple_hostname() { let hostname = "myhost.com"; - let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned()); + let generated = ExaConnectOptionsBuilder::parse_hostname(hostname); assert_eq!(generated, vec!(hostname)); } @@ -338,7 +328,7 @@ mod tests { let hostname = "myhost1..4.com"; let expected = vec!["myhost1.com", "myhost2.com", "myhost3.com", "myhost4.com"]; - let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned()); + let generated = ExaConnectOptionsBuilder::parse_hostname(hostname); assert_eq!(generated, expected); } @@ -347,7 +337,7 @@ mod tests { let hostname = "myhost125..127.com"; let expected = vec!["myhost125.com", "myhost126.com", "myhost127.com"]; - let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned()); + let generated = ExaConnectOptionsBuilder::parse_hostname(hostname); assert_eq!(generated, expected); } @@ -355,7 +345,7 @@ mod tests { fn test_hostname_with_inverse_range() { let hostname = "myhost127..125.com"; - let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned()); + let generated = ExaConnectOptionsBuilder::parse_hostname(hostname); assert!(generated.is_empty()); } @@ -363,7 +353,7 @@ mod tests { fn test_hostname_with_numbers_no_range() { let hostname = "myhost1.4.com"; - let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned()); + let generated = ExaConnectOptionsBuilder::parse_hostname(hostname); assert_eq!(generated, vec![hostname]); } @@ -371,7 +361,7 @@ mod tests { fn test_hostname_with_range_one_numbers() { let hostname = "myhost1..b.com"; - let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned()); + let generated = ExaConnectOptionsBuilder::parse_hostname(hostname); assert_eq!(generated, vec![hostname]); } @@ -379,7 +369,7 @@ mod tests { fn test_hostname_with_range_no_numbers() { let hostname = "myhosta..b.com"; - let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned()); + let generated = ExaConnectOptionsBuilder::parse_hostname(hostname); assert_eq!(generated, vec![hostname]); } @@ -387,7 +377,7 @@ mod tests { fn test_hostname_starts_with_range() { let hostname = "..myhost.com"; - let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned()); + let generated = ExaConnectOptionsBuilder::parse_hostname(hostname); assert_eq!(generated, vec![hostname]); } @@ -395,7 +385,7 @@ mod tests { fn test_hostname_ends_with_range() { let hostname = "myhost.com.."; - let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned()); + let generated = 
ExaConnectOptionsBuilder::parse_hostname(hostname); assert_eq!(generated, vec![hostname]); } @@ -408,7 +398,7 @@ mod tests { "myhosta..bcdef3.com", ]; - let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned()); + let generated = ExaConnectOptionsBuilder::parse_hostname(hostname); assert_eq!(generated, expected); } @@ -421,7 +411,7 @@ mod tests { "myhost3cdef4..7.com", ]; - let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned()); + let generated = ExaConnectOptionsBuilder::parse_hostname(hostname); assert_eq!(generated, expected); } } diff --git a/sqlx-exasol-impl/src/options/compression.rs b/sqlx-exasol-impl/src/options/compression.rs new file mode 100644 index 00000000..b3f7c4e6 --- /dev/null +++ b/sqlx-exasol-impl/src/options/compression.rs @@ -0,0 +1,52 @@ +use std::str::FromStr; + +use super::{error::ExaConfigError, COMPRESSION}; + +/// Options for controlling the desired compression behavior of the connection to the Exasol server. +/// +/// It is used by [`crate::options::builder::ExaConnectOptionsBuilder::compression_mode`]. +#[derive(Debug, Clone, Copy, Default, PartialEq)] +pub enum ExaCompressionMode { + /// Establish an uncompressed connection. + Disabled, + + /// Establish a compressed connection if the compression feature is enabled, falling back to an + /// uncompressed connection if it's not. + /// + /// This is the default if `compression` is not specified. + #[default] + Preferred, + + /// Establish a compressed connection if the compression feature is enabled. + /// The connection attempt fails if a compressed connection cannot be established. + Required, +} + +impl ExaCompressionMode { + const DISABLED: &str = "disabled"; + const PREFERRED: &str = "preferred"; + const REQUIRED: &str = "required"; +} + +impl FromStr for ExaCompressionMode { + type Err = ExaConfigError; + + fn from_str(s: &str) -> Result { + Ok(match &*s.to_ascii_lowercase() { + Self::DISABLED => ExaCompressionMode::Disabled, + Self::PREFERRED => ExaCompressionMode::Preferred, + Self::REQUIRED => ExaCompressionMode::Required, + _ => Err(ExaConfigError::InvalidParameter(COMPRESSION))?, + }) + } +} + +impl AsRef for ExaCompressionMode { + fn as_ref(&self) -> &str { + match self { + ExaCompressionMode::Disabled => Self::DISABLED, + ExaCompressionMode::Preferred => Self::PREFERRED, + ExaCompressionMode::Required => Self::REQUIRED, + } + } +} diff --git a/src/options/error.rs b/sqlx-exasol-impl/src/options/error.rs similarity index 91% rename from src/options/error.rs rename to sqlx-exasol-impl/src/options/error.rs index d75cc03c..23ebe70f 100644 --- a/src/options/error.rs +++ b/sqlx-exasol-impl/src/options/error.rs @@ -9,6 +9,8 @@ pub enum ExaConfigError { MissingHost, #[error("could not resolve hostname")] CouldNotResolve(#[from] std::io::Error), + #[error("no authentication method provided")] + MissingAuthMethod, #[error("multiple authentication methods provided")] MultipleAuthMethods, #[error("invalid URL scheme: {0}, expected: {}", URL_SCHEME)] diff --git a/sqlx-exasol-impl/src/options/mod.rs b/sqlx-exasol-impl/src/options/mod.rs new file mode 100644 index 00000000..bc71de7b --- /dev/null +++ b/sqlx-exasol-impl/src/options/mod.rs @@ -0,0 +1,649 @@ +mod builder; +mod compression; +mod error; +mod protocol_version; +mod ssl_mode; + +use std::{borrow::Cow, net::SocketAddr, num::NonZeroUsize, path::PathBuf, str::FromStr}; + +pub use builder::ExaConnectOptionsBuilder; +pub use compression::ExaCompressionMode; +use error::ExaConfigError; +pub use 
protocol_version::ProtocolVersion; +use sqlx_core::{ + connection::{ConnectOptions, LogSettings}, + net::tls::CertificateInput, + percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC}, +}; +pub use ssl_mode::ExaSslMode; +use tracing::log; +use url::Url; + +use crate::{ + connection::{ + websocket::request::{ExaLoginRequest, LoginRef}, + ExaConnection, + }, + error::ExaProtocolError, + responses::ExaRwAttributes, + SqlxError, SqlxResult, +}; + +const URL_SCHEME: &str = "exa"; + +const DEFAULT_FETCH_SIZE: usize = 5 * 1024 * 1024; +const DEFAULT_PORT: u16 = 8563; +const DEFAULT_CACHE_CAPACITY: NonZeroUsize = match NonZeroUsize::new(100) { + Some(v) => v, + None => unreachable!(), +}; + +const ACCESS_TOKEN: &str = "access-token"; +const REFRESH_TOKEN: &str = "refresh-token"; +const SSL_MODE: &str = "ssl-mode"; +const SSL_CA: &str = "ssl-ca"; +const SSL_CERT: &str = "ssl-cert"; +const SSL_KEY: &str = "ssl-key"; +const STATEMENT_CACHE_CAPACITY: &str = "statement-cache-capacity"; +const FETCH_SIZE: &str = "fetch-size"; +const QUERY_TIMEOUT: &str = "query-timeout"; +const COMPRESSION: &str = "compression"; +const FEEDBACK_INTERVAL: &str = "feedback-interval"; + +/// Options for connecting to the Exasol database. Implementor of [`ConnectOptions`]. +/// +/// While generally automatically created through a connection string, +/// [`ExaConnectOptions::builder()`] can be used to get a [`ExaConnectOptionsBuilder`]. +#[derive(Debug, Clone)] +pub struct ExaConnectOptions { + pub(crate) hosts_details: Vec<(String, Vec)>, + pub(crate) port: u16, + pub(crate) ssl_mode: ExaSslMode, + pub(crate) ssl_ca: Option, + pub(crate) ssl_client_cert: Option, + pub(crate) ssl_client_key: Option, + pub(crate) statement_cache_capacity: NonZeroUsize, + pub(crate) schema: Option, + pub(crate) compression_mode: ExaCompressionMode, + pub(crate) log_settings: LogSettings, + hostname: String, + login: Login, + protocol_version: ProtocolVersion, + fetch_size: usize, + query_timeout: u64, + feedback_interval: u64, +} + +impl ExaConnectOptions { + #[must_use] + pub fn builder() -> ExaConnectOptionsBuilder { + ExaConnectOptionsBuilder::default() + } +} + +impl FromStr for ExaConnectOptions { + type Err = SqlxError; + + fn from_str(s: &str) -> Result { + let url = Url::parse(s) + .map_err(From::from) + .map_err(SqlxError::Configuration)?; + Self::from_url(&url) + } +} + +impl ConnectOptions for ExaConnectOptions { + type Connection = ExaConnection; + + fn from_url(url: &Url) -> SqlxResult { + let scheme = url.scheme(); + + if URL_SCHEME != scheme { + return Err(ExaConfigError::InvalidUrlScheme(scheme.to_owned()).into()); + } + + let mut builder = Self::builder(); + + if let Some(host) = url.host_str() { + builder = builder.host(host.to_owned()); + } + + let username = url.username(); + if !username.is_empty() { + let username = percent_decode_str(username) + .decode_utf8() + .map_err(SqlxError::config)?; + builder = builder.username(username.to_string()); + } + + if let Some(password) = url.password() { + let password = percent_decode_str(password) + .decode_utf8() + .map_err(SqlxError::config)?; + builder = builder.password(password.to_string()); + } + + if let Some(port) = url.port() { + builder = builder.port(port); + } + + let path = url.path().trim_start_matches('/'); + + if !path.is_empty() { + let db_schema = percent_decode_str(path) + .decode_utf8() + .map_err(SqlxError::config)?; + builder = builder.schema(db_schema.to_string()); + } + + for (name, value) in url.query_pairs() { + match name.as_ref() { 
+                ACCESS_TOKEN => builder = builder.access_token(value.to_string()),
+
+                REFRESH_TOKEN => builder = builder.refresh_token(value.to_string()),
+
+                SSL_MODE => {
+                    let ssl_mode = value.parse::<ExaSslMode>()?;
+                    builder = builder.ssl_mode(ssl_mode);
+                }
+
+                SSL_CA => {
+                    let ssl_ca = CertificateInput::File(PathBuf::from(value.to_string()));
+                    builder = builder.ssl_ca(ssl_ca);
+                }
+
+                SSL_CERT => {
+                    let ssl_cert = CertificateInput::File(PathBuf::from(value.to_string()));
+                    builder = builder.ssl_client_cert(ssl_cert);
+                }
+
+                SSL_KEY => {
+                    let ssl_key = CertificateInput::File(PathBuf::from(value.to_string()));
+                    builder = builder.ssl_client_key(ssl_key);
+                }
+
+                STATEMENT_CACHE_CAPACITY => {
+                    let capacity = value
+                        .parse::<NonZeroUsize>()
+                        .map_err(|_| ExaConfigError::InvalidParameter(STATEMENT_CACHE_CAPACITY))?;
+                    builder = builder.statement_cache_capacity(capacity);
+                }
+
+                FETCH_SIZE => {
+                    let fetch_size = value
+                        .parse::<usize>()
+                        .map_err(|_| ExaConfigError::InvalidParameter(FETCH_SIZE))?;
+                    builder = builder.fetch_size(fetch_size);
+                }
+
+                QUERY_TIMEOUT => {
+                    let query_timeout = value
+                        .parse::<u64>()
+                        .map_err(|_| ExaConfigError::InvalidParameter(QUERY_TIMEOUT))?;
+                    builder = builder.query_timeout(query_timeout);
+                }
+
+                COMPRESSION => {
+                    let compression_mode = value
+                        .parse::<ExaCompressionMode>()
+                        .map_err(|_| ExaConfigError::InvalidParameter(COMPRESSION))?;
+                    builder = builder.compression_mode(compression_mode);
+                }
+
+                FEEDBACK_INTERVAL => {
+                    let feedback_interval = value
+                        .parse::<u64>()
+                        .map_err(|_| ExaConfigError::InvalidParameter(FEEDBACK_INTERVAL))?;
+                    builder = builder.feedback_interval(feedback_interval);
+                }
+
+                _ => {
+                    return Err(SqlxError::Protocol(format!(
+                        "Unknown connection string parameter: {name}"
+                    )))
+                }
+            }
+        }
+
+        builder.build()
+    }
+
+    fn to_url_lossy(&self) -> Url {
+        let mut url = Url::parse(&format!("{URL_SCHEME}://{}:{}", self.hostname, self.port))
+            .expect("generated URL must be correct");
+
+        if let Some(schema) = &self.schema {
+            url.set_path(schema);
+        }
+
+        match &self.login {
+            Login::Credentials { username, password } => {
+                url.set_username(username).ok();
+                let password = utf8_percent_encode(password, NON_ALPHANUMERIC).to_string();
+                url.set_password(Some(&password)).ok();
+            }
+            Login::AccessToken { access_token } => {
+                url.query_pairs_mut()
+                    .append_pair(ACCESS_TOKEN, access_token);
+            }
+            Login::RefreshToken { refresh_token } => {
+                url.query_pairs_mut()
+                    .append_pair(REFRESH_TOKEN, refresh_token);
+            }
+        }
+
+        url.query_pairs_mut()
+            .append_pair(SSL_MODE, self.ssl_mode.as_ref());
+
+        if let Some(ssl_ca) = &self.ssl_ca {
+            url.query_pairs_mut()
+                .append_pair(SSL_CA, &ssl_ca.to_string());
+        }
+
+        if let Some(ssl_cert) = &self.ssl_client_cert {
+            url.query_pairs_mut()
+                .append_pair(SSL_CERT, &ssl_cert.to_string());
+        }
+
+        if let Some(ssl_key) = &self.ssl_client_key {
+            url.query_pairs_mut()
+                .append_pair(SSL_KEY, &ssl_key.to_string());
+        }
+
+        url.query_pairs_mut().append_pair(
+            STATEMENT_CACHE_CAPACITY,
+            &self.statement_cache_capacity.to_string(),
+        );
+
+        url.query_pairs_mut()
+            .append_pair(FETCH_SIZE, &self.fetch_size.to_string());
+
+        url.query_pairs_mut()
+            .append_pair(QUERY_TIMEOUT, &self.query_timeout.to_string());
+
+        url.query_pairs_mut()
+            .append_pair(COMPRESSION, self.compression_mode.as_ref());
+
+        url.query_pairs_mut()
+            .append_pair(FEEDBACK_INTERVAL, &self.feedback_interval.to_string());
+
+        url
+    }
+
+    async fn connect(&self) -> SqlxResult<Self::Connection>
+    where
+        Self::Connection: Sized,
+    {
+        ExaConnection::establish(self).await
+    }
+
+    fn log_statements(mut self, level:
log::LevelFilter) -> Self { + self.log_settings.log_statements(level); + self + } + + fn log_slow_statements( + mut self, + level: log::LevelFilter, + duration: std::time::Duration, + ) -> Self { + self.log_settings.log_slow_statements(level, duration); + self + } +} + +impl<'a> TryFrom<&'a ExaConnectOptions> for ExaLoginRequest<'a> { + type Error = ExaProtocolError; + + fn try_from(value: &'a ExaConnectOptions) -> Result { + let crate_version = option_env!("CARGO_PKG_VERSION").unwrap_or("UNKNOWN"); + + let attributes = ExaRwAttributes::new( + value.schema.as_deref().map(Cow::Borrowed), + value.feedback_interval, + value.query_timeout, + ); + + let compression_supported = cfg!(feature = "compression"); + + let use_compression = match value.compression_mode { + ExaCompressionMode::Disabled => false, + ExaCompressionMode::Preferred if !compression_supported => { + tracing::debug!("not using compression: compression support not compiled in"); + false + } + ExaCompressionMode::Preferred => true, + ExaCompressionMode::Required if compression_supported => true, + ExaCompressionMode::Required => return Err(ExaProtocolError::CompressionDisabled), + }; + + let output = Self { + protocol_version: value.protocol_version, + fetch_size: value.fetch_size, + statement_cache_capacity: value.statement_cache_capacity, + login: (&value.login).into(), + use_compression, + client_name: "sqlx-exasol", + client_version: crate_version, + client_os: std::env::consts::OS, + client_runtime: "RUST", + attributes, + }; + + Ok(output) + } +} + +/// Enum representing the possible ways of authenticating a connection. +/// The variant chosen dictates which login process is called. +#[derive(Clone, Debug)] +pub enum Login { + Credentials { username: String, password: String }, + AccessToken { access_token: String }, + RefreshToken { refresh_token: String }, +} + +impl<'a> From<&'a Login> for LoginRef<'a> { + fn from(value: &'a Login) -> Self { + match value { + Login::Credentials { username, password } => LoginRef::Credentials { + username, + password: Cow::Borrowed(password), + }, + Login::AccessToken { access_token } => LoginRef::AccessToken { access_token }, + Login::RefreshToken { refresh_token } => LoginRef::RefreshToken { refresh_token }, + } + } +} + +/// Helper containing TLS related options. 
+#[derive(Debug, Clone, Copy)] +#[allow(clippy::struct_field_names)] +pub struct ExaTlsOptionsRef<'a> { + pub ssl_mode: ExaSslMode, + pub ssl_ca: Option<&'a CertificateInput>, + pub ssl_client_cert: Option<&'a CertificateInput>, + pub ssl_client_key: Option<&'a CertificateInput>, +} + +impl<'a> From<&'a ExaConnectOptions> for ExaTlsOptionsRef<'a> { + fn from(value: &'a ExaConnectOptions) -> Self { + ExaTlsOptionsRef { + ssl_mode: value.ssl_mode, + ssl_ca: value.ssl_ca.as_ref(), + ssl_client_cert: value.ssl_client_cert.as_ref(), + ssl_client_key: value.ssl_client_key.as_ref(), + } + } +} +#[cfg(test)] +mod tests { + use std::num::NonZeroUsize; + + use super::*; + + #[test] + fn test_from_url_basic() { + let url = "exa://user:pass@localhost:8563/schema"; + let options = ExaConnectOptions::from_str(url).unwrap(); + + assert_eq!(options.hostname, "localhost"); + assert_eq!(options.port, 8563); + assert_eq!(options.schema.as_deref(), Some("schema")); + + match &options.login { + Login::Credentials { username, password } => { + assert_eq!(username, "user"); + assert_eq!(password, "pass"); + } + _ => panic!("Expected credentials login"), + } + } + + #[test] + fn test_from_url_with_query_params() { + let url = "exa://localhost:8563?access-token=token123&compression=disabled&fetch-size=1024"; + let options = ExaConnectOptions::from_str(url).unwrap(); + + match &options.login { + Login::AccessToken { access_token } => { + assert_eq!(access_token, "token123"); + } + _ => panic!("Expected access token login"), + } + + assert_eq!(options.compression_mode, ExaCompressionMode::Disabled); + assert_eq!(options.fetch_size, 1024); + } + + #[test] + fn test_from_url_refresh_token() { + let url = "exa://localhost:8563?refresh-token=refresh123"; + let options = ExaConnectOptions::from_str(url).unwrap(); + + match &options.login { + Login::RefreshToken { refresh_token } => { + assert_eq!(refresh_token, "refresh123"); + } + _ => panic!("Expected refresh token login"), + } + } + + #[test] + fn test_from_url_ssl_params() { + let url = "exa://user:p@ssw0rd@localhost:8563?ssl-mode=required&ssl-ca=/path/to/ca.crt"; + let options = ExaConnectOptions::from_str(url).unwrap(); + + assert_eq!(options.ssl_mode, ExaSslMode::Required); + assert!(options.ssl_ca.is_some()); + } + + #[test] + fn test_from_url_numeric_params() { + let url = "exa://user:p@ssw0rd@localhost:8563?statement-cache-capacity=50&\ + query-timeout=30&feedback-interval=10"; + let options = ExaConnectOptions::from_str(url).unwrap(); + + assert_eq!( + options.statement_cache_capacity, + NonZeroUsize::new(50).unwrap() + ); + assert_eq!(options.query_timeout, 30); + assert_eq!(options.feedback_interval, 10); + } + + #[test] + fn test_from_url_invalid_scheme() { + let url = "mysql://localhost:8563"; + let result = ExaConnectOptions::from_str(url); + assert!(result.is_err()); + } + + #[test] + fn test_from_url_unknown_parameter() { + let url = "exa://localhost:8563?unknown-param=value"; + let result = ExaConnectOptions::from_str(url); + assert!(result.is_err()); + } + + #[test] + fn test_to_url_lossy_credentials() { + let options = ExaConnectOptions::builder() + .host("localhost".to_string()) + .port(8563) + .username("user".to_string()) + .password("pass".to_string()) + .schema("schema".to_string()) + .build() + .unwrap(); + + let url = options.to_url_lossy(); + + assert_eq!(url.scheme(), "exa"); + assert_eq!(url.host_str(), Some("localhost")); + assert_eq!(url.port(), Some(8563)); + assert_eq!(url.path(), "/schema"); + assert_eq!(url.username(), "user"); + 
assert_eq!(url.password(), Some("pass")); + } + + #[test] + fn test_to_url_lossy_access_token() { + let options = ExaConnectOptions::builder() + .host("localhost".to_string()) + .access_token("token123".to_string()) + .build() + .unwrap(); + + let url = options.to_url_lossy(); + + let query_pairs: std::collections::HashMap = + url.query_pairs().into_owned().collect(); + + assert_eq!(query_pairs.get(ACCESS_TOKEN), Some(&"token123".to_string())); + } + + #[test] + fn test_to_url_lossy_refresh_token() { + let options = ExaConnectOptions::builder() + .host("localhost".to_string()) + .refresh_token("refresh123".to_string()) + .build() + .unwrap(); + + let url = options.to_url_lossy(); + + let query_pairs: std::collections::HashMap = + url.query_pairs().into_owned().collect(); + + assert_eq!( + query_pairs.get(REFRESH_TOKEN), + Some(&"refresh123".to_string()) + ); + } + + #[test] + fn test_to_url_lossy_all_params() { + let options = ExaConnectOptions::builder() + .host("localhost".to_string()) + .port(8563) + .username("user".to_string()) + .password("pass".to_string()) + .schema("schema".to_string()) + .compression_mode(ExaCompressionMode::Disabled) + .fetch_size(2048) + .query_timeout(60) + .feedback_interval(5) + .statement_cache_capacity(NonZeroUsize::new(200).unwrap()) + .build() + .unwrap(); + + let url = options.to_url_lossy(); + + let query_pairs: std::collections::HashMap = + url.query_pairs().into_owned().collect(); + + assert_eq!(query_pairs.get(COMPRESSION), Some(&"disabled".to_string())); + assert_eq!(query_pairs.get(FETCH_SIZE), Some(&"2048".to_string())); + assert_eq!(query_pairs.get(QUERY_TIMEOUT), Some(&"60".to_string())); + assert_eq!(query_pairs.get(FEEDBACK_INTERVAL), Some(&"5".to_string())); + assert_eq!( + query_pairs.get(STATEMENT_CACHE_CAPACITY), + Some(&"200".to_string()) + ); + } + + #[test] + fn test_roundtrip_conversion() { + let original_url = + "exa://user:pass@localhost:8563/schema?compression=preferred&fetch-size=1024"; + let options = ExaConnectOptions::from_str(original_url).unwrap(); + let reconstructed_url = options.to_url_lossy(); + let options2 = ExaConnectOptions::from_url(&reconstructed_url).unwrap(); + + assert_eq!(options.hostname, options2.hostname); + assert_eq!(options.port, options2.port); + assert_eq!(options.schema, options2.schema); + assert_eq!(options.compression_mode, options2.compression_mode); + assert_eq!(options.fetch_size, options2.fetch_size); + } + #[test] + fn test_compression_modes() { + // Test ExaCompressionMode::Disabled + let url = "exa://user:pass@localhost:8563?compression=disabled"; + let options = ExaConnectOptions::from_str(url).unwrap(); + assert_eq!(options.compression_mode, ExaCompressionMode::Disabled); + + // Test ExaCompressionMode::Preferred + let url = "exa://user:pass@localhost:8563?compression=preferred"; + let options = ExaConnectOptions::from_str(url).unwrap(); + assert_eq!(options.compression_mode, ExaCompressionMode::Preferred); + + // Test ExaCompressionMode::Required + let url = "exa://user:pass@localhost:8563?compression=required"; + let options = ExaConnectOptions::from_str(url).unwrap(); + assert_eq!(options.compression_mode, ExaCompressionMode::Required); + } + + #[test] + fn test_ssl_modes() { + // Test ExaSslMode::Disable + let url = "exa://user:pass@localhost:8563?ssl-mode=disabled"; + let options = ExaConnectOptions::from_str(url).unwrap(); + assert_eq!(options.ssl_mode, ExaSslMode::Disabled); + + // Test ExaSslMode::Preferred + let url = "exa://user:pass@localhost:8563?ssl-mode=preferred"; + let 
options = ExaConnectOptions::from_str(url).unwrap(); + assert_eq!(options.ssl_mode, ExaSslMode::Preferred); + + // Test ExaSslMode::Required + let url = "exa://user:pass@localhost:8563?ssl-mode=required"; + let options = ExaConnectOptions::from_str(url).unwrap(); + assert_eq!(options.ssl_mode, ExaSslMode::Required); + } + + #[test] + fn test_compression_and_ssl_modes_together() { + let url = "exa://user:pass@localhost:8563?compression=required&ssl-mode=required"; + let options = ExaConnectOptions::from_str(url).unwrap(); + assert_eq!(options.compression_mode, ExaCompressionMode::Required); + assert_eq!(options.ssl_mode, ExaSslMode::Required); + } + + #[test] + fn test_compression_mode_to_url_lossy() { + // Test that compression modes are correctly serialized back to URL + let options = ExaConnectOptions::builder() + .host("localhost".to_string()) + .username("user".to_string()) + .password("pass".to_string()) + .compression_mode(ExaCompressionMode::Required) + .build() + .unwrap(); + + let url = options.to_url_lossy(); + let query_pairs: std::collections::HashMap = + url.query_pairs().into_owned().collect(); + + assert_eq!(query_pairs.get(COMPRESSION), Some(&"required".to_string())); + } + + #[test] + fn test_ssl_mode_to_url_lossy() { + // Test that SSL modes are correctly serialized back to URL + let options = ExaConnectOptions::builder() + .host("localhost".to_string()) + .username("user".to_string()) + .password("pass".to_string()) + .ssl_mode(ExaSslMode::Required) + .build() + .unwrap(); + + let url = options.to_url_lossy(); + let query_pairs: std::collections::HashMap = + url.query_pairs().into_owned().collect(); + + assert_eq!(query_pairs.get(SSL_MODE), Some(&"required".to_string())); + } +} diff --git a/sqlx-exasol-impl/src/options/protocol_version.rs b/sqlx-exasol-impl/src/options/protocol_version.rs new file mode 100644 index 00000000..e040274c --- /dev/null +++ b/sqlx-exasol-impl/src/options/protocol_version.rs @@ -0,0 +1,40 @@ +use serde::{Deserialize, Serialize}; + +use super::error::ExaConfigError; + +/// Enum listing the protocol versions that can be used when establishing a websocket connection to +/// Exasol. Defaults to the highest defined protocol version and falls back to the highest protocol +/// version supported by the server. +#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Deserialize, Serialize)] +#[serde(try_from = "u8")] +#[serde(into = "u8")] +#[repr(u8)] +pub enum ProtocolVersion { + V1 = 1, + V2 = 2, + V3 = 3, + V4 = 4, + #[default] + V5 = 5, +} + +impl From for u8 { + fn from(value: ProtocolVersion) -> Self { + value as Self + } +} + +impl TryFrom for ProtocolVersion { + type Error = ExaConfigError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::V1), + 2 => Ok(Self::V2), + 3 => Ok(Self::V3), + 4 => Ok(Self::V4), + 5 => Ok(Self::V5), + _ => Err(ExaConfigError::InvalidParameter("protocol-version")), + } + } +} diff --git a/src/options/ssl_mode.rs b/sqlx-exasol-impl/src/options/ssl_mode.rs similarity index 54% rename from src/options/ssl_mode.rs rename to sqlx-exasol-impl/src/options/ssl_mode.rs index 76b58943..7eb5aad4 100644 --- a/src/options/ssl_mode.rs +++ b/sqlx-exasol-impl/src/options/ssl_mode.rs @@ -1,11 +1,11 @@ use std::str::FromStr; -use super::{error::ExaConfigError, PARAM_SSL_MODE}; +use super::{error::ExaConfigError, SSL_MODE}; /// Options for controlling the desired security state of the connection to the Exasol server. 
/// -/// It is used by the `ssl_mode` method of [`crate::options::builder::ExaConnectOptionsBuilder`]. -#[derive(Debug, Clone, Copy, Default)] +/// It is used by [`crate::options::builder::ExaConnectOptionsBuilder::ssl_mode`]. +#[derive(Debug, Clone, Copy, Default, PartialEq)] pub enum ExaSslMode { /// Establish an unencrypted connection. Disabled, @@ -32,20 +32,37 @@ pub enum ExaSslMode { VerifyIdentity, } +impl ExaSslMode { + const DISABLED: &str = "disabled"; + const PREFERRED: &str = "preferred"; + const REQUIRED: &str = "required"; + const VERIFY_CA: &str = "verify_ca"; + const VERIFY_IDENTITY: &str = "verify_identity"; +} + impl FromStr for ExaSslMode { type Err = ExaConfigError; fn from_str(s: &str) -> Result { Ok(match &*s.to_ascii_lowercase() { - "disabled" => ExaSslMode::Disabled, - "preferred" => ExaSslMode::Preferred, - "required" => ExaSslMode::Required, - "verify_ca" => ExaSslMode::VerifyCa, - "verify_identity" => ExaSslMode::VerifyIdentity, - - _ => { - return Err(ExaConfigError::InvalidParameter(PARAM_SSL_MODE)); - } + Self::DISABLED => ExaSslMode::Disabled, + Self::PREFERRED => ExaSslMode::Preferred, + Self::REQUIRED => ExaSslMode::Required, + Self::VERIFY_CA => ExaSslMode::VerifyCa, + Self::VERIFY_IDENTITY => ExaSslMode::VerifyIdentity, + _ => Err(ExaConfigError::InvalidParameter(SSL_MODE))?, }) } } + +impl AsRef for ExaSslMode { + fn as_ref(&self) -> &str { + match self { + ExaSslMode::Disabled => Self::DISABLED, + ExaSslMode::Preferred => Self::PREFERRED, + ExaSslMode::Required => Self::REQUIRED, + ExaSslMode::VerifyCa => Self::VERIFY_CA, + ExaSslMode::VerifyIdentity => Self::VERIFY_IDENTITY, + } + } +} diff --git a/src/query_result.rs b/sqlx-exasol-impl/src/query_result.rs similarity index 100% rename from src/query_result.rs rename to sqlx-exasol-impl/src/query_result.rs diff --git a/src/responses/attributes.rs b/sqlx-exasol-impl/src/responses/attributes.rs similarity index 100% rename from src/responses/attributes.rs rename to sqlx-exasol-impl/src/responses/attributes.rs diff --git a/src/responses/columns.rs b/sqlx-exasol-impl/src/responses/columns.rs similarity index 100% rename from src/responses/columns.rs rename to sqlx-exasol-impl/src/responses/columns.rs diff --git a/src/responses/describe.rs b/sqlx-exasol-impl/src/responses/describe.rs similarity index 100% rename from src/responses/describe.rs rename to sqlx-exasol-impl/src/responses/describe.rs diff --git a/src/responses/error.rs b/sqlx-exasol-impl/src/responses/error.rs similarity index 100% rename from src/responses/error.rs rename to sqlx-exasol-impl/src/responses/error.rs diff --git a/src/responses/fetch.rs b/sqlx-exasol-impl/src/responses/fetch.rs similarity index 100% rename from src/responses/fetch.rs rename to sqlx-exasol-impl/src/responses/fetch.rs diff --git a/src/responses/hosts.rs b/sqlx-exasol-impl/src/responses/hosts.rs similarity index 100% rename from src/responses/hosts.rs rename to sqlx-exasol-impl/src/responses/hosts.rs diff --git a/src/responses/mod.rs b/sqlx-exasol-impl/src/responses/mod.rs similarity index 100% rename from src/responses/mod.rs rename to sqlx-exasol-impl/src/responses/mod.rs diff --git a/src/responses/prepared_stmt.rs b/sqlx-exasol-impl/src/responses/prepared_stmt.rs similarity index 100% rename from src/responses/prepared_stmt.rs rename to sqlx-exasol-impl/src/responses/prepared_stmt.rs diff --git a/src/responses/public_key.rs b/sqlx-exasol-impl/src/responses/public_key.rs similarity index 100% rename from src/responses/public_key.rs rename to 
diff --git a/src/responses/result.rs b/sqlx-exasol-impl/src/responses/result.rs
similarity index 100%
rename from src/responses/result.rs
rename to sqlx-exasol-impl/src/responses/result.rs
diff --git a/src/responses/session_info.rs b/sqlx-exasol-impl/src/responses/session_info.rs
similarity index 100%
rename from src/responses/session_info.rs
rename to sqlx-exasol-impl/src/responses/session_info.rs
diff --git a/src/row.rs b/sqlx-exasol-impl/src/row.rs
similarity index 88%
rename from src/row.rs
rename to sqlx-exasol-impl/src/row.rs
index f60fa338..17f6dda1 100644
--- a/src/row.rs
+++ b/sqlx-exasol-impl/src/row.rs
@@ -1,14 +1,14 @@
 use std::{fmt::Debug, sync::Arc};
 
 use serde_json::Value;
-use sqlx_core::{column::ColumnIndex, database::Database, row::Row, HashMap};
+use sqlx_core::{column::ColumnIndex, database::Database, ext::ustr::UStr, row::Row, HashMap};
 
 use crate::{column::ExaColumn, database::Exasol, value::ExaValueRef, SqlxError, SqlxResult};
 
 /// Struct representing a result set row. Implementor of [`Row`].
 #[derive(Debug)]
 pub struct ExaRow {
-    column_names: Arc<HashMap<Arc<str>, usize>>,
+    pub(crate) column_names: Arc<HashMap<UStr, usize>>,
     columns: Arc<[ExaColumn]>,
     data: Vec<Value>,
 }
@@ -18,7 +18,7 @@ impl ExaRow {
     pub fn new(
         data: Vec<Value>,
         columns: Arc<[ExaColumn]>,
-        column_names: Arc<HashMap<Arc<str>, usize>>,
+        column_names: Arc<HashMap<UStr, usize>>,
     ) -> Self {
         Self {
             column_names,
diff --git a/src/statement.rs b/sqlx-exasol-impl/src/statement.rs
similarity index 55%
rename from src/statement.rs
rename to sqlx-exasol-impl/src/statement.rs
index 8174ce05..6226f49b 100644
--- a/src/statement.rs
+++ b/sqlx-exasol-impl/src/statement.rs
@@ -1,7 +1,8 @@
-use std::{borrow::Cow, collections::HashMap, sync::Arc};
+use std::sync::Arc;
 
 use sqlx_core::{
-    column::ColumnIndex, database::Database, impl_statement_query, statement::Statement, Either,
+    column::ColumnIndex, database::Database, ext::ustr::UStr, impl_statement_query,
+    sql_str::SqlStr, statement::Statement, Either, HashMap,
 };
 
 use crate::{
@@ -11,45 +12,43 @@ use crate::{
 /// Implementor of [`Statement`].
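With `ExaRow` and the statement metadata (defined just below) sharing a single `Arc<HashMap<UStr, usize>>`, a by-name column lookup is one hash probe with no per-row allocation. A hedged sketch of the consuming side, assuming the usual sqlx `Row` API and a made-up `employees` table:

```rust
use sqlx_core::row::Row;
use sqlx_exasol_impl::{ExaConnection, ExaRow}; // re-export paths are assumptions

// Hypothetical query; the EMPLOYEES table and its NAME column are made up.
async fn first_name(conn: &mut ExaConnection) -> Result<String, sqlx_core::Error> {
    let row: ExaRow = sqlx_core::query::query("SELECT name FROM employees")
        .fetch_one(conn)
        .await?;
    // The `&str` ColumnIndex impl consults the shared column_names map.
    row.try_get("NAME")
}
```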
 #[derive(Debug, Clone)]
-pub struct ExaStatement<'q> {
-    pub(crate) sql: Cow<'q, str>,
+pub struct ExaStatement {
+    pub(crate) sql: SqlStr,
     pub(crate) metadata: ExaStatementMetadata,
 }
 
 #[derive(Debug, Clone)]
 pub struct ExaStatementMetadata {
     pub columns: Arc<[ExaColumn]>,
-    pub column_names: HashMap<Arc<str>, usize>,
+    pub column_names: Arc<HashMap<UStr, usize>>,
     pub parameters: Arc<[ExaTypeInfo]>,
 }
 
 impl ExaStatementMetadata {
     pub fn new(columns: Arc<[ExaColumn]>, parameters: Arc<[ExaTypeInfo]>) -> Self {
-        let mut column_names = HashMap::with_capacity(columns.len());
-
-        for (idx, col) in columns.as_ref().iter().enumerate() {
-            column_names.insert(col.name.clone(), idx);
-        }
+        let column_names = columns
+            .as_ref()
+            .iter()
+            .enumerate()
+            .map(|(idx, col)| (col.name.clone(), idx))
+            .collect();
 
         Self {
             columns,
-            column_names,
+            column_names: Arc::new(column_names),
             parameters,
         }
     }
 }
 
-impl<'q> Statement<'q> for ExaStatement<'q> {
+impl Statement for ExaStatement {
     type Database = Exasol;
 
-    fn to_owned(&self) -> <Self::Database as Database>::Statement<'static> {
-        ExaStatement {
-            sql: Cow::Owned(self.sql.clone().into_owned()),
-            metadata: self.metadata.clone(),
-        }
+    fn into_sql(self) -> SqlStr {
+        self.sql
     }
 
-    fn sql(&self) -> &str {
+    fn sql(&self) -> &SqlStr {
         &self.sql
     }
 
@@ -64,8 +63,8 @@ impl<'q> Statement<'q> for ExaStatement<'q> {
     impl_statement_query!(ExaArguments);
 }
 
-impl ColumnIndex<ExaStatement<'_>> for &'_ str {
-    fn index(&self, statement: &ExaStatement<'_>) -> SqlxResult<usize> {
+impl ColumnIndex<ExaStatement> for &str {
+    fn index(&self, statement: &ExaStatement) -> SqlxResult<usize> {
         statement
             .metadata
             .column_names
diff --git a/src/testing.rs b/sqlx-exasol-impl/src/testing.rs
similarity index 65%
rename from src/testing.rs
rename to sqlx-exasol-impl/src/testing.rs
index 9767d997..4306ea3e 100644
--- a/src/testing.rs
+++ b/sqlx-exasol-impl/src/testing.rs
@@ -1,90 +1,86 @@
 use std::{ops::Deref, str::FromStr, sync::OnceLock, time::Duration};
 
-use futures_core::future::BoxFuture;
 use futures_util::TryStreamExt;
 use sqlx_core::{
     connection::Connection,
     error::DatabaseError,
     executor::Executor,
     pool::{Pool, PoolOptions},
-    query, query_scalar,
+    sql_str::AssertSqlSafe,
     testing::{FixtureSnapshot, TestArgs, TestContext, TestSupport},
     Error,
 };
 
 use crate::{
     connection::ExaConnection, database::Exasol, options::ExaConnectOptions, ExaQueryResult,
+    SqlxResult,
 };
 
 static MASTER_POOL: OnceLock<Pool<Exasol>> = OnceLock::new();
 
 impl TestSupport for Exasol {
-    fn test_context(args: &TestArgs) -> BoxFuture<'_, Result<TestContext<Exasol>, Error>> {
-        Box::pin(test_context(args))
+    async fn test_context(args: &TestArgs) -> SqlxResult<TestContext<Exasol>> {
+        test_context(args).await
     }
 
-    fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> {
-        Box::pin(async move {
-            let mut conn = MASTER_POOL
-                .get()
-                .expect("cleanup_test() invoked outside `#[sqlx::test]")
-                .acquire()
-                .await?;
+    async fn cleanup_test(db_name: &str) -> SqlxResult<()> {
+        let mut conn = MASTER_POOL
+            .get()
+            .expect("cleanup_test() invoked outside `#[sqlx_exasol::test]")
+            .acquire()
+            .await?;
 
-            do_cleanup(&mut conn, db_name).await
-        })
+        do_cleanup(&mut conn, db_name).await
     }
 
-    fn cleanup_test_dbs() -> BoxFuture<'static, Result<Option<usize>, Error>> {
-        Box::pin(async move {
-            let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set");
+    async fn cleanup_test_dbs() -> SqlxResult<Option<usize>> {
+        let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set");
 
-            let mut conn = ExaConnection::connect(&url).await?;
+        let mut conn = ExaConnection::connect(&url).await?;
 
-            let query_str = r#"SELECT db_name FROM "_sqlx_tests"."_sqlx_test_databases";"#;
-
let db_names_to_delete: Vec = query_scalar::query_scalar(query_str) - .fetch_all(&mut conn) - .await?; + let query = r#"SELECT db_name FROM "_sqlx_tests"."_sqlx_test_databases";"#; + let db_names_to_delete: Vec = sqlx_core::query_scalar::query_scalar(query) + .fetch_all(&mut conn) + .await?; - if db_names_to_delete.is_empty() { - return Ok(None); - } + if db_names_to_delete.is_empty() { + return Ok(None); + } + + let mut deleted_db_names = Vec::with_capacity(db_names_to_delete.len()); - let mut deleted_db_names = Vec::with_capacity(db_names_to_delete.len()); - - for db_name in &db_names_to_delete { - let query_str = format!(r#"DROP SCHEMA IF EXISTS "{db_name}" CASCADE;"#); - - match conn.execute(&*query_str).await { - Ok(_deleted) => { - deleted_db_names.push(db_name); - } - // Assume a database error just means the DB is still in use. - Err(Error::Database(dbe)) => { - eprintln!("could not clean test database {db_name}: {dbe}"); - } - // Bubble up other errors - Err(e) => return Err(e), + for db_name in &db_names_to_delete { + let query = format!(r#"DROP SCHEMA IF EXISTS "{db_name}" CASCADE;"#); + + match conn.execute(AssertSqlSafe(query)).await { + Ok(_deleted) => { + deleted_db_names.push(db_name); + } + // Assume a database error just means the DB is still in use. + Err(Error::Database(dbe)) => { + eprintln!("could not clean test database {db_name}: {dbe}"); } + // Bubble up other errors + Err(e) => return Err(e), } + } - if deleted_db_names.is_empty() { - return Ok(None); - } + if deleted_db_names.is_empty() { + return Ok(None); + } - query::query(r#"DELETE FROM "_sqlx_tests"."_sqlx_test_databases" WHERE db_name = ?;"#) - .bind(&deleted_db_names) - .execute(&mut conn) - .await?; + sqlx_core::query::query( + r#"DELETE FROM "_sqlx_tests"."_sqlx_test_databases" WHERE db_name = ?;"#, + ) + .bind(&deleted_db_names) + .execute(&mut conn) + .await?; - conn.close().await.ok(); - Ok(Some(db_names_to_delete.len())) - }) + conn.close().await.ok(); + Ok(Some(db_names_to_delete.len())) } - fn snapshot( - _conn: &mut Self::Connection, - ) -> BoxFuture<'_, Result, Error>> { + async fn snapshot(_conn: &mut Self::Connection) -> SqlxResult> { // TODO: SQLx doesn't implement this yet either. 
todo!() } @@ -118,6 +114,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { "DATABASE_URL changed at runtime, database differs" ); + #[allow(clippy::large_futures, reason = "silencing clippy")] let mut conn = master_pool.acquire().await?; cleanup_old_dbs(&mut conn).await?; @@ -164,13 +161,13 @@ async fn test_context(args: &TestArgs) -> Result, Error> { INSERT INTO "_sqlx_tests"."_sqlx_test_databases" (db_name, test_path) VALUES (?, ?)"#; - query::query(query_str) + sqlx_core::query::query(query_str) .bind(&db_name) .bind(args.test_path) .execute(&mut *tx) .await?; - tx.execute(&*format!(r#"CREATE SCHEMA "{db_name}";"#)) + tx.execute(AssertSqlSafe(format!(r#"CREATE SCHEMA "{db_name}";"#))) .await?; tx.commit().await?; @@ -195,25 +192,28 @@ async fn test_context(args: &TestArgs) -> Result, Error> { } async fn do_cleanup(conn: &mut ExaConnection, db_name: &str) -> Result<(), Error> { - conn.execute(&*format!(r#"DROP SCHEMA IF EXISTS "{db_name}" CASCADE"#)) - .await?; + let query = format!(r#"DROP SCHEMA IF EXISTS "{db_name}" CASCADE"#); + conn.execute(AssertSqlSafe(query)).await?; - query::query(r#"DELETE FROM "_sqlx_tests"."_sqlx_test_databases" WHERE db_name = ?;"#) - .bind(db_name) - .execute(&mut *conn) - .await?; + sqlx_core::query::query( + r#"DELETE FROM "_sqlx_tests"."_sqlx_test_databases" WHERE db_name = ?;"#, + ) + .bind(db_name) + .execute(&mut *conn) + .await?; Ok(()) } /// Pre <0.8.4, test databases were stored by integer ID. async fn cleanup_old_dbs(conn: &mut ExaConnection) -> Result<(), Error> { - let res = - query_scalar::query_scalar(r#"SELECT db_id FROM "_sqlx_tests"."_sqlx_test_databases";"#) - .fetch_all(&mut *conn) - .await; + let res = sqlx_core::query_scalar::query_scalar( + r#"SELECT db_id FROM "_sqlx_tests"."_sqlx_test_databases";"#, + ) + .fetch_all(&mut *conn) + .await; - let db_ids: Vec = match res { + let db_ids: Vec = match res { Ok(db_ids) => db_ids, Err(e) => { return match e @@ -235,7 +235,7 @@ async fn cleanup_old_dbs(conn: &mut ExaConnection) -> Result<(), Error> { // Drop old-style test databases. for id in db_ids { let query = format!(r#"DROP SCHEMA IF EXISTS "_sqlx_test_database_{id}" CASCADE"#); - match conn.execute(&*query).await { + match conn.execute(AssertSqlSafe(query)).await { Ok(_deleted) => (), // Assume a database error just means the DB is still in use. Err(Error::Database(dbe)) => { diff --git a/sqlx-exasol-impl/src/transaction.rs b/sqlx-exasol-impl/src/transaction.rs new file mode 100644 index 00000000..4c0d99dc --- /dev/null +++ b/sqlx-exasol-impl/src/transaction.rs @@ -0,0 +1,51 @@ +use sqlx_core::{sql_str::SqlStr, transaction::TransactionManager}; + +use crate::{ + connection::websocket::future::{Commit, Rollback, WebSocketFuture}, + database::Exasol, + error::ExaProtocolError, + ExaConnection, SqlxResult, +}; + +/// Implementor of [`TransactionManager`]. +#[derive(Debug, Clone, Copy)] +pub struct ExaTransactionManager; + +impl TransactionManager for ExaTransactionManager { + type Database = Exasol; + + async fn begin(conn: &mut ExaConnection, _: Option) -> SqlxResult<()> { + // Exasol does not have nested transactions. + if conn.attributes().open_transaction() { + // A pending rollback indicates that a transaction was dropped before an explicit + // rollback, which is why it's still open. If that's the case, then awaiting the + // rollback is sufficient to proceed. 
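+            // For example, dropping a `Transaction` guard without an explicit commit or
+            // rollback only schedules the rollback via `start_rollback()` (see below);
+            // replaying it here settles the old transaction before a new one begins.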
+ match conn.ws.pending_rollback.take() { + Some(rollback) => rollback.future(&mut conn.ws).await?, + None => return Err(ExaProtocolError::TransactionAlreadyOpen)?, + } + } + + // The next time a request is sent, the transaction will be started. + // We could eagerly start it as well, but that implies one more round-trip to the server + // and back with no benefit. + conn.attributes_mut().set_autocommit(false); + Ok(()) + } + + async fn commit(conn: &mut ExaConnection) -> SqlxResult<()> { + Commit::default().future(&mut conn.ws).await + } + + async fn rollback(conn: &mut ExaConnection) -> SqlxResult<()> { + Rollback::default().future(&mut conn.ws).await + } + + fn start_rollback(conn: &mut ExaConnection) { + conn.ws.pending_rollback = Some(Rollback::default()); + } + + fn get_transaction_depth(conn: &ExaConnection) -> usize { + conn.attributes().open_transaction().into() + } +} diff --git a/sqlx-exasol-impl/src/type_checking.rs b/sqlx-exasol-impl/src/type_checking.rs new file mode 100644 index 00000000..2024ca05 --- /dev/null +++ b/sqlx-exasol-impl/src/type_checking.rs @@ -0,0 +1,100 @@ +#[allow(unused_imports)] +use sqlx_core::{config::drivers::Config, describe::Describe, impl_type_checking}; +use sqlx_macros_core::{ + database::{CachingDescribeBlocking, DatabaseExt}, + query::QueryDriver, +}; + +use crate::{Exasol, SqlxResult}; + +pub const QUERY_DRIVER: QueryDriver = QueryDriver::new::(); + +impl DatabaseExt for Exasol { + const DATABASE_PATH: &'static str = "sqlx_exasol::Exasol"; + + const ROW_PATH: &'static str = "sqlx_exasol::ExaRow"; + + fn describe_blocking( + query: &str, + database_url: &str, + driver_config: &Config, + ) -> SqlxResult> { + static CACHE: CachingDescribeBlocking = CachingDescribeBlocking::new(); + + CACHE.describe(query, database_url, driver_config) + } +} + +mod sqlx_exasol { + #[allow(unused_imports, reason = "used in type checking")] + pub mod types { + pub use sqlx_core::types::*; + + pub use crate::types::*; + + #[cfg(feature = "chrono")] + pub mod chrono { + pub use sqlx_core::types::chrono::*; + + pub use crate::types::chrono::*; + } + + #[cfg(feature = "time")] + pub mod time { + pub use sqlx_core::types::time::*; + + pub use crate::types::time::*; + } + } +} + +impl_type_checking!( + Exasol { + bool, + i8, + i16, + i32, + i64, + f64, + String | &str, + + sqlx_exasol::types::HashType, + sqlx_exasol::types::ExaIntervalYearToMonth, + + #[cfg(feature = "uuid")] + sqlx_exasol::types::Uuid, + + #[cfg(feature = "geo-types")] + sqlx_exasol::types::geo_types::Geometry, + }, + ParamChecking::Weak, + feature-types: info => info.__type_feature_gate(), + datetime-types: { + chrono: { + sqlx_exasol::types::chrono::TimeDelta, + + sqlx_exasol::types::chrono::NaiveDate, + + sqlx_exasol::types::chrono::NaiveDateTime, + + sqlx_exasol::types::chrono::DateTime, + }, + time: { + sqlx_exasol::types::time::Duration, + + sqlx_exasol::types::time::Date, + + sqlx_exasol::types::time::PrimitiveDateTime, + + sqlx_exasol::types::time::OffsetDateTime, + }, + }, + numeric-types: { + bigdecimal: { + sqlx_exasol::types::BigDecimal, + }, + rust_decimal: { + sqlx_exasol::types::Decimal, + }, + }, +); diff --git a/sqlx-exasol-impl/src/type_info.rs b/sqlx-exasol-impl/src/type_info.rs new file mode 100644 index 00000000..5851784b --- /dev/null +++ b/sqlx-exasol-impl/src/type_info.rs @@ -0,0 +1,513 @@ +use std::fmt::{Arguments, Display}; + +use arrayvec::ArrayString; +use serde::{Deserialize, Serialize}; +use sqlx_core::type_info::TypeInfo; + +/// Information about an Exasol data type and 
implementor of [`TypeInfo`].
+// Note that the [`DataTypeName`] is automatically constructed from the provided [`ExaDataType`].
+#[derive(Debug, Clone, Copy, Deserialize)]
+#[serde(from = "ExaDataType")]
+pub struct ExaTypeInfo {
+    pub(crate) name: DataTypeName,
+    pub(crate) data_type: ExaDataType,
+}
+
+impl ExaTypeInfo {
+    #[doc(hidden)]
+    #[allow(clippy::must_use_candidate)]
+    pub fn __type_feature_gate(&self) -> Option<&'static str> {
+        match self.data_type {
+            ExaDataType::Date
+            | ExaDataType::Timestamp
+            | ExaDataType::TimestampWithLocalTimeZone => Some("time"),
+            ExaDataType::Decimal(decimal)
+                if decimal.scale > 0 || decimal.precision > Some(Decimal::MAX_64BIT_PRECISION) =>
+            {
+                Some("bigdecimal")
+            }
+            _ => None,
+        }
+    }
+}
+
+/// Manually implemented because we only want to serialize the `data_type` field while also
+/// flattening the structure.
+// NOTE: On [`Deserialize`] we simply convert from the [`ExaDataType`] to this.
+impl Serialize for ExaTypeInfo {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        self.data_type.serialize(serializer)
+    }
+}
+
+impl From<ExaDataType> for ExaTypeInfo {
+    fn from(data_type: ExaDataType) -> Self {
+        let name = data_type.full_name();
+        Self { name, data_type }
+    }
+}
+
+impl PartialEq for ExaTypeInfo {
+    fn eq(&self, other: &Self) -> bool {
+        self.data_type == other.data_type
+    }
+}
+
+impl Display for ExaTypeInfo {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.name)
+    }
+}
+
+impl TypeInfo for ExaTypeInfo {
+    fn is_null(&self) -> bool {
+        false
+    }
+
+    /// We're going against `sqlx` here, but knowing the full data type definition is actually very
+    /// helpful when displaying error messages, so... ¯\_(ツ)_/¯. This is also due to Exasol's
+    /// limited number of data types. How would it look saying that a `DECIMAL` column does not fit
+    /// in some other `DECIMAL` data type?
+    ///
+    /// In fact, error messages seem to be the only place where this is being used, particularly
+    /// when trying to decode a value but the data type provided by the database does not
+    /// match/fit inside the Rust data type.
+    fn name(&self) -> &str {
+        self.name.as_ref()
+    }
+
+    /// Checks compatibility with other data types.
+    ///
+    /// Returns true if this [`ExaTypeInfo`] instance is able to accommodate the `other` instance.
+    fn type_compatible(&self, other: &Self) -> bool
+    where
+        Self: Sized,
+    {
+        self.data_type.compatible(&other.data_type)
+    }
+}
+
+/// Datatype definitions enum, as Exasol sees them.
+///
+/// If you manually construct them, be aware that there is a [`DataTypeName`] automatically
+/// constructed when converting to [`ExaTypeInfo`] and there are compatibility checks set in place.
+///
+/// In case of incompatibility, the definition is displayed for troubleshooting.
+#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq)]
+#[serde(rename_all = "UPPERCASE")]
+#[serde(tag = "type")]
+pub enum ExaDataType {
+    /// The BOOLEAN data type.
+    Boolean,
+    /// The CHAR data type.
+    #[serde(rename_all = "camelCase")]
+    Char { size: u32, character_set: Charset },
+    /// The DATE data type.
+    Date,
+    /// The DECIMAL data type.
+    Decimal(Decimal),
+    /// The DOUBLE data type.
+    Double,
+    /// The `GEOMETRY` data type.
+    #[serde(rename_all = "camelCase")]
+    Geometry { srid: u16 },
+    /// The `INTERVAL DAY TO SECOND` data type.
+    #[serde(rename = "INTERVAL DAY TO SECOND")]
+    #[serde(rename_all = "camelCase")]
+    IntervalDayToSecond { precision: u32, fraction: u32 },
+    /// The `INTERVAL YEAR TO MONTH` data type.
+    #[serde(rename = "INTERVAL YEAR TO MONTH")]
+    #[serde(rename_all = "camelCase")]
+    IntervalYearToMonth { precision: u32 },
+    /// The TIMESTAMP data type.
+    Timestamp,
+    /// The TIMESTAMP WITH LOCAL TIME ZONE data type.
+    #[serde(rename = "TIMESTAMP WITH LOCAL TIME ZONE")]
+    TimestampWithLocalTimeZone,
+    /// The VARCHAR data type.
+    #[serde(rename_all = "camelCase")]
+    Varchar { size: u32, character_set: Charset },
+    /// The Exasol `HASHTYPE` data type.
+    ///
+    /// NOTE: Exasol returns the size of the column string representation which depends on the
+    /// `HASHTYPE_FORMAT` database parameter. We set the parameter to `HEX` whenever we open
+    /// a connection to allow us to reliably use the column size, particularly for UUIDs.
+    ///
+    /// However, other values (especially the ones to be encoded) through
+    /// [`crate::types::HashType`] cannot be strictly checked because they could be in different
+    /// formats, like hex, base64, etc. In that case we avoid the size check by relying on
+    /// [`None`].
+    ///
+    /// Database columns and prepared statements parameters will **always** be [`Some`].
+    HashType { size: Option<u16> },
+}
+
+impl ExaDataType {
+    // Data type names
+    const BOOLEAN: &'static str = "BOOLEAN";
+    const CHAR: &'static str = "CHAR";
+    const DATE: &'static str = "DATE";
+    const DECIMAL: &'static str = "DECIMAL";
+    const DOUBLE: &'static str = "DOUBLE PRECISION";
+    const GEOMETRY: &'static str = "GEOMETRY";
+    const INTERVAL_DAY_TO_SECOND: &'static str = "INTERVAL DAY TO SECOND";
+    const INTERVAL_YEAR_TO_MONTH: &'static str = "INTERVAL YEAR TO MONTH";
+    const TIMESTAMP: &'static str = "TIMESTAMP";
+    const TIMESTAMP_WITH_LOCAL_TIME_ZONE: &'static str = "TIMESTAMP WITH LOCAL TIME ZONE";
+    const VARCHAR: &'static str = "VARCHAR";
+    const HASHTYPE: &'static str = "HASHTYPE";
+
+    // Datatype constants
+    //
+    /// Accuracy is limited to milliseconds; see the Exasol documentation on interval types.
+    ///
+    /// The fraction has the weird behavior of shifting the milliseconds up the value and mixing it
+    /// with the seconds, minutes, hours or even the days when the value exceeds 3 (the max
+    /// milliseconds digits limit) even though the maximum value is 9.
+    ///
+    /// See the Exasol documentation for the exact behavior.
+    ///
+    /// Therefore, we'll only be handling fractions smaller or equal to 3.
+    #[allow(dead_code, reason = "used by optional dependency")]
+    pub(crate) const INTERVAL_DTS_MAX_FRACTION: u32 = 3;
+    #[allow(dead_code, reason = "used by optional dependency")]
+    pub(crate) const INTERVAL_DTS_MAX_PRECISION: u32 = 9;
+    pub(crate) const INTERVAL_YTM_MAX_PRECISION: u32 = 9;
+    pub(crate) const VARCHAR_MAX_LEN: u32 = 2_000_000;
+    #[cfg_attr(not(test), expect(dead_code))]
+    pub(crate) const CHAR_MAX_LEN: u32 = 2_000;
+    // 1024 * 2 because we set HASHTYPE_FORMAT to HEX.
+    #[cfg_attr(not(test), expect(dead_code))]
+    pub(crate) const HASHTYPE_MAX_LEN: u16 = 2048;
+
+    /// Returns `true` if this instance is compatible with the other one provided.
+    ///
+    /// Compatibility means that the [`self`] instance is bigger/able to accommodate the other
+    /// instance.
+    pub fn compatible(&self, other: &Self) -> bool {
+        match (self, other) {
+            (Self::HashType { size: Some(s1) }, Self::HashType { size: Some(s2) }) => s1 == s2,
+            (Self::Boolean, Self::Boolean)
+            | (
+                Self::Char { .. } | Self::Varchar { .. },
+                Self::Char { .. } | Self::Varchar { .. },
}, + ) + | (Self::Date, Self::Date) + | (Self::Double, Self::Double) + | (Self::Geometry { .. }, Self::Geometry { .. }) + | (Self::IntervalDayToSecond { .. }, Self::IntervalDayToSecond { .. }) + | (Self::IntervalYearToMonth { .. }, Self::IntervalYearToMonth { .. }) + | (Self::Timestamp, Self::Timestamp) + | (Self::TimestampWithLocalTimeZone, Self::TimestampWithLocalTimeZone) + | (Self::HashType { .. }, Self::HashType { .. }) => true, + (Self::Decimal(d1), Self::Decimal(d2)) => d1.compatible(*d2), + _ => false, + } + } + + fn full_name(&self) -> DataTypeName { + match self { + Self::Boolean => Self::BOOLEAN.into(), + Self::Date => Self::DATE.into(), + Self::Double => Self::DOUBLE.into(), + Self::Timestamp => Self::TIMESTAMP.into(), + Self::TimestampWithLocalTimeZone => Self::TIMESTAMP_WITH_LOCAL_TIME_ZONE.into(), + Self::Char { + size, + character_set, + } + | Self::Varchar { + size, + character_set, + } => format_args!("{}({}) {}", self.as_ref(), size, character_set).into(), + Self::Decimal(d) => match d.precision { + Some(p) => format_args!("{}({}, {})", self.as_ref(), p, d.scale).into(), + None => format_args!("{}(*, {})", self.as_ref(), d.scale).into(), + }, + Self::Geometry { srid } => format_args!("{}({srid})", self.as_ref()).into(), + Self::IntervalDayToSecond { + precision, + fraction, + } => format_args!("INTERVAL DAY({precision}) TO SECOND({fraction})").into(), + Self::IntervalYearToMonth { precision } => { + format_args!("INTERVAL YEAR({precision}) TO MONTH").into() + } + Self::HashType { size } => match size { + // We get the HEX len, which is double the byte count. + Some(s) => format_args!("{}({} BYTE)", self.as_ref(), s / 2).into(), + None => format_args!("{}", self.as_ref()).into(), + }, + } + } +} + +impl AsRef for ExaDataType { + fn as_ref(&self) -> &str { + match self { + Self::Boolean => Self::BOOLEAN, + Self::Char { .. } => Self::CHAR, + Self::Date => Self::DATE, + Self::Decimal(_) => Self::DECIMAL, + Self::Double => Self::DOUBLE, + Self::Geometry { .. } => Self::GEOMETRY, + Self::IntervalDayToSecond { .. } => Self::INTERVAL_DAY_TO_SECOND, + Self::IntervalYearToMonth { .. } => Self::INTERVAL_YEAR_TO_MONTH, + Self::Timestamp => Self::TIMESTAMP, + Self::TimestampWithLocalTimeZone => Self::TIMESTAMP_WITH_LOCAL_TIME_ZONE, + Self::Varchar { .. } => Self::VARCHAR, + Self::HashType { .. } => Self::HASHTYPE, + } + } +} + +/// A data type's name, composed from an instance of [`ExaDataType`]. For performance's sake, since +/// data type names are small, we either store them statically or as inlined strings. +/// +/// *IMPORTANT*: Creating absurd [`ExaDataType`] can result in panics if the name exceeds the +/// inlined strings max capacity. Valid values always fit. +#[derive(Debug, Clone, Copy)] +pub enum DataTypeName { + Static(&'static str), + Inline(ArrayString<30>), +} + +impl AsRef for DataTypeName { + fn as_ref(&self) -> &str { + match self { + DataTypeName::Static(s) => s, + DataTypeName::Inline(s) => s.as_str(), + } + } +} + +impl Display for DataTypeName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_ref()) + } +} + +impl From<&'static str> for DataTypeName { + fn from(value: &'static str) -> Self { + Self::Static(value) + } +} + +impl From> for DataTypeName { + fn from(value: Arguments<'_>) -> Self { + Self::Inline(ArrayString::try_from(value).expect("inline data type name too large")) + } +} + +/// The `DECIMAL` data type. 
+#[derive(Debug, Copy, Clone, Deserialize, Serialize, PartialEq)]
+#[serde(rename_all = "camelCase")]
+pub struct Decimal {
+    /// The absence of precision means universal compatibility.
+    pub(crate) precision: Option<u8>,
+    pub(crate) scale: u8,
+}
+
+impl Decimal {
+    /// Max precision values for signed integers.
+    pub(crate) const MAX_8BIT_PRECISION: u8 = 3;
+    pub(crate) const MAX_16BIT_PRECISION: u8 = 5;
+    pub(crate) const MAX_32BIT_PRECISION: u8 = 10;
+    pub(crate) const MAX_64BIT_PRECISION: u8 = 20;
+
+    /// Max supported values.
+    pub(crate) const MAX_PRECISION: u8 = 36;
+    #[allow(dead_code)]
+    pub(crate) const MAX_SCALE: u8 = 36;
+
+    /// The purpose of this is to be able to tell if some [`Decimal`] fits inside another
+    /// [`Decimal`].
+    ///
+    /// Therefore, we consider cases such as:
+    /// - DECIMAL(10, 1) != DECIMAL(9, 2)
+    /// - DECIMAL(10, 1) != DECIMAL(10, 2)
+    /// - DECIMAL(10, 1) < DECIMAL(11, 2)
+    /// - DECIMAL(10, 1) < DECIMAL(17, 4)
+    ///
+    /// - DECIMAL(10, 1) > DECIMAL(9, 1)
+    /// - DECIMAL(10, 1) = DECIMAL(10, 1)
+    /// - DECIMAL(10, 1) < DECIMAL(11, 1)
+    ///
+    /// This boils down to:
+    /// `a.scale >= b.scale AND (a.precision - a.scale) >= (b.precision - b.scale)`
+    ///
+    /// However, decimal Rust types require special handling because they can hold virtually any
+    /// decoded value. Therefore, an absent precision means that the comparison must be skipped.
+    #[rustfmt::skip] // just to skip rules formatting
+    fn compatible(self, dec: Decimal) -> bool {
+        let (precision, scale) = match dec.precision {
+            Some(precision) => (precision, dec.scale),
+            // Short-circuit if we are encoding a Rust decimal type as they have arbitrary precision.
+            None => return true,
+        };
+
+        // If we're decoding to a Rust decimal type then we can accept any DECIMAL precision.
+        let self_diff = self.precision.map_or(Decimal::MAX_PRECISION, |p| p - self.scale);
+        let other_diff = precision - scale;
+
+        self.scale >= scale && self_diff >= other_diff
+    }
+}
+
+/// Exasol supported character sets.
+#[derive(Clone, Copy, Debug, Deserialize, Serialize, PartialEq)]
+#[serde(rename_all = "UPPERCASE")]
+pub enum Charset {
+    Utf8,
+    Ascii,
+}
+
+impl AsRef<str> for Charset {
+    fn as_ref(&self) -> &str {
+        match self {
+            Charset::Utf8 => "UTF8",
+            Charset::Ascii => "ASCII",
+        }
+    }
+}
+
+impl Display for Charset {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.as_ref())
+    }
+}
+
+/// Mainly adding these so that we ensure the inlined type names won't panic when created with their
+/// max values.
+///
+/// If the max values work, the lower ones inherently will too.
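A few concrete instances of that rule, written as a module-internal sketch (both `compatible` and the struct fields are private to this file):

```rust
#[test]
fn decimal_compat_examples() {
    let d = |p: u8, s: u8| Decimal { precision: Some(p), scale: s };
    // DECIMAL(11, 2) accommodates DECIMAL(10, 1): 2 >= 1 and 9 >= 9 integer digits.
    assert!(d(11, 2).compatible(d(10, 1)));
    // DECIMAL(10, 2) does not accommodate DECIMAL(10, 1): only 8 integer digits vs 9.
    assert!(!d(10, 2).compatible(d(10, 1)));
    // An absent precision on the value being encoded short-circuits to true.
    assert!(d(10, 1).compatible(Decimal { precision: None, scale: 36 }));
}
```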
+#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_boolean_name() { + let data_type = ExaDataType::Boolean; + assert_eq!(data_type.full_name().as_ref(), "BOOLEAN"); + } + + #[test] + fn test_max_char_name() { + let data_type = ExaDataType::Char { + size: ExaDataType::CHAR_MAX_LEN, + character_set: Charset::Ascii, + }; + assert_eq!( + data_type.full_name().as_ref(), + format!("CHAR({}) ASCII", ExaDataType::CHAR_MAX_LEN) + ); + } + + #[test] + fn test_date_name() { + let data_type = ExaDataType::Date; + assert_eq!(data_type.full_name().as_ref(), "DATE"); + } + + #[test] + fn test_max_decimal_name() { + let decimal = Decimal { + precision: Some(Decimal::MAX_PRECISION), + scale: Decimal::MAX_SCALE, + }; + let data_type = ExaDataType::Decimal(decimal); + assert_eq!( + data_type.full_name().as_ref(), + format!( + "DECIMAL({}, {})", + Decimal::MAX_PRECISION, + Decimal::MAX_SCALE + ) + ); + } + + #[test] + fn test_double_name() { + let data_type = ExaDataType::Double; + assert_eq!(data_type.full_name().as_ref(), "DOUBLE PRECISION"); + } + + #[test] + fn test_max_geometry_name() { + let data_type = ExaDataType::Geometry { srid: u16::MAX }; + assert_eq!( + data_type.full_name().as_ref(), + format!("GEOMETRY({})", u16::MAX) + ); + } + + #[test] + fn test_max_interval_day_name() { + let data_type = ExaDataType::IntervalDayToSecond { + precision: ExaDataType::INTERVAL_DTS_MAX_PRECISION, + fraction: ExaDataType::INTERVAL_DTS_MAX_FRACTION, + }; + assert_eq!( + data_type.full_name().as_ref(), + format!( + "INTERVAL DAY({}) TO SECOND({})", + ExaDataType::INTERVAL_DTS_MAX_PRECISION, + ExaDataType::INTERVAL_DTS_MAX_FRACTION + ) + ); + } + + #[test] + fn test_max_interval_year_name() { + let data_type = ExaDataType::IntervalYearToMonth { + precision: ExaDataType::INTERVAL_YTM_MAX_PRECISION, + }; + assert_eq!( + data_type.full_name().as_ref(), + format!( + "INTERVAL YEAR({}) TO MONTH", + ExaDataType::INTERVAL_YTM_MAX_PRECISION, + ) + ); + } + + #[test] + fn test_timestamp_name() { + let data_type = ExaDataType::Timestamp; + assert_eq!(data_type.full_name().as_ref(), "TIMESTAMP"); + } + + #[test] + fn test_timestamp_with_tz_name() { + let data_type = ExaDataType::TimestampWithLocalTimeZone; + assert_eq!( + data_type.full_name().as_ref(), + "TIMESTAMP WITH LOCAL TIME ZONE" + ); + } + + #[test] + fn test_max_varchar_name() { + let data_type = ExaDataType::Varchar { + size: ExaDataType::VARCHAR_MAX_LEN, + character_set: Charset::Ascii, + }; + assert_eq!( + data_type.full_name().as_ref(), + format!("VARCHAR({}) ASCII", ExaDataType::VARCHAR_MAX_LEN) + ); + } + + #[test] + fn test_max_hashbyte_name() { + let data_type = ExaDataType::HashType { + size: Some(ExaDataType::HASHTYPE_MAX_LEN), + }; + assert_eq!( + data_type.full_name().as_ref(), + format!("HASHTYPE({} BYTE)", ExaDataType::HASHTYPE_MAX_LEN / 2) + ); + } +} diff --git a/sqlx-exasol-impl/src/types/bigdecimal.rs b/sqlx-exasol-impl/src/types/bigdecimal.rs new file mode 100644 index 00000000..81558d1a --- /dev/null +++ b/sqlx-exasol-impl/src/types/bigdecimal.rs @@ -0,0 +1,44 @@ +use bigdecimal::BigDecimal; +use serde::Deserialize; +use sqlx_core::{ + decode::Decode, + encode::{Encode, IsNull}, + error::BoxDynError, + types::Type, +}; + +use crate::{ + arguments::ExaBuffer, + database::Exasol, + type_info::{Decimal, ExaDataType, ExaTypeInfo}, + value::ExaValueRef, +}; + +impl Type for BigDecimal { + fn type_info() -> ExaTypeInfo { + // A somewhat non-sensical value used to allow decoding any DECIMAL value. 
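+        // (An absent precision opts out of `Decimal::compatible`'s size check
+        // entirely, which is exactly what an arbitrary-precision type needs.)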
+ ExaDataType::Decimal(Decimal { + precision: None, + scale: Decimal::MAX_SCALE, + }) + .into() + } +} + +impl Encode<'_, Exasol> for BigDecimal { + fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { + buf.append(format_args!("{self}"))?; + Ok(IsNull::No) + } + + fn size_hint(&self) -> usize { + // 1 quote + 1 sign + 1 zero + 1 dot + max scale + 1 quote + 4 + Decimal::MAX_SCALE as usize + 1 + } +} + +impl Decode<'_, Exasol> for BigDecimal { + fn decode(value: ExaValueRef<'_>) -> Result { + ::deserialize(value.value).map_err(From::from) + } +} diff --git a/src/types/bool.rs b/sqlx-exasol-impl/src/types/bool.rs similarity index 78% rename from src/types/bool.rs rename to sqlx-exasol-impl/src/types/bool.rs index 1a858962..5b53ac70 100644 --- a/src/types/bool.rs +++ b/sqlx-exasol-impl/src/types/bool.rs @@ -25,16 +25,9 @@ impl Encode<'_, Exasol> for bool { Ok(IsNull::No) } - fn produces(&self) -> Option { - Some(ExaDataType::Boolean.into()) - } - fn size_hint(&self) -> usize { - if *self { - stringify!(true).len() - } else { - stringify!(false).len() - } + // len of `false` + 5 } } diff --git a/sqlx-exasol-impl/src/types/chrono/date.rs b/sqlx-exasol-impl/src/types/chrono/date.rs new file mode 100644 index 00000000..12cf2d11 --- /dev/null +++ b/sqlx-exasol-impl/src/types/chrono/date.rs @@ -0,0 +1,39 @@ +use chrono::NaiveDate; +use serde::Deserialize; +use sqlx_core::{ + decode::Decode, + encode::{Encode, IsNull}, + error::BoxDynError, + types::Type, +}; + +use crate::{ + arguments::ExaBuffer, + database::Exasol, + type_info::{ExaDataType, ExaTypeInfo}, + value::ExaValueRef, +}; + +impl Type for NaiveDate { + fn type_info() -> ExaTypeInfo { + ExaDataType::Date.into() + } +} + +impl Encode<'_, Exasol> for NaiveDate { + fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { + buf.append(self)?; + Ok(IsNull::No) + } + + fn size_hint(&self) -> usize { + // 2 quotes + 4 year + 1 dash + 2 months + 1 dash + 2 days + 12 + } +} + +impl Decode<'_, Exasol> for NaiveDate { + fn decode(value: ExaValueRef<'_>) -> Result { + ::deserialize(value.value).map_err(From::from) + } +} diff --git a/sqlx-exasol-impl/src/types/chrono/datetime.rs b/sqlx-exasol-impl/src/types/chrono/datetime.rs new file mode 100644 index 00000000..91f7f190 --- /dev/null +++ b/sqlx-exasol-impl/src/types/chrono/datetime.rs @@ -0,0 +1,105 @@ +use chrono::{DateTime, Local, NaiveDateTime, Utc}; +use serde::Deserialize; +use sqlx_core::{ + decode::Decode, + encode::{Encode, IsNull}, + error::BoxDynError, + types::Type, +}; + +use crate::{ + arguments::ExaBuffer, + database::Exasol, + type_info::{ExaDataType, ExaTypeInfo}, + value::ExaValueRef, +}; + +const TIMESTAMP_FMT: &str = "%Y-%m-%d %H:%M:%S%.9f"; + +impl Type for NaiveDateTime { + fn type_info() -> ExaTypeInfo { + ExaDataType::Timestamp.into() + } +} + +impl Encode<'_, Exasol> for NaiveDateTime { + fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { + buf.append(format_args!("{}", self.format(TIMESTAMP_FMT)))?; + Ok(IsNull::No) + } + + fn size_hint(&self) -> usize { + // 1 quote + + // 4 years + 1 dash + 2 months + 1 dash + 2 days + + // 1 space + 2 hours + 2 minutes + 2 seconds + 9 subseconds + + // 1 quote + 28 + } +} + +impl Decode<'_, Exasol> for NaiveDateTime { + fn decode(value: ExaValueRef<'_>) -> Result { + let input = <&str>::deserialize(value.value).map_err(Box::new)?; + Self::parse_from_str(input, TIMESTAMP_FMT) + .map_err(Box::new) + .map_err(From::from) + } +} + +impl Type for DateTime { + fn type_info() -> ExaTypeInfo { + ExaDataType::Timestamp.into() + } +} + 
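For reference, what the `TIMESTAMP_FMT` constant above renders for a concrete instant; a standalone sketch using only chrono:

```rust
use chrono::NaiveDate;

fn main() {
    let dt = NaiveDate::from_ymd_opt(2024, 1, 2)
        .unwrap()
        .and_hms_milli_opt(3, 4, 5, 678)
        .unwrap();
    // `%.9f` always emits the dot plus nine subsecond digits.
    assert_eq!(
        dt.format("%Y-%m-%d %H:%M:%S%.9f").to_string(),
        "2024-01-02 03:04:05.678000000"
    );
}
```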
+impl Encode<'_, Exasol> for DateTime { + fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { + Encode::::encode(self.naive_utc(), buf) + } + + fn size_hint(&self) -> usize { + // 1 quote + + // 4 years + 1 dash + 2 months + 1 dash + 2 days + + // 1 space + 2 hours + 2 minutes + 2 seconds + 9 subseconds + + // 1 quote + 28 + } +} + +impl<'r> Decode<'r, Exasol> for DateTime { + fn decode(value: ExaValueRef<'r>) -> Result { + let naive: NaiveDateTime = Decode::::decode(value)?; + Ok(DateTime::from_naive_utc_and_offset(naive, Utc)) + } +} + +impl Type for DateTime { + fn type_info() -> ExaTypeInfo { + ExaDataType::TimestampWithLocalTimeZone.into() + } +} + +impl Encode<'_, Exasol> for DateTime { + fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { + Encode::::encode(self.naive_local(), buf) + } + + fn size_hint(&self) -> usize { + // 1 quote + + // 4 years + 1 dash + 2 months + 1 dash + 2 days + + // 1 space + 2 hours + 2 minutes + 2 seconds + 9 subseconds + + // 1 quote + 28 + } +} + +impl<'r> Decode<'r, Exasol> for DateTime { + fn decode(value: ExaValueRef<'r>) -> Result { + let naive: NaiveDateTime = Decode::::decode(value)?; + naive + .and_local_timezone(Local) + .single() + .ok_or("cannot uniquely determine timezone offset") + .map_err(From::from) + } +} diff --git a/sqlx-exasol-impl/src/types/chrono/mod.rs b/sqlx-exasol-impl/src/types/chrono/mod.rs new file mode 100644 index 00000000..4aadd175 --- /dev/null +++ b/sqlx-exasol-impl/src/types/chrono/mod.rs @@ -0,0 +1,19 @@ +use chrono::Months; +use sqlx_core::error::BoxDynError; + +use crate::types::ExaIntervalYearToMonth; + +mod date; +mod datetime; +mod timedelta; + +pub use chrono::TimeDelta; + +impl TryFrom for ExaIntervalYearToMonth { + type Error = BoxDynError; + + fn try_from(value: Months) -> Result { + let num_months = value.as_u32().into(); + Ok(Self(num_months)) + } +} diff --git a/sqlx-exasol-impl/src/types/chrono/timedelta.rs b/sqlx-exasol-impl/src/types/chrono/timedelta.rs new file mode 100644 index 00000000..4a0d3ff2 --- /dev/null +++ b/sqlx-exasol-impl/src/types/chrono/timedelta.rs @@ -0,0 +1,78 @@ +use chrono::TimeDelta; +use serde::Deserialize; +use sqlx_core::{ + decode::Decode, + encode::{Encode, IsNull}, + error::BoxDynError, + types::Type, +}; + +use crate::{ + arguments::ExaBuffer, + database::Exasol, + type_info::{ExaDataType, ExaTypeInfo}, + value::ExaValueRef, +}; + +impl Type for TimeDelta { + fn type_info() -> ExaTypeInfo { + ExaDataType::IntervalDayToSecond { + precision: ExaDataType::INTERVAL_DTS_MAX_PRECISION, + fraction: ExaDataType::INTERVAL_DTS_MAX_FRACTION, + } + .into() + } +} + +impl Encode<'_, Exasol> for TimeDelta { + fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { + buf.append(format_args!( + "{} {}:{}:{}.{}", + self.num_days(), + self.num_hours().abs() % 24, + self.num_minutes().abs() % 60, + self.num_seconds().abs() % 60, + self.num_milliseconds().abs() % 1000 + ))?; + + Ok(IsNull::No) + } + + fn size_hint(&self) -> usize { + // 1 quote + 1 sign + max days precision + + // 1 space + 2 hours + 1 column + 2 minutes + 1 column + 2 seconds + + // 1 dot + max milliseconds fraction + + // 1 quote + 2 + ExaDataType::INTERVAL_DTS_MAX_PRECISION as usize + + 10 + + ExaDataType::INTERVAL_DTS_MAX_FRACTION as usize + + 1 + } +} + +impl<'r> Decode<'r, Exasol> for TimeDelta { + fn decode(value: ExaValueRef<'r>) -> Result { + let input = <&str>::deserialize(value.value).map_err(Box::new)?; + let input_err_fn = || format!("could not parse {input} as INTERVAL DAY TO SECOND"); + + let (days, rest) 
= input.split_once(' ').ok_or_else(input_err_fn)?; + let (hours, rest) = rest.split_once(':').ok_or_else(input_err_fn)?; + let (minutes, rest) = rest.split_once(':').ok_or_else(input_err_fn)?; + let (seconds, millis) = rest.split_once('.').ok_or_else(input_err_fn)?; + + let days: i64 = days.parse().map_err(Box::new)?; + let hours: i64 = hours.parse().map_err(Box::new)?; + let minutes: i64 = minutes.parse().map_err(Box::new)?; + let seconds: i64 = seconds.parse().map_err(Box::new)?; + let millis: i64 = millis.parse().map_err(Box::new)?; + let sign = if days.is_negative() { -1 } else { 1 }; + + let duration = TimeDelta::days(days) + + TimeDelta::hours(hours * sign) + + TimeDelta::minutes(minutes * sign) + + TimeDelta::seconds(seconds * sign) + + TimeDelta::milliseconds(millis * sign); + + Ok(duration) + } +} diff --git a/sqlx-exasol-impl/src/types/float.rs b/sqlx-exasol-impl/src/types/float.rs new file mode 100644 index 00000000..c9d410f3 --- /dev/null +++ b/sqlx-exasol-impl/src/types/float.rs @@ -0,0 +1,52 @@ +use serde::Deserialize; +use serde_json::Value; +use sqlx_core::{ + decode::Decode, + encode::{Encode, IsNull}, + error::BoxDynError, + types::Type, +}; + +use crate::{ + arguments::ExaBuffer, + database::Exasol, + type_info::{ExaDataType, ExaTypeInfo}, + value::ExaValueRef, +}; + +impl Type for f64 { + fn type_info() -> ExaTypeInfo { + ExaDataType::Double.into() + } +} + +impl Encode<'_, Exasol> for f64 { + fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { + // NaN is treated as NULL by Exasol. + // Infinity is not supported by Exasol but serde_json + // serializes it as NULL as well. + if self.is_finite() { + buf.append(self)?; + Ok(IsNull::No) + } else { + buf.append(())?; + Ok(IsNull::Yes) + } + } + + fn size_hint(&self) -> usize { + // 1 sign + 15 digits + 1 dot + // See + 17 + } +} + +impl Decode<'_, Exasol> for f64 { + fn decode(value: ExaValueRef<'_>) -> Result { + match value.value { + Value::Number(n) => ::deserialize(n).map_err(From::from), + Value::String(s) => serde_json::from_str(s).map_err(From::from), + v => Err(format!("invalid f64 value: {v}").into()), + } + } +} diff --git a/sqlx-exasol-impl/src/types/geo_types.rs b/sqlx-exasol-impl/src/types/geo_types.rs new file mode 100644 index 00000000..48bd5fdf --- /dev/null +++ b/sqlx-exasol-impl/src/types/geo_types.rs @@ -0,0 +1,48 @@ +use std::{fmt::Display, str::FromStr}; + +use geo_types::CoordNum; +pub use geo_types::{ + Geometry, GeometryCollection, Line, LineString, MultiLineString, MultiPoint, MultiPolygon, + Point, Polygon, Rect, Triangle, +}; +use sqlx_core::{ + decode::Decode, + encode::{Encode, IsNull}, + error::BoxDynError, + types::Type, +}; + +use crate::{ + arguments::ExaBuffer, + database::Exasol, + type_info::{ExaDataType, ExaTypeInfo}, + value::ExaValueRef, +}; + +impl Type for Geometry +where + T: CoordNum, +{ + fn type_info() -> ExaTypeInfo { + ExaDataType::Geometry { srid: 0 }.into() + } +} + +impl Encode<'_, Exasol> for Geometry +where + T: CoordNum + Display, +{ + fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { + buf.append_geometry(self)?; + Ok(IsNull::No) + } +} + +impl Decode<'_, Exasol> for Geometry +where + T: CoordNum + FromStr + Default, +{ + fn decode(value: ExaValueRef<'_>) -> Result { + wkt::deserialize::deserialize_wkt(value.value).map_err(From::from) + } +} diff --git a/sqlx-exasol-impl/src/types/hashtype.rs b/sqlx-exasol-impl/src/types/hashtype.rs new file mode 100644 index 00000000..3b3e35e1 --- /dev/null +++ b/sqlx-exasol-impl/src/types/hashtype.rs @@ -0,0 +1,51 @@ +use 
serde::Deserialize; +use sqlx_core::{ + decode::Decode, + encode::{Encode, IsNull}, + error::BoxDynError, + types::Type, +}; + +use crate::{ + arguments::ExaBuffer, + database::Exasol, + type_info::{ExaDataType, ExaTypeInfo}, + value::ExaValueRef, +}; + +/// Newtype used for more explicit encoding/decoding of arbitrary length data into/from HASHTYPE +/// columns. +/// +/// Unlike UUID, this type is not subject to length checks because Exasol can accept multiple +/// formats for these columns. While connections set the `HASHTYPE_FORMAT` database parameter to +/// `HEX` when they're opened to allow the driver to reliably use the reported column size for +/// length checks for UUIDs, this only affects the column output. Exasol will still accept any valid +/// input for the column, which can have different lengths depending on the data format. +/// +/// See , in particular for your exact database version. +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct HashType(pub String); + +impl Type for HashType { + fn type_info() -> ExaTypeInfo { + ExaDataType::HashType { size: None }.into() + } +} + +impl Encode<'_, Exasol> for HashType { + fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { + <&str as Encode>::encode_by_ref(&self.0.as_str(), buf) + } + + fn size_hint(&self) -> usize { + <&str as Encode>::size_hint(&self.0.as_str()) + } +} + +impl<'r> Decode<'r, Exasol> for HashType { + fn decode(value: ExaValueRef<'r>) -> Result { + String::deserialize(value.value) + .map(HashType) + .map_err(From::from) + } +} diff --git a/sqlx-exasol-impl/src/types/int.rs b/sqlx-exasol-impl/src/types/int.rs new file mode 100644 index 00000000..cdd4e82b --- /dev/null +++ b/sqlx-exasol-impl/src/types/int.rs @@ -0,0 +1,143 @@ +use std::ops::Range; + +use serde::Deserialize; +use serde_json::Value; +use sqlx_core::{ + decode::Decode, + encode::{Encode, IsNull}, + error::BoxDynError, + types::Type, +}; + +use crate::{ + arguments::ExaBuffer, + database::Exasol, + type_info::{Decimal, ExaDataType, ExaTypeInfo}, + value::ExaValueRef, +}; + +/// Numbers within this range must be serialized/deserialized as integers. +/// The ones above/under these thresholds are treated as strings. 
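In other words, values with up to 18 digits stay JSON numbers, while anything longer goes over the wire as a string. A self-contained illustration of the split that the `i64` `Encode` impl further down performs:

```rust
use serde_json::json;

fn main() {
    const RANGE: std::ops::Range<i64> = -999_999_999_999_999_999..1_000_000_000_000_000_000;

    let small: i64 = 999_999_999_999_999_999; // 18 digits
    let large: i64 = 1_000_000_000_000_000_000; // 19 digits

    // 18-digit values are serialized as plain JSON numbers...
    assert!(RANGE.contains(&small));
    assert_eq!(json!(small).to_string(), "999999999999999999");
    // ...while larger ones are stringified, as `format_args!("{self}")` does below.
    assert!(!RANGE.contains(&large));
    assert_eq!(json!(large.to_string()).to_string(), "\"1000000000000000000\"");
}
```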
+const NUMERIC_I64_RANGE: Range = -999_999_999_999_999_999..1_000_000_000_000_000_000; + +impl Type for i8 { + fn type_info() -> ExaTypeInfo { + ExaDataType::Decimal(Decimal { + precision: Some(Decimal::MAX_8BIT_PRECISION), + scale: 0, + }) + .into() + } +} + +impl Encode<'_, Exasol> for i8 { + fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { + buf.append(self)?; + Ok(IsNull::No) + } + + fn size_hint(&self) -> usize { + // sign + max num digits + 1 + Decimal::MAX_8BIT_PRECISION as usize + } +} + +impl Decode<'_, Exasol> for i8 { + fn decode(value: ExaValueRef<'_>) -> Result { + ::deserialize(value.value).map_err(From::from) + } +} + +impl Type for i16 { + fn type_info() -> ExaTypeInfo { + ExaDataType::Decimal(Decimal { + precision: Some(Decimal::MAX_16BIT_PRECISION), + scale: 0, + }) + .into() + } +} + +impl Encode<'_, Exasol> for i16 { + fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { + buf.append(self)?; + Ok(IsNull::No) + } + + fn size_hint(&self) -> usize { + // sign + max num digits + 1 + Decimal::MAX_16BIT_PRECISION as usize + } +} + +impl Decode<'_, Exasol> for i16 { + fn decode(value: ExaValueRef<'_>) -> Result { + ::deserialize(value.value).map_err(From::from) + } +} + +impl Type for i32 { + fn type_info() -> ExaTypeInfo { + ExaDataType::Decimal(Decimal { + precision: Some(Decimal::MAX_32BIT_PRECISION), + scale: 0, + }) + .into() + } +} + +impl Encode<'_, Exasol> for i32 { + fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { + buf.append(self)?; + Ok(IsNull::No) + } + + fn size_hint(&self) -> usize { + // sign + max num digits + 1 + Decimal::MAX_32BIT_PRECISION as usize + } +} + +impl Decode<'_, Exasol> for i32 { + fn decode(value: ExaValueRef<'_>) -> Result { + ::deserialize(value.value).map_err(From::from) + } +} + +impl Type for i64 { + fn type_info() -> ExaTypeInfo { + ExaDataType::Decimal(Decimal { + precision: Some(Decimal::MAX_64BIT_PRECISION), + scale: 0, + }) + .into() + } +} + +impl Encode<'_, Exasol> for i64 { + fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { + if NUMERIC_I64_RANGE.contains(self) { + buf.append(self)?; + } else { + // Large numbers get serialized as strings + buf.append(format_args!("{self}"))?; + } + + Ok(IsNull::No) + } + + fn size_hint(&self) -> usize { + // 1 quote + 1 sign + max num digits + 1 quote + 2 + Decimal::MAX_64BIT_PRECISION as usize + 1 + } +} + +impl Decode<'_, Exasol> for i64 { + fn decode(value: ExaValueRef<'_>) -> Result { + match value.value { + Value::Number(n) => ::deserialize(n).map_err(From::from), + Value::String(s) => serde_json::from_str(s).map_err(From::from), + v => Err(format!("invalid i64 value: {v}").into()), + } + } +} diff --git a/sqlx-exasol-impl/src/types/interval_ytm.rs b/sqlx-exasol-impl/src/types/interval_ytm.rs new file mode 100644 index 00000000..13903db2 --- /dev/null +++ b/sqlx-exasol-impl/src/types/interval_ytm.rs @@ -0,0 +1,101 @@ +use std::fmt::{self, Display}; + +use serde::{de, Deserialize, Deserializer, Serialize}; +use sqlx_core::{ + decode::Decode, + encode::{Encode, IsNull}, + error::BoxDynError, + types::Type, +}; + +use crate::{ + arguments::ExaBuffer, + database::Exasol, + type_info::{ExaDataType, ExaTypeInfo}, + value::ExaValueRef, +}; + +/// A duration interval as a representation of the `INTERVAL YEAR TO MONTH` datatype. +/// +/// The duration is expressed in months. 
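Given the month-based representation, the `Display` and `Deserialize` impls defined below round-trip Exasol's `[sign][years]-[months]` literal form. A sketch, assuming the type is re-exported under `types::ExaIntervalYearToMonth`:

```rust
use sqlx_exasol_impl::types::ExaIntervalYearToMonth; // path is an assumption

fn main() {
    // 26 months = 2 years and 2 months.
    assert_eq!(ExaIntervalYearToMonth(26).to_string(), "+2-2");
    assert_eq!(ExaIntervalYearToMonth(-26).to_string(), "-2-2");
    // Parsing "-2-2" splits on the last '-': years = -2, months = 2, total = -26.
}
```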
+#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ExaIntervalYearToMonth(pub i64); + +impl Type for ExaIntervalYearToMonth { + fn type_info() -> ExaTypeInfo { + ExaDataType::IntervalYearToMonth { + precision: ExaDataType::INTERVAL_YTM_MAX_PRECISION, + } + .into() + } +} + +impl Encode<'_, Exasol> for ExaIntervalYearToMonth { + fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { + buf.append(self)?; + Ok(IsNull::No) + } + + fn size_hint(&self) -> usize { + // 1 quote + 1 sign + max year precision + 1 dash + 2 months + 1 quote + 2 + ExaDataType::INTERVAL_YTM_MAX_PRECISION as usize + 4 + } +} + +impl<'r> Decode<'r, Exasol> for ExaIntervalYearToMonth { + fn decode(value: ExaValueRef<'r>) -> Result { + Self::deserialize(value.value).map_err(From::from) + } +} + +impl Display for ExaIntervalYearToMonth { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let years = self.0 / 12; + let months = (self.0 % 12).abs(); + let plus = if years.is_negative() { "" } else { "+" }; + write!(f, "{plus}{years}-{months}") + } +} + +impl Serialize for ExaIntervalYearToMonth { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + format_args!("{self}").serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for ExaIntervalYearToMonth { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct Visitor; + + impl de::Visitor<'_> for Visitor { + type Value = ExaIntervalYearToMonth; + + fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "INTERVAL YEAR TO MONTH in the format [year]-[month]") + } + fn visit_str(self, value: &str) -> Result { + let input_err_fn = || { + let msg = format!("could not parse {value} as INTERVAL YEAR TO MONTH"); + de::Error::custom(msg) + }; + + let (years, months) = value.rsplit_once('-').ok_or_else(input_err_fn)?; + let years = years.parse::().map_err(de::Error::custom)?; + let months = months.parse::().map_err(de::Error::custom)?; + + let sign = if years.is_negative() { -1 } else { 1 }; + let total = years * 12 + months * sign; + + Ok(ExaIntervalYearToMonth(total)) + } + } + + deserializer.deserialize_str(Visitor) + } +} diff --git a/sqlx-exasol-impl/src/types/iter.rs b/sqlx-exasol-impl/src/types/iter.rs new file mode 100644 index 00000000..5c4d0efb --- /dev/null +++ b/sqlx-exasol-impl/src/types/iter.rs @@ -0,0 +1,144 @@ +use std::marker::PhantomData; + +use sqlx_core::{ + database::Database, + encode::{Encode, IsNull}, + error::BoxDynError, + types::Type, +}; + +use crate::{arguments::ExaBuffer, Exasol}; + +/// Adapter allowing any iterator of encodable values to be treated and passed as a one dimensional +/// parameter array for a column to Exasol in a single query invocation. Multi dimensional arrays +/// are not supported. The adapter is needed because [`Encode`] is still a foreign trait and thus +/// cannot be implemented in a generic manner to all types implementing [`IntoIterator`]. +/// +/// Note that the [`Encode`] trait requires the ability to encode by reference, thus the adapter +/// takes a type that implements [`IntoIterator`]. But since iteration requires mutability, +/// the adaptar also requires [`Clone`]. The adapter definition should steer it towards being used +/// with cheaply clonable iterators since it expects the iteration elements to be references. +/// However, care should still be taken so as not to clone expensive [`IntoIterator`] types. 
+///
+/// ```rust
+/// # use sqlx_exasol_impl as sqlx_exasol;
+/// use sqlx_exasol::types::ExaIter;
+///
+/// let vector = vec![1, 2, 3];
+/// let borrowed_iter = ExaIter::new(vector.iter().filter(|v| **v % 2 == 0));
+/// ```
+#[derive(Debug)]
+#[repr(transparent)]
+pub struct ExaIter<I, T> {
+    into_iter: I,
+    data_lifetime: PhantomData<fn() -> T>,
+}
+
+impl<I, T> ExaIter<I, T>
+where
+    I: IntoIterator<Item = T> + Clone,
+    T: for<'q> Encode<'q, Exasol> + Type<Exasol> + Copy,
+{
+    pub fn new(into_iter: I) -> Self {
+        Self {
+            into_iter,
+            data_lifetime: PhantomData,
+        }
+    }
+}
+
+impl<I, T> Type<Exasol> for ExaIter<I, T>
+where
+    I: IntoIterator<Item = T> + Clone,
+    T: Type<Exasol> + Copy,
+{
+    fn type_info() -> <Exasol as Database>::TypeInfo {
+        <I as IntoIterator>::Item::type_info()
+    }
+}
+
+impl<I, T> Encode<'_, Exasol> for ExaIter<I, T>
+where
+    I: IntoIterator<Item = T> + Clone,
+    T: for<'q> Encode<'q, Exasol> + Copy,
+{
+    fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result<IsNull, BoxDynError> {
+        buf.append_iter(self.into_iter.clone())?;
+        Ok(IsNull::No)
+    }
+
+    fn size_hint(&self) -> usize {
+        // Brackets [] + items size
+        2 + self
+            .into_iter
+            .clone()
+            .into_iter()
+            .fold(0, |sum, item| sum + item.size_hint())
+    }
+}
+
+impl<T> Type<Exasol> for &[T]
+where
+    T: Type<Exasol>,
+{
+    fn type_info() -> <Exasol as Database>::TypeInfo {
+        T::type_info()
+    }
+}
+
+impl<T> Encode<'_, Exasol> for &[T]
+where
+    for<'q> T: Encode<'q, Exasol> + Type<Exasol>,
+{
+    fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result<IsNull, BoxDynError> {
+        ExaIter::new(*self).encode_by_ref(buf)
+    }
+
+    fn size_hint(&self) -> usize {
+        ExaIter::new(*self).size_hint()
+    }
+}
+
+impl<T, const N: usize> Type<Exasol> for [T; N]
+where
+    T: Type<Exasol>,
+{
+    fn type_info() -> <Exasol as Database>::TypeInfo {
+        T::type_info()
+    }
+}
+
+impl<T, const N: usize> Encode<'_, Exasol> for [T; N]
+where
+    for<'q> T: Encode<'q, Exasol> + Type<Exasol>,
+{
+    fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result<IsNull, BoxDynError> {
+        ExaIter::new(self.as_slice()).encode_by_ref(buf)
+    }
+
+    fn size_hint(&self) -> usize {
+        ExaIter::new(self.as_slice()).size_hint()
+    }
+}
+
+impl<T> Type<Exasol> for Vec<T>
+where
+    T: Type<Exasol>,
+{
+    fn type_info() -> <Exasol as Database>::TypeInfo {
+        T::type_info()
+    }
+}
+
+impl<T> Encode<'_, Exasol> for Vec<T>
+where
+    for<'q> T: Encode<'q, Exasol> + Type<Exasol>,
+{
+    fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result<IsNull, BoxDynError> {
+        ExaIter::new(self.as_slice()).encode_by_ref(buf)
+    }
+
+    fn size_hint(&self) -> usize {
+        ExaIter::new(self.as_slice()).size_hint()
+    }
+}
diff --git a/sqlx-exasol-impl/src/types/json.rs b/sqlx-exasol-impl/src/types/json.rs
new file mode 100644
index 00000000..950aea21
--- /dev/null
+++ b/sqlx-exasol-impl/src/types/json.rs
@@ -0,0 +1,45 @@
+use serde::{Deserialize, Serialize};
+use sqlx_core::{
+    decode::Decode,
+    encode::{Encode, IsNull},
+    error::BoxDynError,
+    types::{Json, Type},
+};
+
+use crate::{
+    arguments::ExaBuffer,
+    type_info::{Charset, ExaDataType},
+    ExaTypeInfo, ExaValueRef, Exasol,
+};
+
+impl<T> Type<Exasol> for Json<T> {
+    fn type_info() -> ExaTypeInfo {
+        ExaDataType::Varchar {
+            size: ExaDataType::VARCHAR_MAX_LEN,
+            character_set: Charset::Utf8,
+        }
+        .into()
+    }
+}
+
+impl<T> Encode<'_, Exasol> for Json<T>
+where
+    T: Serialize,
+{
+    fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result<IsNull, BoxDynError> {
+        buf.append_json(&self.0)?;
+        Ok(IsNull::No)
+    }
+}
+
+impl<'r, T> Decode<'r, Exasol> for Json<T>
+where
+    T: 'r + Deserialize<'r>,
+{
+    fn decode(value: ExaValueRef<'r>) -> Result<Self, BoxDynError> {
+        <&str>::deserialize(value.value)
+            .and_then(serde_json::from_str)
+            .map(Json)
+            .map_err(From::from)
+    }
+}
diff --git a/sqlx-exasol-impl/src/types/mod.rs b/sqlx-exasol-impl/src/types/mod.rs
new file mode 100644
index 00000000..e3454f38
--- /dev/null
+++ b/sqlx-exasol-impl/src/types/mod.rs
@@ -0,0 +1,27 @@
+#[cfg(feature = "bigdecimal")]
+mod bigdecimal;
+mod bool;
+#[cfg(feature = "chrono")] +pub mod chrono; +mod float; +#[cfg(feature = "geo-types")] +pub mod geo_types; +mod hashtype; +mod int; +mod interval_ytm; +mod iter; +#[cfg(feature = "json")] +mod json; +mod option; +#[cfg(feature = "rust_decimal")] +mod rust_decimal; +mod str; +mod text; +#[cfg(feature = "time")] +pub mod time; +#[cfg(feature = "uuid")] +mod uuid; + +pub use hashtype::HashType; +pub use interval_ytm::ExaIntervalYearToMonth; +pub use iter::ExaIter; diff --git a/src/types/option.rs b/sqlx-exasol-impl/src/types/option.rs similarity index 82% rename from src/types/option.rs rename to sqlx-exasol-impl/src/types/option.rs index 6d28992b..69a41b24 100644 --- a/src/types/option.rs +++ b/sqlx-exasol-impl/src/types/option.rs @@ -4,7 +4,7 @@ use sqlx_core::{ types::Type, }; -use crate::{arguments::ExaBuffer, type_info::ExaDataType, ExaTypeInfo, Exasol}; +use crate::{arguments::ExaBuffer, ExaTypeInfo, Exasol}; impl Encode<'_, Exasol> for Option where @@ -15,7 +15,7 @@ where if let Some(v) = self { v.produces() } else { - Some(ExaDataType::Null.into()) + Some(T::type_info()) } } @@ -41,6 +41,7 @@ where #[inline] fn size_hint(&self) -> usize { - self.as_ref().map(Encode::size_hint).unwrap_or_default() + // We encode `null` when `None`, hence size 4. + self.as_ref().map_or(4, Encode::size_hint) } } diff --git a/sqlx-exasol-impl/src/types/rust_decimal.rs b/sqlx-exasol-impl/src/types/rust_decimal.rs new file mode 100644 index 00000000..ef3a6bda --- /dev/null +++ b/sqlx-exasol-impl/src/types/rust_decimal.rs @@ -0,0 +1,45 @@ +use serde::Deserialize; +use sqlx_core::{ + decode::Decode, + encode::{Encode, IsNull}, + error::BoxDynError, + types::Type, +}; + +use crate::{ + arguments::ExaBuffer, + database::Exasol, + type_info::{Decimal, ExaDataType, ExaTypeInfo}, + value::ExaValueRef, +}; + +impl Type for rust_decimal::Decimal { + #[allow(clippy::cast_possible_truncation)] + fn type_info() -> ExaTypeInfo { + // A somewhat non-sensical value used to allow decoding any DECIMAL value + // with a supported scale. 
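+        // `rust_decimal::Decimal::MAX_SCALE` is 28, well below `u8::MAX`, so the
+        // `as u8` cast below cannot truncate (hence the clippy allow above).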
+ ExaDataType::Decimal(Decimal { + precision: None, + scale: rust_decimal::Decimal::MAX_SCALE as u8, + }) + .into() + } +} + +impl Encode<'_, Exasol> for rust_decimal::Decimal { + fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { + buf.append(format_args!("{self}"))?; + Ok(IsNull::No) + } + + fn size_hint(&self) -> usize { + // 1 quote + 1 sign + 1 zero + 1 dot + max scale + 1 quote + 4 + Decimal::MAX_SCALE as usize + 1 + } +} + +impl Decode<'_, Exasol> for rust_decimal::Decimal { + fn decode(value: ExaValueRef<'_>) -> Result { + ::deserialize(value.value).map_err(From::from) + } +} diff --git a/src/types/str.rs b/sqlx-exasol-impl/src/types/str.rs similarity index 50% rename from src/types/str.rs rename to sqlx-exasol-impl/src/types/str.rs index 7e7c2fea..5d8ed2c5 100644 --- a/src/types/str.rs +++ b/sqlx-exasol-impl/src/types/str.rs @@ -11,18 +11,17 @@ use sqlx_core::{ use crate::{ arguments::ExaBuffer, database::Exasol, - type_info::{Charset, ExaDataType, ExaTypeInfo, StringLike}, + type_info::{Charset, ExaDataType, ExaTypeInfo}, value::ExaValueRef, }; impl Type for str { fn type_info() -> ExaTypeInfo { - let string_like = StringLike::new(StringLike::MAX_VARCHAR_LEN, Charset::Utf8); - ExaDataType::Varchar(string_like).into() - } - - fn compatible(ty: &ExaTypeInfo) -> bool { - >::type_info().compatible(ty) + ExaDataType::Varchar { + size: ExaDataType::VARCHAR_MAX_LEN, + character_set: Charset::Utf8, + } + .into() } } @@ -38,12 +37,8 @@ impl Encode<'_, Exasol> for &'_ str { Ok(IsNull::No) } - fn produces(&self) -> Option { - Some(>::type_info()) - } - fn size_hint(&self) -> usize { - // 2 Quotes + length + // 2 quotes + length 2 + self.len() } } @@ -58,23 +53,15 @@ impl Type for String { fn type_info() -> ExaTypeInfo { >::type_info() } - - fn compatible(ty: &ExaTypeInfo) -> bool { - >::compatible(ty) - } } impl Encode<'_, Exasol> for String { fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { - <&str as Encode>::encode(&**self, buf) - } - - fn produces(&self) -> Option { - <&str as Encode>::produces(&&**self) + <&str as Encode>::encode(self.as_ref(), buf) } fn size_hint(&self) -> usize { - <&str as Encode>::size_hint(&&**self) + <&str as Encode>::size_hint(&self.as_ref()) } } @@ -84,35 +71,12 @@ impl Decode<'_, Exasol> for String { } } -impl Type for Cow<'_, str> { - fn type_info() -> ExaTypeInfo { - <&str as Type>::type_info() - } - - fn compatible(ty: &ExaTypeInfo) -> bool { - <&str as Type>::compatible(ty) - } -} - impl Encode<'_, Exasol> for Cow<'_, str> { fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result { - match self { - Cow::Borrowed(str) => <&str as Encode>::encode(*str, buf), - Cow::Owned(str) => <&str as Encode>::encode(&**str, buf), - } - } - - fn produces(&self) -> Option { - <&str as Encode>::produces(&&**self) + <&str as Encode>::encode(self.as_ref(), buf) } fn size_hint(&self) -> usize { - <&str as Encode>::size_hint(&&**self) - } -} - -impl<'r> Decode<'r, Exasol> for Cow<'r, str> { - fn decode(value: ExaValueRef<'r>) -> Result { - Cow::deserialize(value.value).map_err(From::from) + <&str as Encode>::size_hint(&self.as_ref()) } } diff --git a/src/types/text.rs b/sqlx-exasol-impl/src/types/text.rs similarity index 76% rename from src/types/text.rs rename to sqlx-exasol-impl/src/types/text.rs index 6bcf884a..80b25b3b 100644 --- a/src/types/text.rs +++ b/sqlx-exasol-impl/src/types/text.rs @@ -13,10 +13,6 @@ impl Type for Text { fn type_info() -> ExaTypeInfo { >::type_info() } - - fn compatible(ty: &ExaTypeInfo) -> bool { - >::compatible(ty) - } } impl 
diff --git a/src/types/text.rs b/sqlx-exasol-impl/src/types/text.rs
similarity index 76%
rename from src/types/text.rs
rename to sqlx-exasol-impl/src/types/text.rs
index 6bcf884a..80b25b3b 100644
--- a/src/types/text.rs
+++ b/sqlx-exasol-impl/src/types/text.rs
@@ -13,10 +13,6 @@ impl<T> Type<Exasol> for Text<T> {
     fn type_info() -> ExaTypeInfo {
         <String as Type<Exasol>>::type_info()
     }
-
-    fn compatible(ty: &ExaTypeInfo) -> bool {
-        <String as Type<Exasol>>::compatible(ty)
-    }
 }
 
 impl<T> Encode<'_, Exasol> for Text<T>
@@ -53,14 +49,13 @@ where
 mod tests {
     use sqlx::{types::Text, Encode};
 
-    use crate::ExaArguments;
+    use crate::{ExaArguments, Exasol};
 
     #[test]
     fn test_text_null_string() {
         let mut arg_buffer = ExaArguments::default();
-        let is_null = Text(String::new())
-            .encode_by_ref(&mut arg_buffer.buf)
-            .unwrap();
+        let value = Text(String::new());
+        let is_null = Encode::<Exasol>::encode_by_ref(&value, &mut arg_buffer.buf).unwrap();
 
         assert!(is_null.is_null());
     }
@@ -68,7 +63,8 @@
     #[test]
     fn test_text_null_str() {
         let mut arg_buffer = ExaArguments::default();
-        let is_null = Text("").encode_by_ref(&mut arg_buffer.buf).unwrap();
+        let value = Text("");
+        let is_null = Encode::<Exasol>::encode_by_ref(&value, &mut arg_buffer.buf).unwrap();
 
         assert!(is_null.is_null());
     }
@@ -76,9 +72,8 @@
     #[test]
     fn test_text_non_null_string() {
         let mut arg_buffer = ExaArguments::default();
-        let is_null = Text(String::from("something"))
-            .encode_by_ref(&mut arg_buffer.buf)
-            .unwrap();
+        let value = Text(String::from("something"));
+        let is_null = Encode::<Exasol>::encode_by_ref(&value, &mut arg_buffer.buf).unwrap();
 
         assert!(!is_null.is_null());
     }
@@ -86,9 +81,8 @@
     #[test]
     fn test_text_non_null_str() {
         let mut arg_buffer = ExaArguments::default();
-        let is_null = Text("something")
-            .encode_by_ref(&mut arg_buffer.buf)
-            .unwrap();
+        let value = Text("something");
+        let is_null = Encode::<Exasol>::encode_by_ref(&value, &mut arg_buffer.buf).unwrap();
 
         assert!(!is_null.is_null());
     }
diff --git a/sqlx-exasol-impl/src/types/time/date.rs b/sqlx-exasol-impl/src/types/time/date.rs
new file mode 100644
index 00000000..5c6701da
--- /dev/null
+++ b/sqlx-exasol-impl/src/types/time/date.rs
@@ -0,0 +1,49 @@
+use ::serde::{Deserialize, Serialize};
+use sqlx_core::{
+    decode::Decode,
+    encode::{Encode, IsNull},
+    error::BoxDynError,
+    types::Type,
+};
+use time::{serde, Date};
+
+use crate::{
+    arguments::ExaBuffer,
+    database::Exasol,
+    type_info::{ExaDataType, ExaTypeInfo},
+    value::ExaValueRef,
+};
+
+impl Type<Exasol> for Date {
+    fn type_info() -> ExaTypeInfo {
+        ExaDataType::Date.into()
+    }
+}
+
+impl Encode<'_, Exasol> for Date {
+    fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result<IsNull, BoxDynError> {
+        buf.append(DateSer(self))?;
+        Ok(IsNull::No)
+    }
+
+    fn size_hint(&self) -> usize {
+        // 2 quotes + 4 year + 1 dash + 2 months + 1 dash + 2 days
+        12
+    }
+}
+
+impl Decode<'_, Exasol> for Date {
+    fn decode(value: ExaValueRef<'_>) -> Result<Self, BoxDynError> {
+        DateDe::deserialize(value.value)
+            .map(|v| v.0)
+            .map_err(From::from)
+    }
+}
+
+#[derive(Serialize)]
+struct DateSer<'a>(#[serde(serialize_with = "date::serialize")] &'a Date);
+
+#[derive(Deserialize)]
+struct DateDe(#[serde(deserialize_with = "date::deserialize")] Date);
+
+serde::format_description!(date, Date, "[year]-[month]-[day]");
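The `DateSer`/`DateDe` wrappers exist because `time::Date` has no built-in serde representation matching Exasol's literal form; `time::serde::format_description!` generates the glue. A self-contained sketch of the same technique, assuming the `time` crate with its `serde` and `macros` features plus `serde_json`:

```rust
use ::serde::{Deserialize, Serialize};
use time::{macros::date, serde, Date};

// Same technique as `date.rs` above: generate a serde module for one fixed format.
serde::format_description!(exa_date, Date, "[year]-[month]-[day]");

#[derive(Serialize, Deserialize)]
struct Row {
    #[serde(with = "exa_date")]
    day: Date,
}

fn main() {
    let row = Row { day: date!(2024-01-31) };
    let json = serde_json::to_string(&row).unwrap();
    assert_eq!(json, r#"{"day":"2024-01-31"}"#);
}
```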
diff --git a/sqlx-exasol-impl/src/types/time/datetime.rs b/sqlx-exasol-impl/src/types/time/datetime.rs
new file mode 100644
index 00000000..00e5a802
--- /dev/null
+++ b/sqlx-exasol-impl/src/types/time/datetime.rs
@@ -0,0 +1,86 @@
+use ::serde::{Deserialize, Serialize};
+use sqlx_core::{
+    decode::Decode,
+    encode::{Encode, IsNull},
+    error::BoxDynError,
+    types::Type,
+};
+use time::{serde, OffsetDateTime, PrimitiveDateTime, UtcOffset};
+
+use crate::{
+    arguments::ExaBuffer,
+    database::Exasol,
+    type_info::{ExaDataType, ExaTypeInfo},
+    value::ExaValueRef,
+};
+
+impl Type<Exasol> for PrimitiveDateTime {
+    fn type_info() -> ExaTypeInfo {
+        ExaDataType::TimestampWithLocalTimeZone.into()
+    }
+}
+
+impl Encode<'_, Exasol> for PrimitiveDateTime {
+    fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result<IsNull, BoxDynError> {
+        buf.append(PrimitiveDateTimeSer(self))?;
+        Ok(IsNull::No)
+    }
+
+    fn size_hint(&self) -> usize {
+        // 1 quote +
+        // 4 years + 1 dash + 2 months + 1 dash + 2 days +
+        // 1 space + 2 hours + 2 minutes + 2 seconds + 9 subseconds +
+        // 1 quote
+        28
+    }
+}
+
+impl Decode<'_, Exasol> for PrimitiveDateTime {
+    fn decode(value: ExaValueRef<'_>) -> Result<Self, BoxDynError> {
+        PrimitiveDateTimeDe::deserialize(value.value)
+            .map(|v| v.0)
+            .map_err(From::from)
+    }
+}
+
+impl Type<Exasol> for OffsetDateTime {
+    fn type_info() -> ExaTypeInfo {
+        ExaDataType::Timestamp.into()
+    }
+}
+
+impl Encode<'_, Exasol> for OffsetDateTime {
+    fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result<IsNull, BoxDynError> {
+        let utc_dt = self.to_offset(UtcOffset::UTC);
+        let primitive = PrimitiveDateTime::new(utc_dt.date(), utc_dt.time());
+        primitive.encode(buf)
+    }
+
+    fn size_hint(&self) -> usize {
+        // 1 quote +
+        // 4 years + 1 dash + 2 months + 1 dash + 2 days +
+        // 1 space + 2 hours + 2 minutes + 2 seconds + 9 subseconds +
+        // 1 quote
+        28
+    }
+}
+
+impl<'r> Decode<'r, Exasol> for OffsetDateTime {
+    fn decode(value: ExaValueRef<'r>) -> Result<Self, BoxDynError> {
+        PrimitiveDateTime::decode(value).map(PrimitiveDateTime::assume_utc)
+    }
+}
+
+#[derive(Serialize)]
+struct PrimitiveDateTimeSer<'a>(
+    #[serde(serialize_with = "timestamp::serialize")] &'a PrimitiveDateTime,
+);
+
+#[derive(Deserialize)]
+struct PrimitiveDateTimeDe(#[serde(deserialize_with = "timestamp::deserialize")] PrimitiveDateTime);
+
+serde::format_description!(
+    timestamp,
+    PrimitiveDateTime,
+    "[year]-[month]-[day] [hour]:[minute]:[second].[subsecond]"
+);
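Worth calling out: `OffsetDateTime` values are shifted to UTC and the offset is then dropped before encoding, while decoding re-attaches UTC via `assume_utc`. A stand-alone sketch of that normalization step, assuming `time` with its `macros` feature:

```rust
use time::{macros::datetime, OffsetDateTime, PrimitiveDateTime, UtcOffset};

// Mirror of the Encode impl above: shift to UTC, then drop the offset.
fn normalize(odt: OffsetDateTime) -> PrimitiveDateTime {
    let utc = odt.to_offset(UtcOffset::UTC);
    PrimitiveDateTime::new(utc.date(), utc.time())
}

fn main() {
    // 12:00 at UTC+2 becomes 10:00 UTC before encoding.
    assert_eq!(
        normalize(datetime!(2024-01-31 12:00 +2)),
        datetime!(2024-01-31 10:00)
    );
}
```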
diff --git a/sqlx-exasol-impl/src/types/time/duration.rs b/sqlx-exasol-impl/src/types/time/duration.rs
new file mode 100644
index 00000000..22bdb5be
--- /dev/null
+++ b/sqlx-exasol-impl/src/types/time/duration.rs
@@ -0,0 +1,78 @@
+use serde::Deserialize;
+use sqlx_core::{
+    decode::Decode,
+    encode::{Encode, IsNull},
+    error::BoxDynError,
+    types::Type,
+};
+use time::Duration;
+
+use crate::{
+    arguments::ExaBuffer,
+    database::Exasol,
+    type_info::{ExaDataType, ExaTypeInfo},
+    value::ExaValueRef,
+};
+
+impl Type<Exasol> for Duration {
+    fn type_info() -> ExaTypeInfo {
+        ExaDataType::IntervalDayToSecond {
+            precision: ExaDataType::INTERVAL_DTS_MAX_PRECISION,
+            fraction: ExaDataType::INTERVAL_DTS_MAX_FRACTION,
+        }
+        .into()
+    }
+}
+
+impl Encode<'_, Exasol> for Duration {
+    fn encode_by_ref(&self, buf: &mut ExaBuffer) -> Result<IsNull, BoxDynError> {
+        buf.append(format_args!(
+            "{} {}:{}:{}.{}",
+            self.whole_days(),
+            self.whole_hours().abs() % 24,
+            self.whole_minutes().abs() % 60,
+            self.whole_seconds().abs() % 60,
+            self.whole_milliseconds().abs() % 1000
+        ))?;
+
+        Ok(IsNull::No)
+    }
+
+    fn size_hint(&self) -> usize {
+        // 1 quote + 1 sign + max days precision +
+        // 1 space + 2 hours + 1 colon + 2 minutes + 1 colon + 2 seconds +
+        // 1 dot + max milliseconds fraction +
+        // 1 quote
+        2 + ExaDataType::INTERVAL_DTS_MAX_PRECISION as usize
+            + 10
+            + ExaDataType::INTERVAL_DTS_MAX_FRACTION as usize
+            + 1
+    }
+}
+
+impl<'r> Decode<'r, Exasol> for Duration {
+    fn decode(value: ExaValueRef<'r>) -> Result<Self, BoxDynError> {
+        let input = <&str>::deserialize(value.value).map_err(Box::new)?;
+        let input_err_fn = || format!("could not parse {input} as INTERVAL DAY TO SECOND");
+
+        let (days, rest) = input.split_once(' ').ok_or_else(input_err_fn)?;
+        let (hours, rest) = rest.split_once(':').ok_or_else(input_err_fn)?;
+        let (minutes, rest) = rest.split_once(':').ok_or_else(input_err_fn)?;
+        let (seconds, millis) = rest.split_once('.').ok_or_else(input_err_fn)?;
+
+        let days: i64 = days.parse().map_err(Box::new)?;
+        let hours: i64 = hours.parse().map_err(Box::new)?;
+        let minutes: i64 = minutes.parse().map_err(Box::new)?;
+        let seconds: i64 = seconds.parse().map_err(Box::new)?;
+        let millis: i64 = millis.parse().map_err(Box::new)?;
+        let sign = if days.is_negative() { -1 } else { 1 };
+
+        let duration = Duration::days(days)
+            + Duration::hours(hours * sign)
+            + Duration::minutes(minutes * sign)
+            + Duration::seconds(seconds * sign)
+            + Duration::milliseconds(millis * sign);
+
+        Ok(duration)
+    }
+}
diff --git a/sqlx-exasol-impl/src/types/time/mod.rs b/sqlx-exasol-impl/src/types/time/mod.rs
new file mode 100644
index 00000000..d8083499
--- /dev/null
+++ b/sqlx-exasol-impl/src/types/time/mod.rs
@@ -0,0 +1,5 @@
+mod date;
+mod datetime;
+mod duration;
+
+pub use time::Duration;
diff --git a/src/types/uuid.rs b/sqlx-exasol-impl/src/types/uuid.rs
similarity index 65%
rename from src/types/uuid.rs
rename to sqlx-exasol-impl/src/types/uuid.rs
index fe6249c5..740a77e0 100644
--- a/src/types/uuid.rs
+++ b/sqlx-exasol-impl/src/types/uuid.rs
@@ -10,17 +10,14 @@ use uuid::Uuid;
 use crate::{
     arguments::ExaBuffer,
     database::Exasol,
-    type_info::{ExaDataType, ExaTypeInfo, HashType},
+    type_info::{ExaDataType, ExaTypeInfo},
     value::ExaValueRef,
 };
 
 impl Type<Exasol> for Uuid {
     fn type_info() -> ExaTypeInfo {
-        ExaDataType::HashType(HashType {}).into()
-    }
-
-    fn compatible(ty: &ExaTypeInfo) -> bool {
-        <Self as Type<Exasol>>::type_info().compatible(ty)
+        // 16 bytes * 2 because we set HASHTYPE_FORMAT to HEX
+        ExaDataType::HashType { size: Some(32) }.into()
     }
 }
 
@@ -30,13 +27,9 @@ impl Encode<'_, Exasol> for Uuid {
         Ok(IsNull::No)
     }
 
-    fn produces(&self) -> Option<ExaTypeInfo> {
-        Some(ExaDataType::HashType(HashType {}).into())
-    }
-
     fn size_hint(&self) -> usize {
-        // 16 bytes encoded as HEX, so double
-        32
+        // Serialized as string so: 2 * 16 HEX bytes + 4 dashes + 2 quotes
+        38
     }
 }
diff --git a/src/value.rs b/sqlx-exasol-impl/src/value.rs
similarity index 100%
rename from src/value.rs
rename to sqlx-exasol-impl/src/value.rs
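Since the interval's sign only appears on the day component, the `Duration` decode impl re-applies it to every smaller unit. A self-contained restatement of that parsing logic, with a made-up interval literal:

```rust
use time::Duration;

// Stand-alone restatement of the decode logic above.
fn parse_interval(input: &str) -> Option<Duration> {
    let (days, rest) = input.split_once(' ')?;
    let (hours, rest) = rest.split_once(':')?;
    let (minutes, rest) = rest.split_once(':')?;
    let (seconds, millis) = rest.split_once('.')?;

    let days: i64 = days.parse().ok()?;
    // The sign only shows up on the day component, so it is re-applied to
    // every smaller unit before summing.
    let sign = if days.is_negative() { -1 } else { 1 };

    Some(
        Duration::days(days)
            + Duration::hours(hours.parse::<i64>().ok()? * sign)
            + Duration::minutes(minutes.parse::<i64>().ok()? * sign)
            + Duration::seconds(seconds.parse::<i64>().ok()? * sign)
            + Duration::milliseconds(millis.parse::<i64>().ok()? * sign),
    )
}

fn main() {
    // Made-up INTERVAL DAY TO SECOND literal: minus 2 days, 3:30:15.250.
    let parsed = parse_interval("-2 03:30:15.250").unwrap();
    let expected = -(Duration::days(2)
        + Duration::hours(3)
        + Duration::minutes(30)
        + Duration::seconds(15)
        + Duration::milliseconds(250));
    assert_eq!(parsed, expected);
}
```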
diff --git a/sqlx-exasol-macros/Cargo.toml b/sqlx-exasol-macros/Cargo.toml
new file mode 100644
index 00000000..25e3fbdc
--- /dev/null
+++ b/sqlx-exasol-macros/Cargo.toml
@@ -0,0 +1,38 @@
+[package]
+name = "sqlx-exasol-macros"
+description = "Macros support for sqlx-exasol. Not meant to be used directly."
+version.workspace = true
+license.workspace = true
+edition.workspace = true
+rust-version.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+authors.workspace = true
+
+[lib]
+proc-macro = true
+
+[features]
+# SQLx features
+derive = ["sqlx-macros-core/derive"]
+macros = ["sqlx-macros-core/macros", "sqlx-exasol-impl/macros"]
+migrate = ["sqlx-macros-core/migrate", "sqlx-exasol-impl/migrate"]
+
+# Types
+bigdecimal = ["sqlx-macros-core/bigdecimal", "sqlx-exasol-impl/bigdecimal"]
+chrono = ["sqlx-macros-core/chrono", "sqlx-exasol-impl/chrono"]
+geo-types = ["sqlx-exasol-impl/geo-types"]
+json = ["sqlx-macros-core/json", "sqlx-exasol-impl/json"]
+rust_decimal = ["sqlx-macros-core/rust_decimal", "sqlx-exasol-impl/rust_decimal"]
+time = ["sqlx-macros-core/time", "sqlx-exasol-impl/time"]
+uuid = ["sqlx-macros-core/uuid", "sqlx-exasol-impl/uuid"]
+
+[dependencies]
+sqlx-macros-core = { workspace = true }
+sqlx-exasol-impl = { workspace = true }
+syn = { workspace = true }
+quote = { workspace = true }
+
+[lints]
+workspace = true
diff --git a/sqlx-exasol-macros/src/lib.rs b/sqlx-exasol-macros/src/lib.rs
new file mode 100644
index 00000000..fbb27da5
--- /dev/null
+++ b/sqlx-exasol-macros/src/lib.rs
@@ -0,0 +1,23 @@
+#![cfg_attr(not(test), warn(unused_crate_dependencies))]
+
+use proc_macro::TokenStream;
+use quote::quote;
+use sqlx_macros_core::query;
+
+#[cfg(feature = "macros")]
+#[proc_macro]
+pub fn expand_query(input: TokenStream) -> TokenStream {
+    let input = syn::parse_macro_input!(input as query::QueryMacroInput);
+
+    match query::expand_input(input, &[sqlx_exasol_impl::QUERY_DRIVER]) {
+        Ok(ts) => ts.into(),
+        Err(e) => {
+            if let Some(parse_err) = e.downcast_ref::<syn::Error>() {
+                parse_err.to_compile_error().into()
+            } else {
+                let msg = e.to_string();
+                quote!(::std::compile_error!(#msg)).into()
+            }
+        }
+    }
+}
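`expand_query` is the single proc-macro entry point; the `query!` wrappers added to `src/macros.rs` later in this diff all forward to it. An illustrative call site only: compile-time checking needs the `macros` feature and a reachable `DATABASE_URL` (or offline query data), and the query itself is a placeholder:

```rust
// Expands through sqlx_exasol_macros::expand_query!(source = "SELECT 1 AS one").
async fn ping(con: &mut sqlx_exasol::ExaConnection) -> Result<(), sqlx_exasol::Error> {
    let _row = sqlx_exasol::query!("SELECT 1 AS one").fetch_one(&mut *con).await?;
    Ok(())
}
```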
diff --git a/src/column.rs b/src/column.rs
deleted file mode 100644
index 341cc2f4..00000000
--- a/src/column.rs
+++ /dev/null
@@ -1,51 +0,0 @@
-use std::{fmt::Display, sync::Arc};
-
-use serde::{Deserialize, Deserializer};
-use sqlx_core::{column::Column, database::Database};
-
-use crate::{database::Exasol, type_info::ExaTypeInfo};
-
-/// Implementor of [`Column`].
-#[derive(Debug, Clone, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct ExaColumn {
-    #[serde(skip)]
-    pub(crate) ordinal: usize,
-    #[serde(deserialize_with = "ExaColumn::lowercase_name")]
-    pub(crate) name: Arc<str>,
-    pub(crate) data_type: ExaTypeInfo,
-}
-
-impl ExaColumn {
-    fn lowercase_name<'de, D>(deserializer: D) -> Result<Arc<str>, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        // NOTE: We can borrow because we always deserialize from an owned buffer.
-        <&str>::deserialize(deserializer)
-            .map(str::to_lowercase)
-            .map(From::from)
-    }
-}
-
-impl Display for ExaColumn {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}: {}", self.name, self.data_type)
-    }
-}
-
-impl Column for ExaColumn {
-    type Database = Exasol;
-
-    fn ordinal(&self) -> usize {
-        self.ordinal
-    }
-
-    fn name(&self) -> &str {
-        &self.name
-    }
-
-    fn type_info(&self) -> &<Self::Database as Database>::TypeInfo {
-        &self.data_type
-    }
-}
diff --git a/src/lib.rs b/src/lib.rs
index b4a6323b..5b186107 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -35,7 +35,6 @@
 //! Connection options:
 //! - `access-token`: Use an access token for login instead of credentials
 //! - `refresh-token`: Use a refresh token for login instead of credentials
-//! - `protocol-version`: Select a specific protocol version to use
 //! - `ssl-mode`: Select a specific SSL behavior. See: [`ExaSslMode`]
 //! - `ssl-ca`: Use a certain certificate authority
 //! - `ssl-cert`: Use a certain certificate
@@ -55,11 +54,11 @@
 //! - [`i8`], [`i16`], [`i32`], [`i64`], [`i128`]
 //! - [`f32`], [`f64`]
 //! - [`str`], [`String`], [`std::borrow::Cow`]
-//! - `chrono` feature: [`chrono::DateTime<Utc>`], [`chrono::DateTime<Local>`],
-//!   [`chrono::NaiveDateTime`], [`chrono::NaiveDate`], [`chrono::Duration`], [`Months`] (analog of
-//!   [`chrono::Months`])
-//! - `uuid` feature: [`uuid::Uuid`]
-//! - `rust_decimal` feature: [`rust_decimal::Decimal`]
+//! - `chrono` feature: [`crate::types::chrono::DateTime<Utc>`],
+//!   [`crate::types::chrono::DateTime<Local>`], [`crate::types::chrono::NaiveDateTime`],
+//!   [`crate::types::chrono::NaiveDate`],
+//! - `uuid` feature: [`crate::types::Uuid`]
+//! - `rust_decimal` feature: [`crate::types::Decimal`]
 //!
 //! ## Supported Exasol datatypes:
 //! All Exasol datatypes are supported in some way, also depending on the additional types used
@@ -75,8 +74,8 @@
 //!
 //! The data is always in `CSV` format and job configuration can be done through the
 //! [`etl::ImportBuilder`] and [`etl::ExportBuilder`] structs. The workers implement
-//! [`futures_io::AsyncWrite`] and [`futures_io::AsyncRead`] respectively, providing great
-//! flexibility in terms of how the data is processed.
+//! `AsyncWrite` and `AsyncRead` respectively, providing great flexibility in terms of how the data
+//! is processed.
 //!
 //! The general flow of an ETL job is:
 //! - build the job through [`etl::ImportBuilder`] or [`etl::ExportBuilder`]
@@ -100,7 +99,7 @@
 //! let pool = ExaPool::connect(&env::var("DATABASE_URL").unwrap()).await?;
 //! let mut con = pool.acquire().await?;
 //!
-//! sqlx::query("CREATE SCHEMA RUST_DOC_TEST")
+//! sqlx_exasol::query("CREATE SCHEMA RUST_DOC_TEST")
 //!     .execute(&mut *con)
 //!     .await?;
 //! #
@@ -109,7 +108,7 @@
 //! # };
 //! ```
 //!
-//! Array-like parameter binding, also featuring the [`ExaIter`] adapter.
+//! Array-like parameter binding, also featuring the [`crate::types::ExaIter`] adapter.
 //! An important thing to note is that the parameter sets must be of equal length,
 //! otherwise an error is thrown:
 //! ```rust,no_run
@@ -125,9 +124,9 @@
 //! let params1 = vec![1, 2, 3];
 //! let params2 = HashSet::from([1, 2, 3]);
 //!
-//! sqlx::query("INSERT INTO MY_TABLE VALUES (?, ?)")
+//! sqlx_exasol::query("INSERT INTO MY_TABLE VALUES (?, ?)")
 //!     .bind(&params1)
-//!     .bind(ExaIter::from(&params2))
+//!     .bind(types::ExaIter::new(params2.iter()))
 //!     .execute(&mut *con)
 //!     .await?;
 //! #
@@ -217,65 +216,35 @@
 //! nullable or not, so the driver cannot implicitly decide whether a `NULL` value can go into a
 //! certain database column or not until it actually tries.
 
-/// Gets rid of unused dependencies warning from dev-dependencies.
-mod arguments;
-mod column;
-mod connection;
-mod database;
-mod error;
-#[cfg(feature = "migrate")]
-mod migrate;
-mod options;
-mod query_result;
-mod responses;
-mod row;
-mod statement;
-#[cfg(feature = "migrate")]
-mod testing;
-mod transaction;
-mod type_info;
-mod types;
-mod value;
+pub use sqlx::*;
+pub use sqlx_exasol_impl::*;
 
-pub use arguments::ExaArguments;
-pub use column::ExaColumn;
-#[cfg(feature = "etl")]
-pub use connection::etl;
-pub use connection::ExaConnection;
-pub use database::Exasol;
-pub use options::{ExaConnectOptions, ExaConnectOptionsBuilder, ExaSslMode, ProtocolVersion};
-pub use query_result::ExaQueryResult;
-pub use responses::{ExaAttributes, ExaDatabaseError, SessionInfo};
-pub use row::ExaRow;
-use sqlx_core::{
-    executor::Executor, impl_acquire, impl_column_index_for_row, impl_column_index_for_statement,
-    impl_into_arguments_for_arguments,
-};
-pub use statement::ExaStatement;
-pub use transaction::ExaTransactionManager;
-pub use type_info::ExaTypeInfo;
-pub use types::ExaIter;
-#[cfg(feature = "chrono")]
-pub use types::Months;
-pub use value::{ExaValue, ExaValueRef};
+pub mod types {
+    pub use sqlx::types::*;
+    pub use sqlx_exasol_impl::types::*;
 
-/// An alias for [`Pool`][sqlx_core::pool::Pool], specialized for Exasol.
-pub type ExaPool = sqlx_core::pool::Pool<Exasol>;
+    #[cfg(feature = "chrono")]
+    pub mod chrono {
+        pub use sqlx::types::chrono::*;
+        pub use sqlx_exasol_impl::types::chrono::*;
+    }
 
-/// An alias for [`PoolOptions`][sqlx_core::pool::PoolOptions], specialized for Exasol.
-pub type ExaPoolOptions = sqlx_core::pool::PoolOptions<Exasol>;
+    #[cfg(feature = "time")]
+    pub mod time {
+        pub use sqlx::types::time::*;
+        pub use sqlx_exasol_impl::types::time::*;
+    }
+}
 
-/// An alias for [`Executor<'_, Database = Exasol>`][Executor].
-pub trait ExaExecutor<'c>: Executor<'c, Database = Exasol> {}
-impl<'c, T: Executor<'c, Database = Exasol>> ExaExecutor<'c> for T {}
+pub mod any {
+    pub use sqlx::any::*;
+    pub use sqlx_exasol_impl::any::DRIVER;
+}
 
+#[cfg(feature = "macros")]
+pub use sqlx_exasol_macros;
+#[cfg(feature = "macros")]
+mod macros;
 
-impl_into_arguments_for_arguments!(ExaArguments);
-impl_acquire!(Exasol, ExaConnection);
-impl_column_index_for_row!(ExaRow);
-impl_column_index_for_statement!(ExaStatement);
-
-// ###################
-// ##### Aliases #####
-// ###################
-type SqlxError = sqlx_core::Error;
-type SqlxResult<T> = sqlx_core::Result<T>;
+#[cfg(feature = "macros")]
+#[doc(hidden)]
+pub mod ty_match;
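The new facade re-exports both `sqlx` and `sqlx-exasol-impl`, merging their `types` modules (and the nested `chrono`/`time` modules). Roughly what downstream imports look like, assuming the relevant feature flags (`chrono`, `json`, ...) are enabled:

```rust
#![allow(unused_imports)]

use sqlx_exasol::{ExaConnection, ExaPool, Row}; // impl-crate and sqlx re-exports
use sqlx_exasol::types::{ExaIter, Json};        // merged `types` modules
use sqlx_exasol::types::chrono::{DateTime, Utc}; // nested merged module

fn main() {}
```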
diff --git a/src/macros.rs b/src/macros.rs
new file mode 100644
index 00000000..eac467ff
--- /dev/null
+++ b/src/macros.rs
@@ -0,0 +1,133 @@
+#[macro_export]
+#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
+macro_rules! query (
+    ($query:expr) => ({
+        $crate::sqlx_exasol_macros::expand_query!(source = $query)
+    });
+    ($query:expr, $($args:tt)*) => ({
+        $crate::sqlx_exasol_macros::expand_query!(source = $query, args = [$($args)*])
+    })
+);
+
+#[macro_export]
+#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
+macro_rules! query_unchecked (
+    ($query:expr) => ({
+        $crate::sqlx_exasol_macros::expand_query!(source = $query, checked = false)
+    });
+    ($query:expr, $($args:tt)*) => ({
+        $crate::sqlx_exasol_macros::expand_query!(source = $query, args = [$($args)*], checked = false)
+    })
+);
+
+#[macro_export]
+#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
+macro_rules! query_file (
+    ($path:literal) => ({
+        $crate::sqlx_exasol_macros::expand_query!(source_file = $path)
+    });
+    ($path:literal, $($args:tt)*) => ({
+        $crate::sqlx_exasol_macros::expand_query!(source_file = $path, args = [$($args)*])
+    })
+);
+
+#[macro_export]
+#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
+macro_rules! query_file_unchecked (
+    ($path:literal) => ({
+        $crate::sqlx_exasol_macros::expand_query!(source_file = $path, checked = false)
+    });
+    ($path:literal, $($args:tt)*) => ({
+        $crate::sqlx_exasol_macros::expand_query!(source_file = $path, args = [$($args)*], checked = false)
+    })
+);
+
+#[macro_export]
+#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
+macro_rules! query_as (
+    ($out_struct:path, $query:expr) => ({
+        $crate::sqlx_exasol_macros::expand_query!(record = $out_struct, source = $query)
+    });
+    ($out_struct:path, $query:expr, $($args:tt)*) => ({
+        $crate::sqlx_exasol_macros::expand_query!(record = $out_struct, source = $query, args = [$($args)*])
+    })
+);
+
+#[macro_export]
+#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
+macro_rules! query_file_as (
+    ($out_struct:path, $path:literal) => ({
+        $crate::sqlx_exasol_macros::expand_query!(record = $out_struct, source_file = $path)
+    });
+    ($out_struct:path, $path:literal, $($args:tt)*) => ({
+        $crate::sqlx_exasol_macros::expand_query!(record = $out_struct, source_file = $path, args = [$($args)*])
+    })
+);
+
+#[macro_export]
+#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
+macro_rules! query_as_unchecked (
+    ($out_struct:path, $query:expr) => ({
+        $crate::sqlx_exasol_macros::expand_query!(record = $out_struct, source = $query, checked = false)
+    });
+
+    ($out_struct:path, $query:expr, $($args:tt)*) => ({
+        $crate::sqlx_exasol_macros::expand_query!(record = $out_struct, source = $query, args = [$($args)*], checked = false)
+    })
+);
+
+#[macro_export]
+#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
+macro_rules! query_file_as_unchecked (
+    ($out_struct:path, $path:literal) => ({
+        $crate::sqlx_exasol_macros::expand_query!(record = $out_struct, source_file = $path, checked = false)
+    });
+
+    ($out_struct:path, $path:literal, $($args:tt)*) => ({
+        $crate::sqlx_exasol_macros::expand_query!(record = $out_struct, source_file = $path, args = [$($args)*], checked = false)
+    })
+);
+
+#[macro_export]
+#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
+macro_rules! query_scalar (
+    ($query:expr) => (
+        $crate::sqlx_exasol_macros::expand_query!(scalar = _, source = $query)
+    );
+    ($query:expr, $($args:tt)*) => (
+        $crate::sqlx_exasol_macros::expand_query!(scalar = _, source = $query, args = [$($args)*])
+    )
+);
+
+#[macro_export]
+#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
+macro_rules! query_file_scalar (
+    ($path:literal) => (
+        $crate::sqlx_exasol_macros::expand_query!(scalar = _, source_file = $path)
+    );
+    ($path:literal, $($args:tt)*) => (
+        $crate::sqlx_exasol_macros::expand_query!(scalar = _, source_file = $path, args = [$($args)*])
+    )
+);
+
+#[macro_export]
+#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
+macro_rules! query_scalar_unchecked (
+    ($query:expr) => (
+        $crate::sqlx_exasol_macros::expand_query!(scalar = _, source = $query, checked = false)
+    );
+    ($query:expr, $($args:tt)*) => (
+        $crate::sqlx_exasol_macros::expand_query!(scalar = _, source = $query, args = [$($args)*], checked = false)
+    )
+);
+
+#[macro_export]
+#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
+macro_rules! query_file_scalar_unchecked (
+    ($path:literal) => (
+        $crate::sqlx_exasol_macros::expand_query!(scalar = _, source_file = $path, checked = false)
+    );
+    ($path:literal, $($args:tt)*) => (
+        $crate::sqlx_exasol_macros::expand_query!(scalar = _, source_file = $path, args = [$($args)*], checked = false)
+    )
+);
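All the macro variants above differ only in source (inline vs. file), output shape (anonymous record, named struct, scalar) and whether type checking is enforced. A hedged usage sketch; the table, struct and field types are illustrative and would be verified against the live schema at compile time:

```rust
#[derive(Debug)]
struct Account {
    id: i64,
    name: String,
}

// Expands to: expand_query!(record = Account, source = "...", args = [1_i64])
async fn load(con: &mut sqlx_exasol::ExaConnection) -> Result<Account, sqlx_exasol::Error> {
    sqlx_exasol::query_as!(Account, "SELECT id, name FROM accounts WHERE id = ?", 1_i64)
        .fetch_one(con)
        .await
}
```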
diff --git a/src/migrate.rs b/src/migrate.rs
deleted file mode 100644
index 499ecc4c..00000000
--- a/src/migrate.rs
+++ /dev/null
@@ -1,229 +0,0 @@
-use std::{
-    str::FromStr,
-    time::{Duration, Instant},
-};
-
-use futures_core::future::BoxFuture;
-use sqlx_core::{
-    connection::{ConnectOptions, Connection},
-    executor::Executor,
-    migrate::{AppliedMigration, Migrate, MigrateDatabase, MigrateError, Migration},
-    query::query,
-    query_as::query_as,
-    query_scalar::query_scalar,
-};
-
-use crate::{
-    connection::{
-        websocket::future::{ExecuteBatch, WebSocketFuture},
-        ExaConnection,
-    },
-    database::Exasol,
-    options::ExaConnectOptions,
-    SqlxError, SqlxResult,
-};
-
-const LOCK_WARN: &str = "Exasol does not support database locking!";
-
-fn parse_for_maintenance(url: &str) -> SqlxResult<(ExaConnectOptions, String)> {
-    let mut options = ExaConnectOptions::from_str(url)?;
-
-    let database = options.schema.ok_or_else(|| {
-        SqlxError::Configuration("DATABASE_URL does not specify a database".into())
-    })?;
-
-    // switch to database for create/drop commands
-    options.schema = None;
-
-    Ok((options, database))
-}
-
-impl MigrateDatabase for Exasol {
-    fn create_database(url: &str) -> BoxFuture<'_, SqlxResult<()>> {
-        Box::pin(async move {
-            let (options, database) = parse_for_maintenance(url)?;
-            let mut conn = options.connect().await?;
-
-            let query = format!("CREATE SCHEMA \"{}\"", database.replace('"', "\"\""));
-            let _ = conn.execute(&*query).await?;
-
-            Ok(())
-        })
-    }
-
-    fn database_exists(url: &str) -> BoxFuture<'_, SqlxResult<bool>> {
-        Box::pin(async move {
-            let (options, database) = parse_for_maintenance(url)?;
-            let mut conn = options.connect().await?;
-
-            let query = "SELECT true FROM exa_schemas WHERE schema_name = ?";
-            let exists: bool = query_scalar(query)
-                .bind(database)
-                .fetch_one(&mut conn)
-                .await?;
-
-            Ok(exists)
-        })
-    }
-
-    fn drop_database(url: &str) -> BoxFuture<'_, SqlxResult<()>> {
-        Box::pin(async move {
-            let (options, database) = parse_for_maintenance(url)?;
-            let mut conn = options.connect().await?;
-
-            let query = format!("DROP SCHEMA IF EXISTS `{database}`");
-            let _ = conn.execute(&*query).await?;
-
-            Ok(())
-        })
-    }
-}
-
-impl Migrate for ExaConnection {
-    fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> {
-        Box::pin(async move {
-            let query = r#"
-                CREATE TABLE IF NOT EXISTS "_sqlx_migrations" (
-                    version DECIMAL(20, 0),
-                    description CLOB NOT NULL,
-                    installed_on TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                    success BOOLEAN NOT NULL,
-                    checksum CLOB NOT NULL,
-                    execution_time DECIMAL(20, 0) NOT NULL
-                );"#;
-
-            self.execute(query).await?;
-            Ok(())
-        })
-    }
-
-    fn dirty_version(&mut self) -> BoxFuture<'_, Result<Option<i64>, MigrateError>> {
-        Box::pin(async move {
-            let query = r#"
-                SELECT version
-                FROM "_sqlx_migrations"
-                WHERE success = false
-                ORDER BY version
-                LIMIT 1
-            "#;
-
-            let row: Option<(i64,)> = query_as(query).fetch_optional(self).await?;
-
-            Ok(row.map(|r| r.0))
-        })
-    }
-
-    fn list_applied_migrations(
-        &mut self,
-    ) -> BoxFuture<'_, Result<Vec<AppliedMigration>, MigrateError>> {
-        Box::pin(async move {
-            let query = r#"
-                SELECT version, checksum
-                FROM "_sqlx_migrations"
-                ORDER BY version
-            "#;
-
-            let rows: Vec<(i64, String)> = query_as(query).fetch_all(self).await?;
-
-            let mut migrations = Vec::with_capacity(rows.len());
-
-            for (version, checksum) in rows {
-                let checksum = hex::decode(checksum)
-                    .map_err(From::from)
-                    .map_err(MigrateError::Source)?
-                    .into();
-
-                let migration = AppliedMigration { version, checksum };
-                migrations.push(migration);
-            }
-
-            Ok(migrations)
-        })
-    }
-
-    fn lock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> {
-        Box::pin(async move {
-            tracing::warn!("{LOCK_WARN}");
-            Ok(())
-        })
-    }
-
-    fn unlock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> {
-        Box::pin(async move {
-            tracing::warn!("{LOCK_WARN}");
-            Ok(())
-        })
-    }
-
-    fn apply<'e: 'm, 'm>(
-        &'e mut self,
-        migration: &'m Migration,
-    ) -> BoxFuture<'m, Result<Duration, MigrateError>> {
-        Box::pin(async move {
-            let mut tx = self.begin().await?;
-            let start = Instant::now();
-
-            ExecuteBatch::new(migration.sql.as_ref())
-                .future(&mut tx.ws)
-                .await?;
-
-            let checksum = hex::encode(&*migration.checksum);
-
-            let query_str = r#"
-                INSERT INTO "_sqlx_migrations" ( version, description, success, checksum, execution_time )
-                VALUES ( ?, ?, TRUE, ?, -1 );
-            "#;
-
-            let _ = query(query_str)
-                .bind(migration.version)
-                .bind(&*migration.description)
-                .bind(checksum)
-                .execute(&mut *tx)
-                .await?;
-
-            tx.commit().await?;
-
-            let elapsed = start.elapsed();
-
-            let query_str = r#"
-                UPDATE "_sqlx_migrations"
-                SET execution_time = ?
-                WHERE version = ?
-            "#;
-
-            let _ = query(query_str)
-                .bind(elapsed.as_nanos())
-                .bind(migration.version)
-                .execute(self)
-                .await?;
-
-            Ok(elapsed)
-        })
-    }
-
-    fn revert<'e: 'm, 'm>(
-        &'e mut self,
-        migration: &'m Migration,
-    ) -> BoxFuture<'m, Result<Duration, MigrateError>> {
-        Box::pin(async move {
-            let mut tx = self.begin().await?;
-            let start = Instant::now();
-
-            ExecuteBatch::new(migration.sql.as_ref())
-                .future(&mut tx.ws)
-                .await?;
-
-            let query_str = r#" DELETE FROM "_sqlx_migrations" WHERE version = ? "#;
"#; - let _ = query(query_str) - .bind(migration.version) - .execute(&mut *tx) - .await?; - - tx.commit().await?; - - let elapsed = start.elapsed(); - - Ok(elapsed) - }) - } -} diff --git a/src/options/mod.rs b/src/options/mod.rs deleted file mode 100644 index dfb13527..00000000 --- a/src/options/mod.rs +++ /dev/null @@ -1,291 +0,0 @@ -mod builder; -mod error; -mod protocol_version; -mod ssl_mode; - -use std::{borrow::Cow, net::SocketAddr, num::NonZeroUsize, path::PathBuf, str::FromStr}; - -pub use builder::ExaConnectOptionsBuilder; -use error::ExaConfigError; -pub use protocol_version::ProtocolVersion; -use sqlx_core::{ - connection::{ConnectOptions, LogSettings}, - net::tls::CertificateInput, -}; -pub use ssl_mode::ExaSslMode; -use tracing::log; -use url::Url; - -use crate::{ - connection::{ - websocket::request::{ExaLoginRequest, LoginRef}, - ExaConnection, - }, - responses::ExaRwAttributes, - SqlxError, SqlxResult, -}; - -const URL_SCHEME: &str = "exa"; - -const DEFAULT_FETCH_SIZE: usize = 5 * 1024 * 1024; -const DEFAULT_PORT: u16 = 8563; -const DEFAULT_CACHE_CAPACITY: NonZeroUsize = match NonZeroUsize::new(100) { - Some(v) => v, - None => unreachable!(), -}; - -const PARAM_ACCESS_TOKEN: &str = "access-token"; -const PARAM_REFRESH_TOKEN: &str = "refresh-token"; -const PARAM_PROTOCOL_VERSION: &str = "protocol-version"; -const PARAM_SSL_MODE: &str = "ssl-mode"; -const PARAM_SSL_CA: &str = "ssl-ca"; -const PARAM_SSL_CERT: &str = "ssl-cert"; -const PARAM_SSL_KEY: &str = "ssl-key"; -const PARAM_CACHE_CAP: &str = "statement-cache-capacity"; -const PARAM_FETCH_SIZE: &str = "fetch-size"; -const PARAM_QUERY_TIMEOUT: &str = "query-timeout"; -const PARAM_COMPRESSION: &str = "compression"; -const PARAM_FEEDBACK_INTERVAL: &str = "feedback-interval"; - -/// Options for connecting to the Exasol database. Implementor of [`ConnectOptions`]. -/// -/// While generally automatically created through a connection string, -/// [`ExaConnectOptions::builder()`] can be used to get a [`ExaConnectOptionsBuilder`]. 
-#[derive(Debug, Clone)]
-pub struct ExaConnectOptions {
-    pub(crate) hosts_details: Vec<(String, Vec<SocketAddr>)>,
-    pub(crate) port: u16,
-    pub(crate) ssl_mode: ExaSslMode,
-    pub(crate) ssl_ca: Option<CertificateInput>,
-    pub(crate) ssl_client_cert: Option<CertificateInput>,
-    pub(crate) ssl_client_key: Option<CertificateInput>,
-    pub(crate) statement_cache_capacity: NonZeroUsize,
-    pub(crate) schema: Option<String>,
-    pub(crate) compression: bool,
-    login: Login,
-    protocol_version: ProtocolVersion,
-    fetch_size: usize,
-    query_timeout: u64,
-    feedback_interval: u64,
-    log_settings: LogSettings,
-}
-
-impl ExaConnectOptions {
-    #[must_use]
-    pub fn builder() -> ExaConnectOptionsBuilder {
-        ExaConnectOptionsBuilder::default()
-    }
-}
-
-impl FromStr for ExaConnectOptions {
-    type Err = SqlxError;
-
-    fn from_str(s: &str) -> Result<Self, Self::Err> {
-        let url = Url::parse(s)
-            .map_err(From::from)
-            .map_err(SqlxError::Configuration)?;
-        Self::from_url(&url)
-    }
-}
-
-impl ConnectOptions for ExaConnectOptions {
-    type Connection = ExaConnection;
-
-    fn from_url(url: &Url) -> SqlxResult<Self> {
-        let scheme = url.scheme();
-
-        if URL_SCHEME != scheme {
-            return Err(ExaConfigError::InvalidUrlScheme(scheme.to_owned()).into());
-        }
-
-        let mut builder = Self::builder();
-
-        if let Some(host) = url.host_str() {
-            builder = builder.host(host.to_owned());
-        }
-
-        let username = url.username();
-        if !username.is_empty() {
-            builder = builder.username(username.to_owned());
-        }
-
-        if let Some(password) = url.password() {
-            builder = builder.password(password.to_owned());
-        }
-
-        if let Some(port) = url.port() {
-            builder = builder.port(port);
-        }
-
-        let opt_schema = url.path_segments().into_iter().flatten().next();
-
-        if let Some(schema) = opt_schema {
-            builder = builder.schema(schema.to_owned());
-        }
-
-        for (name, value) in url.query_pairs() {
-            match name.as_ref() {
-                PARAM_ACCESS_TOKEN => builder = builder.access_token(value.to_string()),
-
-                PARAM_REFRESH_TOKEN => builder = builder.refresh_token(value.to_string()),
-
-                PARAM_PROTOCOL_VERSION => {
-                    let protocol_version = value.parse::<ProtocolVersion>()?;
-                    builder = builder.protocol_version(protocol_version);
-                }
-
-                PARAM_SSL_MODE => {
-                    let ssl_mode = value.parse::<ExaSslMode>()?;
-                    builder = builder.ssl_mode(ssl_mode);
-                }
-
-                PARAM_SSL_CA => {
-                    let ssl_ca = CertificateInput::File(PathBuf::from(value.to_string()));
-                    builder = builder.ssl_ca(ssl_ca);
-                }
-
-                PARAM_SSL_CERT => {
-                    let ssl_cert = CertificateInput::File(PathBuf::from(value.to_string()));
-                    builder = builder.ssl_client_cert(ssl_cert);
-                }
-
-                PARAM_SSL_KEY => {
-                    let ssl_key = CertificateInput::File(PathBuf::from(value.to_string()));
-                    builder = builder.ssl_client_key(ssl_key);
-                }
-
-                PARAM_CACHE_CAP => {
-                    let capacity = value
-                        .parse::<NonZeroUsize>()
-                        .map_err(|_| ExaConfigError::InvalidParameter(PARAM_CACHE_CAP))?;
-                    builder = builder.statement_cache_capacity(capacity);
-                }
-
-                PARAM_FETCH_SIZE => {
-                    let fetch_size = value
-                        .parse::<usize>()
-                        .map_err(|_| ExaConfigError::InvalidParameter(PARAM_FETCH_SIZE))?;
-                    builder = builder.fetch_size(fetch_size);
-                }
-
-                PARAM_QUERY_TIMEOUT => {
-                    let query_timeout = value
-                        .parse::<u64>()
-                        .map_err(|_| ExaConfigError::InvalidParameter(PARAM_QUERY_TIMEOUT))?;
-                    builder = builder.query_timeout(query_timeout);
-                }
-
-                PARAM_COMPRESSION => {
-                    let compression = value
-                        .parse::<bool>()
-                        .map_err(|_| ExaConfigError::InvalidParameter(PARAM_COMPRESSION))?;
-                    builder = builder.compression(compression);
-                }
-
-                PARAM_FEEDBACK_INTERVAL => {
-                    let feedback_interval = value
-                        .parse::<u64>()
-                        .map_err(|_| ExaConfigError::InvalidParameter(PARAM_FEEDBACK_INTERVAL))?;
-                    builder = builder.feedback_interval(feedback_interval);
-                }
-
-                _ => {
-                    return Err(SqlxError::Protocol(format!(
-                        "Unknown connection string parameter: {value}"
-                    )))
-                }
-            };
-        }
-
-        builder.build()
-    }
-
-    fn connect(&self) -> futures_util::future::BoxFuture<'_, SqlxResult<Self::Connection>>
-    where
-        Self::Connection: Sized,
-    {
-        Box::pin(ExaConnection::establish(self))
-    }
-
-    fn log_statements(mut self, level: log::LevelFilter) -> Self {
-        self.log_settings.log_statements(level);
-        self
-    }
-
-    fn log_slow_statements(
-        mut self,
-        level: log::LevelFilter,
-        duration: std::time::Duration,
-    ) -> Self {
-        self.log_settings.log_slow_statements(level, duration);
-        self
-    }
-}
-
-impl<'a> From<&'a ExaConnectOptions> for ExaLoginRequest<'a> {
-    fn from(value: &'a ExaConnectOptions) -> Self {
-        let crate_version = option_env!("CARGO_PKG_VERSION").unwrap_or("UNKNOWN");
-
-        let attributes = ExaRwAttributes::new(
-            value.schema.as_deref().map(Cow::Borrowed),
-            value.feedback_interval,
-            value.query_timeout,
-        );
-
-        Self {
-            protocol_version: value.protocol_version,
-            fetch_size: value.fetch_size,
-            statement_cache_capacity: value.statement_cache_capacity,
-            login: (&value.login).into(),
-            use_compression: value.compression,
-            client_name: "sqlx-exasol",
-            client_version: crate_version,
-            client_os: std::env::consts::OS,
-            client_runtime: "RUST",
-            attributes,
-        }
-    }
-}
-
-/// Enum representing the possible ways of authenticating a connection.
-/// The variant chosen dictates which login process is called.
-#[derive(Clone, Debug)]
-pub enum Login {
-    Credentials { username: String, password: String },
-    AccessToken { access_token: String },
-    RefreshToken { refresh_token: String },
-}
-
-impl<'a> From<&'a Login> for LoginRef<'a> {
-    fn from(value: &'a Login) -> Self {
-        match value {
-            Login::Credentials { username, password } => LoginRef::Credentials {
-                username,
-                password: Cow::Borrowed(password),
-            },
-            Login::AccessToken { access_token } => LoginRef::AccessToken { access_token },
-            Login::RefreshToken { refresh_token } => LoginRef::RefreshToken { refresh_token },
-        }
-    }
-}
-
-/// Helper containing TLS related options.
-#[derive(Debug, Clone, Copy)]
-#[allow(clippy::struct_field_names)]
-pub struct ExaTlsOptionsRef<'a> {
-    pub ssl_mode: ExaSslMode,
-    pub ssl_ca: Option<&'a CertificateInput>,
-    pub ssl_client_cert: Option<&'a CertificateInput>,
-    pub ssl_client_key: Option<&'a CertificateInput>,
-}
-
-impl<'a> From<&'a ExaConnectOptions> for ExaTlsOptionsRef<'a> {
-    fn from(value: &'a ExaConnectOptions) -> Self {
-        ExaTlsOptionsRef {
-            ssl_mode: value.ssl_mode,
-            ssl_ca: value.ssl_ca.as_ref(),
-            ssl_client_cert: value.ssl_client_cert.as_ref(),
-            ssl_client_key: value.ssl_client_key.as_ref(),
-        }
-    }
-}
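For reference, a sketch of a DSN exercising the parameters parsed above, assuming the option types keep their re-exported location; host range, schema and values are placeholders:

```rust
use std::str::FromStr;

use sqlx_exasol::ExaConnectOptions;

fn main() -> Result<(), sqlx_exasol::Error> {
    // Placeholder host range and schema; the query parameters map to the
    // PARAM_* constants handled in `from_url` above.
    let url = "exa://sys:exasol@10.10.10.11..13:8563/MY_SCHEMA\
               ?ssl-mode=required&statement-cache-capacity=200&fetch-size=1048576";
    let options = ExaConnectOptions::from_str(url)?;
    println!("{options:?}");
    Ok(())
}
```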
diff --git a/src/options/protocol_version.rs b/src/options/protocol_version.rs
deleted file mode 100644
index 0e269822..00000000
--- a/src/options/protocol_version.rs
+++ /dev/null
@@ -1,62 +0,0 @@
-use std::{
-    fmt::{Display, Formatter},
-    str::FromStr,
-};
-
-use serde::{Deserialize, Serialize};
-
-use super::{error::ExaConfigError, PARAM_PROTOCOL_VERSION};
-
-/// Enum listing the protocol versions that can be used when establishing a websocket connection to
-/// Exasol. Defaults to the highest defined protocol version and falls back to the highest protocol
-/// version supported by the server.
-#[derive(Debug, Clone, Copy, Eq, PartialEq, Deserialize, Serialize)]
-#[serde(try_from = "u8")]
-#[serde(into = "u8")]
-#[repr(u8)]
-pub enum ProtocolVersion {
-    V1 = 1,
-    V2 = 2,
-    V3 = 3,
-    V4 = 4,
-}
-
-impl FromStr for ProtocolVersion {
-    type Err = ExaConfigError;
-
-    fn from_str(s: &str) -> Result<Self, Self::Err> {
-        match s {
-            "1" => Ok(ProtocolVersion::V1),
-            "2" => Ok(ProtocolVersion::V2),
-            "3" => Ok(ProtocolVersion::V3),
-            "4" => Ok(ProtocolVersion::V4),
-            _ => Err(ExaConfigError::InvalidParameter(PARAM_PROTOCOL_VERSION)),
-        }
-    }
-}
-
-impl From<ProtocolVersion> for u8 {
-    fn from(value: ProtocolVersion) -> Self {
-        value as Self
-    }
-}
-
-impl TryFrom<u8> for ProtocolVersion {
-    type Error = ExaConfigError;
-
-    fn try_from(value: u8) -> Result<Self, Self::Error> {
-        match value {
-            1 => Ok(Self::V1),
-            2 => Ok(Self::V2),
-            3 => Ok(Self::V3),
-            4 => Ok(Self::V4),
-            _ => Err(ExaConfigError::InvalidParameter(PARAM_PROTOCOL_VERSION)),
-        }
-    }
-}
-
-impl Display for ProtocolVersion {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}", *self as u8)
-    }
-}
diff --git a/src/transaction.rs b/src/transaction.rs
deleted file mode 100644
index 4e608cd5..00000000
--- a/src/transaction.rs
+++ /dev/null
@@ -1,60 +0,0 @@
-use std::borrow::Cow;
-
-use futures_core::future::BoxFuture;
-use futures_util::FutureExt;
-use sqlx_core::transaction::TransactionManager;
-
-use crate::{
-    connection::websocket::future::{Commit, Rollback, WebSocketFuture},
-    database::Exasol,
-    error::ExaProtocolError,
-    ExaConnection, SqlxResult,
-};
-
-/// Implementor of [`TransactionManager`].
-#[derive(Debug, Clone, Copy)]
-pub struct ExaTransactionManager;
-
-impl TransactionManager for ExaTransactionManager {
-    type Database = Exasol;
-
-    fn begin<'conn>(
-        conn: &'conn mut ExaConnection,
-        _: Option<Cow<'static, str>>,
-    ) -> BoxFuture<'conn, SqlxResult<()>> {
-        Box::pin(async {
-            // Exasol does not have nested transactions.
-            if conn.attributes().open_transaction() {
-                // A pending rollback indicates that a transaction was dropped before an explicit
-                // rollback, which is why it's still open. If that's the case, then awaiting the
-                // rollback is sufficient to proceed.
-                match conn.ws.pending_rollback.take() {
-                    Some(rollback) => rollback.future(&mut conn.ws).await?,
-                    None => return Err(ExaProtocolError::TransactionAlreadyOpen)?,
-                }
-            }
-
-            // The next time a request is sent, the transaction will be started.
-            // We could eagerly start it as well, but that implies one more round-trip to the server
-            // and back with no benefit.
-            conn.attributes_mut().set_autocommit(false);
-            Ok(())
-        })
-    }
-
-    fn commit(conn: &mut ExaConnection) -> BoxFuture<'_, SqlxResult<()>> {
-        async move { Commit::default().future(&mut conn.ws).await }.boxed()
-    }
-
-    fn rollback(conn: &mut ExaConnection) -> BoxFuture<'_, SqlxResult<()>> {
-        async move { Rollback::default().future(&mut conn.ws).await }.boxed()
-    }
-
-    fn start_rollback(conn: &mut ExaConnection) {
-        conn.ws.pending_rollback = Some(Rollback::default());
-    }
-
-    fn get_transaction_depth(conn: &ExaConnection) -> usize {
-        conn.attributes().open_transaction().into()
-    }
-}
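A minimal sketch of the (now removed from this crate) manager's contract from the user side, assuming a connected `ExaConnection`; the table is made up. `begin` only disables autocommit, the transaction opens lazily with the first statement, and nested `begin` calls error out:

```rust
use sqlx_exasol::Connection;

async fn transfer(con: &mut sqlx_exasol::ExaConnection) -> Result<(), sqlx_exasol::Error> {
    let mut tx = con.begin().await?;

    sqlx_exasol::query("UPDATE accounts SET balance = balance - 10 WHERE id = 1")
        .execute(&mut *tx)
        .await?;

    // Dropping `tx` without committing queues a rollback instead (see `start_rollback`).
    tx.commit().await?;
    Ok(())
}
```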
diff --git a/src/ty_match/extension.rs b/src/ty_match/extension.rs
new file mode 100644
index 00000000..2b7a94d3
--- /dev/null
+++ b/src/ty_match/extension.rs
@@ -0,0 +1,371 @@
+//! Extension module that houses the array-like type implementations.
+
+use std::{rc::Rc, sync::Arc};
+
+use crate::{
+    ty_match::{MatchBorrow, MatchBorrowExt, WrapSame, WrapSameExt},
+    types::ExaIter,
+};
+
+/// Owned types have the highest priority in autoref specialization.
+/// This will take precedence if we're wrapping a type in a slice of [`Option`].
+impl<'a, T, U> WrapSameExt for WrapSame<T, &'a [Option<U>]>
+where
+    T: 'a,
+{
+    type Wrapped = &'a [Option<T>];
+}
+
+/// Immutable references have middle priority in autoref specialization.
+/// This will take precedence over the mutable reference implementation but not before the owned
+/// type implementation.
+///
+/// It wraps the type in a slice.
+impl<'a, T, U> WrapSameExt for &WrapSame<T, &'a [U]>
+where
+    T: 'a,
+{
+    type Wrapped = &'a [T];
+}
+
+/// Owned types have the highest priority in autoref specialization.
+/// This will take precedence if we're wrapping a type in a fixed size array of [`Option`].
+impl<T, U, const N: usize> WrapSameExt for WrapSame<T, [Option<U>; N]> {
+    type Wrapped = [Option<T>; N];
+}
+
+/// Immutable references have middle priority in autoref specialization.
+/// This will take precedence over the mutable reference implementation but not before the owned
+/// type implementation.
+///
+/// It wraps the type in a fixed size array.
+impl<T, U, const N: usize> WrapSameExt for &WrapSame<T, [U; N]> {
+    type Wrapped = [T; N];
+}
+
+/// Owned types have the highest priority in autoref specialization.
+/// This will take precedence if we're wrapping a type in a [`Vec