diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ac5fa1864..87b9466ba 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -40,6 +40,7 @@ jobs: push: true tags: | ghcr.io/umccr/htsget-rs:latest + ## SOCI (Seekable OCI) support. Only enable when and if docker layers surpass 10MB in the future, see: # https://github.com/awslabs/soci-snapshotter/issues/100 # - name: Install aws SOCI diff --git a/Cargo.lock b/Cargo.lock index 4bb58e9b1..9a674c8d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -36,9 +36,9 @@ dependencies = [ [[package]] name = "actix-http" -version = "3.6.0" +version = "3.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d223b13fd481fc0d1f83bb12659ae774d9e3601814c68a0bc539731698cca743" +checksum = "4eb9843d84c775696c37d9a418bbb01b932629d01870722c0f13eb3f95e2536d" dependencies = [ "actix-codec", "actix-rt", @@ -46,7 +46,7 @@ dependencies = [ "actix-tls", "actix-utils", "ahash", - "base64 0.21.7", + "base64 0.22.1", "bitflags 2.5.0", "brotli", "bytes", @@ -81,18 +81,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e01ed3140b2f8d422c68afa1ed2e85d996ea619c988ac834d255db32138655cb" dependencies = [ "quote", - "syn 2.0.61", + "syn 2.0.65", ] [[package]] name = "actix-router" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d22475596539443685426b6bdadb926ad0ecaefdfc5fb05e5e3441f15463c511" +checksum = "13d324164c51f63867b57e73ba5936ea151b8a41a1d23d1031eeb9f70d0236f8" dependencies = [ "bytestring", + "cfg-if", "http", "regex", + "regex-lite", "serde", "tracing", ] @@ -137,9 +139,9 @@ dependencies = [ [[package]] name = "actix-tls" -version = "3.3.0" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4cce60a2f2b477bc72e5cde0af1812a6e82d8fd85b5570a5dcf2a5bf2c5be5f" +checksum = 
"ac453898d866cdbecdbc2334fe1738c747b4eba14a677261f2b768ba05329389" dependencies = [ "actix-rt", "actix-service", @@ -166,9 +168,9 @@ dependencies = [ [[package]] name = "actix-web" -version = "4.5.1" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43a6556ddebb638c2358714d853257ed226ece6023ef9364f23f0c70737ea984" +checksum = "b1cf67dadb19d7c95e5a299e2dda24193b89d5d4f33a3b9800888ede9e19aa32" dependencies = [ "actix-codec", "actix-http", @@ -196,6 +198,7 @@ dependencies = [ "once_cell", "pin-project-lite", "regex", + "regex-lite", "serde", "serde_json", "serde_urlencoded", @@ -214,7 +217,7 @@ dependencies = [ "actix-router", "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.65", ] [[package]] @@ -232,6 +235,27 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + [[package]] name = "ahash" version = "0.8.11" @@ -369,6 +393,31 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-crypt4gh" +version = "0.1.0" +dependencies = [ + "async-trait", + "base64 0.22.1", + "bincode", + "bstr", + "bytes", + "crypt4gh", + "futures", + "futures-util", + "hex-literal", + "htsget-test", + "noodles 0.60.0", + "pin-project-lite", + "rand_chacha", + "rustls", + "tempfile", + "thiserror", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "async-stream" version = "0.3.5" @@ -388,7 +437,7 @@ checksum = 
"16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.65", ] [[package]] @@ -399,7 +448,7 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.65", ] [[package]] @@ -949,6 +998,32 @@ dependencies = [ "vsimd", ] +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + +[[package]] +name = "bcrypt-pbkdf" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6aeac2e1fe888769f34f05ac343bbef98b14d1ffb292ab69d4608b3abc86f2a2" +dependencies = [ + "blowfish", + "pbkdf2", + "sha2", +] + +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bit-vec" version = "0.6.3" @@ -967,6 +1042,15 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -976,11 +1060,30 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block-padding" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8894febbff9f758034a5b8e12d87918f56dfc64a8e1fe757d65e29041538d93" +dependencies = [ + "generic-array", +] + +[[package]] +name = "blowfish" +version = "0.9.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "e412e2cd0f2b2d93e02543ceae7917b3c70331573df19ee046bcbc35e45e87d7" +dependencies = [ + "byteorder", + "cipher", +] + [[package]] name = "brotli" -version = "3.5.0" +version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391" +checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -989,9 +1092,9 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "2.5.1" +version = "4.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e2e4afe60d7dd600fdd3de8d0f08c2b7ec039712e3b6137ff98b7004e82de4f" +checksum = "e6221fe77a248b9117d431ad93761222e1cf8ff282d9d1d5d9f53d6299a1cf76" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -1004,6 +1107,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05efc5cfd9110c8416e471df0e96702d58690178e206e61b7173706673c93706" dependencies = [ "memchr", + "regex-automata 0.4.6", "serde", ] @@ -1015,9 +1119,9 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" +checksum = "78834c15cb5d5efe3452d58b1e8ba890dd62d21907f867f383358198e56ebca5" [[package]] name = "byteorder" @@ -1080,11 +1184,20 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" +[[package]] +name = "cbc" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b52a9543ae338f279b96b0b9fed9c8093744685043739079ce85cd58f289a6" +dependencies = [ + "cipher", +] + [[package]] name 
= "cc" -version = "1.0.97" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "099a5357d84c4c61eb35fc8eafa9a79a902c2f76911e5747ced4e032edd8d9b4" +checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f" dependencies = [ "jobserver", "libc", @@ -1097,6 +1210,30 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chacha20" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "chacha20poly1305" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35" +dependencies = [ + "aead", + "chacha20", + "cipher", + "poly1305", + "zeroize", +] + [[package]] name = "chrono" version = "0.4.38" @@ -1137,6 +1274,17 @@ dependencies = [ "half", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", + "zeroize", +] + [[package]] name = "clap" version = "4.5.4" @@ -1156,7 +1304,7 @@ dependencies = [ "anstream", "anstyle", "clap_lex", - "strsim 0.11.1", + "strsim", ] [[package]] @@ -1168,7 +1316,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.65", ] [[package]] @@ -1256,7 +1404,7 @@ dependencies = [ "criterion-plot", "futures", "is-terminal", - "itertools", + "itertools 0.10.5", "num-traits", "once_cell", "oorandom", @@ -1278,14 +1426,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" dependencies = [ "cast", - "itertools", + "itertools 0.10.5", ] [[package]] name = "crossbeam-channel" -version = "0.5.12" +version = "0.5.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab3db02a9c5b5121e1e42fbdb1aeb65f5e02624cc58c43f2884c6ccac0b82f95" +checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2" dependencies = [ "crossbeam-utils", ] @@ -1311,9 +1459,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.19" +version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" [[package]] name = "crunchy" @@ -1321,6 +1469,35 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +[[package]] +name = "crypt4gh" +version = "0.4.1" +source = "git+https://github.com/EGA-archive/crypt4gh-rust#2d41a1770067003bc67ab499841e0def186ed218" +dependencies = [ + "aes", + "base64 0.21.7", + "bcrypt-pbkdf", + "bincode", + "cbc", + "chacha20poly1305", + "clap", + "crypto_kx", + "ctr", + "curve25519-dalek", + "ed25519_to_curve25519", + "itertools 0.11.0", + "lazy_static", + "log", + "pretty_env_logger", + "rand", + "rand_chacha", + "regex", + "rpassword", + "scrypt", + "serde", + "thiserror", +] + [[package]] name = "crypto-common" version = "0.1.6" @@ -1328,14 +1505,62 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", + "rand_core", "typenum", ] +[[package]] +name = "crypto_kx" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"704722d1d929489c8528bb1882805700f1ba20f54325704973e786352320b1ed" +dependencies = [ + "blake2", + "curve25519-dalek", + "rand_core", +] + +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + +[[package]] +name = "curve25519-dalek" +version = "4.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a677b8922c94e01bdbb12126b0bc852f00447528dee1782229af9c720c3f348" +dependencies = [ + "cfg-if", + "cpufeatures", + "curve25519-dalek-derive", + "fiat-crypto", + "platforms", + "rustc_version", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.65", +] + [[package]] name = "darling" -version = "0.20.8" +version = "0.20.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54e36fcd13ed84ffdfda6f5be89b31287cbb80c439841fe69e04841435464391" +checksum = "83b2eb4d90d12bdda5ed17de686c2acb4c57914f8f921b8da7e112b5a36f3fe1" dependencies = [ "darling_core", "darling_macro", @@ -1343,27 +1568,27 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.20.8" +version = "0.20.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c2cf1c23a687a1feeb728783b993c4e1ad83d99f351801977dd809b48d0a70f" +checksum = "622687fe0bac72a04e5599029151f5796111b90f1baaa9b544d807a5e31cd120" dependencies = [ "fnv", "ident_case", "proc-macro2", "quote", - "strsim 0.10.0", - "syn 2.0.61", + "strsim", + "syn 2.0.65", ] [[package]] name = "darling_macro" -version = "0.20.8" +version = "0.20.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f" +checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178" dependencies = [ "darling_core", "quote", - "syn 2.0.61", + "syn 2.0.65", ] [[package]] @@ -1401,6 +1626,12 @@ version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" +[[package]] +name = "difference" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524cbf6897b527295dff137cec09ecf3a05f4fddffd7dfcd1585403449e74198" + [[package]] name = "digest" version = "0.10.7" @@ -1412,11 +1643,23 @@ dependencies = [ "subtle", ] +[[package]] +name = "downcast" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bb454f0228b18c7f4c3b0ebbee346ed9c52e7443b0999cd543ff3571205701d" + +[[package]] +name = "ed25519_to_curve25519" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a976025474add79730a8df2913b114afd39bc53ce5633e045100aceb6d06bb6" + [[package]] name = "either" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" +checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b" [[package]] name = "encoding_rs" @@ -1427,6 +1670,19 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "env_logger" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580" +dependencies = [ + "humantime", + "is-terminal", + "log", + "regex", + "termcolor", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -1461,11 +1717,17 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" +[[package]] +name = "fiat-crypto" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" + [[package]] name = "figment" -version = "0.10.18" +version = "0.10.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d032832d74006f99547004d49410a4b4218e4c33382d56ca3ff89df74f86b953" +checksum = "8cb01cd46b0cf372153850f4c6c272d9cbea2da513e07538405148f95bd789f3" dependencies = [ "atomic", "parking_lot", @@ -1487,6 +1749,15 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "float-cmp" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1267f4ac4f343772758f7b1bdcbe767c218bbab93bb432acbf5162bbf85a6c4" +dependencies = [ + "num-traits", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1502,6 +1773,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fragile" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7464c5c4a3f014d9b2ec4073650e5c06596f385060af740fc45ad5a19f959e8" +dependencies = [ + "fragile 2.0.0", +] + +[[package]] +name = "fragile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" + [[package]] name = "futures" version = "0.3.30" @@ -1558,7 +1844,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.65", ] [[package]] @@ -1677,6 +1963,12 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hex-literal" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46" + [[package]] name = "hex-simd" version = "0.8.0" @@ -1725,12 +2017,14 @@ dependencies = [ name = "htsget-config" version = "0.9.0" dependencies = [ + "async-crypt4gh", "async-trait", "clap", + "crypt4gh", "figment", "http", "http-serde", - "noodles", + "noodles 0.65.0", "rcgen", "regex", "reqwest", @@ -1791,6 +2085,7 @@ dependencies = [ name = "htsget-search" version = "0.7.0" dependencies = [ + "async-crypt4gh", "async-trait", "aws-config", "aws-sdk-s3", @@ -1798,6 +2093,7 @@ dependencies = [ "base64 0.21.7", "bytes", "criterion", + "crypt4gh", "data-url", "futures", "futures-util", @@ -1805,8 +2101,10 @@ dependencies = [ "htsget-test", "http", "hyper", - "noodles", - "pin-project-lite", + "mockall", + "mockall_double", + "noodles 0.65.0", + "pin-project", "reqwest", "rustls-pemfile", "serde", @@ -1819,22 +2117,26 @@ dependencies = [ "tower-http", "tracing", "url", + "walkdir", ] [[package]] name = "htsget-test" version = "0.6.0" dependencies = [ + "async-crypt4gh", "async-trait", "aws-config", "aws-credential-types", "aws-sdk-s3", + "axum", "base64 0.21.7", + "crypt4gh", "futures", "htsget-config", "http", "mime", - "noodles", + "noodles 0.65.0", "rcgen", "reqwest", "s3s", @@ -1845,6 +2147,11 @@ dependencies = [ "tempfile", "thiserror", "tokio", + "tokio-rustls", + "tokio-util", + "tower", + "tower-http", + "walkdir", ] [[package]] @@ -1897,6 +2204,12 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "hyper" version = "0.14.28" @@ -2010,6 +2323,16 @@ version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" +[[package]] +name = "inout" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +dependencies = [ + "block-padding", + "generic-array", +] + [[package]] name = "ipnet" version = "2.9.0" @@ -2042,6 +2365,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -2204,15 +2536,15 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.154" +version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae743338b92ff9146ce83992f766a31066a91a8c84a45e0e9f21e7cf6de6d346" +checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" [[package]] name = "linux-raw-sys" -version = "0.4.13" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" [[package]] name = "local-channel" @@ -2313,9 +2645,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" +checksum = "87dfd01fe195c66b572b37921ad8803d010623c0aca821bea2302239d155cdae" dependencies = [ "adler", ] @@ -2332,6 +2664,45 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mockall" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"18d614ad23f9bb59119b8b5670a85c7ba92c5e9adf4385c81ea00c51c8be33d5" +dependencies = [ + "cfg-if", + "downcast", + "fragile 1.2.2", + "lazy_static", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dd4234635bca06fc96c7368d038061e0aae1b00a764dc817e900dc974e3deea" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "mockall_double" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1ca96e5ac35256ae3e13536edd39b172b88f41615e1d7b653c8ad24524113e8" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.65", +] + [[package]] name = "mutually_exclusive_features" version = "0.0.3" @@ -2348,23 +2719,59 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "noodles" +version = "0.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5c35efef9490612b43ca521e29b448f8416305f2875811d26c2f89c4d94df82" +dependencies = [ + "noodles-bam 0.52.0", + "noodles-bcf 0.45.0", + "noodles-bgzf", + "noodles-cram 0.51.0", + "noodles-csi 0.29.0", + "noodles-fasta 0.31.0", + "noodles-fastq", + "noodles-sam 0.49.0", + "noodles-tabix 0.35.0", + "noodles-vcf 0.48.0", +] + [[package]] name = "noodles" version = "0.65.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38db1833ba39368f7a855c5b9a8412729a61dd86b2563e1a3bdb02aecbe92a3c" dependencies = [ - "noodles-bam", - "noodles-bcf", + "noodles-bam 0.56.0", + "noodles-bcf 0.46.0", "noodles-bgzf", - "noodles-core", - "noodles-cram", - "noodles-csi", - "noodles-fasta", + "noodles-core 0.14.0", + "noodles-cram 0.56.0", + "noodles-csi 0.30.0", + "noodles-fasta 0.33.0", "noodles-fastq", - "noodles-sam", - "noodles-tabix", - "noodles-vcf", + "noodles-sam 0.53.0", + "noodles-tabix 0.36.0", + "noodles-vcf 0.49.0", +] + +[[package]] +name 
= "noodles-bam" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e75824c4fad4713c177148543d96893212c2b8b6efc3cd9fc19934bb9334c97" +dependencies = [ + "bit-vec", + "byteorder", + "bytes", + "futures", + "indexmap 2.2.6", + "noodles-bgzf", + "noodles-core 0.13.0", + "noodles-csi 0.29.0", + "noodles-sam 0.49.0", + "tokio", ] [[package]] @@ -2380,9 +2787,25 @@ dependencies = [ "futures", "indexmap 2.2.6", "noodles-bgzf", - "noodles-core", - "noodles-csi", - "noodles-sam", + "noodles-core 0.14.0", + "noodles-csi 0.30.0", + "noodles-sam 0.53.0", + "tokio", +] + +[[package]] +name = "noodles-bcf" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc9d479487b2e021df164bd488244108ad1695e7d980f9a6bd22efd18e479a73" +dependencies = [ + "byteorder", + "futures", + "indexmap 2.2.6", + "noodles-bgzf", + "noodles-core 0.13.0", + "noodles-csi 0.29.0", + "noodles-vcf 0.48.0", "tokio", ] @@ -2396,9 +2819,9 @@ dependencies = [ "futures", "indexmap 2.2.6", "noodles-bgzf", - "noodles-core", - "noodles-csi", - "noodles-vcf", + "noodles-core 0.14.0", + "noodles-csi 0.30.0", + "noodles-vcf 0.49.0", "tokio", ] @@ -2418,12 +2841,41 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "noodles-core" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2993a01927b449e191670446b8a36e153e89fc4527a246a84eed9057adeefe1b" + [[package]] name = "noodles-core" version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7336c3be652de4e05444c9b12a32331beb5ba3316e8872d92bfdd8ef3b06c282" +[[package]] +name = "noodles-cram" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9322eea979cd8c62da1f04fed6bf23830c50079d25c333ce629958783b23de2a" +dependencies = [ + "async-compression", + "bitflags 2.5.0", + "byteorder", + "bytes", + "bzip2", + "flate2", + "futures", + "md-5", + 
"noodles-bam 0.52.0", + "noodles-core 0.13.0", + "noodles-fasta 0.31.0", + "noodles-sam 0.49.0", + "pin-project-lite", + "tokio", + "xz2", +] + [[package]] name = "noodles-cram" version = "0.56.0" @@ -2440,15 +2892,29 @@ dependencies = [ "futures", "indexmap 2.2.6", "md-5", - "noodles-bam", - "noodles-core", - "noodles-fasta", - "noodles-sam", + "noodles-bam 0.56.0", + "noodles-core 0.14.0", + "noodles-fasta 0.33.0", + "noodles-sam 0.53.0", "pin-project-lite", "tokio", "xz2", ] +[[package]] +name = "noodles-csi" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9abd5616c374ad3da6677603dc1637ef518388537ec4a0263b8e4471ee5b0801" +dependencies = [ + "bit-vec", + "byteorder", + "indexmap 2.2.6", + "noodles-bgzf", + "noodles-core 0.13.0", + "tokio", +] + [[package]] name = "noodles-csi" version = "0.30.0" @@ -2459,7 +2925,20 @@ dependencies = [ "byteorder", "indexmap 2.2.6", "noodles-bgzf", - "noodles-core", + "noodles-core 0.14.0", + "tokio", +] + +[[package]] +name = "noodles-fasta" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc8985866ee2464904d71f26661797ed610afcdb926523afc1b8a88f34d512c0" +dependencies = [ + "bytes", + "memchr", + "noodles-bgzf", + "noodles-core 0.13.0", "tokio", ] @@ -2472,7 +2951,7 @@ dependencies = [ "bytes", "memchr", "noodles-bgzf", - "noodles-core", + "noodles-core 0.14.0", "tokio", ] @@ -2487,6 +2966,23 @@ dependencies = [ "tokio", ] +[[package]] +name = "noodles-sam" +version = "0.49.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b94966806ac7aec118d41eea7080bfbd0e8b843ba64f46522c57f0f55cfb1f0" +dependencies = [ + "bitflags 2.5.0", + "futures", + "indexmap 2.2.6", + "lexical-core", + "memchr", + "noodles-bgzf", + "noodles-core 0.13.0", + "noodles-csi 0.29.0", + "tokio", +] + [[package]] name = "noodles-sam" version = "0.53.0" @@ -2500,8 +2996,23 @@ dependencies = [ "lexical-core", "memchr", "noodles-bgzf", 
- "noodles-core", - "noodles-csi", + "noodles-core 0.14.0", + "noodles-csi 0.30.0", + "tokio", +] + +[[package]] +name = "noodles-tabix" +version = "0.35.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edd762c360b3d611d0457bc7db5f30741f3b658f27e8541538171d058f3d2379" +dependencies = [ + "bit-vec", + "byteorder", + "indexmap 2.2.6", + "noodles-bgzf", + "noodles-core 0.13.0", + "noodles-csi 0.29.0", "tokio", ] @@ -2515,8 +3026,25 @@ dependencies = [ "byteorder", "indexmap 2.2.6", "noodles-bgzf", - "noodles-core", - "noodles-csi", + "noodles-core 0.14.0", + "noodles-csi 0.30.0", + "tokio", +] + +[[package]] +name = "noodles-vcf" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a7799c5e60ff2a3234a778e9bd4dbed470d8ccd5e934fd75028a951f1f11099" +dependencies = [ + "futures", + "indexmap 2.2.6", + "memchr", + "noodles-bgzf", + "noodles-core 0.13.0", + "noodles-csi 0.29.0", + "noodles-tabix 0.35.0", + "percent-encoding", "tokio", ] @@ -2530,13 +3058,19 @@ dependencies = [ "indexmap 2.2.6", "memchr", "noodles-bgzf", - "noodles-core", - "noodles-csi", - "noodles-tabix", + "noodles-core 0.14.0", + "noodles-csi 0.30.0", + "noodles-tabix 0.36.0", "percent-encoding", "tokio", ] +[[package]] +name = "normalize-line-endings" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -2617,6 +3151,12 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" + [[package]] name = "openssl-probe" version = "0.1.5" @@ -2658,6 +3198,17 @@ 
dependencies = [ "windows-targets 0.52.5", ] +[[package]] +name = "password-hash" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" +dependencies = [ + "base64ct", + "rand_core", + "subtle", +] + [[package]] name = "paste" version = "1.0.15" @@ -2682,6 +3233,16 @@ dependencies = [ "once_cell", ] +[[package]] +name = "pbkdf2" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2" +dependencies = [ + "digest", + "hmac", +] + [[package]] name = "pear" version = "0.2.9" @@ -2702,7 +3263,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn 2.0.61", + "syn 2.0.65", ] [[package]] @@ -2738,7 +3299,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.65", ] [[package]] @@ -2759,6 +3320,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +[[package]] +name = "platforms" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db23d408679286588f4d4644f965003d056e3dd5abcaaa938116871d7ce2fee7" + [[package]] name = "plotters" version = "0.3.5" @@ -2787,6 +3354,17 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "poly1305" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf" +dependencies = [ + "cpufeatures", + "opaque-debug", + "universal-hash", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -2799,6 +3377,35 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "predicates" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f49cfaf7fdaa3bfacc6fa3e7054e65148878354a5cfddcf661df4c851f8021df" +dependencies = [ + "difference", + "float-cmp", + "normalize-line-endings", + "predicates-core", + "regex", +] + +[[package]] +name = "predicates-core" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b794032607612e7abeb4db69adb4e33590fa6cf1149e95fd7cb00e634b92f174" + +[[package]] +name = "predicates-tree" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368ba315fb8c5052ab692e68a0eefec6ec57b23a36959c14496f0b0df2c0cecf" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "pretty_assertions" version = "1.4.0" @@ -2809,6 +3416,16 @@ dependencies = [ "yansi 0.5.1", ] +[[package]] +name = "pretty_env_logger" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "865724d4dbe39d9f3dd3b52b88d859d66bcb2d6a0acfd5ea68a65fb66d4bdc1c" +dependencies = [ + "env_logger", + "log", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -2835,9 +3452,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.82" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ad3d49ab951a01fbaafe34f2ec74122942fe18a3f9814c3268f1bb72042131b" +checksum = "0b33eb56c327dec362a9e55b3ad14f9d2f0904fb5a5b03b513ab5465399e9f43" dependencies = [ "unicode-ident", ] @@ -2850,7 +3467,7 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.65", "version_check", "yansi 1.0.1", ] @@ -2988,6 +3605,12 @@ dependencies = [ "regex-syntax 0.8.3", ] +[[package]] +name = "regex-lite" +version = "0.1.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "30b661b2f27137bdbc16f00eda72866a92bb28af1753ffbd56744fb6e2e9cd8e" + [[package]] name = "regex-syntax" version = "0.6.29" @@ -3082,6 +3705,27 @@ dependencies = [ "xmlparser", ] +[[package]] +name = "rpassword" +version = "7.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80472be3c897911d0137b2d2b9055faf6eeac5b14e324073d83bc17b191d7e3f" +dependencies = [ + "libc", + "rtoolbox", + "windows-sys 0.48.0", +] + +[[package]] +name = "rtoolbox" +version = "0.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c247d24e63230cdb56463ae328478bd5eac8b8faa8c69461a77e8e323afac90e" +dependencies = [ + "libc", + "windows-sys 0.48.0", +] + [[package]] name = "rustc-demangle" version = "0.1.24" @@ -3155,9 +3799,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "092474d1a01ea8278f69e6a358998405fae5b8b963ddaeb2b0b04a128bf1dfb0" +checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" [[package]] name = "ryu" @@ -3260,6 +3904,15 @@ dependencies = [ "uuid", ] +[[package]] +name = "salsa20" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97a22f5af31f73a954c10289c93e8a50cc23d971e80ee446f1f6f7137a088213" +dependencies = [ + "cipher", +] + [[package]] name = "same-file" version = "1.0.6" @@ -3284,6 +3937,18 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scrypt" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0516a385866c09368f0b5bcd1caff3366aace790fcd46e2bb032697bb172fd1f" +dependencies = [ + "password-hash", + "pbkdf2", + "salsa20", + "sha2", +] + [[package]] name = "sct" version 
= "0.7.1" @@ -3325,22 +3990,22 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" [[package]] name = "serde" -version = "1.0.201" +version = "1.0.202" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "780f1cebed1629e4753a1a38a3c72d30b97ec044f0aef68cb26650a3c5cf363c" +checksum = "226b61a0d411b2ba5ff6d7f73a476ac4f8bb900373459cd00fab8512828ba395" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.201" +version = "1.0.202" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5e405930b9796f1c00bee880d03fc7e0bb4b9a11afc776885ffe84320da2865" +checksum = "6048858004bcff69094cd972ed40a32500f153bd3be9f716b2eed2e8217c4838" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.65", ] [[package]] @@ -3377,9 +4042,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "0.6.5" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb3622f419d1296904700073ea6cc23ad690adbd66f13ea683df73298736f0c1" +checksum = "79e674e01f999af37c49f70a6ede167a8a60b2503e56c5599532a65baa5969a0" dependencies = [ "serde", ] @@ -3423,7 +4088,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.65", ] [[package]] @@ -3515,12 +4180,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" -[[package]] -name = "strsim" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" - [[package]] name = "strsim" version = "0.11.1" @@ -3546,9 +4205,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.61" +version = "2.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c993ed8ccba56ae856363b1845da7266a7cb78e1d146c8a32d54b45a8b831fc9" +checksum = 
"d2863d96a84c6439701d7a38f9de935ec562c8832cc55d1dde0f513b52fad106" dependencies = [ "proc-macro2", "quote", @@ -3594,24 +4253,39 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "termtree" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" + [[package]] name = "thiserror" -version = "1.0.60" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "579e9083ca58dd9dcf91a9923bb9054071b9ebbd800b342194c9feb0ee89fc18" +checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.60" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2470041c06ec3ac1ab38d0356a6119054dedaea53e12fbefc0de730a1c08524" +checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.65", ] [[package]] @@ -3707,7 +4381,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.65", ] [[package]] @@ -3747,9 +4421,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.12" +version = "0.8.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9dd1545e8208b4a5af1aa9bbd0b4cf7e9ea08fabc5d0a5c67fcaafa17433aa3" +checksum = "a4e43f8cc456c9704c851ae29c67e17ef65d2c30017c17a9765b89c382dc8bba" dependencies = [ "serde", "serde_spanned", @@ -3759,18 +4433,18 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.5" +version = "0.6.6" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" +checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" dependencies = [ "serde", ] [[package]] name = "toml_edit" -version = "0.22.12" +version = "0.22.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3328d4f68a705b2a4498da1d580585d39a6510f98318a2cec3018a7ec61ddef" +checksum = "c127785850e8c20836d49732ae6abfa47616e60bf9d9f57c43c250361a9db96c" dependencies = [ "indexmap 2.2.6", "serde", @@ -3865,7 +4539,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.65", ] [[package]] @@ -3990,6 +4664,16 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "untrusted" version = "0.7.1" @@ -4098,7 +4782,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.65", "wasm-bindgen-shared", ] @@ -4132,7 +4816,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.65", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4423,7 +5107,7 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.61", + "syn 2.0.65", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index fe8017535..1aa90207d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ resolver = "2" members = [ + "async-crypt4gh", "htsget-config", "htsget-actix", "htsget-http", diff --git a/async-crypt4gh/Cargo.toml b/async-crypt4gh/Cargo.toml new file mode 100644 index 
000000000..137340186 --- /dev/null +++ b/async-crypt4gh/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "async-crypt4gh" +version = "0.1.0" +rust-version = "1.70" +authors = ["Marko Malenic "] +edition = "2021" +description = "An async wrapper around crypt4gh-rust using AsyncRead and Stream." +license = "MIT" +documentation = "https://github.com/umccr/htsget-rs/blob/main/async-crypt4gh/README.md" +homepage = "https://github.com/umccr/htsget-rs/blob/main/async-crypt4gh/README.md" +repository = "https://github.com/umccr/htsget-rs" + +[dependencies] +crypt4gh = { version = "0.4", git = "https://github.com/EGA-archive/crypt4gh-rust" } +pin-project-lite = "0.2" +hex-literal = "0.4" +bytes = "1.4" +tokio = { version = "1.29", features = ["macros", "rt-multi-thread", "io-util"] } +tokio-util = { version = "0.7", features = ["io", "compat", "codec"] } +futures = "0.3" +futures-util = "0.3" +thiserror = "1.0" +async-trait = "0.1" +rustls = "0.21" +rand_chacha = "0.3.1" +bincode = "1.3.3" +tempfile = "3.9" +tracing = "0.1" +base64 = "0.22" +bstr = "1.9" + +[dev-dependencies] +noodles = { version = "0.60", features = ["async", "bam", "sam"] } + +htsget-test = { version = "0.6.0", path = "../htsget-test", features = ["http", "crypt4gh"], default-features = false } \ No newline at end of file diff --git a/async-crypt4gh/LICENSE b/async-crypt4gh/LICENSE new file mode 100644 index 000000000..468cd79a8 --- /dev/null +++ b/async-crypt4gh/LICENSE @@ -0,0 +1,23 @@ +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all 
copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/async-crypt4gh/README.md b/async-crypt4gh/README.md new file mode 100644 index 000000000..dd6a76b4a --- /dev/null +++ b/async-crypt4gh/README.md @@ -0,0 +1,11 @@ +# async-crypt4gh + +[![MIT licensed][mit-badge]][mit-url] +[![Build Status][actions-badge]][actions-url] + +[mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg +[mit-url]: https://github.com/umccr/htsget-rs/blob/main/LICENSE +[actions-badge]: https://github.com/umccr/htsget-rs/actions/workflows/action.yml/badge.svg +[actions-url]: https://github.com/umccr/htsget-rs/actions?query=workflow%3Atests+branch%3Amain + +An async wrapper around crypt4gh-rust using AsyncRead and Stream. \ No newline at end of file diff --git a/async-crypt4gh/src/advance.rs b/async-crypt4gh/src/advance.rs new file mode 100644 index 000000000..a4f22a7f8 --- /dev/null +++ b/async-crypt4gh/src/advance.rs @@ -0,0 +1,27 @@ +use std::io; + +use async_trait::async_trait; + +/// A trait which defines the advance operation. +/// +/// Advance to an offset, in bytes, in the stream. This is very similar to seek, but it only +/// operates on the current stream and it does not change the position of the underlying stream +/// or buffer. +/// +/// This is useful for implementing seek-like operations on data types where information about +/// a stream's position can be obtained without having access to the whole stream. 
For example, +/// determining the offsets of data blocks in a Crypt4GH file while only having access to the +/// header and the file's size. +#[async_trait] +pub trait Advance { + /// Advance in the encrypted stream. This function returns the new position of the + /// advanced stream. + async fn advance_encrypted(&mut self, position: u64) -> io::Result; + + /// Advance in the unencrypted stream. This function returns the new position of the + /// advanced stream. + async fn advance_unencrypted(&mut self, position: u64) -> io::Result; + + /// Get the stream length, if it is available. + fn stream_length(&self) -> Option; +} diff --git a/async-crypt4gh/src/decoder/mod.rs b/async-crypt4gh/src/decoder/mod.rs new file mode 100644 index 000000000..c842f484f --- /dev/null +++ b/async-crypt4gh/src/decoder/mod.rs @@ -0,0 +1,427 @@ +use std::io; + +use bytes::{Bytes, BytesMut}; +use crypt4gh::header::{deconstruct_header_info, HeaderInfo}; +use tokio_util::codec::Decoder; + +use crate::error::Error::{ + Crypt4GHError, DecodingHeaderInfo, MaximumHeaderSize, NumericConversionError, + SliceConversionError, +}; +use crate::error::{Error, Result}; +use crate::{EncryptedHeaderPacketBytes, EncryptedHeaderPackets}; + +pub const ENCRYPTED_BLOCK_SIZE: usize = 65536; +pub const NONCE_SIZE: usize = 12; // ChaCha20 IETF Nonce size +pub const MAC_SIZE: usize = 16; + +const DATA_BLOCK_SIZE: usize = NONCE_SIZE + ENCRYPTED_BLOCK_SIZE + MAC_SIZE; + +const MAGIC_STRING_SIZE: usize = 8; +const VERSION_STRING_SIZE: usize = 4; +const HEADER_PACKET_COUNT_SIZE: usize = 4; + +pub const HEADER_INFO_SIZE: usize = + MAGIC_STRING_SIZE + VERSION_STRING_SIZE + HEADER_PACKET_COUNT_SIZE; + +const HEADER_PACKET_LENGTH_SIZE: usize = 4; + +/// Have some sort of maximum header size to prevent any overflows. +const MAX_HEADER_SIZE: usize = 8 * 1024 * 1024; + +/// The type that a block is decoded into. +#[derive(Debug)] +pub enum DecodedBlock { + /// The magic string, version string and header packet count. 
+ /// Corresponds to `deconstruct_header_info`. + HeaderInfo(HeaderInfo), + /// Header packets, both data encryption key packets and a data edit list packets. + /// Corresponds to `deconstruct_header_body`. + HeaderPackets(EncryptedHeaderPackets), + /// The encrypted data blocks + /// Corresponds to `body_decrypt`. + DataBlock(Bytes), +} + +/// State to keep track of the current block being decoded corresponding to `BlockType`. +#[derive(Debug)] +enum BlockState { + /// Expecting header info. + HeaderInfo, + /// Expecting header packets and the number of header packets left to decode. + HeaderPackets(u32), + /// Expecting a data block. + DataBlock, + /// Expecting the end of the file. This is to account for the last data block potentially being + /// shorter. + Eof, +} + +#[derive(Debug)] +pub struct Block { + next_block: BlockState, +} + +impl Block { + fn get_header_info(src: &mut BytesMut) -> Result { + deconstruct_header_info( + src + .split_to(HEADER_INFO_SIZE) + .as_ref() + .try_into() + .map_err(|_| SliceConversionError)?, + ) + .map_err(DecodingHeaderInfo) + } + + /// Parses the header info, updates the state and returns the block type. Unlike the other + /// `decode` methods, this method parses the header info before returning a decoded block + /// because the header info contains the number of packets which is required for decoding + /// the rest of the source. + pub fn decode_header_info(&mut self, src: &mut BytesMut) -> Result> { + // Header info is a fixed size. + if src.len() < HEADER_INFO_SIZE { + src.reserve(HEADER_INFO_SIZE); + return Ok(None); + } + + // Parse the header info because it contains the number of header packets. + let header_info = Self::get_header_info(src)?; + + self.next_block = BlockState::HeaderPackets(header_info.packets_count); + + Ok(Some(DecodedBlock::HeaderInfo(header_info))) + } + + /// Decodes header packets, updates the state and returns a header packet block type. 
+ pub fn decode_header_packets( + &mut self, + src: &mut BytesMut, + header_packets: u32, + ) -> Result> { + let mut header_packet_bytes = vec![]; + for _ in 0..header_packets { + // Get enough bytes to read the header packet length. + if src.len() < HEADER_PACKET_LENGTH_SIZE { + src.reserve(HEADER_PACKET_LENGTH_SIZE); + return Ok(None); + } + + // Read the header packet length. + let length_bytes = src.split_to(HEADER_PACKET_LENGTH_SIZE).freeze(); + let mut length: usize = u32::from_le_bytes( + length_bytes + .as_ref() + .try_into() + .map_err(|_| SliceConversionError)?, + ) + .try_into() + .map_err(|_| NumericConversionError)?; + + // We have already taken 4 bytes out of the length. + length -= HEADER_PACKET_LENGTH_SIZE; + + // Have a maximum header size to prevent any overflows. + if length > MAX_HEADER_SIZE { + return Err(MaximumHeaderSize); + } + + // Get enough bytes to read the entire header packet. + if src.len() < length { + src.reserve(length - src.len()); + return Ok(None); + } + + header_packet_bytes.push(EncryptedHeaderPacketBytes::new( + length_bytes, + src.split_to(length).freeze(), + )); + } + + self.next_block = BlockState::DataBlock; + + let header_length = u64::try_from( + header_packet_bytes + .iter() + .map(|packet| packet.packet_length().len() + packet.header().len()) + .sum::(), + ) + .map_err(|_| NumericConversionError)?; + + Ok(Some(DecodedBlock::HeaderPackets( + EncryptedHeaderPackets::new(header_packet_bytes, header_length), + ))) + } + + /// Decodes data blocks, updates the state and returns a data block type. + pub fn decode_data_block(&mut self, src: &mut BytesMut) -> Result> { + // Data blocks are a fixed size, so we can return the + // next data block without much processing. 
+ if src.len() < DATA_BLOCK_SIZE { + src.reserve(DATA_BLOCK_SIZE); + return Ok(None); + } + + self.next_block = BlockState::DataBlock; + + Ok(Some(DecodedBlock::DataBlock( + src.split_to(DATA_BLOCK_SIZE).freeze(), + ))) + } + + /// Get the standard size of all non-ending data blocks. + pub const fn standard_data_block_size() -> u64 { + DATA_BLOCK_SIZE as u64 + } + + /// Get the size of the magic string, version and header packet count. + pub const fn header_info_size() -> u64 { + HEADER_INFO_SIZE as u64 + } + + /// Get the encrypted block size, without nonce and mac bytes. + pub const fn encrypted_block_size() -> u64 { + ENCRYPTED_BLOCK_SIZE as u64 + } + + /// Get the size of the nonce. + pub const fn nonce_size() -> u64 { + NONCE_SIZE as u64 + } + + /// Get the size of the mac. + pub const fn mac_size() -> u64 { + MAC_SIZE as u64 + } + + /// Get the maximum possible header size. + pub const fn max_header_size() -> u64 { + MAX_HEADER_SIZE as u64 + } +} + +impl Default for Block { + fn default() -> Self { + Self { + next_block: BlockState::HeaderInfo, + } + } +} + +impl Decoder for Block { + type Item = DecodedBlock; + type Error = Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result> { + match self.next_block { + BlockState::HeaderInfo => self.decode_header_info(src), + BlockState::HeaderPackets(header_packets) => self.decode_header_packets(src, header_packets), + BlockState::DataBlock => self.decode_data_block(src), + BlockState::Eof => Ok(None), + } + } + + fn decode_eof(&mut self, buf: &mut BytesMut) -> Result> { + // Need a custom implementation of decode_eof because the last data block can be shorter. + match self.decode(buf)? { + Some(frame) => Ok(Some(frame)), + None => { + if buf.is_empty() { + Ok(None) + } else if let BlockState::DataBlock = self.next_block { + // The last data block can be smaller than 64KiB. 
+ if buf.len() <= DATA_BLOCK_SIZE { + self.next_block = BlockState::Eof; + + Ok(Some(DecodedBlock::DataBlock(buf.split().freeze()))) + } else { + Err(Crypt4GHError( + "the last data block is too large".to_string(), + )) + } + } else { + Err(io::Error::new(io::ErrorKind::Other, "bytes remaining on stream").into()) + } + } + } + } +} + +#[cfg(test)] +pub(crate) mod tests { + use std::io::Cursor; + + use crypt4gh::header::{deconstruct_header_body, DecryptedHeaderPackets}; + use crypt4gh::{body_decrypt, Keys, WriteInfo}; + use futures_util::stream::Skip; + use futures_util::StreamExt; + use tokio::fs::File; + use tokio::io::AsyncReadExt; + use tokio_util::codec::FramedRead; + + use htsget_test::crypt4gh::get_decryption_keys; + use htsget_test::http::get_test_file; + + use crate::tests::get_original_file; + + use super::*; + + #[tokio::test] + async fn decode_header_info() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let mut reader = FramedRead::new(src, Block::default()); + + let header_info = reader.next().await.unwrap().unwrap(); + + // Assert that the first block output is a header info with one packet. + assert!( + matches!(header_info, DecodedBlock::HeaderInfo(header_info) if header_info.packets_count == 1) + ); + } + + #[tokio::test] + async fn decode_header_packets() { + let (recipient_private_key, sender_public_key, header_packet, _) = + get_first_header_packet().await; + let header = get_header_packets(recipient_private_key, sender_public_key, header_packet); + + assert_first_header_packet(header); + + // Todo handle case where there is more than one header packet. 
+ } + + #[tokio::test] + async fn decode_data_block() { + let (header, data_block) = get_data_block(0).await; + + let read_buf = Cursor::new(data_block.to_vec()); + let mut write_buf = Cursor::new(vec![]); + let mut write_info = WriteInfo::new(0, None, &mut write_buf); + + body_decrypt(read_buf, &header.data_enc_packets, &mut write_info, 0).unwrap(); + + let decrypted_bytes = write_buf.into_inner(); + + assert_first_data_block(decrypted_bytes).await; + } + + #[tokio::test] + async fn decode_eof() { + let (header, data_block) = get_data_block(39).await; + + let read_buf = Cursor::new(data_block.to_vec()); + let mut write_buf = Cursor::new(vec![]); + let mut write_info = WriteInfo::new(0, None, &mut write_buf); + + body_decrypt(read_buf, &header.data_enc_packets, &mut write_info, 0).unwrap(); + + let decrypted_bytes = write_buf.into_inner(); + + assert_last_data_block(decrypted_bytes).await; + } + + /// Assert that the first header packet is a data encryption key packet. + pub(crate) fn assert_first_header_packet(header: DecryptedHeaderPackets) { + assert_eq!(header.data_enc_packets.len(), 1); + assert!(header.edit_list_packet.is_none()); + } + + /// Assert that the last data block is equal to the expected ending bytes of the original file. + pub(crate) async fn assert_last_data_block(decrypted_bytes: Vec) { + let mut original_file = get_test_file("bam/htsnexus_test_NA12878.bam").await; + let mut original_bytes = vec![]; + original_file + .read_to_end(&mut original_bytes) + .await + .unwrap(); + + assert_eq!( + decrypted_bytes, + original_bytes + .into_iter() + .rev() + .take(40895) + .rev() + .collect::>() + ); + } + + /// Assert that the first data block is equal to the first 64KiB of the original file. + pub(crate) async fn assert_first_data_block(decrypted_bytes: Vec) { + let original_bytes = get_original_file().await; + + assert_eq!(decrypted_bytes, original_bytes[..65536]); + } + + /// Get the first header packet from the test file. 
+ pub(crate) async fn get_first_header_packet( + ) -> (Keys, Vec, Vec, Skip>) { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = FramedRead::new(src, Block::default()).skip(1); + + // The second block should contain a header packet. + let header_packets = reader.next().await.unwrap().unwrap(); + + let (header_packet, header_length) = + if let DecodedBlock::HeaderPackets(header_packets) = header_packets { + Some(header_packets) + } else { + None + } + .unwrap() + .into_inner(); + + assert_eq!(header_length, 108); + + ( + recipient_private_key, + sender_public_key, + header_packet + .into_iter() + .map(|packet| packet.into_header_bytes()) + .collect(), + reader, + ) + } + + /// Get the first data block from the test file. + pub(crate) async fn get_data_block(skip: usize) -> (DecryptedHeaderPackets, Bytes) { + let (recipient_private_key, sender_public_key, header_packets, reader) = + get_first_header_packet().await; + let header = get_header_packets(recipient_private_key, sender_public_key, header_packets); + + let data_block = reader.skip(skip).next().await.unwrap().unwrap(); + + let data_block = if let DecodedBlock::DataBlock(data_block) = data_block { + Some(data_block) + } else { + None + } + .unwrap(); + + (header, data_block) + } + + /// Get the header packets from a decoded block. + pub(crate) fn get_header_packets( + recipient_private_key: Keys, + sender_public_key: Vec, + header_packets: Vec, + ) -> DecryptedHeaderPackets { + // Assert the size of the header packet is correct. 
+ assert_eq!(header_packets.len(), 1); + assert_eq!(header_packets.first().unwrap().len(), 104); + + deconstruct_header_body( + header_packets + .into_iter() + .map(|header_packet| header_packet.to_vec()) + .collect(), + &[recipient_private_key], + &Some(sender_public_key), + ) + .unwrap() + } +} diff --git a/async-crypt4gh/src/decrypter/builder.rs b/async-crypt4gh/src/decrypter/builder.rs new file mode 100644 index 000000000..4c41f2c0f --- /dev/null +++ b/async-crypt4gh/src/decrypter/builder.rs @@ -0,0 +1,95 @@ +use crypt4gh::Keys; +use tokio::io::{AsyncRead, AsyncSeek}; +use tokio_util::codec::FramedRead; + +use crate::decrypter::DecrypterStream; +use crate::error::Result; +use crate::PublicKey; + +/// A decrypter reader builder. +#[derive(Debug, Default)] +pub struct Builder { + sender_pubkey: Option, + stream_length: Option, + edit_list: Option>, +} + +impl Builder { + /// Sets the sender public key + pub fn with_sender_pubkey(self, sender_pubkey: PublicKey) -> Self { + self.set_sender_pubkey(Some(sender_pubkey)) + } + + /// Sets the sender public key + pub fn set_sender_pubkey(mut self, sender_pubkey: Option) -> Self { + self.sender_pubkey = sender_pubkey; + self + } + + /// Sets the stream length. + pub fn with_stream_length(self, stream_length: u64) -> Self { + self.set_stream_length(Some(stream_length)) + } + + /// Sets the stream length. + pub fn set_stream_length(mut self, stream_length: Option) -> Self { + self.stream_length = stream_length; + self + } + + /// Set the edit list manually. + pub fn with_edit_list(self, edit_list: Vec) -> Self { + self.set_edit_list(Some(edit_list)) + } + + /// Set the edit list manually. + pub fn set_edit_list(mut self, edit_list: Option>) -> Self { + self.edit_list = edit_list; + self + } + + /// Build the decrypter. 
+ pub fn build(self, inner: R, keys: Vec) -> DecrypterStream + where + R: AsyncRead, + { + DecrypterStream { + inner: FramedRead::new(inner, Default::default()), + header_packet_future: None, + keys, + sender_pubkey: self.sender_pubkey, + session_keys: vec![], + encrypted_header_packets: None, + edit_list_packet: DecrypterStream::<()>::create_internal_edit_list(self.edit_list), + header_info: None, + header_length: None, + current_block_size: None, + stream_length: self.stream_length, + } + } + + /// Build the decrypter and compute the stream length for seek operations. This function will + /// ensure that recompute_stream_length is called at least once on the decrypter stream. + /// + /// This means that data block positions past the end of the stream will be valid and will equal + /// the length of the stream. Use the build function if this behaviour is not desired. Seeking + /// past the end of the stream without a stream length is allowed but the behaviour is dependent + /// on the underlying reader and data block positions may not be valid. 
+ pub async fn build_with_stream_length( + self, + inner: R, + keys: Vec, + ) -> Result> + where + R: AsyncRead + AsyncSeek + Unpin, + { + let stream_length = self.stream_length; + let mut stream = self.build(inner, keys); + + if stream_length.is_none() { + stream.recompute_stream_length().await?; + } + + Ok(stream) + } +} diff --git a/async-crypt4gh/src/decrypter/data_block.rs b/async-crypt4gh/src/decrypter/data_block.rs new file mode 100644 index 000000000..e4eee3a06 --- /dev/null +++ b/async-crypt4gh/src/decrypter/data_block.rs @@ -0,0 +1,124 @@ +use std::future::Future; +use std::io::Cursor; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::Bytes; +use crypt4gh::{body_decrypt, WriteInfo}; +use pin_project_lite::pin_project; +use tokio::task::JoinHandle; + +use crate::decrypter::DecrypterStream; +use crate::error::Error::{Crypt4GHError, JoinHandleError}; +use crate::error::Result; +use crate::{DecryptedBytes, DecryptedDataBlock}; + +pin_project! { + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct DataBlockDecrypter { + #[pin] + handle: JoinHandle> + } +} + +impl DataBlockDecrypter { + pub fn new( + data_block: Bytes, + session_keys: Vec>, + edit_list_packet: Option>, + ) -> Self { + Self { + handle: tokio::task::spawn_blocking(move || { + DataBlockDecrypter::decrypt(data_block, session_keys, edit_list_packet) + }), + } + } + + pub fn decrypt( + data_block: Bytes, + session_keys: Vec>, + edit_list_packet: Option>, + ) -> Result { + let size = data_block.len(); + + let read_buf = Cursor::new(data_block.to_vec()); + let mut write_buf = Cursor::new(vec![]); + let mut write_info = WriteInfo::new(0, None, &mut write_buf); + + // Todo crypt4gh-rust body_decrypt_parts does not work properly, so just apply edit list here. 
+ body_decrypt(read_buf, session_keys.as_slice(), &mut write_info, 0) + .map_err(|err| Crypt4GHError(err.to_string()))?; + let mut decrypted_bytes: Bytes = write_buf.into_inner().into(); + let mut edited_bytes = Bytes::new(); + + let edits = DecrypterStream::<()>::create_internal_edit_list(edit_list_packet) + .unwrap_or(vec![(false, decrypted_bytes.len() as u64)]); + if edits.iter().map(|(_, edit)| edit).sum::() > decrypted_bytes.len() as u64 { + return Err(Crypt4GHError( + "invalid edit lists for the decrypted data block".to_string(), + )); + } + + edits.into_iter().for_each(|(discarding, edit)| { + if !discarding { + let edit = decrypted_bytes.slice(0..edit as usize); + edited_bytes = [edited_bytes.clone(), edit].concat().into(); + } + + decrypted_bytes = decrypted_bytes.slice(edit as usize..); + }); + + Ok(DecryptedDataBlock::new( + DecryptedBytes::new(edited_bytes), + size, + )) + } +} + +impl Future for DataBlockDecrypter { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().handle.poll(cx).map_err(JoinHandleError)? 
+ } +} + +#[cfg(test)] +mod tests { + use crate::decoder::tests::{assert_first_data_block, get_data_block}; + use crate::tests::get_original_file; + + use super::*; + + #[tokio::test] + async fn data_block_decrypter() { + let (header_packets, data_block) = get_data_block(0).await; + + let data = DataBlockDecrypter::new( + data_block, + header_packets.data_enc_packets, + header_packets.edit_list_packet, + ) + .await + .unwrap(); + + assert_first_data_block(data.bytes.to_vec()).await; + } + + #[tokio::test] + async fn data_block_decrypter_with_edit_list() { + let (header_packets, data_block) = get_data_block(0).await; + + let data = DataBlockDecrypter::new( + data_block, + header_packets.data_enc_packets, + Some(vec![0, 4668, 60868]), + ) + .await + .unwrap(); + + let original_bytes = get_original_file().await; + + assert_eq!(data.bytes.to_vec(), original_bytes[..4668]); + } +} diff --git a/async-crypt4gh/src/decrypter/header/mod.rs b/async-crypt4gh/src/decrypter/header/mod.rs new file mode 100644 index 000000000..278825fbb --- /dev/null +++ b/async-crypt4gh/src/decrypter/header/mod.rs @@ -0,0 +1,41 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::io::AsyncRead; + +use crate::decrypter::DecrypterStream; +use crate::decrypter::Result; + +pub mod packets; + +/// A struct which will poll a decrypter stream until the session keys are found. +/// After polling the future, the underlying decrypter stream should have processed +/// the session keys. +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct SessionKeysFuture<'a, R> { + handle: &'a mut DecrypterStream, +} + +impl<'a, R> SessionKeysFuture<'a, R> { + /// Create the future. + pub fn new(handle: &'a mut DecrypterStream) -> Self { + Self { handle } + } + + /// Get the inner handle. 
+ pub fn get_mut(&mut self) -> &mut DecrypterStream { + self.handle + } +} + +impl<'a, R> Future for SessionKeysFuture<'a, R> +where + R: AsyncRead + Unpin, +{ + type Output = Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.handle.poll_session_keys_unpin(cx) + } +} diff --git a/async-crypt4gh/src/decrypter/header/packets.rs b/async-crypt4gh/src/decrypter/header/packets.rs new file mode 100644 index 000000000..69ba5bc1c --- /dev/null +++ b/async-crypt4gh/src/decrypter/header/packets.rs @@ -0,0 +1,81 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::Bytes; +use crypt4gh::header::{deconstruct_header_body, DecryptedHeaderPackets}; +use crypt4gh::Keys; +use pin_project_lite::pin_project; +use tokio::task::{spawn_blocking, JoinHandle}; + +use crate::error::Error::JoinHandleError; +use crate::error::Result; +use crate::PublicKey; + +pin_project! { + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct HeaderPacketsDecrypter { + #[pin] + handle: JoinHandle> + } +} + +impl HeaderPacketsDecrypter { + pub fn new( + header_packets: Vec, + keys: Vec, + sender_pubkey: Option, + ) -> Self { + Self { + handle: spawn_blocking(|| { + HeaderPacketsDecrypter::decrypt(header_packets, keys, sender_pubkey) + }), + } + } + + pub fn decrypt( + header_packets: Vec, + keys: Vec, + sender_pubkey: Option, + ) -> Result { + Ok(deconstruct_header_body( + header_packets + .into_iter() + .map(|bytes| bytes.to_vec()) + .collect(), + keys.as_slice(), + &sender_pubkey.map(|pubkey| pubkey.into_inner()), + )?) + } +} + +impl Future for HeaderPacketsDecrypter { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().handle.poll(cx).map_err(JoinHandleError)? 
+ } +} + +#[cfg(test)] +mod tests { + use crate::decoder::tests::{assert_first_header_packet, get_first_header_packet}; + + use super::*; + + #[tokio::test] + async fn header_packet_decrypter() { + let (recipient_private_key, sender_public_key, header_packets, _) = + get_first_header_packet().await; + + let data = HeaderPacketsDecrypter::new( + header_packets, + vec![recipient_private_key], + Some(PublicKey::new(sender_public_key)), + ) + .await + .unwrap(); + + assert_first_header_packet(data); + } +} diff --git a/async-crypt4gh/src/decrypter/mod.rs b/async-crypt4gh/src/decrypter/mod.rs new file mode 100644 index 000000000..33f4b99ff --- /dev/null +++ b/async-crypt4gh/src/decrypter/mod.rs @@ -0,0 +1,1329 @@ +use std::cmp::min; +use std::future::Future; +use std::io; +use std::io::SeekFrom; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use async_trait::async_trait; +use bytes::Bytes; +use crypt4gh::header::HeaderInfo; +use crypt4gh::Keys; +use futures::ready; +use futures::Stream; +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, AsyncSeek, AsyncSeekExt}; +use tokio_util::codec::FramedRead; + +use crate::advance::Advance; +use crate::decoder::Block; +use crate::decoder::DecodedBlock; +use crate::decrypter::data_block::DataBlockDecrypter; +use crate::decrypter::header::packets::HeaderPacketsDecrypter; +use crate::decrypter::header::SessionKeysFuture; +use crate::error::Error::Crypt4GHError; +use crate::error::Result; +use crate::EncryptedHeaderPacketBytes; +use crate::{util, PublicKey}; + +pub mod builder; +pub mod data_block; +pub mod header; + +pin_project! { + /// A decrypter for an entire AsyncRead Crypt4GH file. 
+ pub struct DecrypterStream { + #[pin] + inner: FramedRead, + #[pin] + header_packet_future: Option, + keys: Vec, + sender_pubkey: Option, + encrypted_header_packets: Option>, + header_info: Option, + session_keys: Vec>, + edit_list_packet: Option>, + header_length: Option, + current_block_size: Option, + stream_length: Option, + } +} + +impl DecrypterStream +where + R: AsyncRead, +{ + /// Partitions the edit list packet so that it applies to the current data block, returning a new + /// edit list that correctly discards and keeps the specified bytes this particular data block. + /// Todo, this should possibly go into the decoder, where bytes can be skipped directly. + pub fn partition_edit_list(mut self: Pin<&mut Self>, data_block: &Bytes) -> Option> { + let this = self.as_mut().project(); + + if let Some(edit_list) = this.edit_list_packet { + let mut new_edit_list = vec![]; + let mut bytes_consumed = 0; + + edit_list.retain_mut(|(discarding, value)| { + // If this is not a discarding edit, then discard 0 at the start. + if !*discarding && new_edit_list.is_empty() { + new_edit_list.push(0); + } + + // Get the encrypted block size. + let data_block_len = data_block.len() as u64 - Block::nonce_size() - Block::mac_size(); + + // Can only consume as many bytes as there are in the data block. + if min(bytes_consumed + *value, data_block_len) == data_block_len { + if bytes_consumed != data_block_len { + // If the whole data block hasn't been consumed yet, an edit still needs to be added. + let last_edit = data_block_len - bytes_consumed; + new_edit_list.push(last_edit); + + // And remove this edit from the next value. + *value -= last_edit; + // Now the whole data block has been consumed. + bytes_consumed = data_block_len; + } + + // Keep all values from now. + true + } else { + // Otherwise, consume the value and remove it from the edit list packet. 
+ bytes_consumed += *value; + new_edit_list.push(*value); + false + } + }); + + (!new_edit_list.is_empty()).then_some(new_edit_list) + } else { + // If there is no edit list to begin with, we just keep the whole block. + None + } + } + + /// Polls a data block. This function shouldn't execute until all the header packets have been + /// processed. + pub fn poll_data_block( + mut self: Pin<&mut Self>, + data_block: Bytes, + ) -> Poll>> { + let edit_list = self.as_mut().partition_edit_list(&data_block); + + let this = self.project(); + Poll::Ready(Some(Ok(DataBlockDecrypter::new( + data_block, + // Todo make this so it doesn't use owned Keys and SenderPublicKey as it will be called asynchronously. + this.session_keys.clone(), + edit_list, + )))) + } + + /// Poll the stream until the header packets and session keys are processed. + pub fn poll_session_keys(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Only execute this function if there are no session keys. + if !self.session_keys.is_empty() { + return Poll::Ready(Ok(())); + } + + // Header packets are waiting to be decrypted. + if let Some(header_packet_decrypter) = self.as_mut().project().header_packet_future.as_pin_mut() + { + return match ready!(header_packet_decrypter.poll(cx)) { + Ok(header_packets) => { + let mut this = self.as_mut().project(); + + // Update the session keys and edit list packets. + this.header_packet_future.set(None); + this.session_keys.extend(header_packets.data_enc_packets); + if this.edit_list_packet.is_none() { + *this.edit_list_packet = + Self::create_internal_edit_list(header_packets.edit_list_packet); + } + + Poll::Ready(Ok(())) + } + Err(err) => Poll::Ready(Err(err)), + }; + } + + // No header packets yet, so more data needs to be decoded. + let mut this = self.as_mut().project(); + match ready!(this.inner.poll_next(cx)) { + Some(Ok(buf)) => match buf { + DecodedBlock::HeaderInfo(header_info) => { + // Store the header info but otherwise ignore it and poll again. 
+ *this.header_info = Some(header_info); + cx.waker().wake_by_ref(); + Poll::Pending + } + // todo no clones here. + DecodedBlock::HeaderPackets(header_packets) => { + // Update the header length because we have access to the header packets. + let (header_packets, header_length) = header_packets.into_inner(); + *this.encrypted_header_packets = Some(header_packets.clone()); + *this.header_length = Some(header_length + Block::header_info_size()); + + // Add task for decrypting the header packets. + this + .header_packet_future + .set(Some(HeaderPacketsDecrypter::new( + header_packets + .into_iter() + .map(|packet| packet.into_header_bytes()) + .collect(), + this.keys.clone(), + this.sender_pubkey.clone(), + ))); + + // Poll again. + cx.waker().wake_by_ref(); + Poll::Pending + } + DecodedBlock::DataBlock(_) => Poll::Ready(Err(Crypt4GHError( + "data block reached without finding session keys".to_string(), + ))), + }, + Some(Err(e)) => Poll::Ready(Err(e)), + None => Poll::Ready(Err(Crypt4GHError( + "end of stream reached without finding session keys".to_string(), + ))), + } + } + + /// Convenience for calling [`poll_session_keys`] on [`Unpin`] types. + pub fn poll_session_keys_unpin(&mut self, cx: &mut Context<'_>) -> Poll> + where + Self: Unpin, + { + Pin::new(self).poll_session_keys(cx) + } + + /// Poll the stream until the header has been read. + pub async fn read_header(&mut self) -> Result<()> + where + R: Unpin, + { + SessionKeysFuture::new(self).await + } +} + +impl DecrypterStream { + /// An override for setting the stream length. + pub async fn set_stream_length(&mut self, length: u64) { + self.stream_length = Some(length); + } + + /// Get a reference to the inner reader. + pub fn get_ref(&self) -> &R { + self.inner.get_ref() + } + + /// Get a mutable reference to the inner reader. + pub fn get_mut(&mut self) -> &mut R { + self.inner.get_mut() + } + + /// Get a pinned mutable reference to the inner reader. 
+ pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> { + self.project().inner.get_pin_mut() + } + + /// Get the inner reader. + pub fn into_inner(self) -> R { + self.inner.into_inner() + } + + /// Get the length of the header, including the magic string, version number, packet count + /// and the header packets. Returns `None` before the header packet is polled. + pub fn header_size(&self) -> Option { + self.header_length + } + + /// Get the size of the current data block represented by the encrypted block returned by calling + /// poll_next. This will equal `decoder::DATA_BLOCK_SIZE` except for the last block which may be + /// less than that. Returns `None` before the first data block is polled. + pub fn current_block_size(&self) -> Option { + self.current_block_size + } + + /// Clamps the byte position to the nearest data block if the header length is known. This + /// function takes into account the stream length if it is present. + pub fn clamp_position(&self, position: u64) -> Option { + self.header_size().map(|length| { + if position < length { + length + } else { + match self.stream_length { + Some(end_length) if position >= end_length => end_length, + _ => { + let remainder = (position - length) % Block::standard_data_block_size(); + + position - remainder + } + } + } + }) + } + + /// Convert an unencrypted position to an encrypted position if the header length is known. + pub fn to_encrypted(&self, position: u64) -> Option { + self.header_size().map(|length| { + let encrypted_position = util::to_encrypted(position, length); + + match self.stream_length { + Some(end_length) if encrypted_position + Block::mac_size() > end_length => end_length, + _ => encrypted_position, + } + }) + } + + /// Get the session keys. Empty before the header is polled. + pub fn session_keys(&self) -> &[Vec] { + &self.session_keys + } + + /// Get the edit list packet. Empty before the header is polled. 
+ pub fn edit_list_packet(&self) -> Option> { + self + .edit_list_packet + .as_ref() + .map(|packet| packet.iter().map(|(_, edit)| *edit).collect()) + } + + /// Get the header info. + pub fn header_info(&self) -> Option<&HeaderInfo> { + self.header_info.as_ref() + } + + /// Get the original encrypted header packets, not including the header info. + pub fn encrypted_header_packets(&self) -> Option<&Vec> { + self.encrypted_header_packets.as_ref() + } + + /// Get the stream's keys. + pub fn keys(&self) -> &[Keys] { + self.keys.as_slice() + } + + pub(crate) fn create_internal_edit_list(edit_list: Option>) -> Option> { + edit_list.map(|edits| [true, false].iter().cloned().cycle().zip(edits).collect()) + } +} + +impl DecrypterStream +where + R: AsyncRead + AsyncSeek + Unpin, +{ + /// Recompute the stream length. Having a stream length means that data block positions past the + /// end of the stream will be valid and will equal the the length of the stream. By default this + /// struct contains no stream length when it is initialized. + /// + /// This can take up to 3 seek calls. If the size of the underlying buffer changes, this function + /// should be called again, otherwise data block positions may not be valid. + pub async fn recompute_stream_length(&mut self) -> Result { + let inner = self.inner.get_mut(); + + let position = inner.seek(SeekFrom::Current(0)).await?; + let length = inner.seek(SeekFrom::End(0)).await?; + + if position != length { + inner.seek(SeekFrom::Start(position)).await?; + } + + self.stream_length = Some(length); + + Ok(length) + } +} + +impl Stream for DecrypterStream +where + R: AsyncRead, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // When polling, we first need to process enough data to get the session keys. 
+ if let Err(err) = ready!(self.as_mut().poll_session_keys(cx)) { + return Poll::Ready(Some(Err(err))); + } + + let this = self.as_mut().project(); + let item = this.inner.poll_next(cx); + + match ready!(item) { + Some(Ok(buf)) => match buf { + DecodedBlock::HeaderInfo(_) | DecodedBlock::HeaderPackets(_) => { + // Session keys have already been read, so ignore the header info and header packets + // and poll again + cx.waker().wake_by_ref(); + Poll::Pending + } + DecodedBlock::DataBlock(data_block) => { + // The new size of the data block is available, so update it. + *this.current_block_size = Some(data_block.len()); + + // Session keys have been obtained so process the data blocks. + self.poll_data_block(data_block) + } + }, + Some(Err(e)) => Poll::Ready(Some(Err(e))), + None => Poll::Ready(None), + } + } +} + +impl DecrypterStream +where + R: AsyncRead + AsyncSeek + Unpin + Send, +{ + /// Seek to a position in the encrypted stream. + pub async fn seek_encrypted(&mut self, position: SeekFrom) -> io::Result { + // Make sure that session keys are polled. + self.read_header().await?; + + // First poll to the position specified. + let seek = self.inner.get_mut().seek(position).await?; + + // Then advance to the correct data block position. + let advance = self.advance_encrypted(seek).await?; + + // Then seek to the correct position. + let seek = self.inner.get_mut().seek(SeekFrom::Start(advance)).await?; + self.inner.read_buffer_mut().clear(); + + Ok(seek) + } + + /// Seek to a position in the unencrypted stream. + pub async fn seek_unencrypted(&mut self, position: u64) -> io::Result { + // Make sure that session keys are polled. + self.read_header().await?; + + // Convert to an encrypted position and seek + let position = self + .to_encrypted(position) + .ok_or_else(|| Crypt4GHError("Unable to convert to encrypted position.".to_string()))?; + + // Then do the seek. 
+ self.seek_encrypted(SeekFrom::Start(position)).await + } +} + +#[async_trait] +impl Advance for DecrypterStream +where + R: AsyncRead + Send + Unpin, +{ + async fn advance_encrypted(&mut self, position: u64) -> io::Result { + // Make sure that session keys are polled. + self.read_header().await?; + + // Get the next position. + let data_block_position = self + .clamp_position(position) + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "could not find data block position"))?; + + Ok(data_block_position) + } + + async fn advance_unencrypted(&mut self, position: u64) -> io::Result { + // Make sure that session keys are polled. + self.read_header().await?; + + // Convert to an encrypted position and seek + let position = self + .to_encrypted(position) + .ok_or_else(|| Crypt4GHError("Unable to convert to encrypted position.".to_string()))?; + + // Then do the advance. + self.advance_encrypted(position).await + } + + fn stream_length(&self) -> Option { + self.stream_length + } +} + +#[cfg(test)] +mod tests { + use bytes::BytesMut; + use futures_util::future::join_all; + use futures_util::StreamExt; + use tokio::fs::File; + + use htsget_test::http::get_test_file; + + use crate::decoder::tests::assert_last_data_block; + use crate::decrypter::builder::Builder; + use crate::tests::get_original_file; + use htsget_test::crypt4gh::get_decryption_keys; + + use super::*; + + #[tokio::test] + async fn partition_edit_lists() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .with_edit_list(vec![60113, 100000, 65536]) + .build(src, vec![recipient_private_key]); + + assert_edit_list(&mut stream, Some(vec![60113, 5423]), vec![0; 65564]); + assert_edit_list(&mut stream, Some(vec![0, 65536]), vec![0; 65564]); + assert_edit_list(&mut stream, Some(vec![0, 29041, 36495]), vec![0; 
65564]); + assert_edit_list(&mut stream, Some(vec![29041]), vec![0; 29041 + 12 + 16]); + } + + fn assert_edit_list( + stream: &mut DecrypterStream, + expected: Option>, + bytes: Vec, + ) { + let stream = Pin::new(stream); + let edit_list = stream.partition_edit_list(&Bytes::from(bytes)); + assert_eq!(edit_list, expected); + } + + #[tokio::test] + async fn decrypter_stream() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let mut futures = vec![]; + while let Some(block) = stream.next().await { + futures.push(block.unwrap()); + } + + let decrypted_bytes = + join_all(futures) + .await + .into_iter() + .fold(BytesMut::new(), |mut acc, bytes| { + let (bytes, _) = bytes.unwrap().into_inner(); + acc.extend(bytes.0); + acc + }); + + // Assert that the decrypted bytes are equal to the original file bytes. 
+ let original_bytes = get_original_file().await; + assert_eq!(decrypted_bytes, original_bytes); + } + + #[tokio::test] + async fn get_header_length() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + assert!(stream.header_size().is_none()); + + let _ = stream.next().await.unwrap().unwrap().await; + + assert_eq!(stream.header_size(), Some(124)); + } + + #[tokio::test] + async fn first_block_size() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + assert!(stream.current_block_size().is_none()); + + let _ = stream.next().await.unwrap().unwrap().await; + + assert_eq!(stream.current_block_size(), Some(65564)); + } + + #[tokio::test] + async fn last_block_size() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + assert!(stream.current_block_size().is_none()); + + let mut stream = stream.skip(39); + let _ = stream.next().await.unwrap().unwrap().await; + + assert_eq!(stream.get_ref().current_block_size(), Some(40923)); + } + + #[tokio::test] + async fn clamp_position_first_data_block() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = 
get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let _ = stream.next().await.unwrap().unwrap().await; + + assert_eq!(stream.clamp_position(0), Some(124)); + assert_eq!(stream.clamp_position(124), Some(124)); + assert_eq!(stream.clamp_position(200), Some(124)); + } + + #[tokio::test] + async fn clamp_position_second_data_block() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let _ = stream.next().await.unwrap().unwrap().await; + + assert_eq!(stream.clamp_position(80000), Some(124 + 65564)); + } + + #[tokio::test] + async fn clamp_position_past_end() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + let _ = stream.next().await.unwrap().unwrap().await; + + assert_eq!(stream.clamp_position(2598044), Some(2598043)); + } + + #[tokio::test] + async fn convert_position_first_data_block() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let _ = stream.next().await.unwrap().unwrap().await; + + let pos = stream.to_encrypted(0); + assert_eq!(pos, 
Some(124)); + assert_eq!(stream.clamp_position(pos.unwrap()), Some(124)); + + let pos = stream.to_encrypted(200); + assert_eq!(pos, Some(124 + 12 + 200)); + assert_eq!(stream.clamp_position(pos.unwrap()), Some(124)); + } + + #[tokio::test] + async fn convert_position_second_data_block() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let _ = stream.next().await.unwrap().unwrap().await; + + let pos = stream.to_encrypted(80000); + assert_eq!(pos, Some(124 + 65564 + 12 + (80000 - 65536))); + assert_eq!(stream.clamp_position(pos.unwrap()), Some(124 + 65564)); + } + + #[tokio::test] + async fn convert_position_past_end() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + let _ = stream.next().await.unwrap().unwrap().await; + + let pos = stream.to_encrypted(2596800); + assert_eq!(pos, Some(2598043)); + assert_eq!(stream.clamp_position(pos.unwrap()), Some(2598043)); + } + + #[tokio::test] + async fn seek_first_data_block() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let seek = stream.seek_encrypted(SeekFrom::Start(0)).await.unwrap(); + + assert_eq!(seek, 124); + assert_eq!(stream.header_size(), 
Some(124)); + assert_eq!(stream.current_block_size(), None); + + let mut futures = vec![]; + while let Some(block) = stream.next().await { + futures.push(block.unwrap()); + } + + let decrypted_bytes = + join_all(futures) + .await + .into_iter() + .fold(BytesMut::new(), |mut acc, bytes| { + let (bytes, _) = bytes.unwrap().into_inner(); + acc.extend(bytes.0); + acc + }); + + // Assert that the decrypted bytes are equal to the original file bytes. + let original_bytes = get_original_file().await; + assert_eq!(decrypted_bytes, original_bytes); + } + + #[tokio::test] + async fn seek_second_data_block() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let seek = stream.seek_encrypted(SeekFrom::Start(80000)).await.unwrap(); + + assert_eq!(seek, 124 + 65564); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let seek = stream + .seek_encrypted(SeekFrom::Current(-20000)) + .await + .unwrap(); + + assert_eq!(seek, 124); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let mut futures = vec![]; + while let Some(block) = stream.next().await { + futures.push(block.unwrap()); + } + + let decrypted_bytes = + join_all(futures) + .await + .into_iter() + .fold(BytesMut::new(), |mut acc, bytes| { + let (bytes, _) = bytes.unwrap().into_inner(); + acc.extend(bytes.0); + acc + }); + + // Assert that the decrypted bytes are equal to the original file bytes. 
+ let original_bytes = get_original_file().await; + assert_eq!(decrypted_bytes, original_bytes); + } + + #[tokio::test] + async fn seek_to_end() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let seek = stream.seek_encrypted(SeekFrom::End(-1000)).await.unwrap(); + + assert_eq!(seek, 2598043 - 40923); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let block = stream.next().await.unwrap().unwrap().await.unwrap(); + assert_last_data_block(block.bytes.to_vec()).await; + } + + #[tokio::test] + async fn seek_past_end() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let seek = stream.seek_encrypted(SeekFrom::End(80000)).await.unwrap(); + + assert_eq!(seek, 2598043); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + assert!(stream.next().await.is_none()); + } + + #[tokio::test] + async fn seek_past_end_stream_length_override() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .with_stream_length(2598043) + .build(src, vec![recipient_private_key]); + + let seek = stream.seek_encrypted(SeekFrom::End(80000)).await.unwrap(); + + assert_eq!(seek, 2598043); + assert_eq!(stream.header_size(), 
Some(124)); + assert_eq!(stream.current_block_size(), None); + assert!(stream.next().await.is_none()); + } + + #[tokio::test] + async fn advance_first_data_block() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let advance = stream.advance_encrypted(0).await.unwrap(); + + assert_eq!(advance, 124); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let mut futures = vec![]; + while let Some(block) = stream.next().await { + futures.push(block.unwrap()); + } + + let decrypted_bytes = + join_all(futures) + .await + .into_iter() + .fold(BytesMut::new(), |mut acc, bytes| { + let (bytes, _) = bytes.unwrap().into_inner(); + acc.extend(bytes.0); + acc + }); + + // Assert that the decrypted bytes are equal to the original file bytes. 
+ let original_bytes = get_original_file().await; + assert_eq!(decrypted_bytes, original_bytes); + } + + #[tokio::test] + async fn advance_second_data_block() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let advance = stream.advance_encrypted(80000).await.unwrap(); + + assert_eq!(advance, 124 + 65564); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let mut futures = vec![]; + while let Some(block) = stream.next().await { + futures.push(block.unwrap()); + } + + let decrypted_bytes = + join_all(futures) + .await + .into_iter() + .fold(BytesMut::new(), |mut acc, bytes| { + let (bytes, _) = bytes.unwrap().into_inner(); + acc.extend(bytes.0); + acc + }); + + // Assert that the decrypted bytes are equal to the original file bytes. 
+ let original_bytes = get_original_file().await; + assert_eq!(decrypted_bytes, original_bytes); + } + + #[tokio::test] + async fn advance_to_end() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let advance = stream.advance_encrypted(2598042).await.unwrap(); + + assert_eq!(advance, 2598043 - 40923); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let mut futures = vec![]; + while let Some(block) = stream.next().await { + futures.push(block.unwrap()); + } + + let decrypted_bytes = + join_all(futures) + .await + .into_iter() + .fold(BytesMut::new(), |mut acc, bytes| { + let (bytes, _) = bytes.unwrap().into_inner(); + acc.extend(bytes.0); + acc + }); + + // Assert that the decrypted bytes are equal to the original file bytes. 
+ let original_bytes = get_original_file().await; + assert_eq!(decrypted_bytes, original_bytes); + } + + #[tokio::test] + async fn advance_past_end() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let advance = stream.advance_encrypted(2598044).await.unwrap(); + + assert_eq!(advance, 2598043); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let mut futures = vec![]; + while let Some(block) = stream.next().await { + futures.push(block.unwrap()); + } + + let decrypted_bytes = + join_all(futures) + .await + .into_iter() + .fold(BytesMut::new(), |mut acc, bytes| { + let (bytes, _) = bytes.unwrap().into_inner(); + acc.extend(bytes.0); + acc + }); + + // Assert that the decrypted bytes are equal to the original file bytes. 
+ let original_bytes = get_original_file().await; + assert_eq!(decrypted_bytes, original_bytes); + } + + #[tokio::test] + async fn advance_past_end_stream_length_override() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .with_stream_length(2598043) + .build(src, vec![recipient_private_key]); + + let advance = stream.advance_encrypted(2598044).await.unwrap(); + + assert_eq!(advance, 2598043); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let mut futures = vec![]; + while let Some(block) = stream.next().await { + futures.push(block.unwrap()); + } + + let decrypted_bytes = + join_all(futures) + .await + .into_iter() + .fold(BytesMut::new(), |mut acc, bytes| { + let (bytes, _) = bytes.unwrap().into_inner(); + acc.extend(bytes.0); + acc + }); + + // Assert that the decrypted bytes are equal to the original file bytes. 
+ let original_bytes = get_original_file().await; + assert_eq!(decrypted_bytes, original_bytes); + } + + #[tokio::test] + async fn seek_first_data_block_unencrypted() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let seek = stream.seek_unencrypted(0).await.unwrap(); + + assert_eq!(seek, 124); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let mut futures = vec![]; + while let Some(block) = stream.next().await { + futures.push(block.unwrap()); + } + + let decrypted_bytes = + join_all(futures) + .await + .into_iter() + .fold(BytesMut::new(), |mut acc, bytes| { + let (bytes, _) = bytes.unwrap().into_inner(); + acc.extend(bytes.0); + acc + }); + + // Assert that the decrypted bytes are equal to the original file bytes. 
+ let original_bytes = get_original_file().await; + assert_eq!(decrypted_bytes, original_bytes); + } + + #[tokio::test] + async fn seek_second_data_block_unencrypted() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let seek = stream.seek_unencrypted(65537).await.unwrap(); + + assert_eq!(seek, 124 + 65564); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let seek = stream.seek_unencrypted(65535).await.unwrap(); + + assert_eq!(seek, 124); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let mut futures = vec![]; + while let Some(block) = stream.next().await { + futures.push(block.unwrap()); + } + + let decrypted_bytes = + join_all(futures) + .await + .into_iter() + .fold(BytesMut::new(), |mut acc, bytes| { + let (bytes, _) = bytes.unwrap().into_inner(); + acc.extend(bytes.0); + acc + }); + + // Assert that the decrypted bytes are equal to the original file bytes. 
+ let original_bytes = get_original_file().await; + assert_eq!(decrypted_bytes, original_bytes); + } + + #[tokio::test] + async fn seek_to_end_unencrypted() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let seek = stream.seek_unencrypted(2596799).await.unwrap(); + + assert_eq!(seek, 2598043 - 40923); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let block = stream.next().await.unwrap().unwrap().await.unwrap(); + assert_last_data_block(block.bytes.to_vec()).await; + } + + #[tokio::test] + async fn seek_past_end_unencrypted() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let seek = stream.seek_unencrypted(2596800).await.unwrap(); + + assert_eq!(seek, 2598043); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + assert!(stream.next().await.is_none()); + } + + #[tokio::test] + async fn seek_past_end_stream_unencrypted_length_override() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .with_stream_length(2598043) + .build(src, vec![recipient_private_key]); + + let seek = stream.seek_unencrypted(2596800).await.unwrap(); + + assert_eq!(seek, 2598043); + assert_eq!(stream.header_size(), 
Some(124)); + assert_eq!(stream.current_block_size(), None); + assert!(stream.next().await.is_none()); + } + + #[tokio::test] + async fn advance_first_data_block_unencrypted() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let advance = stream.advance_unencrypted(0).await.unwrap(); + + assert_eq!(advance, 124); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let mut futures = vec![]; + while let Some(block) = stream.next().await { + futures.push(block.unwrap()); + } + + let decrypted_bytes = + join_all(futures) + .await + .into_iter() + .fold(BytesMut::new(), |mut acc, bytes| { + let (bytes, _) = bytes.unwrap().into_inner(); + acc.extend(bytes.0); + acc + }); + + // Assert that the decrypted bytes are equal to the original file bytes. 
+ let original_bytes = get_original_file().await; + assert_eq!(decrypted_bytes, original_bytes); + } + + #[tokio::test] + async fn advance_second_data_block_unencrypted() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let advance = stream.advance_unencrypted(65537).await.unwrap(); + + assert_eq!(advance, 124 + 65564); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let mut futures = vec![]; + while let Some(block) = stream.next().await { + futures.push(block.unwrap()); + } + + let decrypted_bytes = + join_all(futures) + .await + .into_iter() + .fold(BytesMut::new(), |mut acc, bytes| { + let (bytes, _) = bytes.unwrap().into_inner(); + acc.extend(bytes.0); + acc + }); + + // Assert that the decrypted bytes are equal to the original file bytes. 
+ let original_bytes = get_original_file().await; + assert_eq!(decrypted_bytes, original_bytes); + } + + #[tokio::test] + async fn advance_to_end_unencrypted() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let advance = stream.advance_unencrypted(2596799).await.unwrap(); + + assert_eq!(advance, 2598043 - 40923); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let mut futures = vec![]; + while let Some(block) = stream.next().await { + futures.push(block.unwrap()); + } + + let decrypted_bytes = + join_all(futures) + .await + .into_iter() + .fold(BytesMut::new(), |mut acc, bytes| { + let (bytes, _) = bytes.unwrap().into_inner(); + acc.extend(bytes.0); + acc + }); + + // Assert that the decrypted bytes are equal to the original file bytes. 
+ let original_bytes = get_original_file().await; + assert_eq!(decrypted_bytes, original_bytes); + } + + #[tokio::test] + async fn advance_past_end_unencrypted() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let advance = stream.advance_unencrypted(2596800).await.unwrap(); + + assert_eq!(advance, 2598043); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let mut futures = vec![]; + while let Some(block) = stream.next().await { + futures.push(block.unwrap()); + } + + let decrypted_bytes = + join_all(futures) + .await + .into_iter() + .fold(BytesMut::new(), |mut acc, bytes| { + let (bytes, _) = bytes.unwrap().into_inner(); + acc.extend(bytes.0); + acc + }); + + // Assert that the decrypted bytes are equal to the original file bytes. 
+ let original_bytes = get_original_file().await; + assert_eq!(decrypted_bytes, original_bytes); + } + + #[tokio::test] + async fn advance_past_end_unencrypted_stream_length_override() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut stream = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .with_stream_length(2598043) + .build(src, vec![recipient_private_key]); + + let advance = stream.advance_unencrypted(2596800).await.unwrap(); + + assert_eq!(advance, 2598043); + assert_eq!(stream.header_size(), Some(124)); + assert_eq!(stream.current_block_size(), None); + + let mut futures = vec![]; + while let Some(block) = stream.next().await { + futures.push(block.unwrap()); + } + + let decrypted_bytes = + join_all(futures) + .await + .into_iter() + .fold(BytesMut::new(), |mut acc, bytes| { + let (bytes, _) = bytes.unwrap().into_inner(); + acc.extend(bytes.0); + acc + }); + + // Assert that the decrypted bytes are equal to the original file bytes. + let original_bytes = get_original_file().await; + assert_eq!(decrypted_bytes, original_bytes); + } +} diff --git a/async-crypt4gh/src/edit_lists.rs b/async-crypt4gh/src/edit_lists.rs new file mode 100644 index 000000000..017fdbfae --- /dev/null +++ b/async-crypt4gh/src/edit_lists.rs @@ -0,0 +1,343 @@ +use std::collections::HashSet; + +use crypt4gh::header::{encrypt, make_packet_data_edit_list, HeaderInfo}; +use crypt4gh::Keys; +use rustls::PrivateKey; +use tokio::io::AsyncRead; +use tracing::info; + +use crate::error::{Error, Result}; +use crate::reader::Reader; +use crate::PublicKey; + +/// Unencrypted byte range positions. Contains inclusive start values and exclusive end values. 
+#[derive(Debug, Clone)] +pub struct UnencryptedPosition { + start: u64, + end: u64, +} + +impl UnencryptedPosition { + pub fn new(start: u64, end: u64) -> Self { + Self { start, end } + } + + pub fn start(&self) -> u64 { + self.start + } + + pub fn end(&self) -> u64 { + self.end + } +} + +/// Encrypted byte range positions. Contains inclusive start values and exclusive end values. +#[derive(Debug, Clone)] +pub struct ClampedPosition { + start: u64, + end: u64, +} + +impl ClampedPosition { + pub fn new(start: u64, end: u64) -> Self { + Self { start, end } + } + + pub fn start(&self) -> u64 { + self.start + } + + pub fn end(&self) -> u64 { + self.end + } +} + +/// Bytes representing a header packet with an edit list. +#[derive(Debug, Clone)] +pub struct Header { + header_info: Vec, + original_header: Vec, + edit_list_packet: Vec, +} + +impl Header { + pub fn new(header_info: Vec, original_header: Vec, edit_list_packet: Vec) -> Self { + Self { + header_info, + original_header, + edit_list_packet, + } + } + + pub fn into_inner(self) -> (Vec, Vec, Vec) { + ( + self.header_info, + self.original_header, + self.edit_list_packet, + ) + } + + pub fn as_slice(&self) -> Vec { + [ + self.header_info.as_slice(), + self.original_header.as_slice(), + self.edit_list_packet.as_slice(), + ] + .concat() + } +} + +impl From<(Vec, Vec, Vec)> for Header { + fn from((header_info, original_header, edit_list_packet): (Vec, Vec, Vec)) -> Self { + Self::new(header_info, original_header, edit_list_packet) + } +} + +pub struct EditHeader<'a, R> +where + R: AsyncRead + Unpin, +{ + reader: &'a Reader, + unencrypted_positions: Vec, + clamped_positions: Vec, + private_key: PrivateKey, + recipient_public_key: PublicKey, +} + +impl<'a, R> EditHeader<'a, R> +where + R: AsyncRead + Unpin, +{ + pub fn new( + reader: &'a Reader, + unencrypted_positions: Vec, + clamped_positions: Vec, + private_key: PrivateKey, + recipient_public_key: PublicKey, + ) -> Self { + Self { + reader, + unencrypted_positions, + 
clamped_positions, + private_key, + recipient_public_key, + } + } + + /// Encrypt the edit list packet. + pub fn encrypt_edit_list(&self, edit_list_packet: Vec) -> Result> { + info!("encrypting edit list"); + let keys = Keys { + method: 0, + privkey: self.private_key.clone().0, + recipient_pubkey: self.recipient_public_key.clone().into_inner(), + }; + + encrypt(&edit_list_packet, &HashSet::from_iter(vec![keys]))? + .into_iter() + .last() + .ok_or_else(|| Error::Crypt4GHError("could not encrypt header packet".to_string())) + } + + /// Create the edit lists from the unencrypted byte positions. + pub fn create_edit_list(&self) -> Vec { + info!("creating edit list"); + let mut unencrypted_positions: Vec = self + .unencrypted_positions + .iter() + .flat_map(|pos| [pos.start, pos.end]) + .collect(); + + // Collect the clamped and unencrypted positions into separate edit list groups. + let (mut edit_list, last_discard) = + self + .clamped_positions + .iter() + .fold((vec![], 0), |(mut edit_list, previous_discard), pos| { + // Get the correct number of unencrypted positions that fit within this clamped position. + let partition = + unencrypted_positions.partition_point(|unencrypted_pos| unencrypted_pos <= &pos.end); + let mut positions: Vec = unencrypted_positions.drain(..partition).collect(); + + // Merge all positions. + positions.insert(0, pos.start); + positions.push(pos.end); + + // Find the difference between consecutive positions to get the edits. + let mut positions: Vec = positions + .iter() + .zip(positions.iter().skip(1)) + .map(|(start, end)| end - start) + .collect(); + + // Add the previous discard to the first edit. + if let Some(first) = positions.first_mut() { + *first += previous_discard; + } + + // If the last edit is a discard, then carry this over into the next iteration. + let next_discard = if positions.len() % 2 == 0 { + 0 + } else { + positions.pop().unwrap_or(0) + }; + + // Add edits to the accumulating edit list. 
+ edit_list.extend(positions); + (edit_list, next_discard) + }); + + // If there is a final discard, then add this to the edit list. + if last_discard != 0 { + edit_list.push(last_discard); + } + + edit_list + } + + /// Add edit lists and return a header packet. + pub fn edit_list(self) -> Result> { + info!("adding edit list"); + if self.reader.edit_list_packet().is_some() { + return Err(Error::Crypt4GHError("edit lists already exist".to_string())); + } + + // Todo, header info should have copy or clone on it. + let (mut header_info, encrypted_header_packets) = + if let (Some(header_info), Some(encrypted_header_packets)) = ( + self.reader.header_info(), + self.reader.encrypted_header_packets(), + ) { + ( + HeaderInfo { + magic_number: header_info.magic_number, + version: header_info.version, + packets_count: header_info.packets_count, + }, + encrypted_header_packets + .iter() + .flat_map(|packet| [packet.packet_length().to_vec(), packet.header.to_vec()].concat()) + .collect::>(), + ) + } else { + return Ok(None); + }; + + // Todo rewrite this from the context of an encryption stream like the decrypter. 
+ header_info.packets_count += 1; + let header_info_bytes = + bincode::serialize(&header_info).map_err(|err| Error::Crypt4GHError(err.to_string()))?; + + let edit_list = self.create_edit_list(); + let edit_list_packet = + make_packet_data_edit_list(edit_list.into_iter().map(|edit| edit as usize).collect()); + + let edit_list_bytes = self.encrypt_edit_list(edit_list_packet)?; + let edit_list_bytes = [ + ((edit_list_bytes.len() + 4) as u32).to_le_bytes().to_vec(), + edit_list_bytes, + ] + .concat(); + + Ok(Some( + (header_info_bytes, encrypted_header_packets, edit_list_bytes).into(), + )) + } +} + +#[cfg(test)] +mod tests { + use htsget_test::crypt4gh::{get_decryption_keys, get_encryption_keys}; + use htsget_test::http::get_test_file; + + use crate::reader::builder::Builder; + + use super::*; + + #[tokio::test] + async fn test_append_edit_list() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (private_key_decrypt, public_key_decrypt) = get_decryption_keys().await; + let (private_key_encrypt, public_key_encrypt) = get_encryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(public_key_decrypt.clone())) + .with_stream_length(5485112) + .build_with_reader(src, vec![private_key_decrypt.clone()]); + reader.read_header().await.unwrap(); + + let expected_data_packets = reader.session_keys().to_vec(); + + let header = EditHeader::new( + &reader, + test_unencrypted_positions(), + test_clamped_positions(), + PrivateKey(private_key_encrypt.clone().privkey), + PublicKey { + bytes: public_key_encrypt.clone(), + }, + ) + .edit_list() + .unwrap() + .unwrap(); + + let header_slice = header.as_slice(); + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(public_key_decrypt)) + .with_stream_length(5485112) + .build_with_reader(header_slice.as_slice(), vec![private_key_decrypt]); + reader.read_header().await.unwrap(); + + let data_packets = reader.session_keys(); + 
assert_eq!(data_packets, expected_data_packets); + + let edit_lists = reader.edit_list_packet().unwrap(); + assert_eq!(edit_lists, expected_edit_list()); + } + + #[tokio::test] + async fn test_create_edit_list() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (private_key_decrypt, public_key_decrypt) = get_decryption_keys().await; + let (private_key_encrypt, public_key_encrypt) = get_encryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(public_key_decrypt.clone())) + .with_stream_length(5485112) + .build_with_reader(src, vec![private_key_decrypt.clone()]); + reader.read_header().await.unwrap(); + + let edit_list = EditHeader::new( + &reader, + test_unencrypted_positions(), + test_clamped_positions(), + PrivateKey(private_key_encrypt.clone().privkey), + PublicKey { + bytes: public_key_encrypt.clone(), + }, + ) + .create_edit_list(); + + assert_eq!(edit_list, expected_edit_list()); + } + + fn test_unencrypted_positions() -> Vec { + vec![ + UnencryptedPosition::new(0, 7853), + UnencryptedPosition::new(145110, 453039), + UnencryptedPosition::new(5485074, 5485112), + ] + } + + fn test_clamped_positions() -> Vec { + vec![ + ClampedPosition::new(0, 65536), + ClampedPosition::new(131072, 458752), + ClampedPosition::new(5439488, 5485112), + ] + } + + fn expected_edit_list() -> Vec { + vec![0, 7853, 71721, 307929, 51299, 38] + } +} diff --git a/async-crypt4gh/src/error.rs b/async-crypt4gh/src/error.rs new file mode 100644 index 000000000..6e2827427 --- /dev/null +++ b/async-crypt4gh/src/error.rs @@ -0,0 +1,51 @@ +use std::{io, result}; + +use crypt4gh::error::Crypt4GHError; +use thiserror::Error; +use tokio::task; + +/// The result type for Crypt4GH errors. +pub type Result = result::Result; + +/// Errors related to Crypt4GH. 
+#[derive(Error, Debug)] +pub enum Error { + #[error("converting slice to fixed size array")] + SliceConversionError, + #[error("converting between numeric types")] + NumericConversionError, + #[error("decoding header info: `{0}`")] + DecodingHeaderInfo(Crypt4GHError), + #[error("decoding header packet: `{0}`")] + DecodingHeaderPacket(Crypt4GHError), + #[error("io error: `{0}`")] + IOError(io::Error), + #[error("join handle error: `{0}`")] + JoinHandleError(task::JoinError), + #[error("maximum header size exceeded")] + MaximumHeaderSize, + #[error("crypt4gh error: `{0}`")] + Crypt4GHError(String), +} + +impl From for Error { + fn from(error: io::Error) -> Self { + Self::IOError(error) + } +} + +impl From for io::Error { + fn from(error: Error) -> Self { + if let Error::IOError(error) = error { + error + } else { + Self::new(io::ErrorKind::Other, error) + } + } +} + +impl From for Error { + fn from(error: Crypt4GHError) -> Self { + Self::Crypt4GHError(error.to_string()) + } +} diff --git a/async-crypt4gh/src/lib.rs b/async-crypt4gh/src/lib.rs new file mode 100644 index 000000000..85301f8e7 --- /dev/null +++ b/async-crypt4gh/src/lib.rs @@ -0,0 +1,245 @@ +use std::ops::Deref; + +use bytes::Bytes; +use rustls::PrivateKey; + +pub use reader::builder::Builder as Crypt4GHReaderBuilder; +pub use reader::Reader as Crypt4GHReader; + +pub mod advance; +pub mod decoder; +pub mod decrypter; +pub mod edit_lists; +pub mod error; +pub mod reader; +pub mod util; + +/// A wrapper around a vec of bytes that represent a public key. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PublicKey { + bytes: Vec, +} + +/// A key pair containing a public and private key. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct KeyPair { + private_key: PrivateKey, + public_key: PublicKey, +} + +impl KeyPair { + /// Create a new key pair. + pub fn new(private_key: PrivateKey, public_key: PublicKey) -> Self { + Self { + private_key, + public_key, + } + } + + /// Get the inner keys. 
+ pub fn into_inner(self) -> (PrivateKey, PublicKey) { + (self.private_key, self.public_key) + } + + /// Get private key. + pub fn private_key(&self) -> &PrivateKey { + &self.private_key + } + + /// Get the public key. + pub fn public_key(&self) -> &PublicKey { + &self.public_key + } +} + +impl PublicKey { + /// Create a new sender public key from bytes. + pub fn new(bytes: Vec) -> Self { + Self { bytes } + } + + /// Get the inner bytes. + pub fn into_inner(self) -> Vec { + self.bytes + } + + /// Get the inner bytes as a reference. + pub fn get_ref(&self) -> &[u8] { + self.bytes.as_slice() + } +} + +/// Represents an encrypted header packet with the packet length, and the remaining header. +#[derive(Clone, Debug, Default)] +pub struct EncryptedHeaderPacketBytes { + packet_length: Bytes, + header: Bytes, +} + +impl EncryptedHeaderPacketBytes { + /// Create header packet bytes. + pub fn new(packet_length: Bytes, header: Bytes) -> Self { + Self { + packet_length, + header, + } + } + + /// Get packet length bytes. + pub fn packet_length(&self) -> &Bytes { + &self.packet_length + } + + /// Get header bytes. + pub fn header(&self) -> &Bytes { + &self.header + } + + /// Get the owned packet length and header bytes. + pub fn into_inner(self) -> (Bytes, Bytes) { + (self.packet_length, self.header) + } + + /// Get the header bytes only. + pub fn into_header_bytes(self) -> Bytes { + self.header + } +} + +/// Represents the encrypted header packet data, and the total size of all the header packets. +#[derive(Debug, Default)] +pub struct EncryptedHeaderPackets { + header_packets: Vec, + header_length: u64, +} + +impl EncryptedHeaderPackets { + /// Create a new set of encrypted header packets. + pub fn new(header_packets: Vec, size: u64) -> Self { + Self { + header_packets, + header_length: size, + } + } + + /// Get the header packet bytes + pub fn header_packets(&self) -> &Vec { + &self.header_packets + } + + /// Get the size of all the packets. 
+ pub fn header_length(&self) -> u64 { + self.header_length + } + + /// Get the inner bytes and size. + pub fn into_inner(self) -> (Vec, u64) { + (self.header_packets, self.header_length) + } +} + +/// Represents the decrypted data block and its original encrypted size. +#[derive(Debug, Default)] +pub struct DecryptedDataBlock { + bytes: DecryptedBytes, + encrypted_size: usize, +} + +impl DecryptedDataBlock { + /// Create a new decrypted data block. + pub fn new(bytes: DecryptedBytes, encrypted_size: usize) -> Self { + Self { + bytes, + encrypted_size, + } + } + + /// Get the plain text bytes. + pub fn bytes(&self) -> &DecryptedBytes { + &self.bytes + } + + /// Get the encrypted size. + pub fn encrypted_size(&self) -> usize { + self.encrypted_size + } + + /// Get the inner bytes and size. + pub fn into_inner(self) -> (DecryptedBytes, usize) { + (self.bytes, self.encrypted_size) + } + + /// Get the length of the decrypted bytes. + pub const fn len(&self) -> usize { + self.bytes.len() + } + + /// Check if the decrypted bytes are empty + pub const fn is_empty(&self) -> bool { + self.bytes.is_empty() + } +} + +impl Deref for DecryptedDataBlock { + type Target = [u8]; + + #[inline] + fn deref(&self) -> &[u8] { + self.bytes.deref() + } +} + +/// A wrapper around a vec of bytes that represents decrypted data. +#[derive(Debug, Default, Clone)] +pub struct DecryptedBytes(Bytes); + +impl DecryptedBytes { + /// Create new decrypted bytes from bytes. + pub fn new(bytes: Bytes) -> Self { + Self(bytes) + } + + /// Get the inner bytes. + pub fn into_inner(self) -> Bytes { + self.0 + } + + /// Get the length of the inner bytes. + pub const fn len(&self) -> usize { + self.0.len() + } + + /// Check if the inner bytes are empty. 
+ pub const fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +impl Deref for DecryptedBytes { + type Target = [u8]; + + #[inline] + fn deref(&self) -> &[u8] { + self.0.deref() + } +} + +#[cfg(test)] +pub(crate) mod tests { + use tokio::io::AsyncReadExt; + + use htsget_test::http::get_test_file; + + /// Get the original file bytes. + pub(crate) async fn get_original_file() -> Vec { + let mut original_file = get_test_file("bam/htsnexus_test_NA12878.bam").await; + let mut original_bytes = vec![]; + + original_file + .read_to_end(&mut original_bytes) + .await + .unwrap(); + + original_bytes + } +} diff --git a/async-crypt4gh/src/reader/builder.rs b/async-crypt4gh/src/reader/builder.rs new file mode 100644 index 000000000..acc426f51 --- /dev/null +++ b/async-crypt4gh/src/reader/builder.rs @@ -0,0 +1,112 @@ +use std::thread; + +use crypt4gh::Keys; +use futures_util::TryStreamExt; +use tokio::io::{AsyncRead, AsyncSeek}; + +use crate::decrypter::builder::Builder as DecrypterBuilder; +use crate::decrypter::DecrypterStream; +use crate::error::Result; +use crate::PublicKey; + +use super::Reader; + +/// An async Crypt4GH reader builder. +#[derive(Debug, Default)] +pub struct Builder { + worker_count: Option, + sender_pubkey: Option, + stream_length: Option, +} + +impl Builder { + /// Sets a worker count. + pub fn with_worker_count(self, worker_count: usize) -> Self { + self.set_worker_count(Some(worker_count)) + } + + /// Sets the sender public key + pub fn with_sender_pubkey(self, sender_pubkey: PublicKey) -> Self { + self.set_sender_pubkey(Some(sender_pubkey)) + } + + /// Sets a worker count. + pub fn set_worker_count(mut self, worker_count: Option) -> Self { + self.worker_count = worker_count; + self + } + + /// Sets the sender public key + pub fn set_sender_pubkey(mut self, sender_pubkey: Option) -> Self { + self.sender_pubkey = sender_pubkey; + self + } + + /// Sets the stream length. 
+ pub fn with_stream_length(self, stream_length: u64) -> Self { + self.set_stream_length(Some(stream_length)) + } + + /// Sets the stream length. + pub fn set_stream_length(mut self, stream_length: Option) -> Self { + self.stream_length = stream_length; + self + } + + /// Build the Crypt4GH reader. + pub fn build_with_reader(self, inner: R, keys: Vec) -> Reader + where + R: AsyncRead, + { + let worker_counter = self.worker_count(); + + Reader { + stream: DecrypterBuilder::default() + .set_sender_pubkey(self.sender_pubkey) + .set_stream_length(self.stream_length) + .build(inner, keys) + .try_buffered(worker_counter), + // Dummy value for bytes to begin with. + current_block: Default::default(), + buf_position: 0, + block_position: None, + } + } + + /// Build the reader and compute the stream length for seek operations. + pub async fn build_with_stream_length(self, inner: R, keys: Vec) -> Result> + where + R: AsyncRead + AsyncSeek + Unpin, + { + let stream_length = self.stream_length; + let mut reader = self.build_with_reader(inner, keys); + + if stream_length.is_none() { + reader.stream.get_mut().recompute_stream_length().await?; + } + + Ok(reader) + } + + /// Build the Crypt4GH reader with a decrypter stream. + pub fn build_with_stream(self, stream: DecrypterStream) -> Reader + where + R: AsyncRead, + { + Reader { + stream: stream.try_buffered(self.worker_count()), + // Dummy value for bytes to begin with. 
+ current_block: Default::default(), + buf_position: 0, + block_position: None, + } + } + + fn worker_count(&self) -> usize { + self.worker_count.unwrap_or_else(|| { + thread::available_parallelism() + .map(|worker_count| worker_count.get()) + .unwrap_or_else(|_| 1) + }) + } +} diff --git a/async-crypt4gh/src/reader/mod.rs b/async-crypt4gh/src/reader/mod.rs new file mode 100644 index 000000000..e9751821b --- /dev/null +++ b/async-crypt4gh/src/reader/mod.rs @@ -0,0 +1,774 @@ +use std::io::SeekFrom; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{cmp, io}; + +use async_trait::async_trait; +use crypt4gh::header::HeaderInfo; +use crypt4gh::Keys; +use futures::ready; +use futures::stream::TryBuffered; +use futures::Stream; +use pin_project_lite::pin_project; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, ReadBuf}; + +use crate::advance::Advance; +use crate::decoder::Block; +use crate::error::Error::NumericConversionError; +use crate::error::Result; +use crate::reader::builder::Builder; +use crate::{DecryptedDataBlock, EncryptedHeaderPacketBytes}; + +use super::decrypter::DecrypterStream; + +pub mod builder; + +pin_project! { + pub struct Reader + where R: AsyncRead + { + #[pin] + stream: TryBuffered>, + current_block: DecryptedDataBlock, + // The current position in the decrypted buffer. + buf_position: usize, + // The encrypted position of the current data block minus the size of the header. + block_position: Option + } +} + +impl Reader +where + R: AsyncRead, +{ + /// Gets the position of the data block which includes the current position of the underlying + /// reader. This function will return a value that always corresponds the beginning of a data + /// block or `None` if the reader has not read any bytes. + pub fn current_block_position(&self) -> Option { + self.block_position + } + + /// Gets the position of the next data block from the current position of the underlying reader. 
+ /// This function will return a value that always corresponds the beginning of a data block, the + /// size of the file, or `None` if the reader has not read any bytes. + pub fn next_block_position(&self) -> Option { + self.block_position.and_then(|block_position| { + self + .stream + .get_ref() + .clamp_position(block_position + Block::standard_data_block_size()) + }) + } + + /// Get a reference to the inner reader. + pub fn get_ref(&self) -> &R { + self.stream.get_ref().get_ref() + } + + /// Get a mutable reference to the inner reader. + pub fn get_mut(&mut self) -> &mut R { + self.stream.get_mut().get_mut() + } + + /// Get a pinned mutable reference to the inner reader. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> { + self.project().stream.get_pin_mut().get_pin_mut() + } + + /// Get the inner reader. + pub fn into_inner(self) -> R { + self.stream.into_inner().into_inner() + } + + /// Get the session keys. Empty before the header is polled. + pub fn session_keys(&self) -> &[Vec] { + self.stream.get_ref().session_keys() + } + + /// Get the edit list packet. Empty before the header is polled. + pub fn edit_list_packet(&self) -> Option> { + self.stream.get_ref().edit_list_packet() + } + + /// Get the header info. + pub fn header_info(&self) -> Option<&HeaderInfo> { + self.stream.get_ref().header_info() + } + + /// Get the header size + pub fn header_size(&self) -> Option { + self.stream.get_ref().header_size() + } + + /// Get the original encrypted header packets, not including the header info. + pub fn encrypted_header_packets(&self) -> Option<&Vec> { + self.stream.get_ref().encrypted_header_packets() + } + + /// Poll the reader until the header has been read. + pub async fn read_header(&mut self) -> Result<()> + where + R: Unpin, + { + self.stream.get_mut().read_header().await + } + + /// Get the reader's keys. 
+ pub fn keys(&self) -> &[Keys] { + self.stream.get_ref().keys() + } +} + +impl From> for Reader +where + R: AsyncRead, +{ + fn from(stream: DecrypterStream) -> Self { + Builder::default().build_with_stream(stream) + } +} + +impl AsyncRead for Reader +where + R: AsyncRead, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + // Defer to poll_fill_buf to do the read. + let src = ready!(self.as_mut().poll_fill_buf(cx))?; + + // Calculate the correct amount to read and copy over to the read buf. + let amt = cmp::min(src.len(), buf.remaining()); + buf.put_slice(&src[..amt]); + + // Inform the internal buffer that amt has been consumed. + self.consume(amt); + + Poll::Ready(Ok(())) + } +} + +impl AsyncBufRead for Reader +where + R: AsyncRead, +{ + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + // If this is the beginning of the stream, set the block position to the header length, if any. + if let (None, length @ Some(_)) = ( + this.block_position.as_ref(), + this.stream.get_ref().header_size(), + ) { + *this.block_position = length; + } + + // If the position is past the end of the buffer, then all the data has been read and a new + // buffer should be initialised. + if *this.buf_position >= this.current_block.len() { + match ready!(this.stream.poll_next(cx)) { + Some(Ok(block)) => { + // Update the block position with the previous block size. + *this.block_position = Some( + this.block_position.unwrap_or_default() + + u64::try_from(this.current_block.encrypted_size()) + .map_err(|_| NumericConversionError)?, + ); + + // We have a new buffer, reinitialise the position and buffer. + *this.current_block = block; + *this.buf_position = 0; + } + Some(Err(e)) => return Poll::Ready(Err(e.into())), + None => return Poll::Ready(Ok(&[])), + } + } + + // Return the unconsumed data from the buffer. 
+ Poll::Ready(Ok(&this.current_block[*this.buf_position..])) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + let this = self.project(); + // Update the buffer position until the consumed amount reaches the end of the buffer. + *this.buf_position = cmp::min(*this.buf_position + amt, this.current_block.len()); + } +} + +impl Reader +where + R: AsyncRead + AsyncSeek + Unpin + Send, +{ + /// Seek to a position in the encrypted stream. + pub async fn seek_encrypted(&mut self, position: SeekFrom) -> io::Result { + let position = self.stream.get_mut().seek_encrypted(position).await?; + + self.block_position = Some(position); + + Ok(position) + } + + /// Seek to a position in the unencrypted stream. + pub async fn seek_unencrypted(&mut self, position: u64) -> io::Result { + let position = self.stream.get_mut().seek_unencrypted(position).await?; + + self.block_position = Some(position); + + Ok(position) + } +} + +#[async_trait] +impl Advance for Reader +where + R: AsyncRead + Send + Unpin, +{ + async fn advance_encrypted(&mut self, position: u64) -> io::Result { + let position = self.stream.get_mut().advance_encrypted(position).await?; + + self.block_position = Some(position); + + Ok(position) + } + + async fn advance_unencrypted(&mut self, position: u64) -> io::Result { + let position = self.stream.get_mut().advance_unencrypted(position).await?; + + self.block_position = Some(position); + + Ok(position) + } + + fn stream_length(&self) -> Option { + self.stream.get_ref().stream_length() + } +} + +#[cfg(test)] +mod tests { + use std::io::SeekFrom; + + use futures_util::TryStreamExt; + use noodles::bam::AsyncReader; + use noodles::sam::Header; + use tokio::io::AsyncReadExt; + + use htsget_test::http::get_test_file; + + use crate::advance::Advance; + use crate::reader::builder::Builder; + use crate::tests::get_original_file; + use crate::PublicKey; + use htsget_test::crypt4gh::get_decryption_keys; + + #[tokio::test] + async fn reader() { + let src = 
get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let mut decrypted_bytes = vec![]; + reader.read_to_end(&mut decrypted_bytes).await.unwrap(); + + let original_bytes = get_original_file().await; + assert_eq!(decrypted_bytes, original_bytes); + } + + #[tokio::test] + async fn reader_with_noodles() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + let mut reader = AsyncReader::new(reader); + + let original_file = get_test_file("bam/htsnexus_test_NA12878.bam").await; + let mut original_reader = AsyncReader::new(original_file); + + let header: Header = reader.read_header().await.unwrap().parse().unwrap(); + let reference_sequences = reader.read_reference_sequences().await.unwrap(); + + let original_header: Header = original_reader + .read_header() + .await + .unwrap() + .parse() + .unwrap(); + let original_reference_sequences = original_reader.read_reference_sequences().await.unwrap(); + + assert_eq!(header, original_header); + assert_eq!(reference_sequences, original_reference_sequences); + + let mut stream = original_reader.records(&original_header); + let mut original_records = vec![]; + while let Some(record) = stream.try_next().await.unwrap() { + original_records.push(record); + } + + let mut stream = reader.records(&header); + let mut records = vec![]; + while let Some(record) = stream.try_next().await.unwrap() { + records.push(record); + } + + assert_eq!(records, original_records); + } + + #[tokio::test] 
+ async fn first_current_block_position() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + // Before anything is read the current block should not be known. + assert_eq!(reader.current_block_position(), None); + + // Read the first byte of the decrypted data. + let mut buf = [0u8; 1]; + reader.read_exact(&mut buf).await.unwrap(); + + // Now the current position should be at the end of the header. + assert_eq!(reader.current_block_position(), Some(124)); + } + + #[tokio::test] + async fn first_next_block_position() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + // Before anything is read the next block should not be known. + assert_eq!(reader.next_block_position(), None); + + // Read the first byte of the decrypted data. + let mut buf = [0u8; 1]; + reader.read_exact(&mut buf).await.unwrap(); + + // Now the next position should be at the second data block. + assert_eq!(reader.next_block_position(), Some(124 + 65564)); + } + + #[tokio::test] + async fn last_current_block_position() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + // Before anything is read the current block should not be known. 
+ assert_eq!(reader.current_block_position(), None); + + // Read the whole file. + let mut decrypted_bytes = vec![]; + reader.read_to_end(&mut decrypted_bytes).await.unwrap(); + + // Now the current position should be at the last data block. + assert_eq!(reader.current_block_position(), Some(2598043 - 40923)); + } + + #[tokio::test] + async fn last_next_block_position() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + // Before anything is read the next block should not be known. + assert_eq!(reader.next_block_position(), None); + + // Read the whole file. + let mut decrypted_bytes = vec![]; + reader.read_to_end(&mut decrypted_bytes).await.unwrap(); + + // Now the next position should be the size of the file. + assert_eq!(reader.next_block_position(), Some(2598043)); + } + + #[tokio::test] + async fn seek_first_data_block() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + // Before anything is read the block positions should not be known. + assert_eq!(reader.current_block_position(), None); + assert_eq!(reader.next_block_position(), None); + + reader.seek_encrypted(SeekFrom::Start(0)).await.unwrap(); + + // Now the positions should be at the first data block. 
+ assert_eq!(reader.current_block_position(), Some(124)); + assert_eq!(reader.next_block_position(), Some(124 + 65564)); + } + + #[tokio::test] + async fn seek_to_end() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + // Before anything is read the block positions should not be known. + assert_eq!(reader.current_block_position(), None); + assert_eq!(reader.next_block_position(), None); + + reader + .seek_encrypted(SeekFrom::Start(2598042)) + .await + .unwrap(); + + // Now the positions should be at the first data block. + assert_eq!(reader.current_block_position(), Some(2598043 - 40923)); + assert_eq!(reader.next_block_position(), Some(2598043)); + } + + #[tokio::test] + async fn seek_past_end() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + // Before anything is read the block positions should not be known. + assert_eq!(reader.current_block_position(), None); + assert_eq!(reader.next_block_position(), None); + + reader + .seek_encrypted(SeekFrom::Start(2598044)) + .await + .unwrap(); + + // Now the positions should be at the first data block. 
+ assert_eq!(reader.current_block_position(), Some(2598043)); + assert_eq!(reader.next_block_position(), Some(2598043)); + } + + #[tokio::test] + async fn seek_past_end_stream_length_override() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .with_stream_length(2598043) + .build_with_reader(src, vec![recipient_private_key]); + + // Before anything is read the block positions should not be known. + assert_eq!(reader.current_block_position(), None); + assert_eq!(reader.next_block_position(), None); + + reader + .seek_encrypted(SeekFrom::Start(2598044)) + .await + .unwrap(); + + // Now the positions should be at the first data block. + assert_eq!(reader.current_block_position(), Some(2598043)); + assert_eq!(reader.next_block_position(), Some(2598043)); + } + + #[tokio::test] + async fn advance_first_data_block() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + // Before anything is read the block positions should not be known. + assert_eq!(reader.current_block_position(), None); + assert_eq!(reader.next_block_position(), None); + + reader.advance_encrypted(0).await.unwrap(); + + // Now the positions should be at the first data block. 
+ assert_eq!(reader.current_block_position(), Some(124)); + assert_eq!(reader.next_block_position(), Some(124 + 65564)); + } + + #[tokio::test] + async fn advance_to_end() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + // Before anything is read the block positions should not be known. + assert_eq!(reader.current_block_position(), None); + assert_eq!(reader.next_block_position(), None); + + reader.advance_encrypted(2598042).await.unwrap(); + + // Now the positions should be at the first data block. + assert_eq!(reader.current_block_position(), Some(2598043 - 40923)); + assert_eq!(reader.next_block_position(), Some(2598043)); + } + + #[tokio::test] + async fn advance_past_end() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + // Before anything is read the block positions should not be known. + assert_eq!(reader.current_block_position(), None); + assert_eq!(reader.next_block_position(), None); + + reader.advance_encrypted(2598044).await.unwrap(); + + // Now the positions should be at the first data block. 
+ assert_eq!(reader.current_block_position(), Some(2598043)); + assert_eq!(reader.next_block_position(), Some(2598043)); + } + + #[tokio::test] + async fn advance_past_end_stream_length_override() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .with_stream_length(2598043) + .build_with_reader(src, vec![recipient_private_key]); + + // Before anything is read the block positions should not be known. + assert_eq!(reader.current_block_position(), None); + assert_eq!(reader.next_block_position(), None); + + reader.advance_encrypted(2598044).await.unwrap(); + + // Now the positions should be at the first data block. + assert_eq!(reader.current_block_position(), Some(2598043)); + assert_eq!(reader.next_block_position(), Some(2598043)); + } + + #[tokio::test] + async fn seek_first_data_block_unencrypted() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + // Before anything is read the block positions should not be known. + assert_eq!(reader.current_block_position(), None); + assert_eq!(reader.next_block_position(), None); + + reader.seek_unencrypted(0).await.unwrap(); + + // Now the positions should be at the first data block. 
+ assert_eq!(reader.current_block_position(), Some(124)); + assert_eq!(reader.next_block_position(), Some(124 + 65564)); + } + + #[tokio::test] + async fn seek_to_end_unencrypted() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + // Before anything is read the block positions should not be known. + assert_eq!(reader.current_block_position(), None); + assert_eq!(reader.next_block_position(), None); + + reader.seek_unencrypted(2596799).await.unwrap(); + + // Now the positions should be at the first data block. + assert_eq!(reader.current_block_position(), Some(2598043 - 40923)); + assert_eq!(reader.next_block_position(), Some(2598043)); + } + + #[tokio::test] + async fn seek_past_end_unencrypted() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + // Before anything is read the block positions should not be known. + assert_eq!(reader.current_block_position(), None); + assert_eq!(reader.next_block_position(), None); + + reader.seek_unencrypted(2596800).await.unwrap(); + + // Now the positions should be at the first data block. 
+ assert_eq!(reader.current_block_position(), Some(2598043)); + assert_eq!(reader.next_block_position(), Some(2598043)); + } + + #[tokio::test] + async fn seek_past_end_unencrypted_stream_length_override() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .with_stream_length(2598043) + .build_with_reader(src, vec![recipient_private_key]); + + // Before anything is read the block positions should not be known. + assert_eq!(reader.current_block_position(), None); + assert_eq!(reader.next_block_position(), None); + + reader.seek_unencrypted(2596800).await.unwrap(); + + // Now the positions should be at the first data block. + assert_eq!(reader.current_block_position(), Some(2598043)); + assert_eq!(reader.next_block_position(), Some(2598043)); + } + + #[tokio::test] + async fn advance_first_data_block_unencrypted() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + // Before anything is read the block positions should not be known. + assert_eq!(reader.current_block_position(), None); + assert_eq!(reader.next_block_position(), None); + + reader.advance_unencrypted(0).await.unwrap(); + + // Now the positions should be at the first data block. 
+ assert_eq!(reader.current_block_position(), Some(124)); + assert_eq!(reader.next_block_position(), Some(124 + 65564)); + } + + #[tokio::test] + async fn advance_to_end_unencrypted() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + // Before anything is read the block positions should not be known. + assert_eq!(reader.current_block_position(), None); + assert_eq!(reader.next_block_position(), None); + + reader.advance_unencrypted(2596799).await.unwrap(); + + // Now the positions should be at the first data block. + assert_eq!(reader.current_block_position(), Some(2598043 - 40923)); + assert_eq!(reader.next_block_position(), Some(2598043)); + } + + #[tokio::test] + async fn advance_past_end_unencrypted() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .build_with_stream_length(src, vec![recipient_private_key]) + .await + .unwrap(); + + // Before anything is read the block positions should not be known. + assert_eq!(reader.current_block_position(), None); + assert_eq!(reader.next_block_position(), None); + + reader.advance_unencrypted(2596800).await.unwrap(); + + // Now the positions should be at the first data block. 
+ assert_eq!(reader.current_block_position(), Some(2598043)); + assert_eq!(reader.next_block_position(), Some(2598043)); + } + + #[tokio::test] + async fn advance_past_end_unencrypted_stream_length_override() { + let src = get_test_file("crypt4gh/htsnexus_test_NA12878.bam.c4gh").await; + let (recipient_private_key, sender_public_key) = get_decryption_keys().await; + + let mut reader = Builder::default() + .with_sender_pubkey(PublicKey::new(sender_public_key)) + .with_stream_length(2598043) + .build_with_reader(src, vec![recipient_private_key]); + + // Before anything is read the block positions should not be known. + assert_eq!(reader.current_block_position(), None); + assert_eq!(reader.next_block_position(), None); + + reader.advance_unencrypted(2596800).await.unwrap(); + + // Now the positions should be at the first data block. + assert_eq!(reader.current_block_position(), Some(2598043)); + assert_eq!(reader.next_block_position(), Some(2598043)); + } +} diff --git a/async-crypt4gh/src/util/mod.rs b/async-crypt4gh/src/util/mod.rs new file mode 100644 index 000000000..c8850687e --- /dev/null +++ b/async-crypt4gh/src/util/mod.rs @@ -0,0 +1,258 @@ +use base64::engine::general_purpose; +use base64::Engine; +use bstr::ByteSlice; +use crypt4gh::error::Crypt4GHError; +use crypt4gh::keys::generate_private_key; +use rustls::PrivateKey; +use std::cmp::min; +use std::io; +use std::ops::Add; + +use crate::decoder::Block; +use crate::error::{Error, Result}; +use crate::{KeyPair, PublicKey}; + +fn to_current_data_block(pos: u64, header_len: u64) -> u64 { + header_len + (pos / Block::encrypted_block_size()) * Block::standard_data_block_size() +} + +/// Convert an unencrypted file position to an encrypted position if the header length is known. +pub fn to_encrypted(position: u64, header_length: u64) -> u64 { + let number_data_blocks = position / Block::encrypted_block_size(); + // Additional bytes include the full data block size. 
+ let mut additional_bytes = number_data_blocks * (Block::nonce_size() + Block::mac_size()); + + // If there is left over data, then there are more nonce bytes. + let remainder = position % Block::encrypted_block_size(); + if remainder != 0 { + additional_bytes += Block::nonce_size(); + } + + // Then add the extra bytes to the current position. + header_length + position + additional_bytes +} + +/// Convert an encrypted file position to an unencrypted position if the header length is known. +pub fn to_unencrypted(encrypted_position: u64, header_length: u64) -> u64 { + let number_data_blocks = encrypted_position / Block::standard_data_block_size(); + let mut additional_bytes = number_data_blocks * (Block::nonce_size() + Block::mac_size()); + + let remainder = encrypted_position % Block::standard_data_block_size(); + if remainder != 0 { + additional_bytes += Block::nonce_size(); + } + + encrypted_position - header_length - additional_bytes +} + +/// Convert an unencrypted file size to an encrypted file size if the header length is known. +pub fn to_encrypted_file_size(file_size: u64, header_length: u64) -> u64 { + to_encrypted(file_size, header_length) + Block::mac_size() +} + +/// Convert an encrypted file size to an unencrypted file size if the header length is known. +pub fn to_unencrypted_file_size(encrypted_file_size: u64, header_length: u64) -> u64 { + to_unencrypted(encrypted_file_size, header_length) - Block::mac_size() +} + +/// Convert an unencrypted position to an encrypted position as shown in +/// https://samtools.github.io/hts-specs/crypt4gh.pdf chapter 4.1. +pub fn unencrypted_to_data_block(pos: u64, header_len: u64, encrypted_file_size: u64) -> u64 { + min(encrypted_file_size, to_current_data_block(pos, header_len)) +} + +/// Get the next data block position from the unencrypted position. 
+pub fn unencrypted_to_next_data_block(pos: u64, header_len: u64, encrypted_file_size: u64) -> u64 { + min( + encrypted_file_size, + to_current_data_block(pos, header_len) + Block::standard_data_block_size(), + ) +} + +fn unencrypted_clamped_position(pos: u64, encrypted_file_size: u64) -> u64 { + let data_block_positions = unencrypted_to_data_block(pos, 0, encrypted_file_size); + let data_block_count = data_block_positions / Block::standard_data_block_size(); + + data_block_positions - ((Block::nonce_size() + Block::mac_size()) * data_block_count) +} + +/// Convert an unencrypted position to the additional bytes prior to the position that must be +/// included when encrypting data blocks. +pub fn unencrypted_clamp(pos: u64, encrypted_file_size: u64) -> u64 { + min( + to_unencrypted_file_size(encrypted_file_size, 0), + unencrypted_clamped_position(pos, encrypted_file_size), + ) +} + +/// Convert an unencrypted position to the additional bytes after to the position that must be +/// included when encrypting data blocks. +pub fn unencrypted_clamp_next(pos: u64, encrypted_file_size: u64) -> u64 { + min( + to_unencrypted_file_size(encrypted_file_size, 0), + unencrypted_clamped_position(pos, encrypted_file_size) + Block::encrypted_block_size(), + ) +} + +/// Generate a private and public key pair. 
+pub fn generate_key_pair() -> Result { + let skpk = generate_private_key()?; + let (private_key, public_key) = skpk.split_at(32); + + Ok(KeyPair::new( + PrivateKey(Vec::from(private_key)), + PublicKey::new(Vec::from(public_key)), + )) +} + +pub async fn encode_public_key(public_key: PublicKey) -> String { + let pk = String::new(); + let pk = pk.add("-----BEGIN CRYPT4GH PUBLIC KEY-----\n"); + + let pk = pk.add(&general_purpose::STANDARD.encode(public_key.into_inner())); + + pk.add("\n-----END CRYPT4GH PUBLIC KEY-----\n") +} + +/// Read a public key from bytes +pub async fn read_public_key(bytes: Vec) -> Result { + let mut lines = ByteSlice::lines(bytes.as_slice()).collect::>(); + + let error = || { + Error::IOError(io::Error::new( + io::ErrorKind::Other, + "invalid public key".to_string(), + )) + }; + + if lines.is_empty() { + return Err(error()); + } + + // Optionally decode the key from a string. + let key = if lines + .first() + .is_some_and(|first| first.contains_str(b"CRYPT4GH")) + && lines + .last() + .is_some_and(|first| first.contains_str(b"CRYPT4GH")) + { + lines.remove(0); + lines.pop(); + + general_purpose::STANDARD + .decode(lines.into_iter().flatten().copied().collect::>()) + .map_err(|e| Crypt4GHError::BadBase64Error(e.into()))? 
+ } else { + lines.into_iter().flatten().copied().collect() + }; + + Ok(PublicKey::new(key)) +} + +#[cfg(test)] +mod tests { + use crypt4gh::keys::get_public_key; + use htsget_test::http::get_test_path; + use std::fs; + + use super::*; + use crate::util::{unencrypted_clamp, unencrypted_to_data_block, unencrypted_to_next_data_block}; + + #[test] + fn test_to_encrypted() { + let pos = 80000; + let expected = 120 + 65536 + 12 + 16; + let result = unencrypted_to_data_block(pos, 120, to_encrypted_file_size(100000, 120)); + assert_eq!(result, expected); + } + + #[test] + fn test_to_encrypted_file_size() { + let pos = 110000; + let expected = 60148; + let result = unencrypted_to_data_block(pos, 120, to_encrypted_file_size(60000, 120)); + assert_eq!(result, expected); + } + + #[test] + fn test_to_encrypted_pos_greater_than_file_size() { + let pos = 110000; + let expected = 120 + 65536 + 12 + 16; + let result = unencrypted_to_data_block(pos, 120, to_encrypted_file_size(100000, 120)); + assert_eq!(result, expected); + } + + #[test] + fn test_next_data_block() { + let pos = 100000; + let expected = 120 + (65536 + 12 + 16) * 2; + let result = unencrypted_to_next_data_block(pos, 120, to_encrypted_file_size(150000, 120)); + assert_eq!(result, expected); + } + + #[test] + fn test_next_data_block_file_size() { + let pos = 110000; + let expected = 100176; + let result = unencrypted_to_next_data_block(pos, 120, to_encrypted_file_size(100000, 120)); + assert_eq!(result, expected); + } + + #[test] + fn test_unencrypted_clamp() { + let pos = 0; + let expected = 0; + let result = unencrypted_clamp(pos, to_encrypted_file_size(5485112, 0)); + assert_eq!(result, expected); + + let pos = 145110; + let expected = 131072; + let result = unencrypted_clamp(pos, to_encrypted_file_size(5485112, 0)); + assert_eq!(result, expected); + + let pos = 5485074; + let expected = 5439488; + let result = unencrypted_clamp(pos, to_encrypted_file_size(5485112, 0)); + assert_eq!(result, expected); + } + + 
#[test] + fn test_unencrypted_clamp_next() { + let pos = 7853; + let expected = 65536; + let result = unencrypted_clamp_next(pos, to_encrypted_file_size(5485112, 0)); + assert_eq!(result, expected); + + let pos = 453039; + let expected = 458752; + let result = unencrypted_clamp_next(pos, to_encrypted_file_size(5485112, 0)); + assert_eq!(result, expected); + + let pos = 5485112; + let expected = 5485112; + let result = unencrypted_clamp_next(pos, to_encrypted_file_size(5485112, 0)); + assert_eq!(result, expected); + } + + #[tokio::test] + async fn test_read_public_key_raw() { + let test_key = vec![ + 56, 44, 122, 180, 24, 116, 207, 149, 165, 49, 204, 77, 224, 136, 232, 121, 209, 249, 23, 51, + 120, 2, 187, 147, 82, 227, 232, 32, 17, 223, 7, 38, + ]; + + let result = read_public_key(test_key.clone()).await.unwrap(); + assert_eq!(result, PublicKey::new(test_key)) + } + + #[tokio::test] + async fn test_read_public_key_with_header() { + let expected_public_key = get_public_key(get_test_path("crypt4gh/keys/bob.pub")).unwrap(); + let result = read_public_key(fs::read(get_test_path("crypt4gh/keys/bob.pub")).unwrap()) + .await + .unwrap(); + + assert_eq!(result, PublicKey::new(expected_public_key)) + } +} diff --git a/data/crypt4gh/README.md b/data/crypt4gh/README.md new file mode 100644 index 000000000..0d33c7e15 --- /dev/null +++ b/data/crypt4gh/README.md @@ -0,0 +1,34 @@ +# Crypt4GH example file + +This is just a customised summary for htsget-rs. Please refer to the official [`crypt4gh-rust` documentation](https://ega-archive.github.io/crypt4gh-rust) for further information. 
+
+## Keygen
+
+```sh
+cargo install crypt4gh
+crypt4gh keygen --sk keys/alice.sec --pk keys/alice.pub
+crypt4gh keygen --sk keys/bob.sec --pk keys/bob.pub
+```
+
+## Encrypt
+
+```sh
+crypt4gh encrypt --sk keys/alice.sec --recipient_pk keys/bob.pub < htsnexus_test_NA12878.bam > htsnexus_test_NA12878.bam.c4gh
+```
+
+## Decrypt
+
+```sh
+crypt4gh decryptor --range 0-65535 --sk data/crypt4gh/keys/bob.sec \
+    --sender-pk data/crypt4gh/keys/alice.pub \
+    < data/crypt4gh/htsnexus_test_NA12878.bam.c4gh \
+    > out.bam
+
+samtools view out.bam
+(...)
+SRR098401.61822403 83 11 5009470 60 76M = 5009376 -169 TCTTCTTGCCCTGGTGTTTCGCCGTTCCAGTGCCCCCTGCTGCAGACCATAAAGGATGGGACTTTGTTGAGGTAGG ?B6BDCD@I?JFI?FHHFEAIIAHHDIJHHFIIIIIJEIIFIJGHCIJDDEEHHHDEHHHCIGGEGFDGFGFBEDC X0:i:1 X1:i:0 MD:Z:76 RG:Z:SRR098401 AM:i:37 NM:i:0 SM:i:37 MQ:i:60 XT:A:U BQ:Z:@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@B
+
+samtools view: error reading file "out.bam"
+samtools view: error closing "out.bam": -1
+```
+
+The last samtools view error suggests that the returned bytes do not include BAM file termination.
diff --git a/data/crypt4gh/htsnexus_test_NA12878.bam.c4gh b/data/crypt4gh/htsnexus_test_NA12878.bam.c4gh new file mode 100644 index 000000000..35c959faa Binary files /dev/null and b/data/crypt4gh/htsnexus_test_NA12878.bam.c4gh differ diff --git a/data/crypt4gh/htsnexus_test_NA12878.cram.c4gh b/data/crypt4gh/htsnexus_test_NA12878.cram.c4gh new file mode 100644 index 000000000..72dd83e43 Binary files /dev/null and b/data/crypt4gh/htsnexus_test_NA12878.cram.c4gh differ diff --git a/data/crypt4gh/keys/alice.pub b/data/crypt4gh/keys/alice.pub new file mode 100644 index 000000000..686226ca8 --- /dev/null +++ b/data/crypt4gh/keys/alice.pub @@ -0,0 +1,3 @@ +-----BEGIN CRYPT4GH PUBLIC KEY----- +ToQrpj4UfuLgxZRe1wSGIZtXC19fOEHUHe3RQy63qwM= +-----END CRYPT4GH PUBLIC KEY----- diff --git a/data/crypt4gh/keys/alice.sec b/data/crypt4gh/keys/alice.sec new file mode 100644 index 000000000..ecc3b8916 --- /dev/null +++ b/data/crypt4gh/keys/alice.sec @@ -0,0 +1,3 @@ +-----BEGIN CRYPT4GH PRIVATE KEY----- +YzRnaC12MQAEbm9uZQAEbm9uZQAgxi4tNmUO++HAApv9ryZB9S8QfqrWKKe5CunJuChH5vU= +-----END CRYPT4GH PRIVATE KEY----- diff --git a/data/crypt4gh/keys/bob.pub b/data/crypt4gh/keys/bob.pub new file mode 100644 index 000000000..990643c83 --- /dev/null +++ b/data/crypt4gh/keys/bob.pub @@ -0,0 +1,3 @@ +-----BEGIN CRYPT4GH PUBLIC KEY----- +TyKEXZPnfon6dj1kRXl6HumfZDzo/h60RIc8Wd0Ig2s= +-----END CRYPT4GH PUBLIC KEY----- diff --git a/data/crypt4gh/keys/bob.sec b/data/crypt4gh/keys/bob.sec new file mode 100644 index 000000000..0bc62269f --- /dev/null +++ b/data/crypt4gh/keys/bob.sec @@ -0,0 +1,3 @@ +-----BEGIN CRYPT4GH PRIVATE KEY----- +YzRnaC12MQAEbm9uZQAEbm9uZQAg6uLXNqcXAi6FRKzRBk2KBKF4BnmueySZv5MGzKjIPcI= +-----END CRYPT4GH PRIVATE KEY----- diff --git a/data/crypt4gh/sample1-bcbio-cancer.bcf.c4gh b/data/crypt4gh/sample1-bcbio-cancer.bcf.c4gh new file mode 100644 index 000000000..758b98f9a Binary files /dev/null and b/data/crypt4gh/sample1-bcbio-cancer.bcf.c4gh differ diff --git 
a/data/crypt4gh/spec-v4.3.vcf.gz.c4gh b/data/crypt4gh/spec-v4.3.vcf.gz.c4gh new file mode 100644 index 000000000..4e31cc24b Binary files /dev/null and b/data/crypt4gh/spec-v4.3.vcf.gz.c4gh differ diff --git a/deploy/Dockerfile.dockerignore b/deploy/Dockerfile.dockerignore index da4ac473c..0841f2283 100644 --- a/deploy/Dockerfile.dockerignore +++ b/deploy/Dockerfile.dockerignore @@ -1,5 +1,6 @@ * +!/async-crypt4gh !/htsget-actix !/htsget-config !/htsget-http diff --git a/deploy/README.md b/deploy/README.md index 8639377bf..44ecd689f 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -13,11 +13,12 @@ The CDK code in this directory constructs a CDK app from [`HtsgetLambdaStack`][h [`bin/settings.ts`][htsget-settings]: #### HtsgetSettings + These are general settings for the CDK deployment. | Name | Description | Type | -|----------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------| -| `config` | The location of the htsget-rs server config. This must be specified. This config file configures the htsget-rs server. See [htsget-config] for a list of available server configuration options. | `string` | +| -------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------- | +| `config` | The location of the htsget-rs server config. This must be specified. This config file configures the htsget-rs server. See [htsget-config] for a list of available server configuration options. 
| `string` | | `domain` | The domain name for the Route53 Hosted Zone that the htsget-rs server will be under. This must be specified. A hosted zone with this name will either be looked up or created depending on the value of [`lookupHostedZone?`](#lookupHostedZone). | `string` | | `authorizer` | Deployment options related to the authorizer. Note that this option allows specifying an AWS [JWT authorizer][jwt-authorizer]. The JWT authorizer automatically verifies tokens issued by a Cognito user pool. | [`HtsgetJwtAuthSettings`](#htsgetjwtauthsettings) | | `subDomain?` | The domain name prefix to use for the htsget-rs server. Together with the [`domain`](#domain), this specifies url that the htsget-rs server will be reachable under. Defaults to `"htsget"`. | `string` | @@ -25,12 +26,13 @@ These are general settings for the CDK deployment. | `lookupHostedZone?` | Whether to lookup the hosted zone with the domain name. Defaults to `true`. If `true`, attempts to lookup an existing hosted zone using the domain name. Set this to `false` if you want to create a new hosted zone with the domain name. | `boolean` | #### HtsgetJwtAuthSettings + These settings are used to determine if the htsget API gateway endpoint is configured to have a JWT authorizer or not. | Name | Description | Type | -|---------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------|------------| +| ------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | | `public` | Whether this deployment is public. If this is `true` then no authorizer is present on the API gateway and the options below have no effect. | `boolean` | -| `jwtAudience?` | A list of the intended recipients of the JWT. 
A valid JWT must provide an aud that matches at least one entry in this list. | `string[]` | +| `jwtAudience?` | A list of the intended recipients of the JWT. A valid JWT must provide an aud that matches at least one entry in this list. | `string[]` | | `cogUserPoolId?` | The cognito user pool id for the authorizer. If this is not set, then a new user pool is created. No user pool is created if [`public`](#public) is true. | `string` | The [`HtsgetSettings`](#htsgetsettings) are passed into [`HtsgetLambdaStack`][htsget-lambda-stack] in order to change the deployment config. An example of a public instance deployment @@ -47,20 +49,20 @@ can be found under [`bin/htsget-lambda.ts`][htsget-lambda-bin]. This uses the [` After installing the basic dependencies, complete the following steps: -1. Define CDK\_DEFAULT\_* env variables (if not defined already). You must be authenticated with your AWS cloud to run this step. +1. Define CDK_DEFAULT\_\* env variables (if not defined already). You must be authenticated with your AWS cloud to run this step. 1. Add the arm cross-compilation target to rust. 1. Install [cargo-lambda], as it is used to compile artifacts that are uploaded to aws lambda. -1. Define which configuration to use for htsget-rs on `cdk.json` as stated in aforementioned configuration section. +1. Define which configuration to use for htsget-rs on `cdk.json` as stated in aforementioned configuration section. 
Below is a summary of commands to run in this directory: -```sh +````sh ``export CDK_DEFAULT_ACCOUNT=`aws sts get-caller-identity --query Account --output text` export CDK_DEFAULT_REGION=`aws configure get region``` rustup target add aarch64-unknown-linux-gnu cargo install cargo-lambda npm install -``` +```` ### Deploy to AWS diff --git a/docs/crypt4gh/ARCHITECTURE.md b/docs/crypt4gh/ARCHITECTURE.md new file mode 100644 index 000000000..89cc9ca3e --- /dev/null +++ b/docs/crypt4gh/ARCHITECTURE.md @@ -0,0 +1,84 @@ +# Crypt4GH and htsget-rs + +Crypt4GH is a block-level encryption scheme that encrypts data in 64KiB blocks, and supports indexed reading of genomic data. +Currently, htsget-rs supports returning Crypt4GH encrypted data using `UrlStorage` and a custom protocol that defines +how the client, htsget-rs server, and `UrlStorage` server backend should interact. + +Assuming that data is stored in the `UrlStorage` backend encrypted, the aim of this protocol is for the client to receive +encrypted data via htsget-rs, which ought to be able to calculate the correct byte ranges according to the query. + +## The htsget-rs server + +In order for htsget-rs to do this, it needs a few pieces of information: +1. It needs all the usual information to calculate byte ranges for a normal htsget query on unencrypted data. This includes: + * The index file matching the queried file. Indexes are used by htsget-rs to calculate the closest byte ranges that include + the data requested in the query. For example, these indexes are BAI, TABIX, CRAI and CSI files. + * The header of the queried file. This header is used by htsget-rs to map a reference name in the query to the corresponding + index position. For example, these headers are lines at the beginning BAM files that start with `@`, or lines at the beginning. + of VCF files that start with `##`. + * The size of the queried file. 
The file size is used to simplify certain queries and for queries that involve byte ranges that + go to the end of the file. For example, this is the size of the BAM, CRAM, VCF, or BCF file. +2. It also needs the size of the Crypt4GH header which the client will use to decrypt the data. Because each Crypt4GH + file has a header at the start, this is needed in order to calculate byte ranges that align to the 64KiB Crypt4GH boundaries. + +With this information, htsget-rs can calculate the correct byte ranges to return to the client, who in cooperation with the `UrlStorage` +backend can fetch the required data and decrypt it. With this approach, the htsget-rs server needs minimal information to +process the request by passing public key information, which is used by htsget-rs and the `UrlStorage` backend to encrypt +data for the client. + +The htsget-rs server can pass client information to the backend, such as its own public key in order to receive all the +information it requires to calculate byte ranges. With this information, htsget-rs can return the Crypt4GH header for the +client, and an edit list, alongside byte ranges, in order for the client to concatenate and decrypt the data. + +## The protocol + +The protocol starts when the client queries the htsget-rs server. The client sends their public key in a header called +`client-public-key`, which is base64 encoded. From here the htsget-rs queries the `UrlStorage` backend in three separate +requests. + +The htsget-rs sends the requests to the `UrlStorage` backend with its own `client-public-key`, which replaces the client's +`client-public-key`. It is expected that data returned to the htsget-rs server is encrypted with this public key. +The client's original `client-public-key` is mirrored back to the client, which is used when fetching +the URL tickets. Note that the same header contains two different public keys at different points. Between the client and htsget-rs, +it contains the client's public key. 
Between htsget-rs and the `UrlStorage` backend, it contains htsget-rs' public key. + +All of the client's headers are forwarded to `UrlStorage`, except the `client-public-key` and `user-agent` (which htsget-rs replaces +with its own values). Assuming the query is for a BAM file called `id`, and the resolvers do not transform the id, the requests are: + +1. A GET request to fetch the index of the queried file at: `https:///id.bam.bai`. +2. A HEAD request to get the **encrypted** file size at: `https:///id.bam.c4gh`. + 1. The expected response includes the `content-length` header which specifies the file size. This file size is relevant to ensure that the client receives correct end-coordinate URL tickets. + 2. Additionally, there can be a `server-additional-bytes` header which specifies the size of the Crypt4GH header that htsget-rs will receive + from the `UrlStorage` backend. This is used in [request 3.1](#3.1) to ensure that the byte ranges requested from the backend align to the Crypt4GH block boundaries. + If this header is not present, then htsget-rs will request the full file. + 3. Optionally, there could be a `client-additional-bytes` header which specifies the size of the Crypt4GH header that the client will receive + when it queries the `UrlStorage` backend with the first URL ticket, with byte ranges: `range: "bytes=0-"`. + If this is not specified, it defaults to the `server-additional-bytes`. If `server-additional-bytes` is not present, + then the size of the Crypt4GH header is assumed to be the same as the size of the Crypt4GH header that htsget-rs receives from `UrlStorage`. +3. A GET request to get the start of the **encrypted** file containing a Crypt4GH header and the BAM header, at: `https:///id.bam.c4gh`. + 1. This request has additional headers to specify byte ranges for the start of the file: `range: "bytes=0-`". + +Here, `` and `` can be defined in the htsget config. Currently, htsget-rs implements this design. 
+ +The following is a diagram of this process: +![architecture](./htsget-rs-crypt4gh.drawio.png) + +Note, currently this protocol doesn't include a way to verify the public key of the sender of the Crypt4GH data, although +this could be included in the future. This would involve returning additional public keys from `UrlStorage` to htsget-rs and from htsget-rs to the client. + +### Alternative designs + +For [request 3](#3) an alternative would be to return unencrypted header data directly, however this has the +disadvantage of not using Crypt4GH for the `htsget-rs <-> UrlStorage` portion of the data transfer. + +In general, it is simple to convert unencrypted byte positions to encrypted byte positions, and vice versa, so it's not +as important whether the size and range returned by `UrlStorage` is unencrypted or encrypted. + +Other designs could explore the client or the `UrlStorage` backend doing more work in terms of calculating byte ranges. +For example, the htsget-rs server could return unencrypted byte ranges which the client needs to convert into encrypted +byte ranges. This is further complicated by the edit lists, which the client would need to obtain or calculate to discard the correct +bytes after decrypting. With the current design, it is convenient to place this logic in htsget-rs because htsget-rs +already contains logic for byte range calculations for unencrypted data. 
+ + + diff --git a/docs/crypt4gh/htsget-rs-crypt4gh.drawio.png b/docs/crypt4gh/htsget-rs-crypt4gh.drawio.png new file mode 100644 index 000000000..e2675c488 Binary files /dev/null and b/docs/crypt4gh/htsget-rs-crypt4gh.drawio.png differ diff --git a/htsget-actix/Cargo.toml b/htsget-actix/Cargo.toml index 45bb84cd3..3e17c0148 100644 --- a/htsget-actix/Cargo.toml +++ b/htsget-actix/Cargo.toml @@ -13,6 +13,7 @@ repository = "https://github.com/umccr/htsget-rs" [features] s3-storage = ["htsget-config/s3-storage", "htsget-search/s3-storage", "htsget-http/s3-storage", "htsget-test/s3-storage"] url-storage = ["htsget-config/url-storage", "htsget-search/url-storage", "htsget-http/url-storage", "htsget-test/url-storage"] +crypt4gh = ["htsget-config/crypt4gh", "htsget-search/crypt4gh", "htsget-http/crypt4gh", "htsget-test/crypt4gh"] default = [] [dependencies] @@ -28,7 +29,7 @@ htsget-search = { version = "0.7.0", path = "../htsget-search", default-features htsget-config = { version = "0.9.0", path = "../htsget-config", default-features = false } htsget-test = { version = "0.6.0", path = "../htsget-test", features = ["http"], default-features = false } futures = { version = "0.3" } -tokio = { version = "1.28", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.29", features = ["macros", "rt-multi-thread"] } tracing-actix-web = "0.7" tracing = "0.1" diff --git a/htsget-actix/README.md b/htsget-actix/README.md index 698fc0a01..6de6d921a 100644 --- a/htsget-actix/README.md +++ b/htsget-actix/README.md @@ -110,6 +110,7 @@ are exposed in the public API. This crate has the following features: * `s3-storage`: used to enable `S3Storage` functionality. * `url-storage`: used to enable `UrlStorage` functionality. +* `crypt4gh`: used to enable Crypt4GH functionality. 
## Benchmarks Benchmarks for this crate written using [Criterion.rs][criterion-rs], and aim to compare the performance of this crate with the diff --git a/htsget-actix/benches/request_benchmarks.rs b/htsget-actix/benches/request_benchmarks.rs index f3762335c..c9abbcf28 100644 --- a/htsget-actix/benches/request_benchmarks.rs +++ b/htsget-actix/benches/request_benchmarks.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::path::PathBuf; use std::process::{Child, Command}; use std::thread::sleep; @@ -61,11 +62,15 @@ fn request(url: reqwest::Url, json_content: &impl Serialize, client: &Client) -> client .get(&json_url.url) .headers( - json_url - .headers - .as_ref() - .unwrap_or(&Headers::default()) - .as_ref_inner() + (&HashMap::from_iter( + json_url + .headers + .as_ref() + .unwrap_or(&Headers::default()) + .as_ref_inner() + .iter() + .map(|(k, v)| (String::from(k), String::from(v))), + )) .try_into() .unwrap(), ) diff --git a/htsget-config/Cargo.toml b/htsget-config/Cargo.toml index 42893e612..428771613 100644 --- a/htsget-config/Cargo.toml +++ b/htsget-config/Cargo.toml @@ -13,6 +13,7 @@ repository = "https://github.com/umccr/htsget-rs" [features] s3-storage = [] url-storage = ["reqwest"] +crypt4gh = ["dep:async-crypt4gh", "dep:crypt4gh"] default = [] [dependencies] @@ -20,7 +21,7 @@ thiserror = "1.0" async-trait = "0.1" noodles = { version = "0.65", features = ["core"] } serde = { version = "1.0", features = ["derive"] } -serde_with = "3.0" +serde_with = "3.2" serde_regex = "1.1" regex = "1.8" figment = { version = "0.10", features = ["env", "toml"] } @@ -36,9 +37,12 @@ rustls = "0.21" # url-storage reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false, optional = true } +async-crypt4gh = { version = "0.1.0", path = "../async-crypt4gh", default-features = false, optional = true } +crypt4gh = { version = "0.4", git = "https://github.com/EGA-archive/crypt4gh-rust", optional = true } + [dev-dependencies] serde_json = "1.0" figment 
= { version = "0.10", features = ["test"] } -tokio = { version = "1.28", features = ["macros", "rt-multi-thread"] } -tempfile = "3.6" +tokio = { version = "1.29", features = ["macros", "rt-multi-thread"] } +tempfile = "3.7" rcgen = "0.12" diff --git a/htsget-config/README.md b/htsget-config/README.md index 0fde659b7..c6d158c44 100644 --- a/htsget-config/README.md +++ b/htsget-config/README.md @@ -171,18 +171,22 @@ To use `S3Storage`, build htsget-rs with the `s3-storage` feature enabled, and s `UrlStorage` is another storage backend which can be used to serve data from a remote HTTP URL. When using this storage backend, htsget-rs will fetch data from a `url` which is set in the config. It will also forward any headers received with the initial query, which is useful for authentication. To use `UrlStorage`, build htsget-rs with the `url-storage` feature enabled, and set the following options under `[resolvers.storage]`: -| Option | Description | Type | Default | -|--------------------------------------|------------------------------------------------------------------------------------------------------------------------------|--------------------------|-----------------------------------------------------------------------------------------------------------------| -| `url` | The URL to fetch data from. | HTTP URL | `"https://127.0.0.1:8081/"` | -| `response_url` | The URL to return to the client for fetching tickets. | HTTP URL | `"https://127.0.0.1:8081/"` | -| `forward_headers` | When constructing the URL tickets, copy HTTP headers received in the initial query. | Boolean | `true` | -| `header_blacklist` | List of headers that should not be forwarded | Array of headers | `[]` | -| `tls` | Additionally enables client authentication, or sets non-native root certificates for TLS. See [TLS](#tls) for more details. | TOML table | TLS is always allowed, however the default performs no client authentication and uses native root certificates. 
| - -When using `UrlStorage`, the following requests will be made to the `url`. -* `GET` request to fetch only the headers of the data file (e.g. `GET /data.bam`, with `Range: bytes=0-`). -* `GET` request to fetch the entire index file (e.g. `GET /data.bam.bai`). -* `HEAD` request on the data file to get its length (e.g. `HEAD /data.bam`). +| Option | Description | Type | Default | +|---------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------|-----------------------------------------------------------------------------------------------------------------| +| `endpoint_index` | The URL to fetch index for a file. The request will be a GET request which expects the index file specific to a BAM/CRAM/VCF file. | HTTP URL | `"https://127.0.0.1:8081/"` | +| `endpoint_file` | The URL to fetch underlying for a file. The request will be a GET request which expects to get the decrypted underlying header from a BAM/CRAM/VCF file. | HTTP URL | `"https://127.0.0.1:8081/"` | +| `response_url` | The URL to return to the client for fetching tickets. | HTTP URL | `"https://127.0.0.1:8081/"` | +| `forward_headers` | When constructing the URL tickets, copy HTTP headers received in the initial query. Note, the headers received with the query are always forwarded to the `url`. | Boolean | `true` | +| `user_agent` | A user agent to provide when making requests to the URLs. | String | A combination of the cargo package name and version. For example, `htsget-search/0.6.6`. | +| `danger_accept_invalid_certs` | Trusted invalid certificates, such as self-signed certificates. Only affects TLS on the HTTP client in `UrlStorage`. 
| Boolean | false | +| `header_blacklist` | List of headers that should not be forwarded | Array of headers | `[]` | +| `tls` | Additionally enables client authentication, or sets non-native root certificates for TLS. See [TLS](#tls) for more details. | TOML table | TLS is always allowed, however the default performs no client authentication and uses native root certificates. | + +When using `UrlStorage`, the following requests will be made: +* `GET` request to fetch only the crypt4gh headers size of the data file (e.g. `GET /data.bam`), URL used is configured via `endpoint_crypt4gh_header`. +* `GET` request to fetch only the headers of the data file (e.g. `GET /data.bam`, with `Range: bytes=0-`), URL used is configured via `endpoint_header`. +* `GET` request to fetch the entire index file (e.g. `GET /data.bam.bai`), URL used is configured via `endpoint_index`. +* `HEAD` request on the data file to get its length (e.g. `HEAD /data.bam`), URL used is configured via `endpoint_head`. By default, all headers received in the initial query will be included when making these requests. To exclude certain headers from being forwarded, set the `header_blacklist` option. Note that the blacklisted headers are removed from the requests made to `url` and from the URL tickets as well. @@ -224,6 +228,7 @@ bucket = 'bucket' ``` `UrlStorage` can only be specified manually. + Example of a resolver with `UrlStorage`: ```toml [[resolvers]] @@ -249,7 +254,7 @@ Additionally, the resolver component has a feature, which allows resolving IDs b This is useful as allows the resolver to match an ID, if a particular set of query parameters are also present. For example, a resolver can be set to only resolve IDs if the format is also BAM. -This component can be configured by setting the `[resolver.allow_guard]` table with. The following options are available to restrict which queries are resolved by a resolver: +This component can be configured by setting the `[resolver.allow_guard]` table. 
The following options are available to restrict which queries are resolved by a resolver: | Option | Description | Type | Default | |-------------------------|-----------------------------------------------------------------------------------------|-----------------------------------------------------------------------|-------------------------------------| @@ -341,6 +346,33 @@ Further TLS examples are available under [`examples/config-files`][examples-conf [rustls]: https://github.com/rustls/rustls [mkcert]: https://github.com/FiloSottile/mkcert +#### Object type +There is additional configuration that changes the way a resolver treats an object. + +By default, all objects are considered `Regular`. However, the `object_type` can be configured to decrypt Crypt4GH files. + +This component can be configured by setting the `[resolver.object_type]` table in order to enable Crypt4GH: + +| Option | Description | Type | Default | +|----------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------|------------------------------------------------| +| `send_encrypted_to_client` | Whether to send data encrypted byte ranges to the client. Note, this does not affect data sent to the `UrlStorage` backend, which remains encrypted if type Crypt4GH object type is used. | Boolean | Not set | +| `private_key` | Path to the private key used for decrypted Crypt4GH data. | Path | Not set, generates ephemeral keys if not set. | +| `public_key` | Path to the public key used for decrypted Crypt4GH data. | Path | Not set, generates ephemeral keys if not set. | + +Or, to generate keys uniquely for each request, the `private_key` and `public_key` options should not be set. 
+ +For example to enable Crypt4GH for a resolver, build htsget-rs with the `crypt4gh` feature enabled, and set the following options under `[resolvers.object_type]`: + +```toml +[resolvers.object_type] +# Specify the keys that htsget will use manually. +send_encrypted_to_client = true +private_key = "data/crypt4gh/keys/bob.sec" # pragma: allowlist secret +public_key = "data/crypt4gh/keys/bob.pub" +``` + +Note, currently this functionality only works with `UrlStorage`. + #### Config file location The htsget-rs binaries ([htsget-actix] and [htsget-lambda]) support some command line options. The config file location can @@ -500,6 +532,7 @@ regex, and changing it by using a substitution string. This crate has the following features: * `s3-storage`: used to enable `S3Storage` functionality. * `url-storage`: used to enable `UrlStorage` functionality. +* `crypt4gh`: used to enable Crypt4GH functionality. ## License @@ -510,4 +543,4 @@ This project is licensed under the [MIT license][license]. [virtual-addressing]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/VirtualHosting.html#virtual-hosted-style-access [minio-deployment]: ../deploy/examples/minio/README.md [license]: LICENSE -[minio]: https://min.io/ \ No newline at end of file +[minio]: https://min.io/ diff --git a/htsget-config/examples/config-files/crypt4gh.toml b/htsget-config/examples/config-files/crypt4gh.toml new file mode 100644 index 000000000..9c7602fdc --- /dev/null +++ b/htsget-config/examples/config-files/crypt4gh.toml @@ -0,0 +1,32 @@ +# An example that treats files as Crypt4GH encrypted. +# Run with `cargo run -p htsget-actix --features crypt4gh,url-storage -- --config crypt4gh.toml` + +ticket_server_addr = "0.0.0.0:7000" +#data_server_addr = "0.0.0.0:8081" + +[[resolvers]] +regex = ".*" +substitution_string = "$0" + +[resolvers.object_type] +# This option specified Crypt4GH files. +send_encrypted_to_client = true +# Specify the keys that htsget will use manually. 
These can be commented out to generate keys. +private_key = "data/crypt4gh/keys/bob.sec" # pragma: allowlist secret +public_key = "data/crypt4gh/keys/bob.pub" + +[resolvers.storage] +response_url = "https://example.com/" +forward_headers = false + +# Add a custom user agent. +#user_agent = "user-agent" +# Trust invalid certificates. +#danger_accept_invalid_certs = true + +# Add a certificate to the client. +#tls.root_store = "htsget-config/examples/config-files/cert.pem" + +[resolvers.storage.endpoints] +file = "https://example.com/" +index = "https://example.com/" \ No newline at end of file diff --git a/htsget-config/examples/config-files/default.toml b/htsget-config/examples/config-files/default.toml index 9e12c7e13..fc2d67b3c 100644 --- a/htsget-config/examples/config-files/default.toml +++ b/htsget-config/examples/config-files/default.toml @@ -38,3 +38,5 @@ allow_classes = [ "body", "header", ] + +[resolvers.object_type] \ No newline at end of file diff --git a/htsget-config/examples/config-files/url_storage.toml b/htsget-config/examples/config-files/url_storage.toml index b71b6711a..315a4a926 100644 --- a/htsget-config/examples/config-files/url_storage.toml +++ b/htsget-config/examples/config-files/url_storage.toml @@ -16,10 +16,13 @@ regex = ".*" substitution_string = "$0" [resolvers.storage] -url = "http://127.0.0.1:8081" response_url = "https://127.0.0.1:8081" forward_headers = true +[resolvers.storage.endpoints] +file = "https://example.com/" +index = "https://example.com/" + # Set client authentication #tls.key = "key.pem" #tls.cert = "cert.pem" diff --git a/htsget-config/src/config/mod.rs b/htsget-config/src/config/mod.rs index 4c9ac9baa..345ceeee3 100644 --- a/htsget-config/src/config/mod.rs +++ b/htsget-config/src/config/mod.rs @@ -410,7 +410,7 @@ impl Config { pub fn from_path(path: &Path) -> io::Result { let config: Self = from_path(path)?; - Ok(config.resolvers_from_data_server_config()) + Ok(config.validate()?) 
} /// Setup tracing, using a global subscriber. @@ -465,27 +465,28 @@ impl Config { self.resolvers } - /// Set the local resolvers from the data server config. - pub fn resolvers_from_data_server_config(self) -> Self { + /// Validate any settings before constructing config. + pub fn validate(self) -> Result { let Config { - formatting_style: formatting, + formatting_style, ticket_server, data_server, service_info, - mut resolvers, + resolvers, } = self; - resolvers - .iter_mut() - .for_each(|resolver| resolver.resolvers_from_data_server_config(&data_server)); + let resolvers = resolvers + .into_iter() + .map(|resolver| resolver.validate(&data_server)) + .collect::>>()?; - Self::new( - formatting, + Ok(Self::new( + formatting_style, ticket_server, data_server, service_info, resolvers, - ) + )) } } @@ -493,10 +494,10 @@ impl Config { pub(crate) mod tests { use std::fmt::Display; - use crate::config::parser::from_str; use figment::Jail; use http::uri::Authority; + use crate::config::parser::from_str; use crate::storage::Storage; use crate::tls::tests::with_test_certificates; use crate::types::Scheme::Http; @@ -526,12 +527,14 @@ pub(crate) mod tests { test_fn( from_path::(path) .map_err(|err| err.to_string())? - .resolvers_from_data_server_config(), + .validate() + .unwrap(), ); test_fn( from_str::(contents.unwrap_or("")) .map_err(|err| err.to_string())? 
- .resolvers_from_data_server_config(), + .validate() + .unwrap(), ); Ok(()) diff --git a/htsget-config/src/config/parser.rs b/htsget-config/src/config/parser.rs index fb4bb3cf1..014546ce1 100644 --- a/htsget-config/src/config/parser.rs +++ b/htsget-config/src/config/parser.rs @@ -1,13 +1,15 @@ -use crate::config::Config; -use figment::providers::{Env, Format, Serialized, Toml}; -use figment::Figment; -use serde::Deserialize; use std::fmt::Debug; use std::io; use std::io::ErrorKind; use std::path::Path; + +use figment::providers::{Env, Format, Serialized, Toml}; +use figment::Figment; +use serde::Deserialize; use tracing::{info, instrument}; +use crate::config::Config; + const ENVIRONMENT_VARIABLE_PREFIX: &str = "HTSGET_"; /// A struct to represent a string or a path, used for parsing and deserializing config. diff --git a/htsget-config/src/error.rs b/htsget-config/src/error.rs index b92c42e68..6f7b69adf 100644 --- a/htsget-config/src/error.rs +++ b/htsget-config/src/error.rs @@ -10,15 +10,14 @@ pub type Result = result::Result; pub enum Error { #[error("io found: {0}")] IoError(String), - #[error("failed to parse args found: {0}")] ArgParseError(String), - #[error("failed to setup tracing: {0}")] TracingError(String), - #[error("parse error: {0}")] ParseError(String), + #[error("config error: {0}")] + ConfigError(String), } impl From for io::Error { diff --git a/htsget-config/src/resolver/allow_guard.rs b/htsget-config/src/resolver/allow_guard.rs new file mode 100644 index 000000000..8d45d0515 --- /dev/null +++ b/htsget-config/src/resolver/allow_guard.rs @@ -0,0 +1,176 @@ +use std::collections::HashSet; + +use serde::{Deserialize, Serialize}; +use serde_with::with_prefix; + +use crate::types::Format::{Bam, Bcf, Cram, Vcf}; +use crate::types::{Class, Fields, Format, Interval, Query, TaggedTypeAll, Tags}; + +/// Determines whether the query matches for use with the storage. +pub trait QueryAllowed { + /// Does this query match. 
+ fn query_allowed(&self, query: &Query) -> bool; +} + +with_prefix!(allow_interval_prefix "allow_interval_"); + +/// A query guard represents query parameters that can be allowed to storage for a given query. +#[derive(Serialize, Clone, Debug, Deserialize, PartialEq, Eq)] +#[serde(default)] +pub struct AllowGuard { + allow_reference_names: ReferenceNames, + allow_fields: Fields, + allow_tags: Tags, + allow_formats: Vec, + allow_classes: Vec, + #[serde(flatten, with = "allow_interval_prefix")] + allow_interval: Interval, +} + +/// Reference names that can be matched. +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)] +#[serde(untagged)] +pub enum ReferenceNames { + Tagged(TaggedTypeAll), + List(HashSet), +} + +impl AllowGuard { + /// Create a new allow guard. + pub fn new( + allow_reference_names: ReferenceNames, + allow_fields: Fields, + allow_tags: Tags, + allow_formats: Vec, + allow_classes: Vec, + allow_interval: Interval, + ) -> Self { + Self { + allow_reference_names, + allow_fields, + allow_tags, + allow_formats, + allow_classes, + allow_interval, + } + } + + /// Get allow formats. + pub fn allow_formats(&self) -> &[Format] { + &self.allow_formats + } + + /// Get allow classes. + pub fn allow_classes(&self) -> &[Class] { + &self.allow_classes + } + + /// Get allow interval. + pub fn allow_interval(&self) -> Interval { + self.allow_interval + } + + /// Get allow reference names. + pub fn allow_reference_names(&self) -> &ReferenceNames { + &self.allow_reference_names + } + + /// Get allow fields. + pub fn allow_fields(&self) -> &Fields { + &self.allow_fields + } + + /// Get allow tags. 
+ pub fn allow_tags(&self) -> &Tags { + &self.allow_tags + } +} + +impl Default for AllowGuard { + fn default() -> Self { + Self { + allow_formats: vec![Bam, Cram, Vcf, Bcf], + allow_classes: vec![Class::Body, Class::Header], + allow_interval: Default::default(), + allow_reference_names: ReferenceNames::Tagged(TaggedTypeAll::All), + allow_fields: Fields::Tagged(TaggedTypeAll::All), + allow_tags: Tags::Tagged(TaggedTypeAll::All), + } + } +} + +impl QueryAllowed for ReferenceNames { + fn query_allowed(&self, query: &Query) -> bool { + match (self, &query.reference_name()) { + (ReferenceNames::Tagged(TaggedTypeAll::All), _) => true, + (ReferenceNames::List(reference_names), Some(reference_name)) => { + reference_names.contains(*reference_name) + } + (ReferenceNames::List(_), None) => false, + } + } +} + +impl QueryAllowed for Fields { + fn query_allowed(&self, query: &Query) -> bool { + match (self, &query.fields()) { + (Fields::Tagged(TaggedTypeAll::All), _) => true, + (Fields::List(self_fields), Fields::List(query_fields)) => { + self_fields.is_subset(query_fields) + } + (Fields::List(_), Fields::Tagged(TaggedTypeAll::All)) => false, + } + } +} + +impl QueryAllowed for Tags { + fn query_allowed(&self, query: &Query) -> bool { + match (self, &query.tags()) { + (Tags::Tagged(TaggedTypeAll::All), _) => true, + (Tags::List(self_tags), Tags::List(query_tags)) => self_tags.is_subset(query_tags), + (Tags::List(_), Tags::Tagged(TaggedTypeAll::All)) => false, + } + } +} + +impl QueryAllowed for AllowGuard { + fn query_allowed(&self, query: &Query) -> bool { + self.allow_formats.contains(&query.format()) + && self.allow_classes.contains(&query.class()) + && self + .allow_interval + .contains(query.interval().start().unwrap_or(u32::MIN)) + && self + .allow_interval + .contains(query.interval().end().unwrap_or(u32::MAX)) + && self.allow_reference_names.query_allowed(query) + && self.allow_fields.query_allowed(query) + && self.allow_tags.query_allowed(query) + } +} + 
+#[cfg(test)] +mod tests { + use crate::config::tests::test_config_from_file; + + use super::*; + + #[test] + fn config_resolvers_guard_file() { + test_config_from_file( + r#" + [[resolvers]] + regex = "regex" + + [resolvers.allow_guard] + allow_formats = ["BAM"] + "#, + |config| { + assert_eq!( + config.resolvers().first().unwrap().allow_formats(), + &vec![Bam] + ); + }, + ); + } +} diff --git a/htsget-config/src/resolver.rs b/htsget-config/src/resolver/mod.rs similarity index 60% rename from htsget-config/src/resolver.rs rename to htsget-config/src/resolver/mod.rs index b68a91b1e..38c66f2d1 100644 --- a/htsget-config/src/resolver.rs +++ b/htsget-config/src/resolver/mod.rs @@ -1,21 +1,24 @@ -use std::collections::HashSet; use std::result; use async_trait::async_trait; use regex::{Error, Regex}; use serde::{Deserialize, Serialize}; -use serde_with::with_prefix; use tracing::instrument; use crate::config::DataServerConfig; +use crate::error; +use crate::resolver::allow_guard::{AllowGuard, QueryAllowed, ReferenceNames}; +use crate::resolver::object::ObjectType; use crate::storage::local::LocalStorage; #[cfg(feature = "s3-storage")] use crate::storage::s3::S3Storage; #[cfg(feature = "url-storage")] use crate::storage::url::UrlStorageClient; use crate::storage::{ResolvedId, Storage, TaggedStorageTypes}; -use crate::types::Format::{Bam, Bcf, Cram, Vcf}; -use crate::types::{Class, Fields, Format, Interval, Query, Response, Result, TaggedTypeAll, Tags}; +use crate::types::{Class, Fields, Format, Interval, Query, Response, Result, Tags}; + +pub mod allow_guard; +pub mod object; /// A trait which matches the query id, replacing the match in the substitution text. pub trait IdResolver { @@ -48,12 +51,6 @@ pub trait StorageResolver { ) -> Option>; } -/// Determines whether the query matches for use with the storage. -pub trait QueryAllowed { - /// Does this query match. 
- fn query_allowed(&self, query: &Query) -> bool; -} - /// A regex storage is a storage that matches ids using Regex. #[derive(Serialize, Debug, Clone, Deserialize)] #[serde(default)] @@ -64,6 +61,7 @@ pub struct Resolver { substitution_string: String, storage: Storage, allow_guard: AllowGuard, + object_type: ObjectType, } /// A type which holds a resolved storage and an resolved id. @@ -93,185 +91,32 @@ impl ResolvedStorage { } } -impl ResolvedId {} - -with_prefix!(allow_interval_prefix "allow_interval_"); - -/// A query guard represents query parameters that can be allowed to storage for a given query. -#[derive(Serialize, Clone, Debug, Deserialize, PartialEq, Eq)] -#[serde(default)] -pub struct AllowGuard { - allow_reference_names: ReferenceNames, - allow_fields: Fields, - allow_tags: Tags, - allow_formats: Vec, - allow_classes: Vec, - #[serde(flatten, with = "allow_interval_prefix")] - allow_interval: Interval, -} - -/// Reference names that can be matched. -#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)] -#[serde(untagged)] -pub enum ReferenceNames { - Tagged(TaggedTypeAll), - List(HashSet), -} - -impl AllowGuard { - /// Create a new allow guard. - pub fn new( - allow_reference_names: ReferenceNames, - allow_fields: Fields, - allow_tags: Tags, - allow_formats: Vec, - allow_classes: Vec, - allow_interval: Interval, - ) -> Self { - Self { - allow_reference_names, - allow_fields, - allow_tags, - allow_formats, - allow_classes, - allow_interval, - } - } - - /// Get allow formats. - pub fn allow_formats(&self) -> &[Format] { - &self.allow_formats - } - - /// Get allow classes. - pub fn allow_classes(&self) -> &[Class] { - &self.allow_classes - } - - /// Get allow interval. - pub fn allow_interval(&self) -> Interval { - self.allow_interval - } - - /// Get allow reference names. - pub fn allow_reference_names(&self) -> &ReferenceNames { - &self.allow_reference_names - } - - /// Get allow fields. 
- pub fn allow_fields(&self) -> &Fields { - &self.allow_fields - } - - /// Get allow tags. - pub fn allow_tags(&self) -> &Tags { - &self.allow_tags - } - - /// Set the allow reference names. - pub fn with_allow_reference_names(mut self, allow_reference_names: ReferenceNames) -> Self { - self.allow_reference_names = allow_reference_names; - self - } - - /// Set the allow fields. - pub fn with_allow_fields(mut self, allow_fields: Fields) -> Self { - self.allow_fields = allow_fields; - self - } - - /// Set the allow tags. - pub fn with_allow_tags(mut self, allow_tags: Tags) -> Self { - self.allow_tags = allow_tags; - self - } - - /// Set the allow formats. - pub fn with_allow_formats(mut self, allow_formats: Vec) -> Self { - self.allow_formats = allow_formats; - self - } - - /// Set the allow classes. - pub fn with_allow_classes(mut self, allow_classes: Vec) -> Self { - self.allow_classes = allow_classes; - self - } - - /// Set the allow interval. - pub fn with_allow_interval(mut self, allow_interval: Interval) -> Self { - self.allow_interval = allow_interval; - self - } -} - -impl Default for AllowGuard { - fn default() -> Self { - Self { - allow_formats: vec![Bam, Cram, Vcf, Bcf], - allow_classes: vec![Class::Body, Class::Header], - allow_interval: Default::default(), - allow_reference_names: ReferenceNames::Tagged(TaggedTypeAll::All), - allow_fields: Fields::Tagged(TaggedTypeAll::All), - allow_tags: Tags::Tagged(TaggedTypeAll::All), - } - } -} - -impl QueryAllowed for ReferenceNames { - fn query_allowed(&self, query: &Query) -> bool { - match (self, &query.reference_name()) { - (ReferenceNames::Tagged(TaggedTypeAll::All), _) => true, - (ReferenceNames::List(reference_names), Some(reference_name)) => { - reference_names.contains(*reference_name) - } - (ReferenceNames::List(_), None) => false, - } - } -} - -impl QueryAllowed for Fields { - fn query_allowed(&self, query: &Query) -> bool { - match (self, &query.fields()) { - (Fields::Tagged(TaggedTypeAll::All), _) => 
true, - (Fields::List(self_fields), Fields::List(query_fields)) => { - self_fields.is_subset(query_fields) - } - (Fields::List(_), Fields::Tagged(TaggedTypeAll::All)) => false, - } - } -} - -impl QueryAllowed for Tags { - fn query_allowed(&self, query: &Query) -> bool { - match (self, &query.tags()) { - (Tags::Tagged(TaggedTypeAll::All), _) => true, - (Tags::List(self_tags), Tags::List(query_tags)) => self_tags.is_subset(query_tags), - (Tags::List(_), Tags::Tagged(TaggedTypeAll::All)) => false, +impl IdResolver for Resolver { + #[instrument(level = "trace", skip(self), ret)] + fn resolve_id(&self, query: &Query) -> Option { + if self.regex.is_match(query.id()) && self.allow_guard.query_allowed(query) { + Some(ResolvedId::new( + self + .regex + .replace(query.id(), &self.substitution_string) + .to_string(), + )) + } else { + None } } } -impl QueryAllowed for AllowGuard { - fn query_allowed(&self, query: &Query) -> bool { - self.allow_formats.contains(&query.format()) - && self.allow_classes.contains(&query.class()) - && self - .allow_interval - .contains(query.interval().start().unwrap_or(u32::MIN)) - && self - .allow_interval - .contains(query.interval().end().unwrap_or(u32::MAX)) - && self.allow_reference_names.query_allowed(query) - && self.allow_fields.query_allowed(query) - && self.allow_tags.query_allowed(query) - } -} - impl Default for Resolver { fn default() -> Self { - Self::new(Storage::default(), ".*", "$0", AllowGuard::default()) - .expect("expected valid storage") + Self::new( + Storage::default(), + ".*", + "$0", + AllowGuard::default(), + ObjectType::default(), + ) + .expect("expected valid storage") } } @@ -282,22 +127,36 @@ impl Resolver { regex: &str, replacement_string: &str, allow_guard: AllowGuard, + object_type: ObjectType, ) -> result::Result { Ok(Self { regex: Regex::new(regex)?, substitution_string: replacement_string.to_string(), storage, allow_guard, + object_type, }) } - /// Set the local resolvers from the data server config. 
- pub fn resolvers_from_data_server_config(&mut self, config: &DataServerConfig) { + /// Validate resolvers and set the local resolvers from the data server config. + pub fn validate(mut self, config: &DataServerConfig) -> error::Result { if let Storage::Tagged(TaggedStorageTypes::Local) = self.storage() { if let Some(local_storage) = config.into() { self.storage = Storage::Local { local_storage }; } } + + #[cfg(all(feature = "crypt4gh", feature = "url-storage"))] + // `Crypt4GHGenerate` is only supported for `UrlStorage`. + if let ObjectType::GenerateKeys { .. } = self.object_type() { + if !matches!(self.storage(), Storage::Url { .. }) { + return Err(error::Error::ParseError( + "generating Crypt4GH keys is not supported if not using `UrlStorage`".to_string(), + )); + } + }; + + Ok(self) } /// Get the match associated with the capture group at index `i` using the `regex_match`. @@ -337,38 +196,27 @@ impl Resolver { /// Get allow interval. pub fn allow_interval(&self) -> Interval { - self.allow_guard.allow_interval + self.allow_guard.allow_interval() } /// Get allow reference names. pub fn allow_reference_names(&self) -> &ReferenceNames { - &self.allow_guard.allow_reference_names + self.allow_guard.allow_reference_names() } /// Get allow fields. pub fn allow_fields(&self) -> &Fields { - &self.allow_guard.allow_fields + self.allow_guard.allow_fields() } /// Get allow tags. pub fn allow_tags(&self) -> &Tags { - &self.allow_guard.allow_tags + self.allow_guard.allow_tags() } -} -impl IdResolver for Resolver { - #[instrument(level = "trace", skip(self), ret)] - fn resolve_id(&self, query: &Query) -> Option { - if self.regex.is_match(query.id()) && self.allow_guard.query_allowed(query) { - Some(ResolvedId::new( - self - .regex - .replace(query.id(), &self.substitution_string) - .to_string(), - )) - } else { - None - } + /// Get the object type config. 
+ pub fn object_type(&self) -> &ObjectType { + &self.object_type } } @@ -383,6 +231,7 @@ impl StorageResolver for Resolver { let _matched_id = query.id().to_string(); query.set_id(resolved_id.into_inner()); + query.set_object_type(self.object_type().clone()); if let Some(response) = self.storage().resolve_local_storage::(query).await { return Some(response); @@ -438,15 +287,27 @@ impl StorageResolver for &[Resolver] { mod tests { use http::uri::Authority; + #[cfg(feature = "s3-storage")] + use {crate::storage::s3::S3Storage, std::collections::HashSet}; #[cfg(feature = "url-storage")] use { - crate::storage::url, crate::storage::url::ValidatedUrl, http::Uri as InnerUrl, - reqwest::ClientBuilder, std::str::FromStr, + crate::storage::url, crate::storage::url::endpoints::Endpoints, + crate::storage::url::ValidatedUrl, http::Uri as InnerUrl, reqwest::ClientBuilder, + std::str::FromStr, + }; + #[cfg(all(feature = "crypt4gh", feature = "url-storage"))] + use { + crate::tls::crypt4gh::Crypt4GHKeyPair, + crate::tls::tests::with_test_certificates, + async_crypt4gh::{KeyPair, PublicKey}, + crypt4gh::keys::{generate_keys, get_private_key, get_public_key}, + rustls::PrivateKey, + std::path::Path, + tempfile::TempDir, }; use crate::config::tests::{test_config_from_env, test_config_from_file}; - #[cfg(feature = "s3-storage")] - use crate::storage::s3::S3Storage; + use crate::types::Format::Bam; use crate::types::Scheme::Http; use crate::types::Url; @@ -472,7 +333,7 @@ mod tests { async fn from_url(url_storage: &UrlStorageClient, _: &Query) -> Result { Ok(Response::new( Bam, - vec![Url::new(url_storage.url().to_string())], + vec![Url::new(url_storage.endpoints().file().to_string())], )) } } @@ -490,10 +351,11 @@ mod tests { "id", "$0-test", AllowGuard::default(), + ObjectType::default(), ) .unwrap(); - expected_resolved_request(resolver, "127.0.0.1:8080").await; + expected_resolved_request(&vec![resolver], "127.0.0.1:8080").await; } #[cfg(feature = "s3-storage")] @@ -505,10 
+367,11 @@ mod tests { "(id)-1", "$1-test", AllowGuard::default(), + ObjectType::default(), ) .unwrap(); - expected_resolved_request(resolver, "id").await; + expected_resolved_request(&vec![resolver], "id").await; } #[cfg(feature = "s3-storage")] @@ -519,37 +382,85 @@ mod tests { "(id)-1", "$1-test", AllowGuard::default(), + ObjectType::default(), ) .unwrap(); - expected_resolved_request(resolver, "id").await; + expected_resolved_request(&vec![resolver], "id").await; } #[cfg(feature = "url-storage")] #[tokio::test] async fn resolver_resolve_url_request() { - let client = ClientBuilder::new().build().unwrap(); - let url_storage = UrlStorageClient::new( - ValidatedUrl(url::Url { - inner: InnerUrl::from_str("https://example.com/").unwrap(), - }), - ValidatedUrl(url::Url { - inner: InnerUrl::from_str("https://example.com/").unwrap(), - }), - true, - vec![], - client, - ); + let url_storage = create_url_storage("https://example.com/"); let resolver = Resolver::new( Storage::Url { url_storage }, "(id)-1", "$1-test", AllowGuard::default(), + ObjectType::default(), ) .unwrap(); - expected_resolved_request(resolver, "https://example.com/").await; + expected_resolved_request(&vec![resolver], "https://example.com/").await; + } + + #[cfg(feature = "url-storage")] + fn create_url_storage(endpoint: &str) -> UrlStorageClient { + let client = ClientBuilder::new().build().unwrap(); + + UrlStorageClient::new( + Endpoints::new( + ValidatedUrl(url::Url { + inner: InnerUrl::from_str(endpoint).unwrap(), + }), + ValidatedUrl(url::Url { + inner: InnerUrl::from_str(endpoint).unwrap(), + }), + ), + ValidatedUrl(url::Url { + inner: InnerUrl::from_str(endpoint).unwrap(), + }), + true, + vec![], + Some("user-agent".to_string()), + client, + ) + } + + #[cfg(all(feature = "url-storage", feature = "crypt4gh"))] + #[tokio::test] + async fn resolver_conflicting_object_type() { + let resolvers = vec![ + Resolver::new( + Storage::Url { + url_storage: create_url_storage("127.0.0.1:8080"), + }, + 
"(id)-2", + "$1-test", + AllowGuard::default(), + ObjectType::GenerateKeys { + send_encrypted_to_client: true, + }, + ) + .unwrap(), + Resolver::new( + Storage::Url { + url_storage: create_url_storage("127.0.0.1:8081"), + }, + "(id)-1", + "$1-test", + AllowGuard::default(), + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(KeyPair::new(PrivateKey(vec![]), PublicKey::new(vec![]))), + send_encrypted_to_client: true, + }, + ) + .unwrap(), + ]; + + expected_resolved_request(&resolvers, "127.0.0.1:8081").await; } #[test] @@ -559,6 +470,7 @@ mod tests { "^(id)/(?P.*)$", "$0", AllowGuard::default(), + ObjectType::default(), ) .unwrap(); let first_match = resolver.get_match(1, "id/key").unwrap(); @@ -568,8 +480,14 @@ mod tests { #[test] fn resolver_get_matches_no_captures() { - let resolver = - Resolver::new(Storage::default(), "^id/id$", "$0", AllowGuard::default()).unwrap(); + let resolver = Resolver::new( + Storage::default(), + "^id/id$", + "$0", + AllowGuard::default(), + ObjectType::default(), + ) + .unwrap(); let first_match = resolver.get_match(1, "/id/key"); assert_eq!(first_match, None); @@ -577,11 +495,17 @@ mod tests { #[test] fn resolver_resolve_id() { - let resolver = - Resolver::new(Storage::default(), "id", "$0-test", AllowGuard::default()).unwrap(); + let resolver = Resolver::new( + Storage::default(), + "id", + "$0-test", + AllowGuard::default(), + ObjectType::default(), + ) + .unwrap(); assert_eq!( resolver - .resolve_id(&Query::new_with_default_request("id", Bam)) + .resolve_id(&Query::new_with_defaults("id", Bam)) .unwrap() .into_inner(), "id-test" @@ -596,6 +520,7 @@ mod tests { "^(id-1)(.*)$", "$1-test-1", AllowGuard::default(), + ObjectType::default(), ) .unwrap(), Resolver::new( @@ -603,6 +528,7 @@ mod tests { "^(id-2)(.*)$", "$1-test-2", AllowGuard::default(), + ObjectType::default(), ) .unwrap(), ]; @@ -610,7 +536,7 @@ mod tests { assert_eq!( resolver .as_slice() - .resolve_id(&Query::new_with_default_request("id-1", Bam)) + 
.resolve_id(&Query::new_with_defaults("id-1", Bam)) .unwrap() .into_inner(), "id-1-test-1" @@ -618,7 +544,7 @@ mod tests { assert_eq!( resolver .as_slice() - .resolve_id(&Query::new_with_default_request("id-2", Bam)) + .resolve_id(&Query::new_with_defaults("id-2", Bam)) .unwrap() .into_inner(), "id-2-test-2" @@ -704,10 +630,133 @@ mod tests { ); } - async fn expected_resolved_request(resolver: Resolver, expected_id: &str) { + #[cfg(feature = "s3-storage")] + #[test] + fn config_storage_tagged_url_storage_file() { + test_config_from_file( + r#" + [[resolvers]] + regex = "regex" + storage = "S3" + "#, + |config| { + println!("{:?}", config.resolvers().first().unwrap().storage()); + assert!(matches!( + config.resolvers().first().unwrap().storage(), + Storage::Tagged(TaggedStorageTypes::S3) + )); + }, + ); + } + + #[cfg(all(feature = "crypt4gh", feature = "url-storage"))] + #[test] + fn config_resolvers_with_predefined_key_pair() { + with_crypt4gh_keys( + |_, private_key_path, public_key_path, private_key, public_key| { + test_config_from_file( + &format!( + r#" + [[resolvers]] + regex = "regex" + + [resolvers.object_type] + send_encrypted_to_client = false + private_key = "{}" + public_key = "{}" + "#, + private_key_path.to_string_lossy(), + public_key_path.to_string_lossy() + ), + |config| { + assert!(matches!( + config.resolvers().first().unwrap().object_type(), + ObjectType::Crypt4GH { crypt4gh, send_encrypted_to_client } if crypt4gh.key_pair().private_key().0 == private_key && crypt4gh.key_pair().public_key().clone().into_inner() == public_key + && !*send_encrypted_to_client + )); + }, + ); + }, + ); + } + + #[cfg(all(feature = "crypt4gh", feature = "url-storage"))] + #[test] + fn config_resolvers_with_generate_key_pair() { + with_test_certificates(|path, _, _| { + let key_path = path.join("key.pem"); + let cert_path = path.join("cert.pem"); + test_config_from_file( + &format!( + r#" + [[resolvers]] + regex = "regex" + + [resolvers.object_type] + 
send_encrypted_to_client = false + + [resolvers.storage] + response_url = "https://example.com/" + forward_headers = false + tls.key = "{}" + tls.cert = "{}" + tls.root_store = "{}" + + [resolvers.storage.endpoints] + head = "https://example.com/" + file = "https://example.com/" + index = "https://example.com/" + "#, + key_path.to_string_lossy().escape_default(), + cert_path.to_string_lossy().escape_default(), + cert_path.to_string_lossy().escape_default() + ), + |config| { + assert!(config + .resolvers() + .first() + .unwrap() + .object_type() + .is_crypt4gh()); + }, + ); + }); + } + + #[cfg(all(feature = "crypt4gh", feature = "url-storage"))] + pub(crate) fn with_crypt4gh_keys(test: F) + where + F: FnOnce(&Path, &Path, &Path, &[u8], &[u8]), + { + let tmp_dir = TempDir::new().unwrap(); + + let private_key_path = tmp_dir.path().join("alice.sec"); + let public_key_path = tmp_dir.path().join("alice.pub"); + + generate_keys( + private_key_path.clone(), + public_key_path.clone(), + Ok("".to_string()), + None, + ) + .unwrap(); + + let private_key = get_private_key(private_key_path.clone(), Ok("".to_string())).unwrap(); + let public_key = get_public_key(public_key_path.clone()).unwrap(); + + test( + tmp_dir.path(), + &private_key_path, + &public_key_path, + &private_key, + &public_key, + ); + } + + async fn expected_resolved_request(resolvers: &[Resolver], expected_id: &str) { assert_eq!( - resolver - .resolve_request::(&mut Query::new_with_default_request("id-1", Bam)) + resolvers + .resolve_request::(&mut Query::new_with_defaults("id-1", Bam)) .await .unwrap() .unwrap(), diff --git a/htsget-config/src/resolver/object/mod.rs b/htsget-config/src/resolver/object/mod.rs new file mode 100644 index 000000000..1de043f69 --- /dev/null +++ b/htsget-config/src/resolver/object/mod.rs @@ -0,0 +1,60 @@ +//! Config related to how htsget-rs treats files and objects. Used as part of a `Resolver`. +//! 
+ +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "crypt4gh")] +use crate::tls::crypt4gh::Crypt4GHKeyPair; + +#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, Eq)] +#[serde(untagged, deny_unknown_fields)] +#[non_exhaustive] +pub enum ObjectType { + #[default] + Regular, + #[cfg(feature = "crypt4gh")] + // Only valid for url storage. + GenerateKeys { send_encrypted_to_client: bool }, + #[cfg(feature = "crypt4gh")] + Crypt4GH { + send_encrypted_to_client: bool, + #[serde(flatten, skip_serializing)] + crypt4gh: Crypt4GHKeyPair, + }, +} + +impl ObjectType { + #[cfg(feature = "crypt4gh")] + pub fn is_crypt4gh(&self) -> bool { + match self { + #[cfg(feature = "url-storage")] + ObjectType::GenerateKeys { .. } => true, + ObjectType::Crypt4GH { .. } => true, + _ => false, + } + } + + /// Should returned data be unencrypted for the client. + #[cfg(feature = "crypt4gh")] + pub fn send_encrypted_to_client(&self) -> Option { + match self { + #[cfg(feature = "url-storage")] + ObjectType::GenerateKeys { + send_encrypted_to_client, + } => Some(*send_encrypted_to_client), + ObjectType::Crypt4GH { + send_encrypted_to_client, + .. + } => Some(*send_encrypted_to_client), + _ => None, + } + } + + #[cfg(feature = "crypt4gh")] + pub fn crypt4gh_key_pair(&self) -> Option<&Crypt4GHKeyPair> { + match self { + ObjectType::Crypt4GH { crypt4gh, .. 
} => Some(crypt4gh), + _ => None, + } + } +} diff --git a/htsget-config/src/storage/url/endpoints.rs b/htsget-config/src/storage/url/endpoints.rs new file mode 100644 index 000000000..21c67a7d1 --- /dev/null +++ b/htsget-config/src/storage/url/endpoints.rs @@ -0,0 +1,37 @@ +use http::Uri; +use serde::{Deserialize, Serialize}; + +use crate::storage::url::{default_url, ValidatedUrl}; + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(default)] +pub struct Endpoints { + index: ValidatedUrl, + file: ValidatedUrl, +} + +impl Default for Endpoints { + fn default() -> Self { + Self { + index: default_url(), + file: default_url(), + } + } +} + +impl Endpoints { + /// Construct a new endpoints config. + pub fn new(index: ValidatedUrl, file: ValidatedUrl) -> Self { + Self { index, file } + } + + /// Get the index endpoint. + pub fn index(&self) -> &Uri { + &self.index.0.inner + } + + /// Get the file endpoint. + pub fn file(&self) -> &Uri { + &self.file.0.inner + } +} diff --git a/htsget-config/src/storage/url.rs b/htsget-config/src/storage/url/mod.rs similarity index 75% rename from htsget-config/src/storage/url.rs rename to htsget-config/src/storage/url/mod.rs index d08c2bb3e..9359f2c3a 100644 --- a/htsget-config/src/storage/url.rs +++ b/htsget-config/src/storage/url/mod.rs @@ -8,8 +8,11 @@ use serde_with::with_prefix; use crate::error::Error::ParseError; use crate::error::{Error, Result}; use crate::storage::local::default_authority; +use crate::storage::url::endpoints::Endpoints; use crate::tls::client::TlsClientConfig; +pub mod endpoints; + fn default_url() -> ValidatedUrl { ValidatedUrl(Url { inner: InnerUrl::from_str(&format!("https://{}", default_authority())) @@ -22,9 +25,11 @@ with_prefix!(client_auth_prefix "client_"); #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(default)] pub struct UrlStorage { - url: ValidatedUrl, + endpoints: Endpoints, response_url: ValidatedUrl, forward_headers: bool, + user_agent: Option, + danger_accept_invalid_certs: 
bool, header_blacklist: Vec, #[serde(skip_serializing)] tls: TlsClientConfig, @@ -33,9 +38,10 @@ pub struct UrlStorage { #[derive(Deserialize, Debug, Clone)] #[serde(try_from = "UrlStorage")] pub struct UrlStorageClient { - url: ValidatedUrl, + endpoints: Endpoints, response_url: ValidatedUrl, forward_headers: bool, + user_agent: Option, header_blacklist: Vec, client: Client, } @@ -62,10 +68,11 @@ impl TryFrom for UrlStorageClient { .map_err(|err| ParseError(format!("building url storage client: {}", err)))?; Ok(Self::new( - storage.url, + storage.endpoints, storage.response_url, storage.forward_headers, storage.header_blacklist, + storage.user_agent, client, )) } @@ -74,24 +81,26 @@ impl TryFrom for UrlStorageClient { impl UrlStorageClient { /// Create a new url storage client. pub fn new( - url: ValidatedUrl, + endpoints: Endpoints, response_url: ValidatedUrl, forward_headers: bool, header_blacklist: Vec, + user_agent: Option, client: Client, ) -> Self { Self { - url, + endpoints, response_url, forward_headers, + user_agent, header_blacklist, client, } } - /// Get the url called when resolving the query. - pub fn url(&self) -> &InnerUrl { - &self.url.0.inner + /// Get the endpoints config. + pub fn endpoints(&self) -> &Endpoints { + &self.endpoints } /// Get the response url to return to the client @@ -104,6 +113,11 @@ impl UrlStorageClient { self.forward_headers } + /// Get the user agent. + pub fn user_agent(&self) -> Option { + self.user_agent.clone() + } + /// Get the headers that should not be forwarded. pub fn header_blacklist(&self) -> &[String] { &self.header_blacklist @@ -128,6 +142,12 @@ pub(crate) struct Url { #[serde(try_from = "Url")] pub struct ValidatedUrl(pub(crate) Url); +impl From for ValidatedUrl { + fn from(url: InnerUrl) -> Self { + ValidatedUrl(Url { inner: url }) + } +} + impl ValidatedUrl { /// Get the inner url. 
pub fn into_inner(self) -> InnerUrl { @@ -149,31 +169,35 @@ impl TryFrom for ValidatedUrl { impl UrlStorage { /// Create a new url storage. pub fn new( - url: InnerUrl, + endpoints: Endpoints, response_url: InnerUrl, forward_headers: bool, header_blacklist: Vec, + user_agent: Option, + danger_accept_invalid_certs: bool, tls: TlsClientConfig, ) -> Self { Self { - url: ValidatedUrl(Url { inner: url }), + endpoints, response_url: ValidatedUrl(Url { inner: response_url, }), forward_headers, header_blacklist, + user_agent, + danger_accept_invalid_certs, tls, } } - /// Get the url called when resolving the query. - pub fn url(&self) -> &InnerUrl { - &self.url.0.inner + /// Get the endpoints config. + pub fn endpoints(&self) -> &Endpoints { + &self.endpoints } /// Get the response url which is returned to the client. pub fn response_url(&self) -> &InnerUrl { - &self.url.0.inner + &self.response_url.0.inner } /// Whether headers received in a query request should be @@ -182,6 +206,11 @@ impl UrlStorage { self.forward_headers } + /// Get the user agent. + pub fn user_agent(&self) -> Option<&str> { + self.user_agent.as_deref() + } + /// Get the tls client config. 
pub fn tls(&self) -> &TlsClientConfig { &self.tls @@ -191,11 +220,13 @@ impl UrlStorage { impl Default for UrlStorage { fn default() -> Self { Self { - url: default_url(), + endpoints: Default::default(), response_url: default_url(), forward_headers: true, header_blacklist: vec![], - tls: TlsClientConfig::default(), + user_agent: None, + danger_accept_invalid_certs: false, + tls: Default::default(), } } } @@ -216,10 +247,19 @@ mod tests { with_test_certificates(|path, _, _| { let client_config = client_config_from_path(path); let url_storage = UrlStorageClient::try_from(UrlStorage::new( - "https://example.com".parse::().unwrap(), + Endpoints::new( + ValidatedUrl(Url { + inner: "https://example.com".parse::().unwrap(), + }), + ValidatedUrl(Url { + inner: "https://example.com".parse::().unwrap(), + }), + ), "https://example.com".parse::().unwrap(), true, vec![], + Some("user-agent".to_string()), + false, client_config, )); @@ -240,12 +280,17 @@ mod tests { regex = "regex" [resolvers.storage] - url = "https://example.com/" response_url = "https://example.com/" forward_headers = false + user_agent = "user-agent" tls.key = "{}" tls.cert = "{}" tls.root_store = "{}" + + [resolvers.storage.endpoints] + head = "https://example.com/" + file = "https://example.com/" + index = "https://example.com/" "#, key_path.to_string_lossy().escape_default(), cert_path.to_string_lossy().escape_default(), @@ -255,8 +300,8 @@ mod tests { println!("{:?}", config.resolvers().first().unwrap().storage()); assert!(matches!( config.resolvers().first().unwrap().storage(), - Storage::Url { url_storage } if *url_storage.url() == "https://example.com/" - && !url_storage.forward_headers() + Storage::Url { url_storage } if *url_storage.endpoints().file() == "https://example.com/" + && !url_storage.forward_headers() && url_storage.user_agent() == Some("user-agent".to_string()) )); }, ); diff --git a/htsget-config/src/tls/crypt4gh.rs b/htsget-config/src/tls/crypt4gh.rs new file mode 100644 index 
000000000..55ffdf68b --- /dev/null +++ b/htsget-config/src/tls/crypt4gh.rs @@ -0,0 +1,108 @@ +//! Config related to Crypt4GH keys. + +use std::path::PathBuf; + +use crypt4gh::keys::{get_private_key, get_public_key}; +use serde::{Deserialize, Serialize}; +use tracing::warn; + +use async_crypt4gh::{KeyPair, PublicKey}; + +use crate::error::Error::ParseError; +use crate::error::{Error, Result}; +use crate::tls::load_key; + +/// Wrapper around a private key. +#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] +#[serde(try_from = "PathBuf", into = "Vec")] +pub struct PrivateKey(rustls::PrivateKey); + +impl PrivateKey { + /// Get the inner value. + pub fn into_inner(self) -> rustls::PrivateKey { + self.0 + } +} + +impl AsRef for PrivateKey { + fn as_ref(&self) -> &rustls::PrivateKey { + &self.0 + } +} + +impl TryFrom for PrivateKey { + type Error = Error; + + fn try_from(path: PathBuf) -> Result { + Ok(PrivateKey(load_key(path)?)) + } +} + +impl From for Vec { + fn from(key: PrivateKey) -> Self { + key.into_inner().0 + } +} + +/// Config for Crypt4GH keys. +#[derive(Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(try_from = "Crypt4GHPath")] +pub struct Crypt4GHKeyPair { + key_pair: KeyPair, +} + +impl Crypt4GHKeyPair { + /// Create a new Crypt4GH config. 
+ pub fn new(key_pair: KeyPair) -> Self { + Self { key_pair } + } + + /// Get the key pair + pub fn key_pair(&self) -> &KeyPair { + &self.key_pair + } +} + +#[derive(Deserialize, Debug, Clone)] +pub struct Crypt4GHPath { + private_key: PathBuf, + public_key: PathBuf, +} + +impl TryFrom for Crypt4GHKeyPair { + type Error = Error; + + fn try_from(crypt4gh_path: Crypt4GHPath) -> Result { + let private_key = get_private_key(crypt4gh_path.private_key.clone(), Ok("".to_string())); + + let private_key = match private_key { + Ok(key) => key, + Err(err) => { + warn!( + err = err.to_string(), + "error getting crypt4gh key, falling back to rustls key" + ); + PrivateKey::try_from(crypt4gh_path.private_key) + .map_err(|_| ParseError(format!("failed to parse crypt4gh key: {}", err)))? + .into_inner() + .0 + } + }; + + let parse_public_key = |key: Option| { + Ok( + key + .map(|key| { + get_public_key(key).map_err(|err| ParseError(format!("loading public key: {}", err))) + }) + .transpose()? + .map(PublicKey::new), + ) + }; + + Ok(Self::new(KeyPair::new( + rustls::PrivateKey(private_key), + parse_public_key(Some(crypt4gh_path.public_key))?.expect("expected valid public key"), + ))) + } +} diff --git a/htsget-config/src/tls/mod.rs b/htsget-config/src/tls/mod.rs index 90c5881a6..3fff4ba56 100644 --- a/htsget-config/src/tls/mod.rs +++ b/htsget-config/src/tls/mod.rs @@ -17,6 +17,9 @@ use crate::error::{Error, Result}; use crate::types::Scheme; use crate::types::Scheme::{Http, Https}; +#[cfg(feature = "crypt4gh")] +pub mod crypt4gh; + /// A trait to determine which scheme a key pair option has. pub trait KeyPairScheme { /// Get the scheme. 
diff --git a/htsget-config/src/types.rs b/htsget-config/src/types.rs index 08c13284e..26b552c6d 100644 --- a/htsget-config/src/types.rs +++ b/htsget-config/src/types.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt::{Debug, Display, Formatter}; use std::io::ErrorKind::Other; use std::{fmt, io, result}; @@ -12,13 +12,16 @@ use tracing::instrument; use crate::error::Error; use crate::error::Error::ParseError; +use crate::resolver::object::ObjectType; +use crate::types::TaggedTypeAll::All; pub type Result = result::Result; /// An enumeration with all the possible formats. -#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)] #[serde(rename_all(serialize = "UPPERCASE"))] pub enum Format { + #[default] #[serde(alias = "bam", alias = "BAM")] Bam, #[serde(alias = "cram", alias = "CRAM")] @@ -29,6 +32,31 @@ pub enum Format { Bcf, } +/// The type of key of the file. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum KeyType { + File, + Index, +} + +impl KeyType { + /// Get the key type from an ending. + pub fn from_ending>(key: K) -> KeyType { + if key.as_ref().ends_with(Format::Bam.index_file_ending()) + || key.as_ref().ends_with(Format::Bcf.index_file_ending()) + || key.as_ref().ends_with(Format::Cram.index_file_ending()) + || key.as_ref().ends_with(Format::Vcf.index_file_ending()) + || key.as_ref().ends_with(".bam.gzi") + || key.as_ref().ends_with(".vcf.gz.gzi") + || key.as_ref().ends_with(".bcf.gzi") + { + Self::Index + } else { + Self::File + } + } +} + /// Todo allow these to be configurable. 
impl Format { pub fn file_ending(&self) -> &str { @@ -40,8 +68,17 @@ impl Format { } } - pub fn fmt_file(&self, id: &str) -> String { - format!("{id}{}", self.file_ending()) + pub fn fmt_file(&self, query: &Query) -> String { + let id = query.id(); + let id = format!("{id}{}", self.file_ending()); + + #[cfg(feature = "crypt4gh")] + if query.object_type().is_crypt4gh() { + return format!("{id}.c4gh"); + } + + #[allow(clippy::let_and_return)] + id } pub fn index_file_ending(&self) -> &str { @@ -69,6 +106,18 @@ impl Format { } } + pub fn gzi_endings(&self) -> io::Result<&str> { + match self { + Format::Bam => Ok(".bam.gzi"), + Format::Cram => Err(io::Error::new( + Other, + "CRAM does not support GZI".to_string(), + )), + Format::Vcf => Ok(".vcf.gz.gzi"), + Format::Bcf => Ok(".bcf.gzi"), + } + } + pub fn fmt_gzi(&self, id: &str) -> io::Result { Ok(format!("{id}{}", self.gzi_index_file_ending()?)) } @@ -92,9 +141,10 @@ impl Display for Format { } /// Class component of htsget response. -#[derive(Copy, Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +#[derive(Copy, Debug, PartialEq, Eq, Clone, Serialize, Deserialize, Default)] #[serde(rename_all(serialize = "lowercase"))] pub enum Class { + #[default] #[serde(alias = "header", alias = "HEADER")] Header, #[serde(alias = "body", alias = "BODY")] @@ -219,6 +269,12 @@ pub enum Fields { List(HashSet), } +impl Default for Fields { + fn default() -> Self { + Self::Tagged(All) + } +} + /// Possible values for the tags parameter. #[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] #[serde(untagged)] @@ -229,12 +285,18 @@ pub enum Tags { List(HashSet), } +impl Default for Tags { + fn default() -> Self { + Self::Tagged(All) + } +} + /// The no tags parameter. -#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] +#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize, Default)] pub struct NoTags(pub Option>); /// A struct containing the information from the HTTP request. 
-#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Default)] pub struct Request { path: String, query: HashMap, @@ -256,6 +318,12 @@ impl Request { Self::new(id, Default::default(), Default::default()) } + /// Set the request headers. + pub fn with_headers(mut self, headers: HeaderMap) -> Self { + self.headers = headers; + self + } + /// Get the id. pub fn path(&self) -> &str { &self.path @@ -274,7 +342,7 @@ impl Request { /// A query contains all the parameters that can be used when requesting /// a search for either of `reads` or `variants`. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Default)] pub struct Query { id: String, format: Format, @@ -288,11 +356,17 @@ pub struct Query { no_tags: NoTags, /// The raw HTTP request information. request: Request, + object_type: ObjectType, } impl Query { /// Create a new query. - pub fn new(id: impl Into, format: Format, request: Request) -> Self { + pub fn new( + id: impl Into, + format: Format, + request: Request, + object_type: ObjectType, + ) -> Self { Self { id: id.into(), format, @@ -303,13 +377,19 @@ impl Query { tags: Tags::Tagged(TaggedTypeAll::All), no_tags: NoTags(None), request, + object_type, } } /// Create a new query with a default request. - pub fn new_with_default_request(id: impl Into, format: Format) -> Self { + pub fn new_with_defaults(id: impl Into, format: Format) -> Self { let id = id.into(); - Self::new(id.clone(), format, Request::new_with_id(id)) + Self::new( + id.clone(), + format, + Request::new_with_id(id), + Default::default(), + ) } /// Set the id. @@ -408,6 +488,22 @@ impl Query { pub fn request(&self) -> &Request { &self.request } + + /// Get the object type of this query. + pub fn object_type(&self) -> &ObjectType { + &self.object_type + } + + /// Set the object type. + pub fn with_object_type(mut self, object_type: ObjectType) -> Self { + self.set_object_type(object_type); + self + } + + /// Set the object type. 
+ pub fn set_object_type(&mut self, object_type: ObjectType) { + self.object_type = object_type; + } } #[derive(Error, Debug, PartialEq, Eq)] @@ -477,11 +573,11 @@ impl From for HtsGetError { } /// The headers that need to be supplied when requesting data from a url. -#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize)] -pub struct Headers(HashMap); +#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)] +pub struct Headers(BTreeMap); impl Headers { - pub fn new(headers: HashMap) -> Self { + pub fn new(headers: BTreeMap) -> Self { Self(headers) } @@ -512,13 +608,13 @@ impl Headers { self.0.extend(headers.into_inner()); } - /// Get the inner HashMap. - pub fn into_inner(self) -> HashMap { + /// Get the inner BTreeMap. + pub fn into_inner(self) -> BTreeMap { self.0 } - /// Get a reference to the inner HashMap. - pub fn as_ref_inner(&self) -> &HashMap { + /// Get a reference to the inner BTreeMap. + pub fn as_ref_inner(&self) -> &BTreeMap { &self.0 } } @@ -623,7 +719,7 @@ impl Response { #[cfg(test)] mod tests { - use std::collections::{HashMap, HashSet}; + use std::collections::{BTreeMap, HashSet}; use std::str::FromStr; use http::{HeaderMap, HeaderName, HeaderValue}; @@ -741,44 +837,43 @@ mod tests { #[test] fn query_new() { - let result = Query::new_with_default_request("NA12878", Format::Bam); + let result = Query::new_with_defaults("NA12878", Format::Bam); assert_eq!(result.id(), "NA12878"); } #[test] fn query_with_format() { - let result = Query::new_with_default_request("NA12878", Format::Bam); + let result = Query::new_with_defaults("NA12878", Format::Bam); assert_eq!(result.format(), Format::Bam); } #[test] fn query_with_class() { - let result = Query::new_with_default_request("NA12878", Format::Bam).with_class(Class::Header); + let result = Query::new_with_defaults("NA12878", Format::Bam).with_class(Class::Header); assert_eq!(result.class(), Class::Header); } #[test] fn query_with_reference_name() { - let result = - 
Query::new_with_default_request("NA12878", Format::Bam).with_reference_name("chr1"); + let result = Query::new_with_defaults("NA12878", Format::Bam).with_reference_name("chr1"); assert_eq!(result.reference_name(), Some("chr1")); } #[test] fn query_with_start() { - let result = Query::new_with_default_request("NA12878", Format::Bam).with_start(0); + let result = Query::new_with_defaults("NA12878", Format::Bam).with_start(0); assert_eq!(result.interval().start(), Some(0)); } #[test] fn query_with_end() { - let result = Query::new_with_default_request("NA12878", Format::Bam).with_end(0); + let result = Query::new_with_defaults("NA12878", Format::Bam).with_end(0); assert_eq!(result.interval().end(), Some(0)); } #[test] fn query_with_fields() { - let result = Query::new_with_default_request("NA12878", Format::Bam).with_fields(Fields::List( + let result = Query::new_with_defaults("NA12878", Format::Bam).with_fields(Fields::List( HashSet::from_iter(vec!["QNAME".to_string(), "FLAG".to_string()]), )); assert_eq!( @@ -792,15 +887,14 @@ mod tests { #[test] fn query_with_tags() { - let result = Query::new_with_default_request("NA12878", Format::Bam) - .with_tags(Tags::Tagged(TaggedTypeAll::All)); + let result = + Query::new_with_defaults("NA12878", Format::Bam).with_tags(Tags::Tagged(TaggedTypeAll::All)); assert_eq!(result.tags(), &Tags::Tagged(TaggedTypeAll::All)); } #[test] fn query_with_no_tags() { - let result = - Query::new_with_default_request("NA12878", Format::Bam).with_no_tags(vec!["RG", "OQ"]); + let result = Query::new_with_defaults("NA12878", Format::Bam).with_no_tags(vec!["RG", "OQ"]); assert_eq!( result.no_tags(), &NoTags(Some(HashSet::from_iter(vec![ @@ -836,19 +930,19 @@ mod tests { #[test] fn headers_with_header() { - let header = Headers::new(HashMap::new()).with_header("Range", "bytes=0-1023"); + let header = Headers::new(BTreeMap::new()).with_header("Range", "bytes=0-1023"); let result = header.0.get("Range"); assert_eq!(result, 
Some(&"bytes=0-1023".to_string())); } #[test] fn headers_is_empty() { - assert!(Headers::new(HashMap::new()).is_empty()); + assert!(Headers::new(BTreeMap::new()).is_empty()); } #[test] fn headers_insert() { - let mut header = Headers::new(HashMap::new()); + let mut header = Headers::new(BTreeMap::new()); header.insert("Range", "bytes=0-1023"); let result = header.0.get("Range"); assert_eq!(result, Some(&"bytes=0-1023".to_string())); @@ -856,10 +950,10 @@ mod tests { #[test] fn headers_extend() { - let mut headers = Headers::new(HashMap::new()); + let mut headers = Headers::new(BTreeMap::new()); headers.insert("Range", "bytes=0-1023"); - let mut extend_with = Headers::new(HashMap::new()); + let mut extend_with = Headers::new(BTreeMap::new()); extend_with.insert("header", "value"); headers.extend(extend_with); @@ -873,7 +967,7 @@ mod tests { #[test] fn headers_multiple_values() { - let headers = Headers::new(HashMap::new()) + let headers = Headers::new(BTreeMap::new()) .with_header("Range", "bytes=0-1023") .with_header("Range", "bytes=1024-2047"); let result = headers.0.get("Range"); @@ -907,7 +1001,7 @@ mod tests { #[test] fn serialize_headers() { - let headers = Headers::new(HashMap::new()) + let headers = Headers::new(BTreeMap::new()) .with_header("Range", "bytes=0-1023") .with_header("Range", "bytes=1024-2047"); @@ -923,23 +1017,23 @@ mod tests { #[test] fn url_with_headers() { let result = Url::new("data:application/vnd.ga4gh.bam;base64,QkFNAQ==") - .with_headers(Headers::new(HashMap::new())); + .with_headers(Headers::new(BTreeMap::new())); assert_eq!(result.headers, None); } #[test] fn url_add_headers() { - let mut headers = Headers::new(HashMap::new()); + let mut headers = Headers::new(BTreeMap::new()); headers.insert("Range", "bytes=0-1023"); - let mut extend_with = Headers::new(HashMap::new()); + let mut extend_with = Headers::new(BTreeMap::new()); extend_with.insert("header", "value"); let result = Url::new("data:application/vnd.ga4gh.bam;base64,QkFNAQ==") 
.with_headers(headers) .add_headers(extend_with); - let expected_headers = Headers::new(HashMap::new()) + let expected_headers = Headers::new(BTreeMap::new()) .with_header("Range", "bytes=0-1023") .with_header("header", "value"); diff --git a/htsget-http/Cargo.toml b/htsget-http/Cargo.toml index 985c1de7a..d966877e9 100644 --- a/htsget-http/Cargo.toml +++ b/htsget-http/Cargo.toml @@ -13,6 +13,7 @@ repository = "https://github.com/umccr/htsget-rs" [features] s3-storage = ["htsget-config/s3-storage", "htsget-search/s3-storage", "htsget-test/s3-storage"] url-storage = ["htsget-config/url-storage", "htsget-search/url-storage", "htsget-test/url-storage"] +crypt4gh = ["htsget-config/crypt4gh", "htsget-search/crypt4gh", "htsget-test/crypt4gh"] default = [] [dependencies] @@ -23,5 +24,5 @@ htsget-search = { version = "0.7.0", path = "../htsget-search", default-features htsget-config = { version = "0.9.0", path = "../htsget-config", default-features = false } htsget-test = { version = "0.6.0", path = "../htsget-test", default-features = false } futures = { version = "0.3" } -tokio = { version = "1.28", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.29", features = ["macros", "rt-multi-thread"] } tracing = "0.1" diff --git a/htsget-http/README.md b/htsget-http/README.md index 391ce8d22..46dc3fe0f 100644 --- a/htsget-http/README.md +++ b/htsget-http/README.md @@ -38,6 +38,7 @@ These functions take query and endpoint information, and process it using [htsge This crate has the following features: * `s3-storage`: used to enable `S3Storage` functionality. * `url-storage`: used to enable `UrlStorage` functionality. +* `crypt4gh`: used to enable Crypt4GH functionality. 
[warp]: https://github.com/seanmonstar/warp [htsget-search]: ../htsget-search diff --git a/htsget-http/src/lib.rs b/htsget-http/src/lib.rs index fee7b0db0..1fc94bab3 100644 --- a/htsget-http/src/lib.rs +++ b/htsget-http/src/lib.rs @@ -92,7 +92,6 @@ mod tests { use htsget_search::htsget::from_storage::HtsGetFromStorage; use htsget_search::htsget::HtsGet; use htsget_search::storage::local::LocalStorage; - use htsget_test::util::expected_bgzf_eof_data_url; use super::*; @@ -125,7 +124,7 @@ mod tests { let request = HashMap::new(); let mut expected_response_headers = Headers::default(); - expected_response_headers.insert("Range".to_string(), "bytes=0-2596770".to_string()); + expected_response_headers.insert("Range".to_string(), "bytes=0-2596798".to_string()); let request = Request::new( "bam/htsnexus_test_NA12878".to_string(), @@ -164,7 +163,7 @@ mod tests { request.insert("end".to_string(), "200".to_string()); let mut expected_response_headers = Headers::default(); - expected_response_headers.insert("Range".to_string(), "bytes=0-3465".to_string()); + expected_response_headers.insert("Range".to_string(), "bytes=0-3493".to_string()); let request = Request::new( "vcf/sample1-bcbio-cancer".to_string(), @@ -191,7 +190,7 @@ mod tests { }; let mut expected_response_headers = Headers::default(); - expected_response_headers.insert("Range".to_string(), "bytes=0-2596770".to_string()); + expected_response_headers.insert("Range".to_string(), "bytes=0-2596798".to_string()); assert_eq!( post(get_searcher(), body, request, Endpoint::Reads).await, @@ -234,7 +233,7 @@ mod tests { }; let mut expected_response_headers = Headers::default(); - expected_response_headers.insert("Range".to_string(), "bytes=0-3465".to_string()); + expected_response_headers.insert("Range".to_string(), "bytes=0-3493".to_string()); assert_eq!( post(get_searcher(), body, request, Endpoint::Variants).await, @@ -248,7 +247,6 @@ mod tests { vec![ 
Url::new("http://127.0.0.1:8081/data/vcf/sample1-bcbio-cancer.vcf.gz".to_string()) .with_headers(headers), - Url::new(expected_bgzf_eof_data_url()), ], )) } @@ -259,7 +257,6 @@ mod tests { vec![ Url::new("http://127.0.0.1:8081/data/bam/htsnexus_test_NA12878.bam".to_string()) .with_headers(headers), - Url::new(expected_bgzf_eof_data_url()), ], )) } diff --git a/htsget-http/src/post_request.rs b/htsget-http/src/post_request.rs index f41f9a000..ebb5bf3d5 100644 --- a/htsget-http/src/post_request.rs +++ b/htsget-http/src/post_request.rs @@ -81,7 +81,7 @@ mod tests { } .get_queries(request.clone(), &Endpoint::Variants) .unwrap(), - vec![Query::new("id", Format::Vcf, request).with_class(Class::Header)] + vec![Query::new("id", Format::Vcf, request, Default::default()).with_class(Class::Header)] ); } @@ -104,7 +104,7 @@ mod tests { } .get_queries(request.clone(), &Endpoint::Variants) .unwrap(), - vec![Query::new("id", Format::Vcf, request) + vec![Query::new("id", Format::Vcf, request, Default::default()) .with_class(Class::Header) .with_reference_name("20".to_string()) .with_start(150) @@ -139,12 +139,12 @@ mod tests { .get_queries(request.clone(), &Endpoint::Variants) .unwrap(), vec![ - Query::new("id", Format::Vcf, request.clone()) + Query::new("id", Format::Vcf, request.clone(), Default::default()) .with_class(Class::Header) .with_reference_name("20".to_string()) .with_start(150) .with_end(153), - Query::new("id", Format::Vcf, request) + Query::new("id", Format::Vcf, request, Default::default()) .with_class(Class::Header) .with_reference_name("11".to_string()) .with_start(152) diff --git a/htsget-http/src/query_builder.rs b/htsget-http/src/query_builder.rs index dc32c8c11..e474519dc 100644 --- a/htsget-http/src/query_builder.rs +++ b/htsget-http/src/query_builder.rs @@ -17,7 +17,7 @@ impl QueryBuilder { let id = request.path().to_string(); Self { - query: Query::new(id, format, request), + query: Query::new(id, format, request, Default::default()), } } diff --git 
a/htsget-lambda/Cargo.toml b/htsget-lambda/Cargo.toml index 0e0cdad78..6dbcf07a9 100644 --- a/htsget-lambda/Cargo.toml +++ b/htsget-lambda/Cargo.toml @@ -13,10 +13,11 @@ repository = "https://github.com/umccr/htsget-rs" [features] s3-storage = ["htsget-config/s3-storage", "htsget-search/s3-storage", "htsget-http/s3-storage", "htsget-test/s3-storage"] url-storage = ["htsget-config/url-storage", "htsget-search/url-storage", "htsget-http/url-storage", "htsget-test/url-storage"] +crypt4gh = ["htsget-config/crypt4gh", "htsget-search/crypt4gh", "htsget-http/crypt4gh", "htsget-test/crypt4gh"] default = [] [dependencies] -tokio = { version = "1.28", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.29", features = ["macros", "rt-multi-thread"] } tower-http = { version = "0.4", features = ["cors"] } lambda_http = { version = "0.8" } lambda_runtime = { version = "0.8" } @@ -27,7 +28,7 @@ htsget-test = { version = "0.6.0", path = "../htsget-test", features = ["http"], serde = { version = "1.0" } serde_json = "1.0" mime = "0.3" -regex = "1.8" +regex = "1.9" tracing = "0.1" tracing-subscriber = "0.3" bytes = "1.4" @@ -35,4 +36,4 @@ bytes = "1.4" [dev-dependencies] async-trait = "0.1" query_map = { version = "0.7", features = ["url-query"] } -tempfile = "3.6" +tempfile = "3.7" diff --git a/htsget-lambda/README.md b/htsget-lambda/README.md index cb69330be..08c7dd916 100644 --- a/htsget-lambda/README.md +++ b/htsget-lambda/README.md @@ -46,6 +46,7 @@ routing queries are exposed in the public API. This crate has the following features: * `s3-storage`: used to enable `S3Storage` functionality. * `url-storage`: used to enable `UrlStorage` functionality. +* `crypt4gh`: used to enable Crypt4GH functionality. 
## License diff --git a/htsget-lambda/src/lib.rs b/htsget-lambda/src/lib.rs index a18657ea7..f091ebaa6 100644 --- a/htsget-lambda/src/lib.rs +++ b/htsget-lambda/src/lib.rs @@ -282,7 +282,7 @@ mod tests { use htsget_search::storage::configure_cors; use htsget_search::storage::data_server::BindDataServer; use htsget_test::http::server::{expected_url_path, test_response, test_response_service_info}; - use htsget_test::http::{config_with_tls, default_test_config, get_test_file}; + use htsget_test::http::{config_with_tls, default_test_config, get_test_file_string}; use htsget_test::http::{cors, server}; use htsget_test::http::{Header, Response as TestResponse, TestRequest, TestServer}; @@ -698,8 +698,8 @@ mod tests { test(router).await; } - fn get_request_from_file(file_path: &str) -> Request { - let event = get_test_file(file_path); + async fn get_request_from_file(file_path: &str) -> Request { + let event = get_test_file_string(file_path).await; lambda_http::request::from_str(&event).expect("Failed to create lambda request.") } @@ -720,7 +720,7 @@ mod tests { with_router( |router| async move { let response = route_request_to_response( - get_request_from_file(file_path), + get_request_from_file(file_path).await, router, expected_path, config, @@ -739,7 +739,7 @@ mod tests { with_router( |router| async { let response = route_request_to_response( - get_request_from_file(file_path), + get_request_from_file(file_path).await, router, expected_path, config, diff --git a/htsget-search/Cargo.toml b/htsget-search/Cargo.toml index fb6e30b2c..36522d18f 100644 --- a/htsget-search/Cargo.toml +++ b/htsget-search/Cargo.toml @@ -12,26 +12,34 @@ repository = "https://github.com/umccr/htsget-rs" [features] s3-storage = [ - "dep:bytes", "dep:aws-sdk-s3", "dep:aws-config", + "dep:bytes", "htsget-config/s3-storage", "htsget-test/s3-storage", "htsget-test/aws-mocks" ] url-storage = [ "dep:bytes", + "dep:mockall", + "dep:mockall_double", + "dep:pin-project", "hyper/client", "reqwest", - 
"pin-project-lite", "htsget-config/url-storage", "htsget-test/url-storage" ] +crypt4gh = [ + "dep:async-crypt4gh", + "dep:crypt4gh", + "htsget-config/crypt4gh", + "htsget-test/crypt4gh" +] default = [] [dependencies] # Axum server -url = "2.3" +url = "2.4" hyper = { version = "0.14", features = ["http1", "http2", "server"] } tower-http = { version = "0.4", features = ["trace", "cors", "fs"] } http = "0.2" @@ -41,7 +49,7 @@ tower = { version = "0.4", features = ["make"] } # Async tokio-rustls = "0.24" -tokio = { version = "1.28", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.29", features = ["macros", "rt-multi-thread"] } tokio-util = { version = "0.7", features = ["io", "compat"] } futures = { version = "0.3" } futures-util = "0.3" @@ -57,7 +65,11 @@ aws-config = { version = "0.56", optional = true } # Url storage reqwest = { version = "0.11", features = ["rustls-tls", "stream"], default-features = false, optional = true } -pin-project-lite = { version = "0.2", optional = true } +crypt4gh = { version = "0.4", git = "https://github.com/EGA-archive/crypt4gh-rust", optional = true } +async-crypt4gh = { version = "0.1.0", path = "../async-crypt4gh", optional = true } +pin-project = { version = "1.1", optional = true } +mockall = { version = "0.9", optional = true } +mockall_double = { version = "0.3", optional = true } # Error control, tracing, config thiserror = "1.0" @@ -70,10 +82,9 @@ serde = "1.0" [dev-dependencies] tempfile = "3.6" data-url = "0.3" +walkdir = "2.5" -# Axum server -reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] } - +reqwest = { version = "0.11", features = ["rustls-tls", "stream"], default-features = false } criterion = { version = "0.5", features = ["async_tokio"] } [[bench]] diff --git a/htsget-search/README.md b/htsget-search/README.md index 60b332e29..e0d660573 100644 --- a/htsget-search/README.md +++ b/htsget-search/README.md @@ -26,7 +26,16 @@ specific code, this defines an interface 
that handles the core logic of a htsget Future work may split these two modules into separate crates. +There are three different kinds of storage: +* `LocalStorage`: which spawns a local server that can respond to URL tickets. +* `S3Storage`: which returns pre-signed AWS S3 URLs for the tickets. +* `UrlStorage`: which returns a custom URL endpoint which is intended to respond to URL tickets. + * For `UrlStorage`, returning Crypt4GH encrypted files is supported using a custom protocol, + by compiling with the `crypt4gh` flag. See the crypt4gh [ARCHITECTURE.md][architecture] file for Crypt4GH for a description on + how this works. + [noodles]: https://github.com/zaeleus/noodles +[architecture]: ../docs/crypt4gh/ARCHITECTURE.md ### Traits abstraction @@ -56,6 +65,7 @@ For htsget-rs to function, files need to be organised in the following way: * BGZF compressed files (BAM, CRAM, VCF) can optionally also have a [GZ index][gzi] to make byte ranges smaller. * GZI files must end with `.gzi`. * See [minimising byte ranges][minimising-byte-ranges] for more details on GZI. +* Crypt4GH encrypted files must end with `.c4gh`. This is quite inflexible, and is likely to change in the future to allow arbitrary mappings of files and indices. @@ -77,6 +87,7 @@ used to process requests. This crate has the following features: * `s3-storage`: used to enable `S3Storage` functionality. * `url-storage`: used to enable `UrlStorage` functionality. +* `crypt4gh`: used to enable Crypt4GH functionality. 
[htsget]: src/htsget [storage]: src/storage diff --git a/htsget-search/benches/search_benchmarks.rs b/htsget-search/benches/search_benchmarks.rs index f86c1da20..c71c67749 100644 --- a/htsget-search/benches/search_benchmarks.rs +++ b/htsget-search/benches/search_benchmarks.rs @@ -46,12 +46,12 @@ fn criterion_benchmark(c: &mut Criterion) { bench_query( &mut group, "[LIGHT] Bam query all", - Query::new_with_default_request("bam/htsnexus_test_NA12878", Bam), + Query::new_with_defaults("bam/htsnexus_test_NA12878", Bam), ); bench_query( &mut group, "[LIGHT] Bam query specific", - Query::new_with_default_request("bam/htsnexus_test_NA12878", Bam) + Query::new_with_defaults("bam/htsnexus_test_NA12878", Bam) .with_reference_name("11") .with_start(4999977) .with_end(5008321), @@ -59,18 +59,18 @@ fn criterion_benchmark(c: &mut Criterion) { bench_query( &mut group, "[LIGHT] Bam query header", - Query::new_with_default_request("bam/htsnexus_test_NA12878", Bam).with_class(Header), + Query::new_with_defaults("bam/htsnexus_test_NA12878", Bam).with_class(Header), ); bench_query( &mut group, "[LIGHT] Cram query all", - Query::new_with_default_request("cram/htsnexus_test_NA12878", Cram), + Query::new_with_defaults("cram/htsnexus_test_NA12878", Cram), ); bench_query( &mut group, "[LIGHT] Cram query specific", - Query::new_with_default_request("cram/htsnexus_test_NA12878", Cram) + Query::new_with_defaults("cram/htsnexus_test_NA12878", Cram) .with_reference_name("11") .with_start(4999977) .with_end(5008321), @@ -78,18 +78,18 @@ fn criterion_benchmark(c: &mut Criterion) { bench_query( &mut group, "[LIGHT] Cram query header", - Query::new_with_default_request("cram/htsnexus_test_NA12878", Cram).with_class(Header), + Query::new_with_defaults("cram/htsnexus_test_NA12878", Cram).with_class(Header), ); bench_query( &mut group, "[LIGHT] Vcf query all", - Query::new_with_default_request("vcf/sample1-bcbio-cancer", Vcf), + Query::new_with_defaults("vcf/sample1-bcbio-cancer", Vcf), ); 
bench_query( &mut group, "[LIGHT] Vcf query specific", - Query::new_with_default_request("vcf/sample1-bcbio-cancer", Vcf) + Query::new_with_defaults("vcf/sample1-bcbio-cancer", Vcf) .with_reference_name("chrM") .with_start(151) .with_end(153), @@ -97,18 +97,18 @@ fn criterion_benchmark(c: &mut Criterion) { bench_query( &mut group, "[LIGHT] Vcf query header", - Query::new_with_default_request("vcf/sample1-bcbio-cancer", Vcf).with_class(Header), + Query::new_with_defaults("vcf/sample1-bcbio-cancer", Vcf).with_class(Header), ); bench_query( &mut group, "[LIGHT] Bcf query all", - Query::new_with_default_request("bcf/sample1-bcbio-cancer", Bcf), + Query::new_with_defaults("bcf/sample1-bcbio-cancer", Bcf), ); bench_query( &mut group, "[LIGHT] Bcf query specific", - Query::new_with_default_request("bcf/sample1-bcbio-cancer", Bcf) + Query::new_with_defaults("bcf/sample1-bcbio-cancer", Bcf) .with_reference_name("chrM") .with_start(151) .with_end(153), @@ -116,7 +116,7 @@ fn criterion_benchmark(c: &mut Criterion) { bench_query( &mut group, "[LIGHT] Bcf query header", - Query::new_with_default_request("bcf/sample1-bcbio-cancer", Bcf).with_class(Header), + Query::new_with_defaults("bcf/sample1-bcbio-cancer", Bcf).with_class(Header), ); group.finish(); diff --git a/htsget-search/src/htsget/bam_search.rs b/htsget-search/src/htsget/bam_search.rs index 7aa17c65a..f7933d59a 100644 --- a/htsget-search/src/htsget/bam_search.rs +++ b/htsget-search/src/htsget/bam_search.rs @@ -1,6 +1,7 @@ //! Module providing the search capability using BAM/BAI files //! 
+use std::num::NonZeroUsize; use std::sync::Arc; use async_trait::async_trait; @@ -8,6 +9,7 @@ use noodles::bam; use noodles::bam::bai; use noodles::bam::bai::Index; use noodles::bgzf; +use noodles::bgzf::r#async::reader::Builder; use noodles::bgzf::VirtualPosition; use noodles::csi::binning_index::index::reference_sequence::index::LinearIndex; use noodles::csi::binning_index::index::ReferenceSequence; @@ -19,6 +21,7 @@ use tracing::{instrument, trace}; use crate::htsget::search::{BgzfSearch, Search, SearchAll, SearchReads}; use crate::htsget::HtsGetError; +use crate::storage::HeadOutput; use crate::Class::Body; use crate::{ htsget::{Format, Query, Result}, @@ -37,13 +40,14 @@ impl BgzfSearch where S: Storage + Send + Sync + 'static, - ReaderType: AsyncRead + Unpin + Send + Sync, + ReaderType: AsyncRead + Unpin + Send + Sync + 'static, { #[instrument(level = "trace", skip(self, index))] async fn get_byte_ranges_for_unmapped( &self, - query: &Query, + _query: &Query, index: &Index, + head_output: &HeadOutput, ) -> Result> { trace!("getting byte ranges for unmapped reads"); let last_interval = index.last_first_record_start_position(); @@ -60,7 +64,7 @@ where Ok(vec![BytesPosition::default() .with_start(start.compressed()) - .with_end(self.position_at_eof(query).await?) + .with_end(self.position_at_eof(head_output).await?) .with_class(Body)]) } @@ -79,10 +83,14 @@ impl for BamSearch where S: Storage + Send + Sync + 'static, - ReaderType: AsyncRead + Unpin + Send + Sync, + ReaderType: AsyncRead + Unpin + Send + Sync + 'static, { fn init_reader(inner: ReaderType) -> AsyncReader { - AsyncReader::new(inner) + AsyncReader::from( + Builder::default() + .set_worker_count(NonZeroUsize::try_from(1).expect("expected valid non zero usize")) + .build_with_reader(inner), + ) } async fn read_header(reader: &mut AsyncReader) -> io::Result
{ @@ -94,6 +102,10 @@ where reader.read_index().await } + fn into_inner(reader: AsyncReader) -> ReaderType { + reader.into_inner().into_inner() + } + #[instrument(level = "trace", skip(self, index, header, query))] async fn get_byte_ranges_for_reference_name( &self, @@ -101,10 +113,11 @@ where index: &Index, header: &Header, query: &Query, + head_output: &HeadOutput, ) -> Result> { trace!("getting byte ranges for reference name"); self - .get_byte_ranges_for_reference_name_reads(&reference_name, index, header, query) + .get_byte_ranges_for_reference_name_reads(&reference_name, index, header, query, head_output) .await } @@ -123,7 +136,7 @@ impl for BamSearch where S: Storage + Send + Sync + 'static, - ReaderType: AsyncRead + Unpin + Send + Sync, + ReaderType: AsyncRead + Unpin + Send + Sync + 'static, { async fn get_reference_sequence_from_name<'a>( &self, @@ -137,8 +150,11 @@ where &self, query: &Query, bai_index: &Index, + head_output: &HeadOutput, ) -> Result> { - self.get_byte_ranges_for_unmapped(query, bai_index).await + self + .get_byte_ranges_for_unmapped(query, bai_index, head_output) + .await } async fn get_byte_ranges_for_reference_sequence( @@ -146,9 +162,10 @@ where ref_seq_id: usize, query: &Query, index: &Index, + head_output: &HeadOutput, ) -> Result> { self - .get_byte_ranges_for_reference_sequence_bgzf(query, ref_seq_id, index) + .get_byte_ranges_for_reference_sequence_bgzf(query, ref_seq_id, index, head_output) .await } } @@ -156,7 +173,7 @@ where impl BamSearch where S: Storage + Send + Sync + 'static, - ReaderType: AsyncRead + Unpin + Send + Sync, + ReaderType: AsyncRead + Unpin + Send + Sync + 'static, { /// Create the bam search. 
pub fn new(storage: Arc) -> Self { @@ -170,7 +187,6 @@ pub(crate) mod tests { use htsget_config::storage::local::LocalStorage as ConfigLocalStorage; use htsget_test::http::concat::ConcatResponse; - use htsget_test::util::expected_bgzf_eof_data_url; #[cfg(feature = "s3-storage")] use crate::htsget::from_storage::tests::with_aws_storage_fn; @@ -188,17 +204,14 @@ pub(crate) mod tests { async fn search_all_reads() { with_local_storage(|storage| async move { let search = BamSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam); + let query = Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam); let response = search.search(query).await; println!("{response:#?}"); let expected_response = Ok(Response::new( Format::Bam, - vec![ - Url::new(expected_url()) - .with_headers(Headers::default().with_header("Range", "bytes=0-2596770")), - Url::new(expected_bgzf_eof_data_url()), - ], + vec![Url::new(expected_url()) + .with_headers(Headers::default().with_header("Range", "bytes=0-2596798"))], )); assert_eq!(response, expected_response); @@ -211,8 +224,8 @@ pub(crate) mod tests { async fn search_unmapped_reads() { with_local_storage(|storage| async move { let search = BamSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam) - .with_reference_name("*"); + let query = + Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam).with_reference_name("*"); let response = search.search(query).await; println!("{response:#?}"); @@ -223,9 +236,8 @@ pub(crate) mod tests { .with_headers(Headers::default().with_header("Range", "bytes=0-4667")) .with_class(Header), Url::new(expected_url()) - .with_headers(Headers::default().with_header("Range", "bytes=2060795-2596770")) + .with_headers(Headers::default().with_header("Range", "bytes=2060795-2596798")) .with_class(Body), - Url::new(expected_bgzf_eof_data_url()).with_class(Body), ], )); assert_eq!(response, 
expected_response); @@ -239,8 +251,8 @@ pub(crate) mod tests { async fn search_reference_name_without_seq_range_chr11() { with_local_storage(|storage| async move { let search = BamSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam) - .with_reference_name("11"); + let query = + Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam).with_reference_name("11"); let response = search.search(query).await; println!("{response:#?}"); @@ -249,7 +261,8 @@ pub(crate) mod tests { vec![ Url::new(expected_url()) .with_headers(Headers::default().with_header("Range", "bytes=0-996014")), - Url::new(expected_bgzf_eof_data_url()), + Url::new(expected_url()) + .with_headers(Headers::default().with_header("Range", "bytes=2596771-2596798")), ], )); assert_eq!(response, expected_response); @@ -263,8 +276,8 @@ pub(crate) mod tests { async fn search_reference_name_without_seq_range_chr20() { with_local_storage(|storage| async move { let search = BamSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam) - .with_reference_name("20"); + let query = + Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam).with_reference_name("20"); let response = search.search(query).await; println!("{response:#?}"); @@ -277,7 +290,9 @@ pub(crate) mod tests { Url::new(expected_url()) .with_headers(Headers::default().with_header("Range", "bytes=977196-2128165")) .with_class(Body), - Url::new(expected_bgzf_eof_data_url()).with_class(Body), + Url::new(expected_url()) + .with_headers(Headers::default().with_header("Range", "bytes=2596771-2596798")) + .with_class(Body), ], )); assert_eq!(response, expected_response); @@ -291,7 +306,7 @@ pub(crate) mod tests { async fn search_reference_name_with_seq_range() { with_local_storage(|storage| async move { let search = BamSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam) + 
let query = Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam) .with_reference_name("11") .with_start(5015000) .with_end(5050000); @@ -313,7 +328,9 @@ pub(crate) mod tests { Url::new(expected_url()) .with_headers(Headers::default().with_header("Range", "bytes=977196-996014")) .with_class(Body), - Url::new(expected_bgzf_eof_data_url()).with_class(Body), + Url::new(expected_url()) + .with_headers(Headers::default().with_header("Range", "bytes=2596771-2596798")) + .with_class(Body), ], )); assert_eq!(response, expected_response); @@ -327,7 +344,7 @@ pub(crate) mod tests { async fn search_reference_name_no_end_position() { with_local_storage(|storage| async move { let search = BamSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam) + let query = Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam) .with_reference_name("11") .with_start(5015000); let response = search.search(query).await; @@ -342,7 +359,9 @@ pub(crate) mod tests { Url::new(expected_url()) .with_headers(Headers::default().with_header("Range", "bytes=256721-996014")) .with_class(Body), - Url::new(expected_bgzf_eof_data_url()).with_class(Body), + Url::new(expected_url()) + .with_headers(Headers::default().with_header("Range", "bytes=2596771-2596798")) + .with_class(Body), ], )); assert_eq!(response, expected_response); @@ -356,7 +375,7 @@ pub(crate) mod tests { async fn search_many_response_urls() { with_local_storage(|storage| async move { let search = BamSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam) + let query = Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam) .with_reference_name("11") .with_start(4999976) .with_end(5003981); @@ -376,7 +395,8 @@ pub(crate) mod tests { .with_headers(Headers::default().with_header("Range", "bytes=824361-842100")), Url::new(expected_url()) .with_headers(Headers::default().with_header("Range", 
"bytes=977196-996014")), - Url::new(expected_bgzf_eof_data_url()), + Url::new(expected_url()) + .with_headers(Headers::default().with_header("Range", "bytes=2596771-2596798")), ], )); assert_eq!(response, expected_response); @@ -391,7 +411,7 @@ pub(crate) mod tests { with_local_storage_fn( |storage| async move { let search = BamSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam) + let query = Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam) .with_reference_name("11") .with_start(5015000) .with_end(5050000); @@ -407,7 +427,9 @@ pub(crate) mod tests { Url::new(expected_url()) .with_headers(Headers::default().with_header("Range", "bytes=256721-1065951")) .with_class(Body), - Url::new(expected_bgzf_eof_data_url()).with_class(Body), + Url::new(expected_url()) + .with_headers(Headers::default().with_header("Range", "bytes=2596771-2596798")) + .with_class(Body), ], )); assert_eq!(response, expected_response); @@ -424,8 +446,7 @@ pub(crate) mod tests { async fn search_header() { with_local_storage(|storage| async move { let search = BamSearch::new(storage.clone()); - let query = - Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam).with_class(Header); + let query = Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam).with_class(Header); let response = search.search(query).await; println!("{response:#?}"); @@ -449,8 +470,8 @@ pub(crate) mod tests { async fn search_header_with_no_mapped_reads() { with_local_storage(|storage| async move { let search = BamSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam) - .with_reference_name("22"); + let query = + Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam).with_reference_name("22"); let response = search.search(query).await; println!("{response:#?}"); @@ -460,7 +481,9 @@ pub(crate) mod tests { Url::new(expected_url()) 
.with_headers(Headers::default().with_header("Range", "bytes=0-4667")) .with_class(Header), - Url::new(expected_bgzf_eof_data_url()).with_class(Body), + Url::new(expected_url()) + .with_headers(Headers::default().with_header("Range", "bytes=2596771-2596798")) + .with_class(Body), ], )); assert_eq!(response, expected_response); @@ -474,8 +497,8 @@ pub(crate) mod tests { async fn search_header_with_non_existent_reference_name() { with_local_storage(|storage| async move { let search = BamSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam) - .with_reference_name("25"); + let query = + Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam).with_reference_name("25"); let response = search.search(query).await; println!("{response:#?}"); @@ -491,7 +514,7 @@ pub(crate) mod tests { with_local_storage_fn( |storage| async move { let search = BamSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam); + let query = Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam); let response = search.search(query).await; assert!(matches!(response, Err(NotFound(_)))); @@ -508,8 +531,8 @@ pub(crate) mod tests { with_local_storage_fn( |storage| async move { let search = BamSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam) - .with_reference_name("20"); + let query = + Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam).with_reference_name("20"); let response = search.search(query).await; assert!(matches!(response, Err(NotFound(_)))); @@ -527,7 +550,7 @@ pub(crate) mod tests { |storage| async move { let search = BamSearch::new(storage.clone()); let query = - Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam).with_class(Header); + Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam).with_class(Header); let response = search.search(query).await; 
assert!(matches!(response, Err(NotFound(_)))); @@ -545,7 +568,7 @@ pub(crate) mod tests { |storage| async move { let search = BamSearch::new(storage.clone()); let query = - Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam).with_class(Header); + Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam).with_class(Header); let index = search.read_index(&query).await.unwrap(); let response = search.get_header_end_offset(&index).await; @@ -566,7 +589,7 @@ pub(crate) mod tests { with_aws_storage_fn( |storage| async move { let search = BamSearch::new(storage); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam); + let query = Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam); let response = search.search(query).await; assert!(response.is_err()); @@ -584,8 +607,8 @@ pub(crate) mod tests { with_aws_storage_fn( |storage| async move { let search = BamSearch::new(storage); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam) - .with_reference_name("20"); + let query = + Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam).with_reference_name("20"); let response = search.search(query).await; assert!(response.is_err()); @@ -604,7 +627,7 @@ pub(crate) mod tests { |storage| async move { let search = BamSearch::new(storage); let query = - Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam).with_class(Header); + Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam).with_class(Header); let response = search.search(query).await; assert!(response.is_err()); diff --git a/htsget-search/src/htsget/bcf_search.rs b/htsget-search/src/htsget/bcf_search.rs index 545e5a7f3..cc73bc7c5 100644 --- a/htsget-search/src/htsget/bcf_search.rs +++ b/htsget-search/src/htsget/bcf_search.rs @@ -1,11 +1,13 @@ //! Module providing the search capability using BCF files //! 
+use std::num::NonZeroUsize; use std::sync::Arc; use async_trait::async_trait; use futures_util::stream::FuturesOrdered; use noodles::bcf; +use noodles::bgzf::r#async::reader::Builder; use noodles::bgzf::VirtualPosition; use noodles::csi::binning_index::index::reference_sequence::index::BinnedIndex; use noodles::csi::binning_index::index::ReferenceSequence; @@ -18,7 +20,7 @@ use tracing::{instrument, trace}; use crate::htsget::search::{find_first, BgzfSearch, Search}; use crate::htsget::ParsedHeader; -use crate::storage::{BytesPosition, Storage}; +use crate::storage::{BytesPosition, HeadOutput, Storage}; use crate::{Format, Query, Result}; type AsyncReader = bcf::AsyncReader>; @@ -33,7 +35,7 @@ impl BgzfSearch where S: Storage + Send + Sync + 'static, - ReaderType: AsyncRead + Unpin + Send + Sync, + ReaderType: AsyncRead + Unpin + Send + Sync + 'static, { async fn read_bytes(_header: &Header, reader: &mut AsyncReader) -> Option { reader.read_lazy_record(&mut Default::default()).await.ok() @@ -50,10 +52,14 @@ impl for BcfSearch where S: Storage + Send + Sync + 'static, - ReaderType: AsyncRead + Unpin + Send + Sync, + ReaderType: AsyncRead + Unpin + Send + Sync + 'static, { fn init_reader(inner: ReaderType) -> AsyncReader { - AsyncReader::new(inner) + AsyncReader::from( + Builder::default() + .set_worker_count(NonZeroUsize::try_from(1).expect("expected valid non zero usize")) + .build_with_reader(inner), + ) } async fn read_header(reader: &mut AsyncReader) -> io::Result
{ @@ -72,6 +78,10 @@ where csi::AsyncReader::new(inner).read_index().await } + fn into_inner(reader: AsyncReader) -> ReaderType { + reader.into_inner().into_inner() + } + #[instrument(level = "trace", skip(self, index, header, query))] async fn get_byte_ranges_for_reference_name( &self, @@ -79,6 +89,7 @@ where index: &Index, header: &Header, query: &Query, + head_output: &HeadOutput, ) -> Result> { trace!("getting byte ranges for reference name"); // We are assuming the order of the contigs in the header and the references sequences @@ -102,7 +113,7 @@ where .await?; let byte_ranges = self - .get_byte_ranges_for_reference_sequence_bgzf(query, ref_seq_id, index) + .get_byte_ranges_for_reference_sequence_bgzf(query, ref_seq_id, index, head_output) .await?; Ok(byte_ranges) } @@ -119,7 +130,7 @@ where impl BcfSearch where S: Storage + Send + Sync + 'static, - ReaderType: AsyncRead + Unpin + Send + Sync, + ReaderType: AsyncRead + Unpin + Send + Sync + 'static, { /// Create the bcf search. pub fn new(storage: Arc) -> Self { @@ -134,7 +145,6 @@ mod tests { use htsget_config::storage::local::LocalStorage as ConfigLocalStorage; use htsget_config::types::Class::Body; use htsget_test::http::concat::ConcatResponse; - use htsget_test::util::expected_bgzf_eof_data_url; #[cfg(feature = "s3-storage")] use crate::htsget::from_storage::tests::with_aws_storage_fn; @@ -155,7 +165,7 @@ mod tests { with_local_storage(|storage| async move { let search = BcfSearch::new(storage.clone()); let filename = "sample1-bcbio-cancer"; - let query = Query::new_with_default_request(filename, Format::Bcf); + let query = Query::new_with_defaults(filename, Format::Bcf); let response = search.search(query).await; println!("{response:#?}"); @@ -175,17 +185,14 @@ mod tests { with_local_storage(|storage| async move { let search = BcfSearch::new(storage.clone()); let filename = "vcf-spec-v4.3"; - let query = Query::new_with_default_request(filename, Format::Bcf).with_reference_name("20"); + let query = 
Query::new_with_defaults(filename, Format::Bcf).with_reference_name("20"); let response = search.search(query).await; println!("{response:#?}"); let expected_response = Ok(Response::new( Format::Bcf, - vec![ - Url::new(expected_url(filename)) - .with_headers(Headers::default().with_header("Range", "bytes=0-949")), - Url::new(expected_bgzf_eof_data_url()), - ], + vec![Url::new(expected_url(filename)) + .with_headers(Headers::default().with_header("Range", "bytes=0-977"))], )); assert_eq!(response, expected_response); @@ -210,7 +217,7 @@ mod tests { with_local_storage(|storage| async move { let search = BcfSearch::new(storage.clone()); let filename = "sample1-bcbio-cancer"; - let query = Query::new_with_default_request(filename, Format::Bcf) + let query = Query::new_with_defaults(filename, Format::Bcf) .with_reference_name("chrM") .with_start(151); let response = search.search(query).await; @@ -242,7 +249,7 @@ mod tests { with_local_storage(|storage| async move { let search = BcfSearch::new(storage.clone()); let filename = "vcf-spec-v4.3"; - let query = Query::new_with_default_request(filename, Format::Bcf).with_class(Header); + let query = Query::new_with_defaults(filename, Format::Bcf).with_class(Header); let response = search.search(query).await; println!("{response:#?}"); @@ -267,7 +274,7 @@ mod tests { with_local_storage_fn( |storage| async move { let search = BcfSearch::new(storage.clone()); - let query = Query::new_with_default_request("vcf-spec-v4.3", Format::Bcf); + let query = Query::new_with_defaults("vcf-spec-v4.3", Format::Bcf); let response = search.search(query).await; assert!(matches!(response, Err(NotFound(_)))); @@ -285,7 +292,7 @@ mod tests { |storage| async move { let search = BcfSearch::new(storage.clone()); let query = - Query::new_with_default_request("vcf-spec-v4.3", Format::Bcf).with_reference_name("chrM"); + Query::new_with_defaults("vcf-spec-v4.3", Format::Bcf).with_reference_name("chrM"); let response = search.search(query).await; 
assert!(matches!(response, Err(NotFound(_)))); @@ -302,8 +309,7 @@ mod tests { with_local_storage_fn( |storage| async move { let search = BcfSearch::new(storage.clone()); - let query = - Query::new_with_default_request("vcf-spec-v4.3", Format::Bcf).with_class(Header); + let query = Query::new_with_defaults("vcf-spec-v4.3", Format::Bcf).with_class(Header); let response = search.search(query).await; assert!(matches!(response, Err(NotFound(_)))); @@ -320,7 +326,7 @@ mod tests { with_local_storage(|storage| async move { let search = BcfSearch::new(storage.clone()); let query = - Query::new_with_default_request("vcf-spec-v4.3", Format::Bcf).with_reference_name("chr1"); + Query::new_with_defaults("vcf-spec-v4.3", Format::Bcf).with_reference_name("chr1"); let response = search.search(query).await; println!("{response:#?}"); @@ -336,8 +342,7 @@ mod tests { with_local_storage_fn( |storage| async move { let search = BcfSearch::new(storage.clone()); - let query = - Query::new_with_default_request("vcf-spec-v4.3", Format::Bcf).with_class(Header); + let query = Query::new_with_defaults("vcf-spec-v4.3", Format::Bcf).with_class(Header); let index = search.read_index(&query).await.unwrap(); let response = search.get_header_end_offset(&index).await; @@ -358,7 +363,7 @@ mod tests { with_aws_storage_fn( |storage| async move { let search = BcfSearch::new(storage); - let query = Query::new_with_default_request("vcf-spec-v4.3", Format::Bcf); + let query = Query::new_with_defaults("vcf-spec-v4.3", Format::Bcf); let response = search.search(query).await; assert!(response.is_err()); @@ -377,7 +382,7 @@ mod tests { |storage| async move { let search = BcfSearch::new(storage); let query = - Query::new_with_default_request("vcf-spec-v4.3", Format::Bcf).with_reference_name("chrM"); + Query::new_with_defaults("vcf-spec-v4.3", Format::Bcf).with_reference_name("chrM"); let response = search.search(query).await; assert!(response.is_err()); @@ -395,8 +400,7 @@ mod tests { with_aws_storage_fn( 
|storage| async move { let search = BcfSearch::new(storage); - let query = - Query::new_with_default_request("vcf-spec-v4.3", Format::Bcf).with_class(Header); + let query = Query::new_with_defaults("vcf-spec-v4.3", Format::Bcf).with_class(Header); let response = search.search(query).await; assert!(response.is_err()); @@ -413,7 +417,7 @@ mod tests { ) -> Option<(String, ConcatResponse)> { let search = BcfSearch::new(storage.clone()); let filename = "sample1-bcbio-cancer"; - let query = Query::new_with_default_request(filename, Format::Bcf) + let query = Query::new_with_defaults(filename, Format::Bcf) .with_reference_name("chrM") .with_start(151) .with_end(153); @@ -432,11 +436,8 @@ mod tests { fn expected_bcf_response(filename: &str) -> Response { Response::new( Format::Bcf, - vec![ - Url::new(expected_url(filename)) - .with_headers(Headers::default().with_header("Range", "bytes=0-3529")), - Url::new(expected_bgzf_eof_data_url()), - ], + vec![Url::new(expected_url(filename)) + .with_headers(Headers::default().with_header("Range", "bytes=0-3557"))], ) } diff --git a/htsget-search/src/htsget/cram_search.rs b/htsget-search/src/htsget/cram_search.rs index 49fa286fc..1a1183fed 100644 --- a/htsget-search/src/htsget/cram_search.rs +++ b/htsget-search/src/htsget/cram_search.rs @@ -21,7 +21,7 @@ use htsget_config::types::Interval; use crate::htsget::search::{Search, SearchAll, SearchReads}; use crate::htsget::{ConcurrencyError, ParsedHeader}; -use crate::storage::{BytesPosition, DataBlock, Storage}; +use crate::storage::{BytesPosition, DataBlock, HeadOutput, Storage}; use crate::Class::Body; use crate::{Format, HtsGetError, Query, Result}; @@ -45,12 +45,12 @@ impl for CramSearch where S: Storage + Send + Sync + 'static, - ReaderType: AsyncRead + Unpin + Send + Sync, + ReaderType: AsyncRead + Unpin + Send + Sync + 'static, { #[instrument(level = "trace", skip_all, ret)] - async fn get_byte_ranges_for_all(&self, query: &Query) -> Result> { + async fn 
get_byte_ranges_for_all(&self, head_output: &HeadOutput) -> Result> { Ok(vec![ - BytesPosition::default().with_end(self.position_at_eof(query).await?) + BytesPosition::default().with_end(self.position_at_eof(head_output).await?) ]) } @@ -75,6 +75,7 @@ where _header: &Header, _reader: &mut AsyncReader, _query: &Query, + _head_output: &HeadOutput, ) -> Result { Ok( BytesPosition::default() @@ -101,7 +102,7 @@ impl for CramSearch where S: Storage + Send + Sync + 'static, - ReaderType: AsyncRead + Unpin + Send + Sync, + ReaderType: AsyncRead + Unpin + Send + Sync + 'static, { async fn get_reference_sequence_from_name<'a>( &self, @@ -115,12 +116,14 @@ where &self, query: &Query, index: &Index, + head_output: &HeadOutput, ) -> Result> { Self::bytes_ranges_from_index( self, query, index, Arc::new(|record: &Record| record.reference_sequence_id().is_none()), + head_output, ) .await } @@ -130,12 +133,14 @@ where ref_seq_id: usize, query: &Query, index: &Index, + head_output: &HeadOutput, ) -> Result> { Self::bytes_ranges_from_index( self, query, index, Arc::new(move |record: &Record| record.reference_sequence_id() == Some(ref_seq_id)), + head_output, ) .await } @@ -147,7 +152,7 @@ impl Search, Index, AsyncReader< for CramSearch where S: Storage + Send + Sync + 'static, - ReaderType: AsyncRead + Unpin + Send + Sync, + ReaderType: AsyncRead + Unpin + Send + Sync + 'static, { fn init_reader(inner: ReaderType) -> AsyncReader { AsyncReader::new(BufReader::new(inner)) @@ -169,15 +174,20 @@ where crai::AsyncReader::new(inner).read_index().await } + fn into_inner(reader: AsyncReader) -> ReaderType { + reader.into_inner().into_inner() + } + async fn get_byte_ranges_for_reference_name( &self, reference_name: String, index: &Index, header: &Header, query: &Query, + head_output: &HeadOutput, ) -> Result> { self - .get_byte_ranges_for_reference_name_reads(&reference_name, index, header, query) + .get_byte_ranges_for_reference_name_reads(&reference_name, index, header, query, head_output) 
.await } @@ -193,7 +203,7 @@ where impl CramSearch where S: Storage + Send + Sync + 'static, - ReaderType: AsyncRead + Unpin + Send + Sync, + ReaderType: AsyncRead + Unpin + Send + Sync + 'static, { /// Create the cram search. pub fn new(storage: Arc) -> Self { @@ -207,6 +217,7 @@ where query: &Query, crai_index: &[Record], predicate: Arc, + head_output: &HeadOutput, ) -> Result> where F: Fn(&Record) -> bool + Send + Sync + 'static, @@ -247,9 +258,11 @@ where )); } Some(last) if predicate(last) => { - if let Some(range) = - Self::bytes_ranges_for_record(query.interval(), last, self.position_at_eof(query).await?)? - { + if let Some(range) = Self::bytes_ranges_for_record( + query.interval(), + last, + self.position_at_eof(head_output).await?, + )? { byte_ranges.push(range); } } @@ -293,7 +306,6 @@ mod tests { use htsget_config::storage::local::LocalStorage as ConfigLocalStorage; use htsget_test::http::concat::ConcatResponse; - use htsget_test::util::expected_cram_eof_data_url; #[cfg(feature = "s3-storage")] use crate::htsget::from_storage::tests::with_aws_storage_fn; @@ -311,17 +323,14 @@ mod tests { async fn search_all_reads() { with_local_storage(|storage| async move { let search = CramSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Cram); + let query = Query::new_with_defaults("htsnexus_test_NA12878", Format::Cram); let response = search.search(query).await; println!("{response:#?}"); let expected_response = Ok(Response::new( Format::Cram, - vec![ - Url::new(expected_url()) - .with_headers(Headers::default().with_header("Range", "bytes=0-1672409")), - Url::new(expected_cram_eof_data_url()), - ], + vec![Url::new(expected_url()) + .with_headers(Headers::default().with_header("Range", "bytes=0-1672447"))], )); assert_eq!(response, expected_response); @@ -334,8 +343,8 @@ mod tests { async fn search_unmapped_reads() { with_local_storage(|storage| async move { let search = CramSearch::new(storage.clone()); - 
let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Cram) - .with_reference_name("*"); + let query = + Query::new_with_defaults("htsnexus_test_NA12878", Format::Cram).with_reference_name("*"); let response = search.search(query).await; println!("{response:#?}"); @@ -346,9 +355,8 @@ mod tests { .with_headers(Headers::default().with_header("Range", "bytes=0-6133")) .with_class(Header), Url::new(expected_url()) - .with_headers(Headers::default().with_header("Range", "bytes=1324614-1672409")) + .with_headers(Headers::default().with_header("Range", "bytes=1324614-1672447")) .with_class(Body), - Url::new(expected_cram_eof_data_url()).with_class(Body), ], )); assert_eq!(response, expected_response); @@ -362,8 +370,8 @@ mod tests { async fn search_reference_name_without_seq_range_chr11() { with_local_storage(|storage| async move { let search = CramSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Cram) - .with_reference_name("11"); + let query = + Query::new_with_defaults("htsnexus_test_NA12878", Format::Cram).with_reference_name("11"); let response = search.search(query).await; println!("{response:#?}"); @@ -372,7 +380,8 @@ mod tests { vec![ Url::new(expected_url()) .with_headers(Headers::default().with_header("Range", "bytes=0-625727")), - Url::new(expected_cram_eof_data_url()), + Url::new(expected_url()) + .with_headers(Headers::default().with_header("Range", "bytes=1672410-1672447")), ], )); assert_eq!(response, expected_response); @@ -386,8 +395,8 @@ mod tests { async fn search_reference_name_without_seq_range_chr20() { with_local_storage(|storage| async move { let search = CramSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Cram) - .with_reference_name("20"); + let query = + Query::new_with_defaults("htsnexus_test_NA12878", Format::Cram).with_reference_name("20"); let response = search.search(query).await; 
println!("{response:#?}"); @@ -400,7 +409,9 @@ mod tests { Url::new(expected_url()) .with_headers(Headers::default().with_header("Range", "bytes=625728-1324613")) .with_class(Body), - Url::new(expected_cram_eof_data_url()).with_class(Body), + Url::new(expected_url()) + .with_headers(Headers::default().with_header("Range", "bytes=1672410-1672447")) + .with_class(Body), ], )); assert_eq!(response, expected_response); @@ -414,7 +425,7 @@ mod tests { async fn search_reference_name_with_seq_range_no_overlap() { with_local_storage(|storage| async move { let search = CramSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Cram) + let query = Query::new_with_defaults("htsnexus_test_NA12878", Format::Cram) .with_reference_name("11") .with_start(5000000) .with_end(5050000); @@ -426,7 +437,8 @@ mod tests { vec![ Url::new(expected_url()) .with_headers(Headers::default().with_header("Range", "bytes=0-480537")), - Url::new(expected_cram_eof_data_url()), + Url::new(expected_url()) + .with_headers(Headers::default().with_header("Range", "bytes=1672410-1672447")), ], )); assert_eq!(response, expected_response); @@ -440,7 +452,7 @@ mod tests { async fn search_reference_name_with_seq_range_overlap() { with_local_storage(|storage| async move { let search = CramSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Cram) + let query = Query::new_with_defaults("htsnexus_test_NA12878", Format::Cram) .with_reference_name("11") .with_start(5000000) .with_end(5100000); @@ -459,7 +471,7 @@ mod tests { async fn search_reference_name_with_no_end_position() { with_local_storage(|storage| async move { let search = CramSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Cram) + let query = Query::new_with_defaults("htsnexus_test_NA12878", Format::Cram) .with_reference_name("11") .with_start(5000000); let response = 
search.search(query).await; @@ -479,7 +491,8 @@ mod tests { vec![ Url::new(expected_url()) .with_headers(Headers::default().with_header("Range", "bytes=0-625727")), - Url::new(expected_cram_eof_data_url()), + Url::new(expected_url()) + .with_headers(Headers::default().with_header("Range", "bytes=1672410-1672447")), ], ) } @@ -489,7 +502,7 @@ mod tests { with_local_storage(|storage| async move { let search = CramSearch::new(storage.clone()); let query = - Query::new_with_default_request("htsnexus_test_NA12878", Format::Cram).with_class(Header); + Query::new_with_defaults("htsnexus_test_NA12878", Format::Cram).with_class(Header); let response = search.search(query).await; println!("{response:#?}"); @@ -514,7 +527,7 @@ mod tests { with_local_storage_fn( |storage| async move { let search = CramSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Cram); + let query = Query::new_with_defaults("htsnexus_test_NA12878", Format::Cram); let response = search.search(query).await; assert!(matches!(response, Err(NotFound(_)))); @@ -531,8 +544,8 @@ mod tests { with_local_storage_fn( |storage| async move { let search = CramSearch::new(storage.clone()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Cram) - .with_reference_name("20"); + let query = + Query::new_with_defaults("htsnexus_test_NA12878", Format::Cram).with_reference_name("20"); let response = search.search(query).await; assert!(matches!(response, Err(NotFound(_)))); @@ -550,7 +563,7 @@ mod tests { |storage| async move { let search = CramSearch::new(storage.clone()); let query = - Query::new_with_default_request("htsnexus_test_NA12878", Format::Cram).with_class(Header); + Query::new_with_defaults("htsnexus_test_NA12878", Format::Cram).with_class(Header); let response = search.search(query).await; assert!(matches!(response, Err(NotFound(_)))); @@ -568,7 +581,7 @@ mod tests { with_aws_storage_fn( |storage| async move { let search = 
CramSearch::new(storage); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Cram); + let query = Query::new_with_defaults("htsnexus_test_NA12878", Format::Cram); let response = search.search(query).await; assert!(response.is_err()); @@ -586,8 +599,8 @@ mod tests { with_aws_storage_fn( |storage| async move { let search = CramSearch::new(storage); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Cram) - .with_reference_name("20"); + let query = + Query::new_with_defaults("htsnexus_test_NA12878", Format::Cram).with_reference_name("20"); let response = search.search(query).await; assert!(response.is_err()); @@ -606,7 +619,7 @@ mod tests { |storage| async move { let search = CramSearch::new(storage); let query = - Query::new_with_default_request("htsnexus_test_NA12878", Format::Cram).with_class(Header); + Query::new_with_defaults("htsnexus_test_NA12878", Format::Cram).with_class(Header); let response = search.search(query).await; assert!(response.is_err()); diff --git a/htsget-search/src/htsget/from_storage.rs b/htsget-search/src/htsget/from_storage.rs index a392c42f8..2c6ee03c4 100644 --- a/htsget-search/src/htsget/from_storage.rs +++ b/htsget-search/src/htsget/from_storage.rs @@ -57,7 +57,7 @@ impl HtsGet for &[Resolver] { #[async_trait] impl HtsGet for HtsGetFromStorage where - R: AsyncRead + Send + Sync + Unpin, + R: AsyncRead + Send + Sync + Unpin + 'static, S: Storage + Sync + Send + 'static, { #[instrument(level = "debug", skip(self))] @@ -101,11 +101,15 @@ impl ResolveResponse for HtsGetFromStorage { async fn from_url(url_storage_config: &UrlStorageConfig, query: &Query) -> Result { let searcher = HtsGetFromStorage::new(UrlStorage::new( url_storage_config.client_cloned(), - url_storage_config.url().clone(), + url_storage_config.endpoints().clone(), url_storage_config.response_url().clone(), url_storage_config.forward_headers(), url_storage_config.header_blacklist().to_vec(), - )); + 
url_storage_config.user_agent(), + query, + #[cfg(feature = "crypt4gh")] + Default::default(), + )?); searcher.search(query.clone()).await } } @@ -137,7 +141,6 @@ pub(crate) mod tests { use htsget_config::types::Class::Body; use htsget_config::types::Scheme::Http; use htsget_test::http::concat::ConcatResponse; - use htsget_test::util::expected_bgzf_eof_data_url; use crate::htsget::bam_search::tests::{ expected_url as bam_expected_url, with_local_storage as with_bam_local_storage, BAM_FILE_NAME, @@ -156,17 +159,14 @@ pub(crate) mod tests { async fn search_bam() { with_bam_local_storage(|storage| async move { let htsget = HtsGetFromStorage::new(Arc::try_unwrap(storage).unwrap()); - let query = Query::new_with_default_request("htsnexus_test_NA12878", Format::Bam); + let query = Query::new_with_defaults("htsnexus_test_NA12878", Format::Bam); let response = htsget.search(query).await; println!("{response:#?}"); let expected_response = Ok(Response::new( Format::Bam, - vec![ - Url::new(bam_expected_url()) - .with_headers(Headers::default().with_header("Range", "bytes=0-2596770")), - Url::new(expected_bgzf_eof_data_url()), - ], + vec![Url::new(bam_expected_url()) + .with_headers(Headers::default().with_header("Range", "bytes=0-2596798"))], )); assert_eq!(response, expected_response); @@ -180,7 +180,7 @@ pub(crate) mod tests { with_vcf_local_storage(|storage| async move { let htsget = HtsGetFromStorage::new(Arc::try_unwrap(storage).unwrap()); let filename = "spec-v4.3"; - let query = Query::new_with_default_request(filename, Format::Vcf); + let query = Query::new_with_defaults(filename, Format::Vcf); let response = htsget.search(query).await; println!("{response:#?}"); @@ -199,7 +199,7 @@ pub(crate) mod tests { with_config_local_storage( |_, local_storage| async move { let filename = "spec-v4.3"; - let query = Query::new_with_default_request(filename, Format::Vcf); + let query = Query::new_with_defaults(filename, Format::Vcf); let response = 
HtsGetFromStorage::<()>::from_local(&local_storage, &query).await; assert_eq!(response, expected_vcf_response(filename)); @@ -224,11 +224,12 @@ pub(crate) mod tests { ".*", "$0", Default::default(), + Default::default(), ) .unwrap()]; let filename = "spec-v4.3"; - let query = Query::new_with_default_request(filename, Format::Vcf); + let query = Query::new_with_defaults(filename, Format::Vcf); let response = resolvers.search(query).await; assert_eq!(response, expected_vcf_response(filename)); @@ -247,11 +248,8 @@ pub(crate) mod tests { fn expected_vcf_response(filename: &str) -> Result { Ok(Response::new( Format::Vcf, - vec![ - Url::new(vcf_expected_url(filename)) - .with_headers(Headers::default().with_header("Range", "bytes=0-822")), - Url::new(expected_bgzf_eof_data_url()), - ], + vec![Url::new(vcf_expected_url(filename)) + .with_headers(Headers::default().with_header("Range", "bytes=0-850"))], )) } diff --git a/htsget-search/src/htsget/search.rs b/htsget-search/src/htsget/search.rs index 88fc56d9a..171950bd0 100644 --- a/htsget-search/src/htsget/search.rs +++ b/htsget-search/src/htsget/search.rs @@ -8,6 +8,8 @@ use std::collections::BTreeSet; use std::sync::Arc; +// #[cfg(feature = "crypt4gh")] +// use async_crypt4gh::reader::builder::Builder; use async_trait::async_trait; use futures::StreamExt; use futures_util::stream::FuturesOrdered; @@ -26,7 +28,9 @@ use tracing::{instrument, trace, trace_span, Instrument}; use htsget_config::types::Class::Header; use crate::htsget::ConcurrencyError; -use crate::storage::{BytesPosition, HeadOptions, RangeUrlOptions, Storage}; +use crate::storage::{ + BytesPosition, BytesPositionOptions, HeadOptions, HeadOutput, RangeUrlOptions, Storage, +}; use crate::storage::{DataBlock, GetOptions}; use crate::{Class, Class::Body, Format, HtsGetError, Query, Response, Result}; @@ -73,7 +77,7 @@ where Index: Send + Sync, { /// This returns mapped and placed unmapped ranges. 
- async fn get_byte_ranges_for_all(&self, query: &Query) -> Result>; + async fn get_byte_ranges_for_all(&self, head_output: &HeadOutput) -> Result>; /// Get the offset in the file of the end of the header. async fn get_header_end_offset(&self, index: &Index) -> Result; @@ -85,6 +89,7 @@ where header: &Header, reader: &mut Reader, query: &Query, + head_output: &HeadOutput, ) -> Result; /// Get the eof marker for this format. @@ -92,6 +97,28 @@ where /// Get the eof data block for this format. fn get_eof_data_block(&self) -> Option; + + /// Get the eof bytes positions converting from a data block. + fn get_eof_byte_positions(&self, file_size: u64) -> Option> { + if let Some(DataBlock::Data(data, class)) = self.get_eof_data_block() { + let data_len = + u64::try_from(data.len()).map_err(|err| HtsGetError::InvalidInput(err.to_string())); + + return match data_len { + Ok(data_len) => { + let bytes_position = BytesPosition::default() + .with_start(file_size - data_len) + .with_end(file_size); + let bytes_position = bytes_position.set_class(class); + + Some(Ok(bytes_position)) + } + Err(err) => Some(Err(err)), + }; + } + + None + } } /// [SearchReads] represents searching bytes ranges for the reads endpoint. @@ -124,6 +151,7 @@ where &self, query: &Query, index: &Index, + head_output: &HeadOutput, ) -> Result>; /// Get reads ranges for a reference sequence implementation. @@ -132,6 +160,7 @@ where ref_seq_id: usize, query: &Query, index: &Index, + head_output: &HeadOutput, ) -> Result>; ///Get reads for a given reference name and an optional sequence range. 
@@ -141,9 +170,12 @@ where index: &Index, header: &Header, query: &Query, + head_output: &HeadOutput, ) -> Result> { if reference_name == "*" { - return self.get_byte_ranges_for_unmapped_reads(query, index).await; + return self + .get_byte_ranges_for_unmapped_reads(query, index, head_output) + .await; } let maybe_ref_seq = self @@ -155,7 +187,8 @@ where "reference name not found: {reference_name}" ))), Some(ref_seq_id) => { - Self::get_byte_ranges_for_reference_sequence(self, ref_seq_id, query, index).await + Self::get_byte_ranges_for_reference_sequence(self, ref_seq_id, query, index, head_output) + .await } }?; Ok(byte_ranges) @@ -185,6 +218,8 @@ where async fn read_header(reader: &mut Reader) -> io::Result
; async fn read_index_inner(inner: T) -> io::Result; + fn into_inner(reader: Reader) -> ReaderType; + /// Get ranges for a given reference name and an optional sequence range. async fn get_byte_ranges_for_reference_name( &self, @@ -192,6 +227,7 @@ where index: &Index, header: &Header, query: &Query, + head_output: &HeadOutput, ) -> Result>; /// Get the storage of this format. @@ -202,21 +238,28 @@ where /// Get the position at the end of file marker. #[instrument(level = "trace", skip(self), ret)] - async fn position_at_eof(&self, query: &Query) -> Result { - let file_size = self - .get_storage() - .head( - query.format().fmt_file(query.id()), - HeadOptions::new(query.request().headers()), - ) - .await?; + async fn position_at_eof(&self, file_size: &HeadOutput) -> Result { Ok( - file_size + file_size.content_length() - u64::try_from(self.get_eof_marker().len()) .map_err(|err| HtsGetError::InvalidInput(err.to_string()))?, ) } + /// Get the file size. + #[instrument(level = "trace", skip(self), ret)] + async fn file_size(&self, query: &Query) -> Result { + Ok( + self + .get_storage() + .head( + query.format().fmt_file(query), + HeadOptions::new(query.request().headers(), query.object_type()), + ) + .await?, + ) + } + /// Read the index from the key. #[instrument(level = "trace", skip(self))] async fn read_index(&self, query: &Query) -> Result { @@ -225,7 +268,8 @@ where .get_storage() .get( query.format().fmt_index(query.id()), - GetOptions::new_with_default_range(query.request().headers()), + GetOptions::new_with_default_range(query.request().headers(), query.object_type()), + &mut None, ) .await?; Self::read_index_inner(storage) @@ -235,6 +279,13 @@ where /// Search based on the query. 
async fn search(&self, query: Query) -> Result { + let index = self.read_index(&query).await?; + + let header_end = self.get_header_end_offset(&index).await?; + let mut output = self.file_size(&query).await?; + + let (header, mut reader) = self.get_header(&query, header_end, &mut output).await?; + match query.class() { Body => { let format = self.get_format(); @@ -246,26 +297,22 @@ where ))); } - let byte_ranges = match query.reference_name().as_ref() { - None => self.get_byte_ranges_for_all(&query).await?, + let mut byte_ranges = match query.reference_name().as_ref() { + None => self.get_byte_ranges_for_all(&output).await?, Some(reference_name) => { - let index = self.read_index(&query).await?; - - let header_end = self.get_header_end_offset(&index).await?; - let (header, mut reader) = self.get_header(&query, header_end).await?; - let mut byte_ranges = self .get_byte_ranges_for_reference_name( reference_name.to_string(), &index, &header, &query, + &output, ) .await?; byte_ranges.push( self - .get_byte_ranges_for_header(&index, &header, &mut reader, &query) + .get_byte_ranges_for_header(&index, &header, &mut reader, &query, &output) .await?, ); @@ -273,11 +320,23 @@ where } }; - let mut blocks = DataBlock::from_bytes_positions(byte_ranges); - if let Some(eof) = self.get_eof_data_block() { - blocks.push(eof); + if let Some(eof) = self.get_eof_byte_positions(output.content_length()) { + byte_ranges.push(eof?); } + let blocks = self + .get_storage() + .update_byte_positions( + Self::into_inner(reader), + BytesPositionOptions::new( + byte_ranges, + output.content_length(), + query.request().headers(), + query.object_type(), + ), + ) + .await?; + self.build_response(&query, blocks).await } Class::Header => { @@ -285,26 +344,29 @@ where self .get_storage() .head( - query.format().fmt_file(query.id()), - HeadOptions::new(query.request().headers()), + query.format().fmt_file(&query), + HeadOptions::new(query.request().headers(), query.object_type()), ) .await?; - let index 
= self.read_index(&query).await?; - - let header_end = self.get_header_end_offset(&index).await?; - let (header, mut reader) = self.get_header(&query, header_end).await?; - let header_byte_ranges = self - .get_byte_ranges_for_header(&index, &header, &mut reader, &query) + .get_byte_ranges_for_header(&index, &header, &mut reader, &query, &output) .await?; - self - .build_response( - &query, - DataBlock::from_bytes_positions(vec![header_byte_ranges]), + let blocks = self + .get_storage() + .update_byte_positions( + Self::into_inner(reader), + BytesPositionOptions::new( + vec![header_byte_ranges], + output.content_length(), + query.request().headers(), + query.object_type(), + ), ) - .await + .await?; + + self.build_response(&query, blocks).await } } } @@ -324,8 +386,12 @@ where storage_futures.push_back(tokio::spawn(async move { storage .range_url( - query_owned.format().fmt_file(query_owned.id()), - RangeUrlOptions::new(range, query_owned.request().headers()), + query_owned.format().fmt_file(&query_owned), + RangeUrlOptions::new( + range, + query_owned.request().headers(), + query_owned.object_type(), + ), ) .await })); @@ -349,19 +415,29 @@ where /// Get the header from the file specified by the id and format. 
#[instrument(level = "trace", skip(self))] - async fn get_header(&self, query: &Query, offset: u64) -> Result<(Header, Reader)> { + async fn get_header( + &self, + query: &Query, + offset: u64, + output: &mut HeadOutput, + ) -> Result<(Header, Reader)> { trace!("getting header"); let get_options = GetOptions::new( BytesPosition::default().with_end(offset), query.request().headers(), + query.object_type(), ); let reader_type = self .get_storage() - .get(query.format().fmt_file(query.id()), get_options) + .get( + query.format().fmt_file(query), + get_options, + &mut Some(output), + ) .await?; - let mut reader = Self::init_reader(reader_type); + let mut reader = Self::init_reader(reader_type); Ok(( Self::read_header(&mut reader).await.map_err(|err| { HtsGetError::io_error(format!("reading `{}` header: {}", self.get_format(), err)) @@ -430,6 +506,7 @@ where query: &Query, ref_seq_id: usize, index: &Index, + head_output: &HeadOutput, ) -> Result> { let chunks: Result> = trace_span!("querying chunks").in_scope(|| { trace!(id = ?query.id(), ref_seq_id = ?ref_seq_id, "querying chunks"); @@ -446,7 +523,8 @@ where .get_storage() .get( query.format().fmt_gzi(query.id())?, - GetOptions::new_with_default_range(query.request().headers()), + GetOptions::new_with_default_range(query.request().headers(), query.object_type()), + &mut None, ) .await; let byte_ranges: Vec = match gzi_data { @@ -469,13 +547,13 @@ where .await; self - .bytes_positions_from_chunks(query, chunks?.into_iter(), gzi?.into_iter()) + .bytes_positions_from_chunks(head_output, chunks?.into_iter(), gzi?.into_iter()) .await? 
} Err(_) => { self .bytes_positions_from_chunks( - query, + head_output, chunks?.into_iter(), Self::index_positions(index).into_iter(), ) @@ -490,7 +568,7 @@ where #[instrument(level = "trace", skip(self, chunks, positions))] async fn bytes_positions_from_chunks<'a>( &self, - query: &Query, + head_output: &HeadOutput, chunks: impl Iterator + Send + 'a, mut positions: impl Iterator + Send + 'a, ) -> Result> { @@ -523,7 +601,7 @@ where let end = match maybe_end { None => match end_position { None => { - let pos = self.position_at_eof(query).await?; + let pos = self.position_at_eof(head_output).await?; end_position = Some(pos); pos } @@ -543,6 +621,7 @@ where &self, _query: &Query, _index: &Index, + _head_output: &HeadOutput, ) -> Result> { Ok(Vec::new()) } @@ -566,9 +645,9 @@ where T: BgzfSearch + Send + Sync, { #[instrument(level = "debug", skip(self), ret)] - async fn get_byte_ranges_for_all(&self, query: &Query) -> Result> { + async fn get_byte_ranges_for_all(&self, head_output: &HeadOutput) -> Result> { Ok(vec![ - BytesPosition::default().with_end(self.position_at_eof(query).await?) + BytesPosition::default().with_end(self.position_at_eof(head_output).await?) ]) } @@ -585,6 +664,7 @@ where )) })?; + // Todo consider the header length if it includes the crypt4gh header. // The header can only extend past the first index position by the maximum BGZF block size // because otherwise the first index position wouldn't be representing the first reference. Ok(first_index_position + MAX_BGZF_ISIZE) @@ -595,7 +675,8 @@ where index: &Index, header: &Header, reader: &mut Reader, - query: &Query, + _query: &Query, + head_output: &HeadOutput, ) -> Result { let current_block_index = self.virtual_position(reader); @@ -621,7 +702,7 @@ where let position = positions.into_iter().next().unwrap_or_default(); if position == 0 { - self.position_at_eof(query).await? + self.position_at_eof(head_output).await? 
} else { position } diff --git a/htsget-search/src/htsget/vcf_search.rs b/htsget-search/src/htsget/vcf_search.rs index 08a391a6a..c9c837ba1 100644 --- a/htsget-search/src/htsget/vcf_search.rs +++ b/htsget-search/src/htsget/vcf_search.rs @@ -1,11 +1,13 @@ //! Module providing the search capability using VCF files //! +use std::num::NonZeroUsize; use std::sync::Arc; use async_trait::async_trait; use futures_util::stream::FuturesOrdered; use noodles::bgzf; +use noodles::bgzf::r#async::reader::Builder; use noodles::bgzf::VirtualPosition; use noodles::csi::binning_index::index::reference_sequence::index::LinearIndex; use noodles::csi::binning_index::index::ReferenceSequence; @@ -21,7 +23,7 @@ use tracing::{instrument, trace}; use htsget_config::types::HtsGetError; use crate::htsget::search::{find_first, BgzfSearch, Search}; -use crate::storage::{BytesPosition, Storage}; +use crate::storage::{BytesPosition, HeadOutput, Storage}; use crate::{Format, Query, Result}; type AsyncReader = vcf::AsyncReader>; @@ -36,7 +38,7 @@ impl BgzfSearch where S: Storage + Send + Sync + 'static, - ReaderType: AsyncRead + Unpin + Send + Sync, + ReaderType: AsyncRead + Unpin + Send + Sync + 'static, { async fn read_bytes(header: &Header, reader: &mut AsyncReader) -> Option { reader @@ -56,10 +58,14 @@ impl for VcfSearch where S: Storage + Send + Sync + 'static, - ReaderType: AsyncRead + Unpin + Send + Sync, + ReaderType: AsyncRead + Unpin + Send + Sync + 'static, { fn init_reader(inner: ReaderType) -> AsyncReader { - AsyncReader::new(bgzf::AsyncReader::new(inner)) + AsyncReader::new( + Builder::default() + .set_worker_count(NonZeroUsize::try_from(1).expect("expected valid non zero usize")) + .build_with_reader(inner), + ) } async fn read_header(reader: &mut AsyncReader) -> io::Result
{ @@ -70,6 +76,10 @@ where tabix::AsyncReader::new(inner).read_index().await } + fn into_inner(reader: AsyncReader) -> ReaderType { + reader.into_inner().into_inner() + } + #[instrument(level = "trace", skip(self, index, query))] async fn get_byte_ranges_for_reference_name( &self, @@ -77,6 +87,7 @@ where index: &Index, _header: &Header, query: &Query, + head_output: &HeadOutput, ) -> Result> { trace!("getting byte ranges for reference name"); // We are assuming the order of the names and the references sequences @@ -107,7 +118,7 @@ where .await?; let byte_ranges = self - .get_byte_ranges_for_reference_sequence_bgzf(query, ref_seq_id, index) + .get_byte_ranges_for_reference_sequence_bgzf(query, ref_seq_id, index, head_output) .await?; Ok(byte_ranges) } @@ -124,7 +135,7 @@ where impl VcfSearch where S: Storage + Send + Sync + 'static, - ReaderType: AsyncRead + Unpin + Send + Sync, + ReaderType: AsyncRead + Unpin + Send + Sync + 'static, { /// Create the vcf search. pub fn new(storage: Arc) -> Self { @@ -139,7 +150,6 @@ pub(crate) mod tests { use htsget_config::storage::local::LocalStorage as ConfigLocalStorage; use htsget_config::types::Class::Body; use htsget_test::http::concat::ConcatResponse; - use htsget_test::util::expected_bgzf_eof_data_url; #[cfg(feature = "s3-storage")] use crate::htsget::from_storage::tests::with_aws_storage_fn; @@ -160,7 +170,7 @@ pub(crate) mod tests { with_local_storage(|storage| async move { let search = VcfSearch::new(storage.clone()); let filename = "sample1-bcbio-cancer"; - let query = Query::new_with_default_request(filename, Format::Vcf); + let query = Query::new_with_defaults(filename, Format::Vcf); let response = search.search(query).await; println!("{response:#?}"); @@ -180,17 +190,14 @@ pub(crate) mod tests { with_local_storage(|storage| async move { let search = VcfSearch::new(storage.clone()); let filename = "spec-v4.3"; - let query = Query::new_with_default_request(filename, Format::Vcf).with_reference_name("20"); + let query 
= Query::new_with_defaults(filename, Format::Vcf).with_reference_name("20"); let response = search.search(query).await; println!("{response:#?}"); let expected_response = Ok(Response::new( Format::Vcf, - vec![ - Url::new(expected_url(filename)) - .with_headers(Headers::default().with_header("Range", "bytes=0-822")), - Url::new(expected_bgzf_eof_data_url()), - ], + vec![Url::new(expected_url(filename)) + .with_headers(Headers::default().with_header("Range", "bytes=0-850"))], )); assert_eq!(response, expected_response); @@ -213,7 +220,7 @@ pub(crate) mod tests { with_local_storage(|storage| async move { let search = VcfSearch::new(storage.clone()); let filename = "sample1-bcbio-cancer"; - let query = Query::new_with_default_request(filename, Format::Vcf) + let query = Query::new_with_defaults(filename, Format::Vcf) .with_reference_name("chrM") .with_start(151) .with_end(153); @@ -249,7 +256,7 @@ pub(crate) mod tests { with_local_storage(|storage| async move { let search = VcfSearch::new(storage.clone()); let filename = "spec-v4.3"; - let query = Query::new_with_default_request(filename, Format::Vcf).with_class(Header); + let query = Query::new_with_defaults(filename, Format::Vcf).with_class(Header); let response = search.search(query).await; println!("{response:#?}"); @@ -274,7 +281,7 @@ pub(crate) mod tests { with_local_storage_fn( |storage| async move { let search = VcfSearch::new(storage.clone()); - let query = Query::new_with_default_request("spec-v4.3", Format::Vcf); + let query = Query::new_with_defaults("spec-v4.3", Format::Vcf); let response = search.search(query).await; assert!(matches!(response, Err(NotFound(_)))); @@ -291,8 +298,7 @@ pub(crate) mod tests { with_local_storage_fn( |storage| async move { let search = VcfSearch::new(storage.clone()); - let query = - Query::new_with_default_request("spec-v4.3", Format::Vcf).with_reference_name("chrM"); + let query = Query::new_with_defaults("spec-v4.3", Format::Vcf).with_reference_name("chrM"); let response = 
search.search(query).await; assert!(matches!(response, Err(NotFound(_)))); @@ -309,7 +315,7 @@ pub(crate) mod tests { with_local_storage_fn( |storage| async move { let search = VcfSearch::new(storage.clone()); - let query = Query::new_with_default_request("spec-v4.3", Format::Vcf).with_class(Header); + let query = Query::new_with_defaults("spec-v4.3", Format::Vcf).with_class(Header); let response = search.search(query).await; assert!(matches!(response, Err(NotFound(_)))); @@ -325,8 +331,7 @@ pub(crate) mod tests { async fn search_header_with_non_existent_reference_name() { with_local_storage(|storage| async move { let search = VcfSearch::new(storage.clone()); - let query = - Query::new_with_default_request("spec-v4.3", Format::Vcf).with_reference_name("chr1"); + let query = Query::new_with_defaults("spec-v4.3", Format::Vcf).with_reference_name("chr1"); let response = search.search(query).await; println!("{response:#?}"); @@ -342,7 +347,7 @@ pub(crate) mod tests { with_local_storage_fn( |storage| async move { let search = VcfSearch::new(storage.clone()); - let query = Query::new_with_default_request("spec-v4.3", Format::Vcf).with_class(Header); + let query = Query::new_with_defaults("spec-v4.3", Format::Vcf).with_class(Header); let index = search.read_index(&query).await.unwrap(); let response = search.get_header_end_offset(&index).await; @@ -363,7 +368,7 @@ pub(crate) mod tests { with_aws_storage_fn( |storage| async move { let search = VcfSearch::new(storage); - let query = Query::new_with_default_request("spec-v4.3", Format::Vcf); + let query = Query::new_with_defaults("spec-v4.3", Format::Vcf); let response = search.search(query).await; assert!(response.is_err()); @@ -381,8 +386,7 @@ pub(crate) mod tests { with_aws_storage_fn( |storage| async move { let search = VcfSearch::new(storage); - let query = - Query::new_with_default_request("spec-v4.3", Format::Vcf).with_reference_name("chrM"); + let query = Query::new_with_defaults("spec-v4.3", 
Format::Vcf).with_reference_name("chrM"); let response = search.search(query).await; assert!(response.is_err()); @@ -400,7 +404,7 @@ pub(crate) mod tests { with_aws_storage_fn( |storage| async move { let search = VcfSearch::new(storage); - let query = Query::new_with_default_request("spec-v4.3", Format::Vcf).with_class(Header); + let query = Query::new_with_defaults("spec-v4.3", Format::Vcf).with_class(Header); let response = search.search(query).await; assert!(response.is_err()); @@ -417,7 +421,7 @@ pub(crate) mod tests { ) -> Option<(String, ConcatResponse)> { let search = VcfSearch::new(storage.clone()); let filename = "sample1-bcbio-cancer"; - let query = Query::new_with_default_request(filename, Format::Vcf) + let query = Query::new_with_defaults(filename, Format::Vcf) .with_reference_name("chrM") .with_start(151) .with_end(153); @@ -436,11 +440,8 @@ pub(crate) mod tests { fn expected_vcf_response(filename: &str) -> Response { Response::new( Format::Vcf, - vec![ - Url::new(expected_url(filename)) - .with_headers(Headers::default().with_header("Range", "bytes=0-3465")), - Url::new(expected_bgzf_eof_data_url()), - ], + vec![Url::new(expected_url(filename)) + .with_headers(Headers::default().with_header("Range", "bytes=0-3493"))], ) } diff --git a/htsget-search/src/lib.rs b/htsget-search/src/lib.rs index f8da24578..8ee37e0e4 100644 --- a/htsget-search/src/lib.rs +++ b/htsget-search/src/lib.rs @@ -1,6 +1,6 @@ pub use htsget_config::config::{Config, DataServerConfig, ServiceInfo, TicketServerConfig}; pub use htsget_config::resolver::{ - IdResolver, QueryAllowed, ResolveResponse, Resolver, StorageResolver, + allow_guard::QueryAllowed, IdResolver, ResolveResponse, Resolver, StorageResolver, }; pub use htsget_config::storage::Storage; pub use htsget_config::types::{ diff --git a/htsget-search/src/storage/local.rs b/htsget-search/src/storage/local.rs index ec02b8829..e5be189f4 100644 --- a/htsget-search/src/storage/local.rs +++ b/htsget-search/src/storage/local.rs @@ 
-11,7 +11,7 @@ use tracing::debug; use tracing::instrument; use url::Url; -use crate::storage::{HeadOptions, Storage, UrlFormatter}; +use crate::storage::{HeadOptions, HeadOutput, Storage, UrlFormatter}; use crate::Url as HtsGetUrl; use super::{GetOptions, RangeUrlOptions, Result, StorageError}; @@ -87,6 +87,7 @@ impl Storage for LocalStorage { &self, key: K, _options: GetOptions<'_>, + _head_output: &mut Option<&mut HeadOutput>, ) -> Result { debug!(calling_from = ?self, key = key.as_ref(), "getting file with key {:?}", key.as_ref()); self.get(key).await @@ -129,7 +130,7 @@ impl Storage for LocalStorage { &self, key: K, _options: HeadOptions<'_>, - ) -> Result { + ) -> Result { let path = self.get_path_from_key(&key)?; let len = tokio::fs::metadata(path) .await @@ -137,7 +138,7 @@ impl Storage for LocalStorage { .len(); debug!(calling_from = ?self, key = key.as_ref(), len, "size of key {:?} is {}", key.as_ref(), len); - Ok(len) + Ok(len.into()) } } @@ -174,7 +175,8 @@ pub(crate) mod tests { let result = Storage::get( &storage, "folder", - GetOptions::new_with_default_range(&Default::default()), + GetOptions::new_with_default_range(&Default::default(), &Default::default()), + &mut Default::default(), ) .await; assert!(matches!(result, Err(StorageError::KeyNotFound(msg)) if msg == "folder")); @@ -188,7 +190,8 @@ pub(crate) mod tests { let result = Storage::get( &storage, "folder/../../passwords", - GetOptions::new_with_default_range(&Default::default()), + GetOptions::new_with_default_range(&Default::default(), &Default::default()), + &mut Default::default(), ) .await; assert!( @@ -204,7 +207,8 @@ pub(crate) mod tests { let result = Storage::get( &storage, "folder/../key1", - GetOptions::new_with_default_range(&Default::default()), + GetOptions::new_with_default_range(&Default::default(), &Default::default()), + &mut Default::default(), ) .await; assert!(result.is_ok()); @@ -218,7 +222,7 @@ pub(crate) mod tests { let result = Storage::range_url( &storage, 
"non-existing-key", - RangeUrlOptions::new_with_default_range(&Default::default()), + RangeUrlOptions::new_with_default_range(&Default::default(), &Default::default()), ) .await; assert!(matches!(result, Err(StorageError::KeyNotFound(msg)) if msg == "non-existing-key")); @@ -232,7 +236,7 @@ pub(crate) mod tests { let result = Storage::range_url( &storage, "folder", - RangeUrlOptions::new_with_default_range(&Default::default()), + RangeUrlOptions::new_with_default_range(&Default::default(), &Default::default()), ) .await; assert!(matches!(result, Err(StorageError::KeyNotFound(msg)) if msg == "folder")); @@ -246,7 +250,7 @@ pub(crate) mod tests { let result = Storage::range_url( &storage, "folder/../../passwords", - RangeUrlOptions::new_with_default_range(&Default::default()), + RangeUrlOptions::new_with_default_range(&Default::default(), &Default::default()), ) .await; assert!( @@ -262,7 +266,7 @@ pub(crate) mod tests { let result = Storage::range_url( &storage, "folder/../key1", - RangeUrlOptions::new_with_default_range(&Default::default()), + RangeUrlOptions::new_with_default_range(&Default::default(), &Default::default()), ) .await; let expected = Url::new("http://127.0.0.1:8081/data/key1"); @@ -280,6 +284,7 @@ pub(crate) mod tests { RangeUrlOptions::new( BytesPosition::new(Some(7), Some(10), None), &Default::default(), + &Default::default(), ), ) .await; @@ -296,7 +301,11 @@ pub(crate) mod tests { let result = Storage::range_url( &storage, "folder/../key1", - RangeUrlOptions::new(BytesPosition::new(Some(7), None, None), &Default::default()), + RangeUrlOptions::new( + BytesPosition::new(Some(7), None, None), + &Default::default(), + &Default::default(), + ), ) .await; let expected = Url::new("http://127.0.0.1:8081/data/key1") @@ -312,11 +321,11 @@ pub(crate) mod tests { let result = Storage::head( &storage, "folder/../key1", - HeadOptions::new(&Default::default()), + HeadOptions::new(&Default::default(), &Default::default()), ) .await; let expected: u64 = 6; - 
assert!(matches!(result, Ok(size) if size == expected)); + assert!(matches!(result, Ok(size) if size.content_length() == expected)); }) .await; } diff --git a/htsget-search/src/storage/mod.rs b/htsget-search/src/storage/mod.rs index 812061481..e72aabb48 100644 --- a/htsget-search/src/storage/mod.rs +++ b/htsget-search/src/storage/mod.rs @@ -5,6 +5,8 @@ use std::fmt::{Debug, Display, Formatter}; use std::io; use std::io::ErrorKind; use std::net::AddrParseError; +use std::num::ParseIntError; +use std::str::FromStr; use std::time::Duration; use async_trait::async_trait; @@ -16,7 +18,12 @@ use tokio::io::AsyncRead; use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer, ExposeHeaders}; use tracing::instrument; +#[cfg(feature = "crypt4gh")] +use async_crypt4gh::util::{unencrypted_clamp, unencrypted_clamp_next}; +#[cfg(feature = "crypt4gh")] +use async_crypt4gh::util::{unencrypted_to_data_block, unencrypted_to_next_data_block}; use htsget_config::config::cors::CorsConfig; +use htsget_config::resolver::object::ObjectType; use htsget_config::storage::local::LocalStorage; use htsget_config::types::{Class, Scheme}; @@ -31,6 +38,39 @@ pub mod url; type Result = core::result::Result; +/// Output for the head function in Storage. +#[derive(Debug, Default, Clone)] +pub struct HeadOutput { + pub(crate) content_length: u64, + pub(crate) response_headers: Option, +} + +impl HeadOutput { + /// Create a new HeadOutput. + pub fn new(content_length: u64, response_headers: Option) -> Self { + Self { + content_length, + response_headers, + } + } + + /// Get the content length. + pub fn content_length(&self) -> u64 { + self.content_length + } + + /// Get any additional response headers. 
+ pub fn response_headers(&self) -> Option<&HeaderMap> { + self.response_headers.as_ref() + } +} + +impl From for HeadOutput { + fn from(content_length: u64) -> Self { + Self::new(content_length, None) + } +} + /// A Storage represents some kind of object based storage (either locally or in the cloud) /// that can be used to retrieve files for alignments, variants or its respective indexes. #[async_trait] @@ -42,6 +82,7 @@ pub trait Storage { &self, key: K, options: GetOptions<'_>, + head_output: &mut Option<&mut HeadOutput>, ) -> Result; /// Get the url of the object represented by the key using a bytes range. It is not required for @@ -57,7 +98,7 @@ pub trait Storage { &self, key: K, options: HeadOptions<'_>, - ) -> Result; + ) -> Result; /// Get the url of the object using an inline data uri. #[instrument(level = "trace", ret)] @@ -71,6 +112,18 @@ pub trait Storage { )) .set_class(class) } + + /// Optionally update byte positions before they are passed to the other functions. + #[instrument(level = "trace", ret, skip(self, _reader))] + async fn update_byte_positions( + &self, + _reader: Self::Streamable, + positions_options: BytesPositionOptions<'_>, + ) -> Result> { + Ok(DataBlock::from_bytes_positions( + positions_options.merge_all().into_inner(), + )) + } } /// Formats a url for use with storage. @@ -207,12 +260,9 @@ pub enum DataBlock { } impl DataBlock { - /// Convert a vec of bytes positions to a vec of data blocks. Merges bytes positions. + /// Convert a vec of bytes positions to a vec of data blocks. pub fn from_bytes_positions(positions: Vec) -> Vec { - BytesPosition::merge_all(positions) - .into_iter() - .map(DataBlock::Range) - .collect() + positions.into_iter().map(DataBlock::Range).collect() } /// Update the classes of all blocks so that they all contain a class, or None. Does not merge @@ -263,6 +313,38 @@ impl From<&BytesRange> for String { } } +/// Convert from a http range to a bytes position. 
+impl FromStr for BytesPosition { + type Err = StorageError; + + fn from_str(range: &str) -> Result { + let range = range.replacen("bytes=", "", 1); + + let split: Vec<&str> = range.splitn(2, '-').collect(); + if split.len() > 2 { + return Err(StorageError::InternalError( + "failed to split range".to_string(), + )); + } + + let parse_range = |range: Option<&str>| { + let range = range.unwrap_or_default(); + if range.is_empty() { + Ok::<_, Self::Err>(None) + } else { + Ok(Some(range.parse().map_err(|err: ParseIntError| { + StorageError::InternalError(err.to_string()) + })?)) + } + }; + + let start = parse_range(split.first().copied())?; + let end = parse_range(split.last().copied())?.map(|value| value + 1); + + Ok(Self::new(start, end, None)) + } +} + impl Display for BytesRange { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let start = self @@ -390,24 +472,55 @@ impl BytesPosition { optimized_ranges } } + + /// Convert the range to crypt4gh byte range. + #[cfg(feature = "crypt4gh")] + pub fn convert_to_crypt4gh_ranges(mut self, crypt4gh_header_length: u64, file_size: u64) -> Self { + self.start = self + .start + .map(|start| unencrypted_to_data_block(start, crypt4gh_header_length, file_size)); + self.end = self + .end + .map(|end| unencrypted_to_next_data_block(end, crypt4gh_header_length, file_size)); + + self + } + + /// Convert the range to clamped crypt4gh ranges. 
+ #[cfg(feature = "crypt4gh")] + pub fn convert_to_clamped_crypt4gh_ranges(mut self, file_size: u64) -> Self { + self.start = self.start.map(|start| unencrypted_clamp(start, file_size)); + self.end = self.end.map(|end| unencrypted_clamp_next(end, file_size)); + + self + } } #[derive(Debug)] pub struct GetOptions<'a> { range: BytesPosition, request_headers: &'a HeaderMap, + object_type: &'a ObjectType, } impl<'a> GetOptions<'a> { - pub fn new(range: BytesPosition, request_headers: &'a HeaderMap) -> Self { + pub fn new( + range: BytesPosition, + request_headers: &'a HeaderMap, + object_type: &'a ObjectType, + ) -> Self { Self { range, request_headers, + object_type, } } - pub fn new_with_default_range(request_headers: &'a HeaderMap) -> Self { - Self::new(Default::default(), request_headers) + pub fn new_with_default_range( + request_headers: &'a HeaderMap, + object_type: &'a ObjectType, + ) -> Self { + Self::new(Default::default(), request_headers, object_type) } pub fn with_max_length(mut self, max_length: u64) -> Self { @@ -420,6 +533,11 @@ impl<'a> GetOptions<'a> { self } + /// Get the object type. + pub fn object_type(&self) -> &ObjectType { + self.object_type + } + /// Get the range. pub fn range(&self) -> &BytesPosition { &self.range @@ -431,22 +549,96 @@ impl<'a> GetOptions<'a> { } } +#[derive(Debug, Clone)] +pub struct BytesPositionOptions<'a> { + positions: Vec, + file_size: u64, + headers: &'a HeaderMap, + object_type: &'a ObjectType, +} + +impl<'a> BytesPositionOptions<'a> { + pub fn new( + positions: Vec, + file_size: u64, + headers: &'a HeaderMap, + object_type: &'a ObjectType, + ) -> Self { + Self { + positions, + file_size, + headers, + object_type, + } + } + + /// Get the response headers. + pub fn headers(&self) -> &'a HeaderMap { + self.headers + } + + pub fn positions(&self) -> &Vec { + &self.positions + } + + pub fn file_size(&self) -> u64 { + self.file_size + } + + /// Get the inner value. 
+ pub fn into_inner(self) -> Vec { + self.positions + } + + /// Merge all bytes positions + pub fn merge_all(mut self) -> Self { + self.positions = BytesPosition::merge_all(self.positions); + self + } + + /// Get the object type. + pub fn object_type(&self) -> &ObjectType { + self.object_type + } + + /// Convert the ranges to crypt4gh byte ranges. Does not include the crypt4gh header. + #[cfg(feature = "crypt4gh")] + pub fn convert_to_crypt4gh_ranges(mut self, header_length: u64, file_size: u64) -> Self { + self.positions = self + .positions + .into_iter() + .map(|pos| pos.convert_to_crypt4gh_ranges(header_length, file_size)) + .collect(); + + self + } +} + #[derive(Debug)] pub struct RangeUrlOptions<'a> { range: BytesPosition, response_headers: &'a HeaderMap, + object_type: &'a ObjectType, } impl<'a> RangeUrlOptions<'a> { - pub fn new(range: BytesPosition, response_headers: &'a HeaderMap) -> Self { + pub fn new( + range: BytesPosition, + response_headers: &'a HeaderMap, + object_type: &'a ObjectType, + ) -> Self { Self { range, response_headers, + object_type, } } - pub fn new_with_default_range(request_headers: &'a HeaderMap) -> Self { - Self::new(Default::default(), request_headers) + pub fn new_with_default_range( + request_headers: &'a HeaderMap, + object_type: &'a ObjectType, + ) -> Self { + Self::new(Default::default(), request_headers, object_type) } pub fn with_range(mut self, range: BytesPosition) -> Self { @@ -475,29 +667,43 @@ impl<'a> RangeUrlOptions<'a> { pub fn response_headers(&self) -> &'a HeaderMap { self.response_headers } + + /// Get the object type. + pub fn object_type(&self) -> &ObjectType { + self.object_type + } } /// A struct to represent options passed to a `Storage` head call. #[derive(Debug)] pub struct HeadOptions<'a> { request_headers: &'a HeaderMap, + object_type: &'a ObjectType, } impl<'a> HeadOptions<'a> { /// Create a new HeadOptions struct. 
- pub fn new(request_headers: &'a HeaderMap) -> Self { - Self { request_headers } + pub fn new(request_headers: &'a HeaderMap, object_type: &'a ObjectType) -> Self { + Self { + request_headers, + object_type, + } } /// Get the request headers. pub fn request_headers(&self) -> &'a HeaderMap { self.request_headers } + + /// Get the object type. + pub fn object_type(&self) -> &ObjectType { + self.object_type + } } #[cfg(test)] mod tests { - use std::collections::HashMap; + use std::collections::BTreeMap; use http::uri::Authority; @@ -881,10 +1087,10 @@ mod tests { #[test] fn data_block_from_bytes_positions() { - let blocks = DataBlock::from_bytes_positions(vec![ + let blocks = DataBlock::from_bytes_positions(BytesPosition::merge_all(vec![ BytesPosition::new(None, Some(1), None), BytesPosition::new(Some(1), Some(2), None), - ]); + ])); assert_eq!( blocks, vec![DataBlock::Range(BytesPosition::new(None, Some(2), None))] @@ -901,7 +1107,10 @@ mod tests { #[test] fn get_options_with_max_length() { let request_headers = Default::default(); - let result = GetOptions::new_with_default_range(&request_headers).with_max_length(1); + let object_type = Default::default(); + + let result = + GetOptions::new_with_default_range(&request_headers, &object_type).with_max_length(1); assert_eq!( result.range(), &BytesPosition::default().with_start(0).with_end(1) @@ -911,7 +1120,9 @@ mod tests { #[test] fn get_options_with_range() { let request_headers = Default::default(); - let result = GetOptions::new_with_default_range(&request_headers) + let object_type = Default::default(); + + let result = GetOptions::new_with_default_range(&request_headers, &object_type) .with_range(BytesPosition::new(Some(5), Some(11), Some(Class::Header))); assert_eq!( result.range(), @@ -922,7 +1133,8 @@ mod tests { #[test] fn url_options_with_range() { let request_headers = Default::default(); - let result = RangeUrlOptions::new_with_default_range(&request_headers) + let object_type = Default::default(); + let 
result = RangeUrlOptions::new_with_default_range(&request_headers, &object_type) .with_range(BytesPosition::new(Some(5), Some(11), Some(Class::Header))); assert_eq!( result.range(), @@ -935,20 +1147,22 @@ mod tests { let result = RangeUrlOptions::new( BytesPosition::new(Some(5), Some(11), Some(Class::Header)), &Default::default(), + &Default::default(), ) .apply(Url::new("")); println!("{result:?}"); assert_eq!( result, Url::new("") - .with_headers(Headers::new(HashMap::new()).with_header("Range", "bytes=5-10")) + .with_headers(Headers::new(BTreeMap::new()).with_header("Range", "bytes=5-10")) .with_class(Class::Header) ); } #[test] fn url_options_apply_no_bytes_range() { - let result = RangeUrlOptions::new_with_default_range(&Default::default()).apply(Url::new("")); + let result = RangeUrlOptions::new_with_default_range(&Default::default(), &Default::default()) + .apply(Url::new("")); assert_eq!(result, Url::new("")); } @@ -957,6 +1171,7 @@ mod tests { let result = RangeUrlOptions::new( BytesPosition::new(Some(5), Some(11), Some(Class::Header)), &Default::default(), + &Default::default(), ) .apply(Url::new("").with_headers(Headers::default().with_header("header", "value"))); println!("{result:?}"); @@ -965,7 +1180,7 @@ mod tests { result, Url::new("") .with_headers( - Headers::new(HashMap::new()) + Headers::new(BTreeMap::new()) .with_header("Range", "bytes=5-10") .with_header("header", "value") ) diff --git a/htsget-search/src/storage/s3.rs b/htsget-search/src/storage/s3.rs index a2bb89414..b73dca117 100644 --- a/htsget-search/src/storage/s3.rs +++ b/htsget-search/src/storage/s3.rs @@ -23,7 +23,7 @@ use tracing::{debug, warn}; use crate::storage::s3::Retrieval::{Delayed, Immediate}; use crate::storage::StorageError::{AwsS3Error, KeyNotFound}; -use crate::storage::{BytesPosition, HeadOptions, StorageError}; +use crate::storage::{BytesPosition, HeadOptions, HeadOutput, StorageError}; use crate::storage::{BytesRange, Storage}; use crate::Url; @@ -219,6 +219,7 @@ impl 
Storage for S3Storage { &self, key: K, options: GetOptions<'_>, + _head_output: &mut Option<&mut HeadOutput>, ) -> Result { let key = key.as_ref(); debug!(calling_from = ?self, key, "getting file with key {:?}", key); @@ -248,7 +249,7 @@ impl Storage for S3Storage { &self, key: K, _options: HeadOptions<'_>, - ) -> Result { + ) -> Result { let key = key.as_ref(); let head = self.s3_head(key).await?; @@ -260,7 +261,7 @@ impl Storage for S3Storage { })?; debug!(calling_from = ?self, key, len, "size of key {:?} is {}", key, len); - Ok(len) + Ok(len.into()) } } @@ -304,7 +305,8 @@ pub(crate) mod tests { let result = storage .get( "key2", - GetOptions::new_with_default_range(&Default::default()), + GetOptions::new_with_default_range(&Default::default(), &Default::default()), + &mut Default::default(), ) .await; assert!(result.is_ok()); @@ -318,7 +320,8 @@ pub(crate) mod tests { let result = storage .get( "non-existing-key", - GetOptions::new_with_default_range(&Default::default()), + GetOptions::new_with_default_range(&Default::default(), &Default::default()), + &mut Default::default(), ) .await; assert!(matches!(result, Err(StorageError::AwsS3Error(_, _)))); @@ -332,7 +335,7 @@ pub(crate) mod tests { let result = storage .range_url( "key2", - RangeUrlOptions::new_with_default_range(&Default::default()), + RangeUrlOptions::new_with_default_range(&Default::default(), &Default::default()), ) .await .unwrap(); @@ -354,6 +357,7 @@ pub(crate) mod tests { RangeUrlOptions::new( BytesPosition::new(Some(7), Some(9), None), &Default::default(), + &Default::default(), ), ) .await @@ -378,7 +382,11 @@ pub(crate) mod tests { let result = storage .range_url( "key2", - RangeUrlOptions::new(BytesPosition::new(Some(7), None, None), &Default::default()), + RangeUrlOptions::new( + BytesPosition::new(Some(7), None, None), + &Default::default(), + &Default::default(), + ), ) .await .unwrap(); @@ -399,11 +407,12 @@ pub(crate) mod tests { #[tokio::test] async fn file_size() { 
with_aws_s3_storage(|storage| async move { + let object_type = Default::default(); let result = storage - .head("key2", HeadOptions::new(&Default::default())) + .head("key2", HeadOptions::new(&Default::default(), &object_type)) .await; let expected: u64 = 6; - assert!(matches!(result, Ok(size) if size == expected)); + assert!(matches!(result, Ok(size) if size.content_length() == expected)); }) .await; } diff --git a/htsget-search/src/storage/url.rs b/htsget-search/src/storage/url.rs deleted file mode 100644 index 2ac418cba..000000000 --- a/htsget-search/src/storage/url.rs +++ /dev/null @@ -1,648 +0,0 @@ -use std::fmt::Debug; -use std::pin::Pin; -use std::task::{Context, Poll}; - -use async_trait::async_trait; -use bytes::Bytes; -use futures::Stream; -use futures_util::TryStreamExt; -use http::header::CONTENT_LENGTH; -use http::{HeaderMap, Method, Request, Uri}; -use pin_project_lite::pin_project; -use reqwest::{Client, ClientBuilder}; -use tokio_util::io::StreamReader; -use tracing::{debug, instrument}; - -use htsget_config::error; - -use crate::storage::StorageError::{InternalError, KeyNotFound, ResponseError, UrlParseError}; -use crate::storage::{GetOptions, HeadOptions, RangeUrlOptions, Result, Storage, StorageError}; -use crate::Url as HtsGetUrl; - -/// A storage struct which derives data from HTTP URLs. -#[derive(Debug, Clone)] -pub struct UrlStorage { - client: Client, - url: Uri, - response_url: Uri, - forward_headers: bool, - header_blacklist: Vec, -} - -impl UrlStorage { - /// Construct a new UrlStorage. - pub fn new( - client: Client, - url: Uri, - response_url: Uri, - forward_headers: bool, - header_blacklist: Vec, - ) -> Self { - Self { - client, - url, - response_url, - forward_headers, - header_blacklist, - } - } - - /// Construct a new UrlStorage with a default client. 
- pub fn new_with_default_client( - url: Uri, - response_url: Uri, - forward_headers: bool, - header_blacklist: Vec, - ) -> Result { - Ok(Self { - client: ClientBuilder::new() - .build() - .map_err(|err| InternalError(format!("failed to build reqwest client: {}", err)))?, - url, - response_url, - forward_headers, - header_blacklist, - }) - } - - /// Get a url from the key. - pub fn get_url_from_key + Send>(&self, key: K) -> Result { - format!("{}{}", self.url, key.as_ref()) - .parse::() - .map_err(|err| UrlParseError(err.to_string())) - } - - /// Get a url from the key. - pub fn get_response_url_from_key + Send>(&self, key: K) -> Result { - format!("{}{}", self.response_url, key.as_ref()) - .parse::() - .map_err(|err| UrlParseError(err.to_string())) - } - - /// Remove blacklisted headers from the headers. - pub fn remove_blacklisted_headers(&self, mut headers: HeaderMap) -> HeaderMap { - for blacklisted_header in &self.header_blacklist { - headers.remove(blacklisted_header); - } - headers - } - - /// Construct and send a request - pub async fn send_request + Send>( - &self, - key: K, - headers: &HeaderMap, - method: Method, - ) -> Result { - let key = key.as_ref(); - let url = self.get_url_from_key(key)?; - - let request = Request::builder().method(method).uri(&url); - - let request = headers - .iter() - .fold(request, |acc, (key, value)| acc.header(key, value)) - .body(vec![]) - .map_err(|err| UrlParseError(err.to_string()))?; - - let response = self - .client - .execute( - request - .try_into() - .map_err(|err| InternalError(format!("failed to create http request: {}", err)))?, - ) - .await - .map_err(|err| KeyNotFound(format!("{} with key {}", err, key)))?; - - let status = response.status(); - if status.is_client_error() || status.is_server_error() { - Err(KeyNotFound(format!( - "url returned {} for key {}", - status, key - ))) - } else { - Ok(response) - } - } - - /// Construct and send a request - pub fn format_url + Send>( - &self, - key: K, - options: 
RangeUrlOptions<'_>, - ) -> Result { - let url = self.get_response_url_from_key(key)?.into_parts(); - let url = Uri::from_parts(url) - .map_err(|err| InternalError(format!("failed to convert to uri from parts: {}", err)))?; - - let mut url = HtsGetUrl::new(url.to_string()); - if self.forward_headers { - url = url.with_headers( - options - .response_headers() - .try_into() - .map_err(|err: error::Error| StorageError::InvalidInput(err.to_string()))?, - ) - } - - Ok(options.apply(url)) - } - - /// Get the head from the key. - pub async fn head_key + Send>( - &self, - key: K, - headers: &HeaderMap, - ) -> Result { - self.send_request(key, headers, Method::HEAD).await - } - - /// Get the key. - pub async fn get_key + Send>( - &self, - key: K, - headers: &HeaderMap, - ) -> Result { - self.send_request(key, headers, Method::GET).await - } -} - -pin_project! { - /// A wrapper around a stream used by `UrlStorage`. - pub struct UrlStream { - #[pin] - inner: Box> + Unpin + Send + Sync> - } -} - -impl UrlStream { - /// Create a new UrlStream. 
- pub fn new(inner: Box> + Unpin + Send + Sync>) -> Self { - Self { inner } - } -} - -impl Stream for UrlStream { - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_next(cx) - } -} - -#[async_trait] -impl Storage for UrlStorage { - type Streamable = StreamReader; - - #[instrument(level = "trace", skip(self))] - async fn get + Send + Debug>( - &self, - key: K, - options: GetOptions<'_>, - ) -> Result { - let key = key.as_ref().to_string(); - debug!(calling_from = ?self, key, "getting file with key {:?}", key); - - let request_headers = self.remove_blacklisted_headers(options.request_headers().clone()); - let response = self.get_key(key.to_string(), &request_headers).await?; - - Ok(StreamReader::new(UrlStream::new(Box::new( - response - .bytes_stream() - .map_err(|err| ResponseError(format!("reading body from response: {}", err))), - )))) - } - - #[instrument(level = "trace", skip(self))] - async fn range_url + Send + Debug>( - &self, - key: K, - options: RangeUrlOptions<'_>, - ) -> Result { - let key = key.as_ref(); - debug!(calling_from = ?self, key, "getting url with key {:?}", key); - - let response_headers = self.remove_blacklisted_headers(options.response_headers().clone()); - let new_options = RangeUrlOptions::new(options.range().clone(), &response_headers); - - self.format_url(key, new_options) - } - - #[instrument(level = "trace", skip(self))] - async fn head + Send + Debug>( - &self, - key: K, - options: HeadOptions<'_>, - ) -> Result { - let key = key.as_ref(); - - let request_headers = self.remove_blacklisted_headers(options.request_headers().clone()); - let head = self.head_key(key, &request_headers).await?; - - let len = head - .headers() - .get(CONTENT_LENGTH) - .and_then(|content_length| content_length.to_str().ok()) - .and_then(|content_length| content_length.parse().ok()) - .ok_or_else(|| { - ResponseError(format!( - "failed to get content length from head response for key: {}", 
- key - )) - })?; - - debug!(calling_from = ?self, key, len, "size of key {:?} is {}", key, len); - Ok(len) - } -} - -#[cfg(test)] -mod tests { - use std::future::Future; - use std::net::TcpListener; - use std::path::Path; - use std::str::FromStr; - use std::{result, vec}; - - use axum::middleware::Next; - use axum::response::Response; - use axum::{middleware, Router}; - use http::header::{AUTHORIZATION, HOST}; - use http::{HeaderName, HeaderValue, Request, StatusCode}; - use tokio::io::AsyncReadExt; - use tower_http::services::ServeDir; - - use htsget_config::types::Headers; - - use crate::storage::local::tests::create_local_test_files; - - use super::*; - - #[test] - fn get_url_from_key() { - let storage = UrlStorage::new( - test_client(), - Uri::from_str("https://example.com").unwrap(), - Uri::from_str("https://localhost:8080").unwrap(), - true, - vec![], - ); - - assert_eq!( - storage.get_url_from_key("assets/key1").unwrap(), - Uri::from_str("https://example.com/assets/key1").unwrap() - ); - } - - #[test] - fn get_response_url_from_key() { - let storage = UrlStorage::new( - test_client(), - Uri::from_str("https://example.com").unwrap(), - Uri::from_str("https://localhost:8080").unwrap(), - true, - vec![], - ); - - assert_eq!( - storage.get_response_url_from_key("assets/key1").unwrap(), - Uri::from_str("https://localhost:8080/assets/key1").unwrap() - ); - } - - #[test] - fn remove_blacklisted_headers() { - let storage = UrlStorage::new( - test_client(), - Uri::from_str("https://example.com").unwrap(), - Uri::from_str("https://localhost:8080").unwrap(), - true, - vec![HOST.to_string()], - ); - - let mut headers = HeaderMap::default(); - headers.insert( - HeaderName::from_str(HOST.as_str()).unwrap(), - HeaderValue::from_str("example.com").unwrap(), - ); - headers.insert( - HeaderName::from_str(AUTHORIZATION.as_str()).unwrap(), - HeaderValue::from_str("secret").unwrap(), - ); - - let headers = storage.remove_blacklisted_headers(headers.clone()); - - 
assert_eq!(headers.len(), 1); - } - - #[tokio::test] - async fn send_request() { - with_url_test_server(|url| async move { - let storage = UrlStorage::new( - test_client(), - Uri::from_str(&url).unwrap(), - Uri::from_str(&url).unwrap(), - true, - vec![], - ); - - let mut headers = HeaderMap::default(); - let headers = test_headers(&mut headers); - - let response = String::from_utf8( - storage - .send_request("assets/key1", headers, Method::GET) - .await - .unwrap() - .bytes() - .await - .unwrap() - .to_vec(), - ) - .unwrap(); - assert_eq!(response, "value1"); - }) - .await; - } - - #[tokio::test] - async fn get_key() { - with_url_test_server(|url| async move { - let storage = UrlStorage::new( - test_client(), - Uri::from_str(&url).unwrap(), - Uri::from_str(&url).unwrap(), - true, - vec![], - ); - - let mut headers = HeaderMap::default(); - let headers = test_headers(&mut headers); - - let response = String::from_utf8( - storage - .get_key("assets/key1", headers) - .await - .unwrap() - .bytes() - .await - .unwrap() - .to_vec(), - ) - .unwrap(); - assert_eq!(response, "value1"); - }) - .await; - } - - #[tokio::test] - async fn head_key() { - with_url_test_server(|url| async move { - let storage = UrlStorage::new( - test_client(), - Uri::from_str(&url).unwrap(), - Uri::from_str(&url).unwrap(), - true, - vec![], - ); - - let mut headers = HeaderMap::default(); - let headers = test_headers(&mut headers); - - let response: u64 = storage - .get_key("assets/key1", headers) - .await - .unwrap() - .headers() - .get(CONTENT_LENGTH) - .unwrap() - .to_str() - .unwrap() - .parse() - .unwrap(); - assert_eq!(response, 6); - }) - .await; - } - - #[tokio::test] - async fn get_storage() { - with_url_test_server(|url| async move { - let storage = UrlStorage::new( - test_client(), - Uri::from_str(&url).unwrap(), - Uri::from_str(&url).unwrap(), - true, - vec![], - ); - - let mut headers = HeaderMap::default(); - let headers = test_headers(&mut headers); - let options = 
GetOptions::new_with_default_range(headers); - - let mut reader = storage.get("assets/key1", options).await.unwrap(); - - let mut response = [0; 6]; - reader.read_exact(&mut response).await.unwrap(); - - assert_eq!(String::from_utf8(response.to_vec()).unwrap(), "value1"); - }) - .await; - } - - #[tokio::test] - async fn range_url_storage() { - with_url_test_server(|url| async move { - let storage = UrlStorage::new( - test_client(), - Uri::from_str(&url).unwrap(), - Uri::from_str(&url).unwrap(), - true, - vec![], - ); - - let mut headers = HeaderMap::default(); - let options = test_range_options(&mut headers); - - assert_eq!( - storage.range_url("assets/key1", options).await.unwrap(), - HtsGetUrl::new(format!("{}/assets/key1", url)) - .with_headers(Headers::default().with_header(AUTHORIZATION.as_str(), "secret")) - ); - }) - .await; - } - - #[tokio::test] - async fn range_url_storage_blacklisted_headers() { - with_url_test_server(|url| async move { - let storage = UrlStorage::new( - test_client(), - Uri::from_str(&url).unwrap(), - Uri::from_str(&url).unwrap(), - true, - vec![HOST.to_string()], - ); - - let mut headers = HeaderMap::default(); - headers.insert( - HeaderName::from_str(HOST.as_str()).unwrap(), - HeaderValue::from_str("example.com").unwrap(), - ); - - let options = test_range_options(&mut headers); - - assert_eq!( - storage.range_url("assets/key1", options).await.unwrap(), - HtsGetUrl::new(format!("{}/assets/key1", url)) - .with_headers(Headers::default().with_header(AUTHORIZATION.as_str(), "secret")) - ); - }) - .await; - } - - #[tokio::test] - async fn head_storage() { - with_url_test_server(|url| async move { - let storage = UrlStorage::new( - test_client(), - Uri::from_str(&url).unwrap(), - Uri::from_str(&url).unwrap(), - true, - vec![], - ); - - let mut headers = HeaderMap::default(); - let headers = test_headers(&mut headers); - let options = HeadOptions::new(headers); - - assert_eq!(storage.head("assets/key1", options).await.unwrap(), 6); - }) - 
.await; - } - - #[test] - fn format_url() { - let storage = UrlStorage::new( - test_client(), - Uri::from_str("https://example.com").unwrap(), - Uri::from_str("https://localhost:8080").unwrap(), - true, - vec![], - ); - - let mut headers = HeaderMap::default(); - let options = test_range_options(&mut headers); - - assert_eq!( - storage.format_url("assets/key1", options).unwrap(), - HtsGetUrl::new("https://localhost:8080/assets/key1") - .with_headers(Headers::default().with_header(AUTHORIZATION.as_str(), "secret")) - ); - } - - #[test] - fn format_url_different_response_scheme() { - let storage = UrlStorage::new( - test_client(), - Uri::from_str("https://example.com").unwrap(), - Uri::from_str("http://example.com").unwrap(), - true, - vec![], - ); - - let mut headers = HeaderMap::default(); - let options = test_range_options(&mut headers); - - assert_eq!( - storage.format_url("assets/key1", options).unwrap(), - HtsGetUrl::new("http://example.com/assets/key1") - .with_headers(Headers::default().with_header(AUTHORIZATION.as_str(), "secret")) - ); - } - - #[test] - fn format_url_no_headers() { - let storage = UrlStorage::new( - test_client(), - Uri::from_str("https://example.com").unwrap(), - Uri::from_str("https://localhost:8081").unwrap(), - false, - vec![], - ); - - let mut headers = HeaderMap::default(); - let options = test_range_options(&mut headers); - - assert_eq!( - storage.format_url("assets/key1", options).unwrap(), - HtsGetUrl::new("https://localhost:8081/assets/key1") - ); - } - - fn test_client() -> Client { - ClientBuilder::new().build().unwrap() - } - - pub(crate) async fn with_url_test_server(test: F) - where - F: FnOnce(String) -> Fut, - Fut: Future, - { - let (_, base_path) = create_local_test_files().await; - with_test_server(base_path.path(), test).await; - } - - async fn test_auth( - request: Request, - next: Next, - ) -> result::Result { - let auth_header = request - .headers() - .get(AUTHORIZATION) - .and_then(|header| header.to_str().ok()); - - 
match auth_header { - Some("secret") => Ok(next.run(request).await), - _ => Err(StatusCode::UNAUTHORIZED), - } - } - - pub(crate) async fn with_test_server(server_base_path: &Path, test: F) - where - F: FnOnce(String) -> Fut, - Fut: Future, - { - let router = Router::new() - .nest_service("/assets", ServeDir::new(server_base_path.to_str().unwrap())) - .route_layer(middleware::from_fn(test_auth)); - - // TODO fix this in htsget-test to bind and return tcp listener. - let listener = TcpListener::bind("127.0.0.1:0").unwrap(); - let addr = listener.local_addr().unwrap(); - - tokio::spawn( - axum::Server::from_tcp(listener) - .unwrap() - .serve(router.into_make_service()), - ); - - test(format!("http://{}", addr)).await; - } - - fn test_headers(headers: &mut HeaderMap) -> &HeaderMap { - headers.append( - HeaderName::from_str(AUTHORIZATION.as_str()).unwrap(), - HeaderValue::from_str("secret").unwrap(), - ); - headers - } - - fn test_range_options(headers: &mut HeaderMap) -> RangeUrlOptions { - let headers = test_headers(headers); - let options = RangeUrlOptions::new_with_default_range(headers); - - options - } -} diff --git a/htsget-search/src/storage/url/encrypt.rs b/htsget-search/src/storage/url/encrypt.rs new file mode 100644 index 000000000..75962fb2d --- /dev/null +++ b/htsget-search/src/storage/url/encrypt.rs @@ -0,0 +1,69 @@ +use crate::storage::url::UrlStreamReader; +use crate::storage::Result; +use crate::storage::StorageError::UrlParseError; +use async_crypt4gh::edit_lists::{ClampedPosition, EditHeader, UnencryptedPosition}; +use async_crypt4gh::reader::Reader; +use async_crypt4gh::{util, KeyPair, PublicKey}; +use mockall::mock; +use std::fmt::{Debug, Formatter}; +use tokio_rustls::rustls::PrivateKey; + +/// A wrapper around url storage encryption. 
+#[derive(Debug, Clone, Default)] +pub struct Encrypt; + +impl Encrypt { + pub fn new_with_defaults() -> Self { + Self + } + + pub fn generate_key_pair(&self) -> Result { + util::generate_key_pair().map_err(|err| UrlParseError(err.to_string())) + } + + pub fn edit_list( + &self, + reader: &Reader, + unencrypted_positions: Vec, + clamped_positions: Vec, + private_key: PrivateKey, + public_key: PublicKey, + ) -> Result<(Vec, Vec)> { + let (header_info, _, edit_list_packet) = EditHeader::new( + reader, + unencrypted_positions, + clamped_positions, + private_key, + public_key, + ) + .edit_list() + .map_err(|err| UrlParseError(err.to_string()))? + .ok_or_else(|| UrlParseError("crypt4gh header has not been read".to_string()))? + .into_inner(); + + Ok((header_info, edit_list_packet)) + } +} + +mock! { + pub Encrypt { + pub fn generate_key_pair(&self) -> Result; + + pub fn edit_list( + &self, + reader: &Reader, + unencrypted_positions: Vec, + clamped_positions: Vec, + private_key: PrivateKey, + public_key: PublicKey, + ) -> Result<(Vec, Vec)>; + } + + impl Clone for Encrypt { + fn clone(&self) -> Self; + } + + impl Debug for Encrypt { + fn fmt<'a>(&self, f: &mut Formatter<'a>) -> std::fmt::Result; + } +} diff --git a/htsget-search/src/storage/url/mod.rs b/htsget-search/src/storage/url/mod.rs new file mode 100644 index 000000000..6e1eec47c --- /dev/null +++ b/htsget-search/src/storage/url/mod.rs @@ -0,0 +1,1769 @@ +#[cfg(feature = "crypt4gh")] +pub mod encrypt; + +use std::fmt::Debug; +use std::pin::Pin; +use std::result; +use std::task::{Context, Poll}; + +use async_trait::async_trait; +use bytes::Bytes; +use futures_util::stream::MapErr; +use futures_util::{Stream, TryStreamExt}; +use http::header::{Entry, IntoHeaderName, CONTENT_LENGTH, USER_AGENT}; +use http::{HeaderMap, Method, Request, Uri}; +use pin_project::pin_project; +use reqwest::{Client, ClientBuilder}; +use tokio::io::{AsyncRead, ReadBuf}; +use tokio_util::io::StreamReader; +use tracing::{debug, info, 
instrument}; +#[cfg(feature = "crypt4gh")] +use { + crate::storage::{BytesPosition, BytesRange}, + async_crypt4gh::edit_lists::{ClampedPosition, UnencryptedPosition}, + async_crypt4gh::reader::builder::Builder, + async_crypt4gh::reader::Reader, + async_crypt4gh::util::encode_public_key, + async_crypt4gh::util::to_encrypted_file_size, + async_crypt4gh::util::{read_public_key, to_unencrypted_file_size}, + async_crypt4gh::KeyPair, + async_crypt4gh::PublicKey, + base64::engine::general_purpose, + base64::Engine, + crypt4gh::Keys, + htsget_config::types::Class, + http::header::InvalidHeaderValue, + http::header::RANGE, + http::HeaderValue, + mockall_double::double, + tokio_rustls::rustls::PrivateKey, +}; + +#[cfg(feature = "crypt4gh")] +#[double] +use crate::storage::url::encrypt::Encrypt; +use crate::storage::StorageError::{InternalError, KeyNotFound, ResponseError, UrlParseError}; +use crate::storage::{ + BytesPositionOptions, DataBlock, GetOptions, HeadOptions, HeadOutput, RangeUrlOptions, Result, + Storage, StorageError, +}; +use crate::Url as HtsGetUrl; +use htsget_config::error; +#[cfg(feature = "crypt4gh")] +use htsget_config::resolver::object::ObjectType; +use htsget_config::storage::url::endpoints::Endpoints; +use htsget_config::types::{KeyType, Query}; + +pub const CLIENT_PUBLIC_KEY_NAME: &str = "client-public-key"; +pub const CLIENT_ADDITIONAL_BYTES: &str = "client-additional-bytes"; +pub const SERVER_ADDITIONAL_BYTES: &str = "server-additional-bytes"; + +static HTSGET_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),); + +/// A storage struct which derives data from HTTP URLs. +#[derive(Debug, Clone)] +pub struct UrlStorage { + client: Client, + endpoints: Endpoints, + response_url: Uri, + forward_headers: bool, + header_blacklist: Vec, + user_agent: Option, + #[cfg(feature = "crypt4gh")] + key_pair: Option, + #[cfg(feature = "crypt4gh")] + encrypt: Encrypt, +} + +impl UrlStorage { + /// Construct a new UrlStorage. 
+ #[allow(clippy::too_many_arguments)] + pub fn new( + client: Client, + endpoints: Endpoints, + response_url: Uri, + forward_headers: bool, + header_blacklist: Vec, + user_agent: Option, + _query: &Query, + #[cfg(feature = "crypt4gh")] _encrypt: Encrypt, + ) -> Result { + #[cfg(feature = "crypt4gh")] + let mut key_pair = None; + #[cfg(feature = "crypt4gh")] + if _query.object_type().crypt4gh_key_pair().is_none() { + key_pair = Some( + _encrypt + .generate_key_pair() + .map_err(|err| UrlParseError(err.to_string()))?, + ); + } + + Ok(Self { + client, + endpoints, + response_url, + forward_headers, + header_blacklist, + user_agent, + #[cfg(feature = "crypt4gh")] + key_pair, + #[cfg(feature = "crypt4gh")] + encrypt: _encrypt, + }) + } + + /// Construct a new UrlStorage with a default client. + pub fn new_with_default_client( + endpoints: Endpoints, + response_url: Uri, + forward_headers: bool, + header_blacklist: Vec, + user_agent: Option, + _query: &Query, + #[cfg(feature = "crypt4gh")] _encrypt: Encrypt, + ) -> Result { + #[cfg(feature = "crypt4gh")] + let mut key_pair = None; + #[cfg(feature = "crypt4gh")] + if _query.object_type().crypt4gh_key_pair().is_none() { + key_pair = Some( + _encrypt + .generate_key_pair() + .map_err(|err| UrlParseError(err.to_string()))?, + ); + } + + Ok(Self { + client: ClientBuilder::new() + .build() + .map_err(|err| InternalError(format!("failed to build reqwest client: {}", err)))?, + endpoints, + response_url, + forward_headers, + header_blacklist, + user_agent, + #[cfg(feature = "crypt4gh")] + key_pair, + #[cfg(feature = "crypt4gh")] + encrypt: _encrypt, + }) + } + + /// Construct the Crypt4GH query. + #[cfg(feature = "crypt4gh")] + fn encode_key(public_key: &PublicKey) -> String { + general_purpose::STANDARD.encode(public_key.get_ref()) + } + + /// Decode a public key using base64. 
+ #[cfg(feature = "crypt4gh")] + fn decode_public_key(headers: &HeaderMap, name: &str) -> Result> { + general_purpose::STANDARD + .decode( + headers + .get(name) + .ok_or_else(|| StorageError::InvalidInput("no public key found in header".to_string()))? + .as_bytes(), + ) + .map_err(|err| StorageError::InvalidInput(format!("failed to decode public key: {}", err))) + } + + /// Get a url from the key. + pub fn get_url_from_key + Send>(&self, key: K, endpoint: &Uri) -> Result { + // Todo: proper url parsing here, probably with the `url` crate. + let uri = if endpoint.to_string().ends_with('/') { + format!("{}{}", endpoint, key.as_ref()) + } else { + format!("{}/{}", endpoint, key.as_ref()) + }; + + uri + .parse::() + .map_err(|err| UrlParseError(err.to_string())) + } + + /// Remove blacklisted headers from the headers. + pub fn remove_blacklisted_headers(&self, mut headers: HeaderMap) -> HeaderMap { + for blacklisted_header in &self.header_blacklist { + headers.remove(blacklisted_header); + } + headers + } + + /// Construct and send a request + pub async fn send_request + Send>( + &self, + key: K, + headers: &HeaderMap, + method: Method, + url: &Uri, + ) -> Result { + let key = key.as_ref(); + let url = self.get_url_from_key(key, url)?; + + let request = Request::builder().method(method).uri(&url); + + let request = headers + .iter() + .fold(request, |acc, (key, value)| acc.header(key, value)) + .header( + USER_AGENT, + self.user_agent.as_deref().unwrap_or(HTSGET_USER_AGENT), + ) + .body(vec![]) + .map_err(|err| UrlParseError(err.to_string()))?; + + debug!("Calling with request: {:#?}", &request); + + let response = self + .client + .execute( + request + .try_into() + .map_err(|err| InternalError(format!("failed to create http request: {}", err)))?, + ) + .await + .map_err(|err| KeyNotFound(format!("{} with key {}", err, key)))?; + + let status = response.status(); + if status.is_client_error() || status.is_server_error() { + Err(KeyNotFound(format!( + "url returned {} 
for key {}", + status, key + ))) + } else { + Ok(response) + } + } + + /// Construct and send a request + pub async fn format_url + Send>( + &self, + key: K, + options: RangeUrlOptions<'_>, + endpoint: &Uri, + ) -> Result { + let key = key.as_ref(); + + #[cfg(feature = "crypt4gh")] + let key = if options + .object_type() + .send_encrypted_to_client() + .is_some_and(|value| !value) + { + key.strip_suffix(".c4gh").unwrap_or(key) + } else { + key + }; + + let url = self.get_url_from_key(key, endpoint)?.into_parts(); + let url = Uri::from_parts(url) + .map_err(|err| InternalError(format!("failed to convert to uri from parts: {}", err)))?; + + let mut url = HtsGetUrl::new(url.to_string()); + if self.forward_headers { + url = url.with_headers( + options + .response_headers() + .try_into() + .map_err(|err: error::Error| StorageError::InvalidInput(err.to_string()))?, + ) + } + + Ok(options.apply(url)) + } + + /// Get the head from the key. + pub async fn head_key + Send>( + &self, + key: K, + headers: &HeaderMap, + ) -> Result { + self + .send_request(key, headers, Method::HEAD, self.endpoints.file()) + .await + } + + /// Get the key. + pub async fn get_header + Send>( + &self, + key: K, + headers: &HeaderMap, + ) -> Result { + self + .send_request(key, headers, Method::GET, self.endpoints.file()) + .await + } + + /// Get the key. + pub async fn get_index + Send>( + &self, + key: K, + headers: &HeaderMap, + ) -> Result { + self + .send_request(key, headers, Method::GET, self.endpoints.index()) + .await + } + + /// Remove all header entries from the header map. + pub fn remove_header_entries(headers: &mut HeaderMap, key: K) { + match headers.entry(key) { + Entry::Occupied(entry) => { + entry.remove_entry_mult(); + } + Entry::Vacant(_) => {} + } + } + + /// Update the headers with the correct keys and user agent. 
+ #[cfg(feature = "crypt4gh")] + pub async fn update_headers( + &self, + object_type: &ObjectType, + headers: &HeaderMap, + ) -> Result<(HeaderMap, KeyPair)> { + let key_pair = if let Some(key_pair) = object_type.crypt4gh_key_pair() { + let key_pair = key_pair.key_pair().clone(); + info!("Got key pair from config"); + key_pair + } else { + info!("Got key pair generated"); + self + .key_pair + .as_ref() + .ok_or_else(|| InternalError("missing key pair".to_string()))? + .clone() + }; + + let mut headers = headers.clone(); + Self::remove_header_entries(&mut headers, CLIENT_PUBLIC_KEY_NAME); + Self::remove_header_entries(&mut headers, USER_AGENT); + let mut headers = self.remove_blacklisted_headers(headers); + + headers.append( + CLIENT_PUBLIC_KEY_NAME, + Self::encode_key(&PublicKey::new( + encode_public_key(key_pair.public_key().clone()) + .await + .as_bytes() + .to_vec(), + )) + .try_into() + .map_err(|err: InvalidHeaderValue| UrlParseError(err.to_string()))?, + ); + + info!("appended server public key"); + + Ok((headers, key_pair)) + } +} + +/// Type representing the `StreamReader` for `UrlStorage`. +/// Todo, definitely tidy this type... +pub type UrlStreamReader = StreamReader< + MapErr< + Pin> + Send + Sync>>, + fn(reqwest::Error) -> StorageError, + >, + Bytes, +>; + +/// An enum representing the variants of a stream reader. Note, cannot use tokio_util::Either +/// directly because this needs to be gated behind a feature flag. +/// Todo, make this less ugly, better separate feature flags. 
+#[pin_project(project = ProjectUrlStream)] +pub enum UrlStreamEither { + A(#[pin] UrlStreamReader), + #[cfg(feature = "crypt4gh")] + B(#[pin] Crypt4GHReader), +} + +#[cfg(feature = "crypt4gh")] +#[pin_project] +pub struct Crypt4GHReader { + #[pin] + reader: Reader, + client_additional_bytes: Option, +} + +#[cfg(feature = "crypt4gh")] +impl AsyncRead for Crypt4GHReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.project().reader.poll_read(cx, buf) + } +} + +impl AsyncRead for UrlStreamEither { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.project() { + ProjectUrlStream::A(a) => a.poll_read(cx, buf), + #[cfg(feature = "crypt4gh")] + ProjectUrlStream::B(b) => b.poll_read(cx, buf), + } + } +} + +impl From for UrlStreamEither { + fn from(response: reqwest::Response) -> Self { + let response: Pin> + Send + Sync>> = + Box::pin(response.bytes_stream()); + let stream_reader: UrlStreamReader = StreamReader::new( + response.map_err(|err| ResponseError(format!("reading body from response: {}", err))), + ); + + Self::A(stream_reader) + } +} + +#[async_trait] +impl Storage for UrlStorage { + type Streamable = UrlStreamEither; + + #[instrument(level = "trace", skip(self))] + async fn get + Send + Debug>( + &self, + key: K, + options: GetOptions<'_>, + head_output: &mut Option<&mut HeadOutput>, + ) -> Result { + info!("Getting underlying file"); + let key = key.as_ref().to_string(); + debug!(calling_from = ?self, key, "getting file with key {:?}", key); + + match KeyType::from_ending(&key) { + KeyType::File => { + #[cfg(feature = "crypt4gh")] + if options.object_type().is_crypt4gh() { + let (mut headers, key_pair) = self + .update_headers(options.object_type(), options.request_headers()) + .await?; + { + // Additional length for the header. 
+ let output_headers = head_output + .as_ref() + .and_then(|output| output.response_headers()); + + let additional_header_length: Option = output_headers + .and_then(|headers| headers.get(SERVER_ADDITIONAL_BYTES)) + .and_then(|length| length.to_str().ok()) + .and_then(|length| length.parse().ok()); + + let file_size: Option = output_headers + .and_then(|headers| headers.get(CONTENT_LENGTH)) + .and_then(|length| length.to_str().ok()) + .and_then(|length| length.parse().ok()); + + if let (Some(crypt4gh_header_length), Some(file_size)) = + (additional_header_length, file_size) + { + let range = options.range(); + let range = range + .clone() + .convert_to_crypt4gh_ranges(crypt4gh_header_length, file_size); + + let range: String = String::from(&BytesRange::from(&range)); + if !range.is_empty() { + headers.append( + RANGE, + range + .parse() + .map_err(|err: InvalidHeaderValue| UrlParseError(err.to_string()))?, + ); + } + } + } + + info!("Sending to storage backend with headers: {:#?}", headers); + + let response = self.get_header(key.to_string(), &headers).await?; + + let crypt4gh_keys = Keys { + method: 0, + privkey: key_pair.private_key().clone().0, + recipient_pubkey: key_pair.public_key().clone().into_inner(), + }; + + let response: Pin< + Box> + Send + Sync>, + > = Box::pin(response.bytes_stream()); + let stream_reader: UrlStreamReader = StreamReader::new( + response.map_err(|err| ResponseError(format!("reading body from response: {}", err))), + ); + + info!("got stream reader"); + + let mut reader = Builder::default().build_with_reader(stream_reader, vec![crypt4gh_keys]); + + reader + .read_header() + .await + .map_err(|err| UrlParseError(err.to_string()))?; + + // Additional length for the header. 
+ let client_additional_bytes: Option = head_output + .as_ref() + .and_then(|output| output.response_headers()) + .and_then(|headers| { + headers + .get(CLIENT_ADDITIONAL_BYTES) + .or_else(|| headers.get(SERVER_ADDITIONAL_BYTES)) + }) + .and_then(|length| length.to_str().ok()) + .and_then(|length| length.parse().ok()); + + // Convert back to the original file size for the rest of the code. + head_output.iter_mut().try_for_each(|output| { + let original_file_size = to_unencrypted_file_size( + output.content_length, + client_additional_bytes.unwrap_or_else(|| reader.header_size().unwrap_or_default()), + ); + + output.content_length = original_file_size; + + let header_content_length = HeaderValue::from_str(&original_file_size.to_string()) + .map_err(|err| UrlParseError(err.to_string()))?; + output.response_headers.iter_mut().for_each(|header| { + header + .get_mut(CONTENT_LENGTH) + .iter_mut() + .for_each(|header| **header = header_content_length.clone()) + }); + + Ok::<_, StorageError>(()) + })?; + + info!( + "additional bytes to return to client: {:#?}", + client_additional_bytes + ); + + return Ok(UrlStreamEither::B(Crypt4GHReader { + reader, + client_additional_bytes, + })); + } + + Ok( + self + .get_header(key.to_string(), options.request_headers()) + .await? + .into(), + ) + } + KeyType::Index => { + #[cfg(feature = "crypt4gh")] + if options.object_type().is_crypt4gh() { + let (headers, _) = self + .update_headers(options.object_type(), options.request_headers()) + .await?; + { + return Ok(self.get_index(key.to_string(), &headers).await?.into()); + } + } + + Ok( + self + .get_index(key.to_string(), options.request_headers()) + .await? 
+ .into(), + ) + } + } + } + + #[instrument(level = "trace", skip(self))] + async fn range_url + Send + Debug>( + &self, + key: K, + options: RangeUrlOptions<'_>, + ) -> Result { + info!("creating range urls"); + let key = key.as_ref(); + debug!(calling_from = ?self, key, "getting url with key {:?}", key); + + let response_headers = self.remove_blacklisted_headers(options.response_headers().clone()); + let new_options = RangeUrlOptions::new( + options.range().clone(), + &response_headers, + options.object_type(), + ); + + self.format_url(key, new_options, &self.response_url).await + } + + #[instrument(level = "trace", skip(self))] + async fn head + Send + Debug>( + &self, + key: K, + options: HeadOptions<'_>, + ) -> Result { + info!("getting head"); + let key = key.as_ref(); + + #[allow(unused_mut)] + let mut headers = options.request_headers().clone(); + #[cfg(feature = "crypt4gh")] + if options.object_type().is_crypt4gh() { + let (updated_headers, _) = self.update_headers(options.object_type(), &headers).await?; + headers = updated_headers; + } + + let head = self.head_key(key, &headers).await?; + + let len: u64 = head + .headers() + .get(CONTENT_LENGTH) + .and_then(|content_length| content_length.to_str().ok()) + .and_then(|content_length| content_length.parse().ok()) + .ok_or_else(|| { + ResponseError(format!( + "failed to get content length from head response for key: {}", + key + )) + })?; + + debug!(calling_from = ?self, key, len, "size of key {:?} is {}", key, len); + Ok(HeadOutput::new(len, Some(head.headers().clone()))) + } + + #[instrument(level = "trace", skip(self, reader))] + async fn update_byte_positions( + &self, + reader: Self::Streamable, + positions_options: BytesPositionOptions<'_>, + ) -> Result> { + info!("updating bytes positions"); + + match reader { + #[cfg(feature = "crypt4gh")] + UrlStreamEither::B(reader) + if positions_options.object_type.is_crypt4gh() + && positions_options + .object_type + .send_encrypted_to_client() + 
.is_some_and(|send_encrypted_to_client| send_encrypted_to_client) => + { + let Crypt4GHReader { + reader, + client_additional_bytes, + } = reader; + + let keys = reader + .keys() + .first() + .ok_or_else(|| UrlParseError("missing crypt4gh keys from reader".to_string()))?; + let file_size = positions_options.file_size(); + + info!("got keys from reader"); + + let client_additional_bytes = if let Some(bytes) = client_additional_bytes { + bytes + } else { + reader + .header_size() + .ok_or_else(|| UrlParseError("crypt4gh header has not been read".to_string()))? + }; + + // Convert back to an encrypted file size for encrypted byte ranges. + let file_size = to_encrypted_file_size(file_size, client_additional_bytes); + + info!("got client additional bytes from reader"); + + let client_public_key = + Self::decode_public_key(positions_options.headers, CLIENT_PUBLIC_KEY_NAME)?; + + info!("decoded client public key: {:#?}", client_public_key); + + let client_public_key = read_public_key(client_public_key) + .await + .map_err(|err| UrlParseError(format!("failed to parse client public key: {}", err)))?; + + info!("got client public key: {:#?}", client_public_key); + + // Need to work from the context of defined start and end ranges. 
+ let positions = positions_options + .positions + .clone() + .into_iter() + .map(|mut pos| { + if pos.start.is_none() { + pos.start = Some(0); + } + if pos.end.is_none() { + pos.end = Some(file_size); + } + + pos + }) + .collect::>(); + + let unencrypted_positions = BytesPosition::merge_all(positions.clone()); + let clamped_positions = BytesPosition::merge_all( + positions + .clone() + .into_iter() + .map(|pos| pos.convert_to_clamped_crypt4gh_ranges(file_size)) + .collect::>(), + ); + + // Calculate edit lists + let (header_info, edit_list_packet) = self.encrypt.edit_list( + &reader, + unencrypted_positions + .iter() + .map(|position| { + UnencryptedPosition::new( + position.start.unwrap_or_default(), + position.end.unwrap_or(file_size), + ) + }) + .collect(), + clamped_positions + .iter() + .map(|position| { + ClampedPosition::new( + position.start.unwrap_or_default(), + position.end.unwrap_or(file_size), + ) + }) + .collect(), + PrivateKey(keys.privkey.clone()), + client_public_key, + )?; + + info!("created edit list"); + + let encrypted_positions = BytesPosition::merge_all( + positions + .clone() + .into_iter() + .map(|pos| pos.convert_to_crypt4gh_ranges(client_additional_bytes, file_size)) + .collect::>(), + ); + + // Append header with edit lists attached. 
+ let header_info_size = header_info.len() as u64; + let mut blocks = vec![ + DataBlock::Data(header_info, Some(Class::Header)), + DataBlock::Range( + BytesPosition::default() + .with_start(header_info_size) + .with_end(client_additional_bytes), + ), + DataBlock::Data(edit_list_packet, Some(Class::Header)), + ]; + blocks.extend(DataBlock::from_bytes_positions(encrypted_positions)); + + info!("data blocks returned: {:#?}", blocks); + + Ok(blocks) + } + _ => Ok(DataBlock::from_bytes_positions( + positions_options.merge_all().into_inner(), + )), + } + } +} + +#[cfg(test)] +mod tests { + use std::future::Future; + use std::str::FromStr; + + use htsget_config::resolver::object::ObjectType; + use htsget_test::http::server::with_test_server; + use http::header::{AUTHORIZATION, HOST}; + use http::{HeaderName, HeaderValue}; + use tokio::io::AsyncReadExt; + #[cfg(feature = "crypt4gh")] + use { + crate::htsget::from_storage::HtsGetFromStorage, + crate::htsget::HtsGet, + crate::Response as HtsgetResponse, + async_crypt4gh::KeyPair, + htsget_config::tls::crypt4gh::Crypt4GHKeyPair, + htsget_config::types::Class::{Body, Header}, + htsget_config::types::Request as HtsgetRequest, + htsget_config::types::{Format, Query, Url}, + htsget_test::crypt4gh::get_encryption_keys, + htsget_test::http::default_dir, + htsget_test::http::test_bam_crypt4gh_byte_ranges, + htsget_test::http::test_parsable_byte_ranges, + htsget_test::http::{get_byte_ranges_from_url_storage_response, parse_as_bgzf}, + http::header::USER_AGENT, + }; + + use htsget_config::types::Headers; + + use crate::storage::local::tests::create_local_test_files; + + use super::*; + + #[test] + fn get_url_from_key() { + let storage = UrlStorage::new( + test_client(), + endpoints_test(), + Uri::from_str("https://localhost:8080").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &Default::default(), + #[cfg(feature = "crypt4gh")] + default_key_gen(), + ) + .unwrap(); + + assert_eq!( + storage + .get_url_from_key( + 
"assets/key1", + &Uri::from_str("https://example.com").unwrap() + ) + .unwrap(), + Uri::from_str("https://example.com/assets/key1").unwrap() + ); + } + + #[test] + fn get_response_url_from_key() { + let storage = UrlStorage::new( + test_client(), + endpoints_test(), + Uri::from_str("https://localhost:8080").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &Default::default(), + #[cfg(feature = "crypt4gh")] + default_key_gen(), + ) + .unwrap(); + + assert_eq!( + storage + .get_url_from_key( + "assets/key1", + &Uri::from_str("https://localhost:8080").unwrap() + ) + .unwrap(), + Uri::from_str("https://localhost:8080/assets/key1").unwrap() + ); + } + + #[test] + fn remove_blacklisted_headers() { + let storage = UrlStorage::new( + test_client(), + endpoints_test(), + Uri::from_str("https://localhost:8080").unwrap(), + true, + vec![HOST.to_string()], + Some("user-agent".to_string()), + &Default::default(), + #[cfg(feature = "crypt4gh")] + default_key_gen(), + ) + .unwrap(); + + let mut headers = HeaderMap::default(); + headers.insert( + HeaderName::from_str(HOST.as_str()).unwrap(), + HeaderValue::from_str("example.com").unwrap(), + ); + headers.insert( + HeaderName::from_str(AUTHORIZATION.as_str()).unwrap(), + HeaderValue::from_str("secret").unwrap(), + ); + + let headers = storage.remove_blacklisted_headers(headers.clone()); + + assert_eq!(headers.len(), 1); + } + + #[tokio::test] + async fn send_request() { + with_url_test_server(|url| async move { + let storage = UrlStorage::new( + test_client(), + endpoints_from_url(&url), + Uri::from_str(&url).unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &Default::default(), + #[cfg(feature = "crypt4gh")] + default_key_gen(), + ) + .unwrap(); + + let mut headers = HeaderMap::default(); + let headers = test_headers(&mut headers); + + let response = String::from_utf8( + storage + .send_request( + "assets/key1", + headers, + Method::GET, + &Uri::from_str(&url).unwrap(), + ) + .await + .unwrap() + 
.bytes() + .await + .unwrap() + .to_vec(), + ) + .unwrap(); + assert_eq!(response, "value1"); + }) + .await; + } + + #[tokio::test] + async fn get_key() { + with_url_test_server(|url| async move { + let storage = UrlStorage::new( + test_client(), + endpoints_from_url(&url), + Uri::from_str(&url).unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &Default::default(), + #[cfg(feature = "crypt4gh")] + default_key_gen(), + ) + .unwrap(); + + let mut headers = HeaderMap::default(); + let headers = test_headers(&mut headers); + + let response = String::from_utf8( + storage + .get_header("assets/key1", headers) + .await + .unwrap() + .bytes() + .await + .unwrap() + .to_vec(), + ) + .unwrap(); + assert_eq!(response, "value1"); + }) + .await; + } + + #[tokio::test] + async fn head_key() { + with_url_test_server(|url| async move { + let storage = UrlStorage::new( + test_client(), + endpoints_from_url(&url), + Uri::from_str(&url).unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &Default::default(), + #[cfg(feature = "crypt4gh")] + default_key_gen(), + ) + .unwrap(); + + let mut headers = HeaderMap::default(); + let headers = test_headers(&mut headers); + + let response: u64 = storage + .get_header("assets/key1", headers) + .await + .unwrap() + .headers() + .get(CONTENT_LENGTH) + .unwrap() + .to_str() + .unwrap() + .parse() + .unwrap(); + assert_eq!(response, 6); + }) + .await; + } + + #[tokio::test] + async fn get_storage() { + with_url_test_server(|url| async move { + let storage = UrlStorage::new( + test_client(), + endpoints_from_url(&url), + Uri::from_str(&url).unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &Default::default(), + #[cfg(feature = "crypt4gh")] + default_key_gen(), + ) + .unwrap(); + + let mut headers = HeaderMap::default(); + let headers = test_headers(&mut headers); + let object_type = Default::default(); + let options = GetOptions::new_with_default_range(headers, &object_type); + + let mut reader = storage + 
.get("assets/key1", options, &mut None) + .await + .unwrap(); + + let mut response = [0; 6]; + reader.read_exact(&mut response).await.unwrap(); + + assert_eq!(String::from_utf8(response.to_vec()).unwrap(), "value1"); + }) + .await; + } + + #[tokio::test] + async fn range_url_storage() { + with_url_test_server(|url| async move { + let storage = UrlStorage::new( + test_client(), + endpoints_from_url(&url), + Uri::from_str(&url).unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &Default::default(), + #[cfg(feature = "crypt4gh")] + default_key_gen(), + ) + .unwrap(); + + let mut headers = HeaderMap::default(); + let object_type = Default::default(); + let options = test_range_options(&mut headers, &object_type); + + assert_eq!( + storage.range_url("assets/key1", options).await.unwrap(), + HtsGetUrl::new(format!("{}/assets/key1", url)) + .with_headers(Headers::default().with_header(AUTHORIZATION.as_str(), "secret")) + ); + }) + .await; + } + + #[tokio::test] + async fn range_url_storage_blacklisted_headers() { + with_url_test_server(|url| async move { + let storage = UrlStorage::new( + test_client(), + endpoints_from_url(&url), + Uri::from_str(&url).unwrap(), + true, + vec![HOST.to_string()], + Some("user-agent".to_string()), + &Default::default(), + #[cfg(feature = "crypt4gh")] + default_key_gen(), + ) + .unwrap(); + + let mut headers = HeaderMap::default(); + headers.insert( + HeaderName::from_str(HOST.as_str()).unwrap(), + HeaderValue::from_str("example.com").unwrap(), + ); + + let object_type = Default::default(); + let options = test_range_options(&mut headers, &object_type); + + assert_eq!( + storage.range_url("assets/key1", options).await.unwrap(), + HtsGetUrl::new(format!("{}/assets/key1", url)) + .with_headers(Headers::default().with_header(AUTHORIZATION.as_str(), "secret")) + ); + }) + .await; + } + + #[tokio::test] + async fn head_storage() { + with_url_test_server(|url| async move { + let storage = UrlStorage::new( + test_client(), + 
endpoints_from_url(&url), + Uri::from_str(&url).unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &Default::default(), + #[cfg(feature = "crypt4gh")] + default_key_gen(), + ) + .unwrap(); + + let mut headers = HeaderMap::default(); + let headers = test_headers(&mut headers); + let object_type = Default::default(); + let options = HeadOptions::new(headers, &object_type); + + assert_eq!( + storage + .head("assets/key1", options) + .await + .unwrap() + .content_length(), + 6 + ); + }) + .await; + } + + #[tokio::test] + async fn format_url() { + let storage = UrlStorage::new( + test_client(), + endpoints_test(), + Uri::from_str("https://localhost:8080").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &Default::default(), + #[cfg(feature = "crypt4gh")] + default_key_gen(), + ) + .unwrap(); + + let mut headers = HeaderMap::default(); + let object_type = Default::default(); + let options = test_range_options(&mut headers, &object_type); + + assert_eq!( + storage + .format_url( + "assets/key1", + options, + &Uri::from_str("https://localhost:8080").unwrap() + ) + .await + .unwrap(), + HtsGetUrl::new("https://localhost:8080/assets/key1") + .with_headers(Headers::default().with_header(AUTHORIZATION.as_str(), "secret")) + ); + } + + #[tokio::test] + async fn format_url_different_response_scheme() { + let storage = UrlStorage::new( + test_client(), + endpoints_test(), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &Default::default(), + #[cfg(feature = "crypt4gh")] + default_key_gen(), + ) + .unwrap(); + + let mut headers = HeaderMap::default(); + let object_type = Default::default(); + let options = test_range_options(&mut headers, &object_type); + + assert_eq!( + storage + .format_url( + "assets/key1", + options, + &Uri::from_str("http://example.com").unwrap() + ) + .await + .unwrap(), + HtsGetUrl::new("http://example.com/assets/key1") + 
.with_headers(Headers::default().with_header(AUTHORIZATION.as_str(), "secret")) + ); + } + + #[tokio::test] + async fn format_url_no_headers() { + let storage = UrlStorage::new( + test_client(), + endpoints_test(), + Uri::from_str("https://localhost:8081").unwrap(), + false, + vec![], + Some("user-agent".to_string()), + &Default::default(), + #[cfg(feature = "crypt4gh")] + default_key_gen(), + ) + .unwrap(); + + let mut headers = HeaderMap::default(); + let object_type = Default::default(); + let options = test_range_options(&mut headers, &object_type); + + assert_eq!( + storage.range_url("assets/key1", options,).await.unwrap(), + HtsGetUrl::new("https://localhost:8081/assets/key1") + ); + } + + #[cfg(feature = "crypt4gh")] + #[tokio::test] + async fn test_endpoints_with_real_file() { + with_url_test_server(|url| async move { + let mut header_map = HeaderMap::default(); + test_headers(&mut header_map); + let request = + HtsgetRequest::new_with_id("htsnexus_test_NA12878".to_string()).with_headers(header_map); + let query = Query::new( + "htsnexus_test_NA12878", + Format::Bam, + request, + Default::default(), + ) + .with_reference_name("11") + .with_start(5015000) + .with_end(5050000); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + #[cfg(feature = "crypt4gh")] + default_key_gen(), + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await; + + let expected_response = Ok(expected_bam_response()); + assert_eq!(response, expected_response); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response.unwrap(), + default_dir().join("data/bam/htsnexus_test_NA12878.bam"), + ) + .await; + + parse_as_bgzf(bytes.clone()).await; + }) + .await; + } + + #[cfg(feature = "crypt4gh")] + #[tokio::test] + async fn test_endpoints_with_real_file_encrypted() 
{ + with_url_test_server(|url| async move { + let mut key_gen = default_key_gen(); + key_gen + .expect_edit_list() + .times(1) + .returning(|_, _, _, _, _| Ok(expected_edit_list())); + + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = + HtsgetRequest::new_with_id("htsnexus_test_NA12878".to_string()).with_headers(header_map); + let query = Query::new( + "htsnexus_test_NA12878", + Format::Bam, + request, + ObjectType::GenerateKeys { + send_encrypted_to_client: true, + }, + ) + .with_reference_name("11") + .with_start(5015000) + .with_end(5050000); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + key_gen, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + assert_encrypted_endpoints(&public_key, response).await; + }) + .await; + } + + #[cfg(feature = "crypt4gh")] + #[tokio::test] + async fn test_endpoints_with_encrypted_file_unencrypted_to_client() { + with_url_test_server(|url| async move { + let key_gen = default_key_gen(); + + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + 
HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = + HtsgetRequest::new_with_id("htsnexus_test_NA12878".to_string()).with_headers(header_map); + let query = Query::new( + "htsnexus_test_NA12878", + Format::Bam, + request, + ObjectType::GenerateKeys { + send_encrypted_to_client: false, + }, + ) + .with_reference_name("11") + .with_start(5015000) + .with_end(5050000); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + key_gen, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let mut expected_response = expected_bam_response(); + expected_response.urls.iter_mut().for_each(|url| { + url.headers.iter_mut().for_each(|header| { + *header = header + .clone() + .with_header(CLIENT_PUBLIC_KEY_NAME, public_key.clone()) + .with_header(USER_AGENT.to_string(), "client-user-agent") + }) + }); + + assert_eq!(response, expected_response); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + default_dir().join("data/bam/htsnexus_test_NA12878.bam"), + ) + .await; + + parse_as_bgzf(bytes).await; + }) + .await; + } + + #[cfg(feature = "crypt4gh")] + #[tokio::test] + async fn test_endpoints_with_predefined_key_pair() { + with_url_test_server(|url| async move { + let mut key_gen = Encrypt::default(); + key_gen + .expect_edit_list() + .times(1) + .returning(|_, _, _, _, _| Ok(expected_edit_list())); + + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + 
HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = + HtsgetRequest::new_with_id("htsnexus_test_NA12878".to_string()).with_headers(header_map); + let query = Query::new( + "htsnexus_test_NA12878", + Format::Bam, + request, + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ) + .with_reference_name("11") + .with_start(5015000) + .with_end(5050000); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + key_gen, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + assert_encrypted_endpoints(&public_key, response).await; + }) + .await; + } + + #[cfg(feature = "crypt4gh")] + #[tokio::test] + async fn test_endpoints_with_full_file_encrypted() { + with_url_test_server(|url| async move { + let mut key_gen = Encrypt::default(); + key_gen + .expect_edit_list() + .times(1) + .returning(|_, _, _, _, _| { + Ok(( + vec![99, 114, 121, 112, 116, 52, 103, 104, 1, 0, 0, 0, 2, 0, 0, 0], + vec![ + 92, 0, 0, 0, 0, 0, 0, 0, 56, 44, 122, 180, 24, 116, 207, 149, 165, 49, 204, 77, 224, + 136, 232, 121, 209, 249, 23, 51, 120, 2, 187, 147, 82, 227, 232, 32, 17, 223, 7, 38, + 137, 197, 83, 68, 73, 33, 229, 38, 173, 186, 106, 216, 22, 90, 243, 19, 191, 45, 212, + 253, 97, 151, 103, 27, 151, 29, 169, 155, 208, 93, 197, 217, 40, 133, 166, 160, 125, + 43, 82, 75, 1, 20, 104, 45, 116, 193, 165, 160, 189, 186, 146, 175, + ], + )) + }); + + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + 
HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = + HtsgetRequest::new_with_id("htsnexus_test_NA12878".to_string()).with_headers(header_map); + let query = Query::new( + "htsnexus_test_NA12878", + Format::Bam, + request, + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + key_gen, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let expected_response = HtsgetResponse::new( + Format::Bam, + vec![ + // header info + Url::new("data:;base64,Y3J5cHQ0Z2gBAAAAAgAAAA=="), + // original header + Url::new("http://example.com/htsnexus_test_NA12878.bam.c4gh").with_headers( + Headers::default() + .with_header("authorization", "secret") + .with_header(CLIENT_PUBLIC_KEY_NAME, public_key.clone()) + .with_header("Range", format!("bytes={}-{}", 16, 123)) + .with_header(USER_AGENT.to_string(), "client-user-agent"), + ), + // edit list packet + Url::new( + "data:;base64,XAAAAAAAAAA4LHq0GHTPlaUxzE3giOh50fkXM3gCu5NS4+ggEd8HJonFU0RJIeUmrbpq2\ + BZa8xO/LdT9YZdnG5cdqZvQXcXZKIWmoH0rUksBFGgtdMGloL26kq8=", + ), + Url::new("http://example.com/htsnexus_test_NA12878.bam.c4gh").with_headers( + Headers::default() + .with_header("authorization", "secret") + .with_header(CLIENT_PUBLIC_KEY_NAME, public_key.clone()) + .with_header("Range", format!("bytes={}-{}", 124, 2598043 - 1)) + .with_header(USER_AGENT.to_string(), "client-user-agent"), + ), + ], + ); + + assert_eq!(response, expected_response); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + 
default_dir().join("data/crypt4gh/htsnexus_test_NA12878.bam.c4gh"), + ) + .await; + + let (expected_bytes, _) = get_byte_ranges_from_url_storage_response( + HtsgetResponse::new( + Format::Bam, + vec![ + Url::new("http://example.com/htsnexus_test_NA12878.bam").with_headers( + Headers::default() + .with_header("authorization", "secret") + .with_header("Range", "bytes=0-2596798"), + ), + ], + ), + default_dir().join("data/bam/htsnexus_test_NA12878.bam"), + ) + .await; + + test_bam_crypt4gh_byte_ranges(bytes.clone(), expected_bytes).await; + test_parsable_byte_ranges(bytes.clone(), Format::Bam, Body).await; + }) + .await; + } + + #[cfg(feature = "crypt4gh")] + fn expected_key_pair() -> KeyPair { + KeyPair::new( + PrivateKey(vec![ + 162, 124, 25, 18, 207, 218, 241, 41, 162, 107, 29, 40, 10, 93, 30, 193, 104, 42, 176, 235, + 207, 248, 126, 230, 97, 205, 253, 224, 215, 160, 67, 239, + ]), + PublicKey::new(vec![ + 56, 44, 122, 180, 24, 116, 207, 149, 165, 49, 204, 77, 224, 136, 232, 121, 209, 249, 23, + 51, 120, 2, 187, 147, 82, 227, 232, 32, 17, 223, 7, 38, + ]), + ) + } + + #[cfg(feature = "crypt4gh")] + fn expected_edit_list() -> (Vec, Vec) { + ( + vec![99, 114, 121, 112, 116, 52, 103, 104, 1, 0, 0, 0, 2, 0, 0, 0], + vec![ + 132, 0, 0, 0, 0, 0, 0, 0, 56, 44, 122, 180, 24, 116, 207, 149, 165, 49, 204, 77, 224, 136, + 232, 121, 209, 249, 23, 51, 120, 2, 187, 147, 82, 227, 232, 32, 17, 223, 7, 38, 34, 167, + 71, 22, 226, 141, 116, 29, 102, 158, 147, 237, 135, 239, 3, 75, 15, 202, 173, 254, 237, 63, + 4, 74, 55, 123, 247, 21, 64, 80, 22, 138, 80, 64, 123, 116, 45, 229, 168, 155, 206, 72, + 114, 91, 7, 157, 53, 64, 129, 126, 191, 28, 135, 43, 222, 239, 224, 44, 33, 236, 253, 227, + 238, 111, 15, 132, 138, 99, 251, 156, 186, 26, 98, 81, 117, 63, 75, 17, 133, 22, 24, 98, + 78, 61, 153, 239, 164, 230, 224, 120, 159, 111, + ], + ) + } + + #[cfg(feature = "crypt4gh")] + fn default_key_gen() -> Encrypt { + let mut key_gen = Encrypt::default(); + key_gen + 
.expect_generate_key_pair() + .times(1) + .returning(|| Ok(expected_key_pair())); + key_gen + } + + #[cfg(feature = "crypt4gh")] + async fn assert_encrypted_endpoints(public_key: &String, response: HtsgetResponse) { + let expected_response = HtsgetResponse::new( + Format::Bam, + vec![ + // header info + Url::new("data:;base64,Y3J5cHQ0Z2gBAAAAAgAAAA=="), + // original header + Url::new("http://example.com/htsnexus_test_NA12878.bam.c4gh").with_headers( + Headers::default() + .with_header("authorization", "secret") + .with_header(CLIENT_PUBLIC_KEY_NAME, public_key) + .with_header("Range", format!("bytes={}-{}", 16, 123)) + .with_header(USER_AGENT.to_string(), "client-user-agent"), + ), + // edit list packet + Url::new( + "data:;base64,hAAAAAAAAAA4LHq0GHTPlaUxzE3giOh50fkXM3gCu5NS4+ggEd8HJiKnRxbijXQdZp6T7Yf\ + vA0sPyq3+7T8ESjd79xVAUBaKUEB7dC3lqJvOSHJbB501QIF+vxyHK97v4Cwh7P3j7m8PhIpj+5y6GmJRdT9LEYUW\ + GGJOPZnvpObgeJ9v", + ), + Url::new("http://example.com/htsnexus_test_NA12878.bam.c4gh").with_headers( + Headers::default() + .with_header("authorization", "secret") + .with_header(CLIENT_PUBLIC_KEY_NAME, public_key) + .with_header("Range", format!("bytes={}-{}", 124, 124 + 65564 - 1)) + .with_header(USER_AGENT.to_string(), "client-user-agent"), + ), + Url::new("http://example.com/htsnexus_test_NA12878.bam.c4gh").with_headers( + Headers::default() + .with_header("authorization", "secret") + .with_header(CLIENT_PUBLIC_KEY_NAME, public_key) + .with_header( + "Range", + format!("bytes={}-{}", 124 + 196692, 124 + 1114588 - 1), + ) + .with_header(USER_AGENT.to_string(), "client-user-agent"), + ), + Url::new("http://example.com/htsnexus_test_NA12878.bam.c4gh").with_headers( + Headers::default() + .with_header("authorization", "secret") + .with_header(CLIENT_PUBLIC_KEY_NAME, public_key) + .with_header("Range", format!("bytes={}-{}", 124 + 2556996, 2598043 - 1)) + .with_header(USER_AGENT.to_string(), "client-user-agent"), + ), + ], + ); + + assert_eq!(response, expected_response); 
+ + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + default_dir().join("data/crypt4gh/htsnexus_test_NA12878.bam.c4gh"), + ) + .await; + + let (expected_bytes, _) = get_byte_ranges_from_url_storage_response( + expected_bam_response(), + default_dir().join("data/bam/htsnexus_test_NA12878.bam"), + ) + .await; + + test_bam_crypt4gh_byte_ranges(bytes.clone(), expected_bytes).await; + test_parsable_byte_ranges(bytes.clone(), Format::Bam, Body).await; + } + + #[cfg(feature = "crypt4gh")] + fn expected_bam_response() -> HtsgetResponse { + HtsgetResponse::new( + Format::Bam, + vec![ + Url::new("http://example.com/htsnexus_test_NA12878.bam") + .with_headers( + Headers::default() + .with_header("authorization", "secret") + .with_header("Range", "bytes=0-4667"), + ) + .with_class(Header), + Url::new("http://example.com/htsnexus_test_NA12878.bam") + .with_headers( + Headers::default() + .with_header("authorization", "secret") + .with_header("Range", "bytes=256721-1065951"), + ) + .with_class(Body), + Url::new("http://example.com/htsnexus_test_NA12878.bam") + .with_headers( + Headers::default() + .with_header("authorization", "secret") + .with_header("Range", "bytes=2596771-2596798"), + ) + .with_class(Body), + ], + ) + } + + fn test_client() -> Client { + ClientBuilder::new().build().unwrap() + } + + pub(crate) async fn with_url_test_server(test: F) + where + F: FnOnce(String) -> Fut, + Fut: Future, + { + let (_, base_path) = create_local_test_files().await; + with_test_server(base_path.path(), test).await; + } + + fn test_headers(headers: &mut HeaderMap) -> &HeaderMap { + headers.append( + HeaderName::from_str(AUTHORIZATION.as_str()).unwrap(), + HeaderValue::from_str("secret").unwrap(), + ); + headers + } + + fn test_range_options<'a>( + headers: &'a mut HeaderMap, + object_type: &'a ObjectType, + ) -> RangeUrlOptions<'a> { + let headers = test_headers(headers); + let options = RangeUrlOptions::new_with_default_range(headers, object_type); + + options 
+ } + + fn endpoints_test() -> Endpoints { + Endpoints::new( + Uri::from_str("https://example.com").unwrap().into(), + Uri::from_str("https://example.com").unwrap().into(), + ) + } + + fn endpoints_from_url(url: &str) -> Endpoints { + Endpoints::new( + Uri::from_str(url).unwrap().into(), + Uri::from_str(url).unwrap().into(), + ) + } + + #[cfg(feature = "crypt4gh")] + fn endpoints_from_url_with_path(url: &str) -> Endpoints { + Endpoints::new( + Uri::from_str(&format!("{}/endpoint_index", url)) + .unwrap() + .into(), + Uri::from_str(&format!("{}/endpoint_file", url)) + .unwrap() + .into(), + ) + } +} diff --git a/htsget-search/tests/url_storage_crypt4gh.rs b/htsget-search/tests/url_storage_crypt4gh.rs new file mode 100644 index 000000000..444f13dcb --- /dev/null +++ b/htsget-search/tests/url_storage_crypt4gh.rs @@ -0,0 +1,947 @@ +#![cfg(all(feature = "crypt4gh", feature = "url-storage"))] + +use base64::engine::general_purpose; +use base64::Engine; +use htsget_config::resolver::object::ObjectType; +use htsget_config::storage::url::endpoints::Endpoints; +use htsget_config::tls::crypt4gh::Crypt4GHKeyPair; +use htsget_config::types::Class::{Body, Header}; +use htsget_config::types::Request as HtsgetRequest; +use htsget_config::types::{Format, Query}; +use htsget_search::htsget::from_storage::HtsGetFromStorage; +use htsget_search::htsget::HtsGet; +use htsget_search::storage::url::encrypt::Encrypt; +use htsget_search::storage::url::{UrlStorage, CLIENT_PUBLIC_KEY_NAME}; +use htsget_test::crypt4gh::{create_local_test_files, expected_key_pair, get_encryption_keys}; +use htsget_test::http::server::with_test_server; +use htsget_test::http::{ + default_dir, get_byte_ranges_from_url_storage_response, test_parsable_byte_ranges, +}; +use http::header::{AUTHORIZATION, USER_AGENT}; +use http::{HeaderMap, HeaderName, HeaderValue, Uri}; +use reqwest::{Client, ClientBuilder}; +use std::future::Future; +use std::str::FromStr; + +fn test_client() -> Client { + 
ClientBuilder::new().build().unwrap() +} + +fn test_headers(headers: &mut HeaderMap) -> &HeaderMap { + headers.append( + HeaderName::from_str(AUTHORIZATION.as_str()).unwrap(), + HeaderValue::from_str("secret").unwrap(), + ); + headers +} + +async fn with_url_test_server(test: F) +where + F: FnOnce(String) -> Fut, + Fut: Future, +{ + let (_, base_path) = create_local_test_files().await; + with_test_server(base_path.path(), test).await; +} + +fn endpoints_from_url_with_path(url: &str) -> Endpoints { + Endpoints::new( + Uri::from_str(&format!("{}/endpoint_index", url)) + .unwrap() + .into(), + Uri::from_str(&format!("{}/endpoint_file", url)) + .unwrap() + .into(), + ) +} + +#[tokio::test] +async fn test_encrypted_bam() { + with_url_test_server(|url| async move { + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = + HtsgetRequest::new_with_id("htsnexus_test_NA12878".to_string()).with_headers(header_map); + let query = Query::new( + "htsnexus_test_NA12878", + Format::Bam, + request, + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + Encrypt, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + 
default_dir().join("data/crypt4gh/htsnexus_test_NA12878.bam.c4gh"), + ) + .await; + + test_parsable_byte_ranges(bytes.clone(), Format::Bam, Body).await; + }) + .await; +} + +#[tokio::test] +async fn test_encrypted_cram() { + with_url_test_server(|url| async move { + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = + HtsgetRequest::new_with_id("htsnexus_test_NA12878".to_string()).with_headers(header_map); + let query = Query::new( + "htsnexus_test_NA12878", + Format::Cram, + request, + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + Encrypt, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + default_dir().join("data/crypt4gh/htsnexus_test_NA12878.cram.c4gh"), + ) + .await; + + test_parsable_byte_ranges(bytes.clone(), Format::Cram, Body).await; + }) + .await; +} + +#[tokio::test] +async fn test_encrypted_vcf() { + with_url_test_server(|url| async move { + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + 
HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = HtsgetRequest::new_with_id("spec-v4.3".to_string()).with_headers(header_map); + let query = Query::new( + "spec-v4.3", + Format::Vcf, + request, + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + Encrypt, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + default_dir().join("data/crypt4gh/spec-v4.3.vcf.gz.c4gh"), + ) + .await; + + test_parsable_byte_ranges(bytes.clone(), Format::Vcf, Body).await; + }) + .await; +} + +#[tokio::test] +async fn test_encrypted_bcf() { + with_url_test_server(|url| async move { + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = + HtsgetRequest::new_with_id("sample1-bcbio-cancer".to_string()).with_headers(header_map); + let query = Query::new( + "sample1-bcbio-cancer", + Format::Bcf, + request, + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ); + + let storage = 
UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + Encrypt, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + default_dir().join("data/crypt4gh/sample1-bcbio-cancer.bcf.c4gh"), + ) + .await; + + test_parsable_byte_ranges(bytes.clone(), Format::Bcf, Body).await; + }) + .await; +} + +#[tokio::test] +async fn test_encrypted_bam_with_range() { + with_url_test_server(|url| async move { + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = + HtsgetRequest::new_with_id("htsnexus_test_NA12878".to_string()).with_headers(header_map); + let query = Query::new( + "htsnexus_test_NA12878", + Format::Bam, + request, + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ) + .with_reference_name("11") + .with_start(5015000) + .with_end(5050000); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + Encrypt, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + 
default_dir().join("data/crypt4gh/htsnexus_test_NA12878.bam.c4gh"), + ) + .await; + + test_parsable_byte_ranges(bytes.clone(), Format::Bam, Body).await; + }) + .await; +} + +#[tokio::test] +async fn test_encrypted_cram_with_range() { + with_url_test_server(|url| async move { + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = + HtsgetRequest::new_with_id("htsnexus_test_NA12878".to_string()).with_headers(header_map); + let query = Query::new( + "htsnexus_test_NA12878", + Format::Cram, + request, + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ) + .with_reference_name("11") + .with_start(5000000) + .with_end(5050000); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + Encrypt, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + default_dir().join("data/crypt4gh/htsnexus_test_NA12878.cram.c4gh"), + ) + .await; + + test_parsable_byte_ranges(bytes.clone(), Format::Cram, Body).await; + }) + .await; +} + +#[tokio::test] +async fn test_encrypted_vcf_with_range() { + with_url_test_server(|url| async move { + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + 
test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = HtsgetRequest::new_with_id("spec-v4.3".to_string()).with_headers(header_map); + let query = Query::new( + "spec-v4.3", + Format::Vcf, + request, + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ) + .with_reference_name("20") + .with_start(150) + .with_end(153); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + Encrypt, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + default_dir().join("data/crypt4gh/spec-v4.3.vcf.gz.c4gh"), + ) + .await; + + test_parsable_byte_ranges(bytes.clone(), Format::Vcf, Body).await; + }) + .await; +} + +#[tokio::test] +async fn test_encrypted_bcf_with_range() { + with_url_test_server(|url| async move { + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = + HtsgetRequest::new_with_id("sample1-bcbio-cancer".to_string()).with_headers(header_map); + let query = Query::new( + "sample1-bcbio-cancer", + Format::Bcf, + request, + 
ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ) + .with_reference_name("chrM") + .with_start(150) + .with_end(153); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + Encrypt, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + default_dir().join("data/crypt4gh/sample1-bcbio-cancer.bcf.c4gh"), + ) + .await; + + test_parsable_byte_ranges(bytes.clone(), Format::Bcf, Body).await; + }) + .await; +} + +#[tokio::test] +async fn test_encrypted_bam_header() { + with_url_test_server(|url| async move { + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = + HtsgetRequest::new_with_id("htsnexus_test_NA12878".to_string()).with_headers(header_map); + let query = Query::new( + "htsnexus_test_NA12878", + Format::Bam, + request, + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ) + .with_class(Header); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + Encrypt, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = 
searcher.search(query.clone()).await.unwrap(); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + default_dir().join("data/crypt4gh/htsnexus_test_NA12878.bam.c4gh"), + ) + .await; + + test_parsable_byte_ranges(bytes.clone(), Format::Bam, Header).await; + }) + .await; +} + +#[tokio::test] +async fn test_encrypted_cram_header() { + with_url_test_server(|url| async move { + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = + HtsgetRequest::new_with_id("htsnexus_test_NA12878".to_string()).with_headers(header_map); + let query = Query::new( + "htsnexus_test_NA12878", + Format::Cram, + request, + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ) + .with_class(Header); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + Encrypt, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + default_dir().join("data/crypt4gh/htsnexus_test_NA12878.cram.c4gh"), + ) + .await; + + test_parsable_byte_ranges(bytes.clone(), Format::Cram, Header).await; + }) + .await; +} + +#[tokio::test] +async fn test_encrypted_vcf_header() { + with_url_test_server(|url| async move { + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let 
public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = HtsgetRequest::new_with_id("spec-v4.3".to_string()).with_headers(header_map); + let query = Query::new( + "spec-v4.3", + Format::Vcf, + request, + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ) + .with_class(Header); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + Encrypt, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + default_dir().join("data/crypt4gh/spec-v4.3.vcf.gz.c4gh"), + ) + .await; + + test_parsable_byte_ranges(bytes.clone(), Format::Vcf, Header).await; + }) + .await; +} + +#[tokio::test] +async fn test_encrypted_bcf_header() { + with_url_test_server(|url| async move { + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = + HtsgetRequest::new_with_id("sample1-bcbio-cancer".to_string()).with_headers(header_map); + let query = Query::new( + "sample1-bcbio-cancer", + Format::Bcf, + request, 
+ ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ) + .with_class(Header); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + Encrypt, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + default_dir().join("data/crypt4gh/sample1-bcbio-cancer.bcf.c4gh"), + ) + .await; + + test_parsable_byte_ranges(bytes.clone(), Format::Bcf, Header).await; + }) + .await; +} + +// The following tests assume the existence of a large Test.1000G file. They are ignored by default. +// Run with `cargo test --all-features -- --ignored --test-threads=1`. It might take a while. +#[ignore] +#[tokio::test] +async fn test_encrypted_large_vcf_chr8_with_range() { + with_url_test_server(|url| async move { + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = HtsgetRequest::new_with_id("Test.1000G.phase3.joint.lifted.UMCCR".to_string()) + .with_headers(header_map); + let query = Query::new( + "Test.1000G.phase3.joint.lifted.UMCCR", + Format::Vcf, + request, + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ) + .with_reference_name("chr8") + .with_start(1000000) + .with_end(1000100); + + let storage = UrlStorage::new( + test_client(), + 
endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + Encrypt, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + default_dir().join("data/crypt4gh/Test.1000G.phase3.joint.lifted.UMCCR.vcf.gz.c4gh"), + ) + .await; + + test_parsable_byte_ranges(bytes.clone(), Format::Vcf, Header).await; + }) + .await; +} + +#[ignore] +#[tokio::test] +async fn test_encrypted_large_vcf_chr2_no_range() { + with_url_test_server(|url| async move { + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = HtsgetRequest::new_with_id("Test.1000G.phase3.joint.lifted.UMCCR".to_string()) + .with_headers(header_map); + let query = Query::new( + "Test.1000G.phase3.joint.lifted.UMCCR", + Format::Vcf, + request, + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ) + .with_reference_name("chr2"); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + Encrypt, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + 
default_dir().join("data/crypt4gh/Test.1000G.phase3.joint.lifted.UMCCR.vcf.gz.c4gh"), + ) + .await; + + test_parsable_byte_ranges(bytes.clone(), Format::Vcf, Header).await; + }) + .await; +} + +#[ignore] +#[tokio::test] +async fn test_encrypted_large_vcf_chr20_no_end_range() { + with_url_test_server(|url| async move { + let (_, public_key) = get_encryption_keys().await; + let mut header_map = HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = HtsgetRequest::new_with_id("Test.1000G.phase3.joint.lifted.UMCCR".to_string()) + .with_headers(header_map); + let query = Query::new( + "Test.1000G.phase3.joint.lifted.UMCCR", + Format::Vcf, + request, + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ) + .with_reference_name("chr20") + .with_start(10000000); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + Encrypt, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + default_dir().join("data/crypt4gh/Test.1000G.phase3.joint.lifted.UMCCR.vcf.gz.c4gh"), + ) + .await; + + test_parsable_byte_ranges(bytes.clone(), Format::Vcf, Header).await; + }) + .await; +} + +#[ignore] +#[tokio::test] +async fn test_encrypted_large_vcf_chr11_no_start_range() { + with_url_test_server(|url| async move { + let (_, public_key) = get_encryption_keys().await; + let mut header_map 
= HeaderMap::default(); + let public_key = general_purpose::STANDARD.encode(public_key); + test_headers(&mut header_map); + header_map.append( + HeaderName::from_str(CLIENT_PUBLIC_KEY_NAME).unwrap(), + HeaderValue::from_str(&public_key).unwrap(), + ); + header_map.append( + HeaderName::from_str(USER_AGENT.as_ref()).unwrap(), + HeaderValue::from_str("client-user-agent").unwrap(), + ); + + let request = HtsgetRequest::new_with_id("Test.1000G.phase3.joint.lifted.UMCCR".to_string()) + .with_headers(header_map); + let query = Query::new( + "Test.1000G.phase3.joint.lifted.UMCCR", + Format::Vcf, + request, + ObjectType::Crypt4GH { + crypt4gh: Crypt4GHKeyPair::new(expected_key_pair()), + send_encrypted_to_client: true, + }, + ) + .with_reference_name("chr11") + .with_end(50000); + + let storage = UrlStorage::new( + test_client(), + endpoints_from_url_with_path(&url), + Uri::from_str("http://example.com").unwrap(), + true, + vec![], + Some("user-agent".to_string()), + &query, + Encrypt, + ) + .unwrap(); + + let searcher = HtsGetFromStorage::new(storage); + let response = searcher.search(query.clone()).await.unwrap(); + + let (bytes, _) = get_byte_ranges_from_url_storage_response( + response, + default_dir().join("data/crypt4gh/Test.1000G.phase3.joint.lifted.UMCCR.vcf.gz.c4gh"), + ) + .await; + + test_parsable_byte_ranges(bytes.clone(), Format::Vcf, Header).await; + }) + .await; +} diff --git a/htsget-test/Cargo.toml b/htsget-test/Cargo.toml index 22a0b3ae0..676cfc126 100644 --- a/htsget-test/Cargo.toml +++ b/htsget-test/Cargo.toml @@ -19,10 +19,8 @@ http = [ "dep:htsget-config", "dep:noodles", "dep:reqwest", - "dep:tokio", "dep:futures", "dep:mime", - "dep:base64" ] aws-mocks = [ "s3-storage", @@ -36,6 +34,13 @@ aws-mocks = [ ] s3-storage = ["htsget-config?/s3-storage"] url-storage = ["htsget-config?/url-storage"] +crypt4gh = [ + "dep:crypt4gh", + "dep:tempfile", + "dep:async-crypt4gh", + "htsget-config?/crypt4gh", + "http" +] default = [] [dependencies] @@ -45,14 +50,14 @@ 
htsget-config = { version = "0.9.0", path = "../htsget-config", default-features noodles = { version = "0.65", optional = true, features = ["async", "bgzf", "vcf", "cram", "bcf", "bam", "fasta"] } reqwest = { version = "0.11", default-features = false, features = ["json", "rustls-tls"], optional = true } -tokio = { version = "1", features = ["rt-multi-thread", "fs"], optional = true } +tokio = { version = "1", features = ["rt-multi-thread", "fs"] } futures = { version = "0.3", optional = true } async-trait = { version = "0.1", optional = true } http = { version = "0.2", optional = true } mime = { version = "0.3", optional = true } serde_json = { version = "1.0", features = ["preserve_order"], optional = true } serde = { version = "1", optional = true } -base64 = { version = "0.21", optional = true } +base64 = { version = "0.21" } tempfile = { version = "3.3", optional = true } aws-sdk-s3 = { version = "0.34", features = ["test-util"], optional = true } @@ -62,6 +67,15 @@ s3s = { version = "0.8", optional = true } s3s-fs = { version = "0.8", optional = true } s3s-aws = { version = "0.8", optional = true } +crypt4gh = { version = "0.4", git = "https://github.com/EGA-archive/crypt4gh-rust", optional = true } +async-crypt4gh = { version = "0.1.0", path = "../async-crypt4gh", optional = true } +axum = { version = "0.6" } +walkdir = { version = "2.5" } +tokio-util = { version = "0.7" } +tower = { version = "0.4", features = ["make"] } +tower-http = { version = "0.4", features = ["trace", "cors", "fs"] } +tokio-rustls = { version = "0.24" } + # Default dependencies rcgen = "0.12" thiserror = "1.0" diff --git a/htsget-test/README.md b/htsget-test/README.md index e2d6a2eee..ece166ab7 100644 --- a/htsget-test/README.md +++ b/htsget-test/README.md @@ -41,6 +41,7 @@ This crate has the following features: * `server-tests`: used to enable server tests. * `s3-storage`: used to enable `S3Storage` functionality. * `url-storage`: used to enable `UrlStorage` functionality. 
+* `crypt4gh`: used to enable Crypt4GH functionality. [dev-dependencies]: https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#development-dependencies diff --git a/htsget-test/src/aws_mocks.rs b/htsget-test/src/aws_mocks.rs index 38d5cda53..675b6784c 100644 --- a/htsget-test/src/aws_mocks.rs +++ b/htsget-test/src/aws_mocks.rs @@ -1,3 +1,6 @@ +use std::future::Future; +use std::path::{Path, PathBuf}; + use aws_config::SdkConfig; use aws_credential_types::provider::SharedCredentialsProvider; use aws_credential_types::Credentials; @@ -6,8 +9,6 @@ use aws_sdk_s3::Client; use s3s::auth::SimpleAuth; use s3s::service::S3ServiceBuilder; use s3s_fs::FileSystem; -use std::future::Future; -use std::path::{Path, PathBuf}; use tempfile::TempDir; /// Default domain to use for mock s3 server diff --git a/htsget-test/src/crypt4gh.rs b/htsget-test/src/crypt4gh.rs new file mode 100644 index 000000000..c25b32070 --- /dev/null +++ b/htsget-test/src/crypt4gh.rs @@ -0,0 +1,91 @@ +use axum::middleware::Next; +use axum::response::Response; +use crypt4gh::keys::{get_private_key, get_public_key}; +use crypt4gh::Keys; +use http::header::AUTHORIZATION; +use http::{Request, StatusCode}; +use tempfile::TempDir; +use tokio::fs::{create_dir, File}; +use tokio::io::AsyncWriteExt; +use tokio_rustls::rustls::PrivateKey; + +use async_crypt4gh::{KeyPair, PublicKey}; + +use crate::http::get_test_path; + +/// Returns the private keys of the recipient and the senders public key from the context of decryption. +pub async fn get_decryption_keys() -> (Keys, Vec) { + get_keys("crypt4gh/keys/bob.sec", "crypt4gh/keys/alice.pub").await +} + +/// Returns the private keys of the recipient and the senders public key from the context of encryption. +pub async fn get_encryption_keys() -> (Keys, Vec) { + get_keys("crypt4gh/keys/alice.sec", "crypt4gh/keys/bob.pub").await +} + +/// Get the crypt4gh keys from the paths. 
+pub async fn get_keys(private_key: &str, public_key: &str) -> (Keys, Vec) { + let private_key = get_private_key(get_test_path(private_key), Ok("".to_string())).unwrap(); + let public_key = get_public_key(get_test_path(public_key)).unwrap(); + + ( + Keys { + method: 0, + privkey: private_key, + recipient_pubkey: public_key.clone(), + }, + public_key, + ) +} + +pub fn expected_key_pair() -> KeyPair { + KeyPair::new( + PrivateKey(vec![ + 162, 124, 25, 18, 207, 218, 241, 41, 162, 107, 29, 40, 10, 93, 30, 193, 104, 42, 176, 235, + 207, 248, 126, 230, 97, 205, 253, 224, 215, 160, 67, 239, + ]), + PublicKey::new(vec![ + 56, 44, 122, 180, 24, 116, 207, 149, 165, 49, 204, 77, 224, 136, 232, 121, 209, 249, 23, 51, + 120, 2, 187, 147, 82, 227, 232, 32, 17, 223, 7, 38, + ]), + ) +} + +pub async fn test_auth(request: Request, next: Next) -> Result { + let auth_header = request + .headers() + .get(AUTHORIZATION) + .and_then(|header| header.to_str().ok()); + + match auth_header { + Some("secret") => Ok(next.run(request).await), + _ => Err(StatusCode::UNAUTHORIZED), + } +} + +pub async fn create_local_test_files() -> (String, TempDir) { + let base_path = TempDir::new().unwrap(); + + let folder_name = "folder"; + let key1 = "key1"; + let value1 = b"value1"; + let key2 = "key2"; + let value2 = b"value2"; + File::create(base_path.path().join(key1)) + .await + .unwrap() + .write_all(value1) + .await + .unwrap(); + create_dir(base_path.path().join(folder_name)) + .await + .unwrap(); + File::create(base_path.path().join(folder_name).join(key2)) + .await + .unwrap() + .write_all(value2) + .await + .unwrap(); + + (folder_name.to_string(), base_path) +} diff --git a/htsget-test/src/http/mod.rs b/htsget-test/src/http/mod.rs index 140abe034..b5f290ed7 100644 --- a/htsget-test/src/http/mod.rs +++ b/htsget-test/src/http/mod.rs @@ -5,15 +5,21 @@ pub mod concat; pub mod cors; pub mod server; -use std::fs; +use std::fs::File as StdFile; +use std::io::Read; use std::net::{SocketAddr, 
TcpListener}; use std::path::{Path, PathBuf}; use std::str::FromStr; use async_trait::async_trait; +use base64::engine::general_purpose; +use base64::Engine; use http::uri::Authority; use http::HeaderMap; +use noodles::bgzf; use serde::de; +use tokio::fs::File; +use tokio::io::AsyncReadExt; use htsget_config::config::cors::{AllowType, CorsConfig}; use htsget_config::config::{DataServerConfig, TicketServerConfig}; @@ -22,8 +28,17 @@ use htsget_config::storage::{local::LocalStorage, Storage}; use htsget_config::tls::{ load_certs, load_key, tls_server_config, CertificateKeyPair, TlsServerConfig, }; +use htsget_config::types; +#[cfg(feature = "crypt4gh")] +use htsget_config::types::{Class, Format}; use htsget_config::types::{Scheme, TaggedTypeAll}; +#[cfg(feature = "crypt4gh")] +use {async_crypt4gh::reader::builder::Builder, std::io::Cursor}; +#[cfg(feature = "crypt4gh")] +use crate::crypt4gh::get_decryption_keys; +#[cfg(feature = "crypt4gh")] +use crate::http::concat::ReadRecords; use crate::util::generate_test_certificates; use crate::Config; @@ -98,7 +113,92 @@ pub fn default_dir() -> PathBuf { .to_path_buf() } -/// Get the default directory where data is present.. +/// Get byte ranges from a url storage response. 
+pub async fn get_byte_ranges_from_url_storage_response( + response: types::Response, + file: PathBuf, +) -> (Vec, Vec) { + println!("{:#?}", response); + + let file_str = file.to_str().unwrap(); + let mut buf = vec![]; + StdFile::open(file_str) + .unwrap() + .read_to_end(&mut buf) + .unwrap(); + + let mut public_key = vec![]; + let output = response + .urls + .into_iter() + .map(|url| { + if let Some(data_uri) = url.url.as_str().strip_prefix("data:;base64,") { + general_purpose::STANDARD.decode(data_uri).unwrap() + } else { + let headers = url.headers.unwrap().into_inner(); + let range = headers.get("Range").unwrap(); + let mut range = range.strip_prefix("bytes=").unwrap().split('-'); + + let start = usize::from_str(range.next().unwrap()).unwrap(); + let end = usize::from_str(range.next().unwrap()).unwrap() + 1; + + if let Some(header_public_key) = headers.get("public-key") { + public_key = general_purpose::STANDARD.decode(header_public_key).unwrap(); + } + + buf[start..end].to_vec() + } + }) + .reduce(|acc, x| [acc, x].concat()) + .unwrap(); + + (output, public_key) +} + +/// Pass the bytes through a BGZF reader. 
+pub async fn parse_as_bgzf(bytes: Vec) { + let mut reader = bgzf::AsyncReader::new(bytes.as_slice()); + + let mut data = Vec::new(); + reader.read_to_end(&mut data).await.unwrap(); +} + +#[cfg(feature = "crypt4gh")] +pub async fn test_bam_crypt4gh_byte_ranges(output_bytes: Vec, expected_bytes: Vec) { + let (recipient_private_key, _) = get_decryption_keys().await; + + let mut reader = Builder::default() + .build_with_stream_length(Cursor::new(output_bytes), vec![recipient_private_key]) + .await + .unwrap(); + + let mut unencrypted_out = vec![]; + reader.read_to_end(&mut unencrypted_out).await.unwrap(); + + parse_as_bgzf(unencrypted_out.clone()).await; + + assert_eq!(unencrypted_out, expected_bytes); +} + +#[cfg(feature = "crypt4gh")] +pub async fn test_parsable_byte_ranges(output_bytes: Vec, format: Format, class: Class) { + let (recipient_private_key, _) = get_decryption_keys().await; + + let mut reader = Builder::default() + .build_with_stream_length(Cursor::new(output_bytes), vec![recipient_private_key]) + .await + .unwrap(); + + let mut unencrypted_out = vec![]; + reader.read_to_end(&mut unencrypted_out).await.unwrap(); + + ReadRecords::new(format, class, unencrypted_out) + .read_records() + .await + .unwrap(); +} + +/// Get the default directory where data is present. pub fn default_dir_data() -> PathBuf { default_dir().join("data") } @@ -119,6 +219,7 @@ pub fn default_test_resolver(addr: SocketAddr, scheme: Scheme) -> Vec "^1-(.*)$", "$1", Default::default(), + Default::default(), ) .unwrap(), Resolver::new( @@ -126,6 +227,7 @@ pub fn default_test_resolver(addr: SocketAddr, scheme: Scheme) -> Vec "^2-(.*)$", "$1", Default::default(), + Default::default(), ) .unwrap(), ] @@ -207,8 +309,25 @@ pub fn test_tls_server_config(key_path: PathBuf, cert_path: PathBuf) -> TlsServe TlsServerConfig::new(server_config) } -/// Get the event associated with the file. 
-pub fn get_test_file>(path: P) -> String { - let path = default_dir().join("data").join(path); - fs::read_to_string(path).expect("failed to read file") +/// Get a test file as a string. +pub async fn get_test_file_string>(path: P) -> String { + let mut string = String::new(); + get_test_file(path) + .await + .read_to_string(&mut string) + .await + .expect("failed to read to string"); + string +} + +/// Get a test file path. +pub fn get_test_path>(path: P) -> PathBuf { + default_dir().join("data").join(path) +} + +/// Get a test file. +pub async fn get_test_file>(path: P) -> File { + File::open(get_test_path(path)) + .await + .expect("failed to read file") } diff --git a/htsget-test/src/http/server.rs b/htsget-test/src/http/server.rs index d721f65b7..0a7a47150 100644 --- a/htsget-test/src/http/server.rs +++ b/htsget-test/src/http/server.rs @@ -1,17 +1,50 @@ +use axum::body::StreamBody; +#[cfg(feature = "crypt4gh")] +use axum::middleware; +use axum::response::IntoResponse; +use axum::routing::{get, head}; +use axum::Router; +#[cfg(feature = "crypt4gh")] +use base64::engine::general_purpose; +#[cfg(feature = "crypt4gh")] +use base64::Engine; +#[cfg(feature = "crypt4gh")] +use crypt4gh::{encrypt, Keys}; +#[cfg(feature = "crypt4gh")] +use std::collections::HashSet; use std::fmt::Debug; -use std::net::SocketAddr; +use std::future::Future; +use std::io::Cursor; +use std::net::{SocketAddr, TcpListener}; +use std::path::Path; -use http::Method; +#[cfg(feature = "crypt4gh")] +use async_crypt4gh::util::read_public_key; +#[cfg(feature = "crypt4gh")] +use async_crypt4gh::{KeyPair, PublicKey}; +use axum::extract::path::Path as AxumPath; +use htsget_config::types::Format; +#[cfg(feature = "crypt4gh")] +use http::header::RANGE; +use http::header::{CONTENT_LENGTH, USER_AGENT}; +use http::{HeaderMap, HeaderValue, Method, StatusCode}; use reqwest::ClientBuilder; use serde::Deserialize; use serde_json::{json, Value}; +use tokio::fs::File; +use tokio::io::AsyncReadExt; +#[cfg(feature 
= "crypt4gh")] +use tokio_rustls::rustls::PrivateKey; +use tokio_util::io::ReaderStream; +use tower_http::services::ServeDir; +use walkdir::WalkDir; use crate::http::concat::ConcatResponse; use htsget_config::types::Class; -use htsget_config::types::Format; -use crate::http::{Header, Response, TestRequest, TestServer}; -use crate::util::expected_bgzf_eof_data_url; +#[cfg(feature = "crypt4gh")] +use crate::crypt4gh::{expected_key_pair, test_auth}; +use crate::http::{default_dir, Header, Response, TestRequest, TestServer}; use crate::Config; /// Test response with with class. @@ -243,23 +276,20 @@ pub async fn test_service_info(tester: &impl TestServer) { /// An example VCF search response. pub fn expected_response(class: Class, url_path: String) -> Value { let url = format!("{url_path}/data/vcf/sample1-bcbio-cancer.vcf.gz"); - let headers = ["Range", "bytes=0-3465"]; let urls = match class { Class::Header => json!([{ "url": url, "headers": { - headers[0]: headers[1] + "Range": "bytes=0-3465" }, "class": "header" }]), Class::Body => json!([{ "url": url, "headers": { - headers[0]: headers[1] + "Range": "bytes=0-3493" }, - }, { - "url": expected_bgzf_eof_data_url() }]), }; @@ -270,3 +300,245 @@ pub fn expected_response(class: Class, url_path: String) -> Value { } }) } + +pub async fn with_test_server(server_base_path: &Path, test: F) +where + F: FnOnce(String) -> Fut, + Fut: Future, +{ + let router = Router::new() + .route( + "/endpoint_file/:id", + get( + |headers: HeaderMap, AxumPath(_id): AxumPath| async move { + assert_eq!( + headers.get(USER_AGENT), + Some(&HeaderValue::from_static("user-agent")) + ); + + #[cfg(feature = "crypt4gh")] + if headers.contains_key("client-public-key") { + let entry = WalkDir::new(default_dir().join("data")) + .min_depth(2) + .into_iter() + .filter_entry(|e| { + e.path() + .file_name() + .map(|file| file.to_string_lossy() == _id.strip_suffix(".c4gh").unwrap_or(&_id)) + .unwrap_or(false) + }) + .filter_map(|v| v.ok()) + .next() + 
.unwrap(); + + let range = headers.get(RANGE).unwrap().to_str().unwrap(); + + let range = range.replacen("bytes=", "", 1); + + let split: Vec<&str> = range.splitn(2, '-').collect(); + + let parse_range = |range: Option<&str>| { + let range = range.unwrap_or_default(); + if range.is_empty() { + None + } else { + Some(range.parse().unwrap()) + } + }; + + let start: Option = parse_range(split.first().copied()); + let end: Option = parse_range(split.last().copied()).map(|value| value + 1); + + let mut bytes = vec![]; + let path = entry.path(); + File::open(path) + .await + .unwrap() + .read_to_end(&mut bytes) + .await + .unwrap(); + + let encryption_keys = KeyPair::new( + PrivateKey(vec![ + 161, 61, 174, 214, 146, 101, 139, 42, 247, 73, 68, 96, 8, 198, 29, 26, 68, 113, + 200, 182, 20, 217, 151, 89, 211, 14, 110, 80, 111, 138, 255, 194, + ]), + PublicKey::new(vec![ + 249, 209, 232, 54, 131, 32, 40, 191, 15, 205, 151, 70, 90, 37, 149, 101, 55, 138, + 22, 59, 176, 0, 59, 7, 167, 10, 194, 129, 55, 147, 141, 101, + ]), + ); + + let keys = Keys { + method: 0, + privkey: encryption_keys.private_key().clone().0, + recipient_pubkey: read_public_key( + general_purpose::STANDARD + .decode(headers.get("client-public-key").unwrap()) + .unwrap(), + ) + .await + .unwrap() + .into_inner(), + }; + + assert_eq!( + keys.recipient_pubkey, + expected_key_pair().public_key().clone().into_inner() + ); + + let mut read_buf = Cursor::new(bytes); + let mut write_buf = Cursor::new(vec![]); + + encrypt( + &HashSet::from_iter(vec![keys]), + &mut read_buf, + &mut write_buf, + 0, + None, + ) + .unwrap(); + + let data = write_buf.into_inner(); + + let data = match (start, end) { + (None, None) => data, + (Some(start), None) => data[start as usize..].to_vec(), + (None, Some(end)) => data[..end as usize].to_vec(), + (Some(start), Some(end)) => data[start as usize..end as usize].to_vec(), + }; + + let stream = ReaderStream::new(Cursor::new(data)); + let body = StreamBody::new(stream); + + return 
(StatusCode::OK, body).into_response(); + } + + let mut bytes = vec![]; + let path = default_dir().join("data/bam/htsnexus_test_NA12878.bam"); + File::open(path) + .await + .unwrap() + .read_to_end(&mut bytes) + .await + .unwrap(); + + let bytes = bytes[..4668].to_vec(); + + let stream = ReaderStream::new(Cursor::new(bytes)); + let body = StreamBody::new(stream); + + (StatusCode::OK, body).into_response() + }, + ), + ) + .route( + "/endpoint_index/:id", + get(|AxumPath(id): AxumPath| async move { + let entry = WalkDir::new(default_dir().join("data")) + .min_depth(2) + .into_iter() + .filter_entry(|e| { + e.path() + .file_name() + .map(|file| { + file.to_string_lossy() == id.clone() && !file.to_string_lossy().ends_with(".gzi") + }) + .unwrap_or(false) + }) + .filter_map(|v| v.ok()) + .next(); + + match entry { + None => { + let bytes: Vec = vec![]; + let stream = ReaderStream::new(Cursor::new(bytes)); + let body = StreamBody::new(stream); + + (StatusCode::NOT_FOUND, body).into_response() + } + Some(entry) => { + let mut bytes = vec![]; + let path = entry.path(); + File::open(path) + .await + .unwrap() + .read_to_end(&mut bytes) + .await + .unwrap(); + + let stream = ReaderStream::new(Cursor::new(bytes)); + let body = StreamBody::new(stream); + + (StatusCode::OK, body).into_response() + } + } + }), + ) + .route( + "/endpoint_file/:id", + head( + |AxumPath(id): AxumPath, headers: HeaderMap| async move { + assert_eq!( + headers.get(USER_AGENT), + Some(&HeaderValue::from_static("user-agent")) + ); + + #[cfg(feature = "crypt4gh")] + if headers.contains_key("client-public-key") { + let public_key = read_public_key( + general_purpose::STANDARD + .decode(headers.get("client-public-key").unwrap()) + .unwrap(), + ) + .await + .unwrap() + .into_inner(); + assert_eq!( + public_key, + expected_key_pair().public_key().clone().into_inner() + ); + } + + let length = WalkDir::new(default_dir().join("data")) + .min_depth(2) + .into_iter() + .filter_entry(|e| { + e.path() + 
.file_name() + .map(|file| file.to_string_lossy() == id.clone()) + .unwrap_or(false) + }) + .filter_map(|v| v.ok()) + .next() + .map(|entry| entry.metadata().unwrap().len().to_string()) + .unwrap_or_else(|| "2596799".to_string()); + + axum::response::Response::builder() + .header("server-additional-bytes", 124) + .header("client-additional-bytes", 124) + .header(CONTENT_LENGTH, HeaderValue::from_str(&length).unwrap()) + .status(StatusCode::OK) + .body(StreamBody::new(ReaderStream::new(Cursor::new(vec![])))) + .unwrap() + .into_response() + }, + ), + ) + .nest_service("/assets", ServeDir::new(server_base_path.to_str().unwrap())); + + #[cfg(feature = "crypt4gh")] + let router = router.route_layer(middleware::from_fn(test_auth)); + + // TODO fix this in htsget-test to bind and return tcp listener. + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn( + axum::Server::from_tcp(listener) + .unwrap() + .serve(router.into_make_service()), + ); + + test(format!("http://{}", addr)).await; +} diff --git a/htsget-test/src/lib.rs b/htsget-test/src/lib.rs index ca18fbd9c..4b8c8042b 100644 --- a/htsget-test/src/lib.rs +++ b/htsget-test/src/lib.rs @@ -10,3 +10,6 @@ pub mod error; #[cfg(feature = "http")] pub mod http; pub mod util; + +#[cfg(feature = "crypt4gh")] +pub mod crypt4gh;