From 65f552b6f5ae02e0ff790b55050ed8677ea09963 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Wed, 15 Jan 2025 10:42:27 +0100 Subject: [PATCH] Deprecate old crates Replace old crates with stubs that give compile-time errors, notifying maintainers about the rename. --- .github/workflows/publish-edgedb-derive.yaml | 14 +- .github/workflows/publish-edgedb-errors.yaml | 14 +- .../workflows/publish-edgedb-protocol.yaml | 16 +- .github/workflows/publish-edgedb-tokio.yaml | 16 +- Cargo.toml | 2 - edgedb-derive/Cargo.toml | 18 +- edgedb-derive/README.md | 17 +- edgedb-derive/src/attrib.rs | 103 - edgedb-derive/src/enums.rs | 54 - edgedb-derive/src/json.rs | 44 - edgedb-derive/src/lib.rs | 137 +- edgedb-derive/src/shape.rs | 190 -- edgedb-derive/src/variables.rs | 80 - edgedb-derive/tests/enums.rs | 24 - edgedb-derive/tests/fail.rs | 5 - edgedb-derive/tests/json.rs | 52 - edgedb-derive/tests/list_scalar_types.rs | 50 - edgedb-derive/tests/no-prelude.rs | 9 - edgedb-derive/tests/varnames.rs | 27 - edgedb-errors/Cargo.toml | 12 +- edgedb-errors/README.md | 17 +- edgedb-errors/src/bin/edgedb_gen_errors.rs | 143 - edgedb-errors/src/display.rs | 77 - edgedb-errors/src/error.rs | 226 -- edgedb-errors/src/fields.rs | 8 - edgedb-errors/src/kinds.rs | 180 -- edgedb-errors/src/lib.rs | 89 +- edgedb-errors/src/miette.rs | 28 - edgedb-errors/src/traits.rs | 98 - edgedb-protocol/Cargo.toml | 34 +- edgedb-protocol/README.md | 6 +- edgedb-protocol/src/annotations.rs | 197 -- edgedb-protocol/src/client_message.rs | 1016 ------- edgedb-protocol/src/codec.rs | 1633 ------------ edgedb-protocol/src/common.rs | 154 -- edgedb-protocol/src/descriptors.rs | 1129 -------- edgedb-protocol/src/encoding.rs | 218 -- edgedb-protocol/src/error_response.rs | 11 - edgedb-protocol/src/errors.rs | 194 -- edgedb-protocol/src/features.rs | 54 - edgedb-protocol/src/lib.rs | 79 +- edgedb-protocol/src/model.rs | 81 - edgedb-protocol/src/model/bignum.rs | 421 --- .../src/model/bignum/bigdecimal_interop.rs | 402 --- .../src/model/bignum/num_bigint_interop.rs | 216 -- edgedb-protocol/src/model/json.rs | 44 - edgedb-protocol/src/model/memory.rs | 32 - edgedb-protocol/src/model/range.rs | 75 - edgedb-protocol/src/model/time.rs | 1898 -------------- edgedb-protocol/src/model/vector.rs | 52 - edgedb-protocol/src/query_arg.rs | 555 ---- edgedb-protocol/src/query_result.rs | 67 - edgedb-protocol/src/queryable.rs | 101 - edgedb-protocol/src/serialization.rs | 4 - edgedb-protocol/src/serialization/decode.rs | 11 - .../src/serialization/decode/chrono.rs | 27 - .../src/serialization/decode/queryable.rs | 3 - .../decode/queryable/collections.rs | 73 - .../serialization/decode/queryable/scalars.rs | 309 --- .../serialization/decode/queryable/tuples.rs | 53 - .../src/serialization/decode/raw_composite.rs | 212 -- .../src/serialization/decode/raw_scalar.rs | 736 ------ .../src/serialization/test_scalars.rs | 225 -- edgedb-protocol/src/server_message.rs | 1026 -------- edgedb-protocol/src/value.rs | 272 -- edgedb-protocol/src/value_opt.rs | 126 - edgedb-protocol/tests/base.rs | 12 - edgedb-protocol/tests/client_messages.rs | 307 --- edgedb-protocol/tests/codecs.rs | 1430 ---------- edgedb-protocol/tests/datetime_chrono.rs | 400 --- edgedb-protocol/tests/datetime_system.rs | 264 -- edgedb-protocol/tests/decode.rs | 8 - edgedb-protocol/tests/error_response.bin | Bin 458 -> 0 bytes edgedb-protocol/tests/parameter_status.bin | Bin 46 -> 0 bytes edgedb-protocol/tests/server_key_data.bin | Bin 37 -> 0 bytes edgedb-protocol/tests/server_messages.rs | 370 --- edgedb-protocol/tests/type_descriptors.rs | 401 --- edgedb-tokio/Cargo.toml | 63 +- edgedb-tokio/README.md | 39 +- edgedb-tokio/examples/query_args.rs | 47 - edgedb-tokio/examples/simple.rs | 9 - edgedb-tokio/examples/transaction.rs | 17 - edgedb-tokio/examples/transaction_errors.rs | 56 - edgedb-tokio/src/builder.rs | 2335 ----------------- edgedb-tokio/src/client.rs | 613 ----- edgedb-tokio/src/credentials.rs | 212 -- edgedb-tokio/src/env.rs | 226 -- edgedb-tokio/src/errors.rs | 2 - edgedb-tokio/src/letsencrypt_staging.pem | 47 - edgedb-tokio/src/lib.rs | 178 +- edgedb-tokio/src/nebula_development.pem | 13 - edgedb-tokio/src/options.rs | 201 -- edgedb-tokio/src/query_executor.rs | 253 -- edgedb-tokio/src/raw/connection.rs | 854 ------ edgedb-tokio/src/raw/dumps.rs | 308 --- edgedb-tokio/src/raw/mod.rs | 191 -- edgedb-tokio/src/raw/options.rs | 14 - edgedb-tokio/src/raw/queries.rs | 735 ------ edgedb-tokio/src/raw/response.rs | 319 --- edgedb-tokio/src/raw/state.rs | 445 ---- edgedb-tokio/src/sealed.rs | 1 - edgedb-tokio/src/server_params.rs | 70 - edgedb-tokio/src/state.rs | 7 - edgedb-tokio/src/tls.rs | 183 -- edgedb-tokio/src/transaction.rs | 420 --- edgedb-tokio/src/tutorial.md | 473 ---- edgedb-tokio/src/tutorial.rs | 2 - edgedb-tokio/tests/credentials1.json | 6 - edgedb-tokio/tests/func/.gitignore | 1 - edgedb-tokio/tests/func/client.rs | 303 --- edgedb-tokio/tests/func/dbschema/test.esdl | 22 - edgedb-tokio/tests/func/derive.rs | 73 - edgedb-tokio/tests/func/globals.rs | 58 - edgedb-tokio/tests/func/main.rs | 17 - edgedb-tokio/tests/func/raw.rs | 51 - edgedb-tokio/tests/func/server.rs | 39 - edgedb-tokio/tests/func/transactions.rs | 208 -- examples/globals/Cargo.toml | 12 - examples/globals/dbschema/default.esdl | 3 - .../globals/dbschema/migrations/00001.edgeql | 5 - examples/globals/src/main.rs | 18 - examples/query-error/Cargo.toml | 13 - examples/query-error/src/main.rs | 33 - 123 files changed, 51 insertions(+), 25832 deletions(-) delete mode 100644 edgedb-derive/src/attrib.rs delete mode 100644 edgedb-derive/src/enums.rs delete mode 100644 edgedb-derive/src/json.rs delete mode 100644 edgedb-derive/src/shape.rs delete mode 100644 edgedb-derive/src/variables.rs delete mode 100644 edgedb-derive/tests/enums.rs delete mode 100644 edgedb-derive/tests/fail.rs delete mode 100644 edgedb-derive/tests/json.rs delete mode 100644 edgedb-derive/tests/list_scalar_types.rs delete mode 100644 edgedb-derive/tests/no-prelude.rs delete mode 100644 edgedb-derive/tests/varnames.rs delete mode 100644 edgedb-errors/src/bin/edgedb_gen_errors.rs delete mode 100644 edgedb-errors/src/display.rs delete mode 100644 edgedb-errors/src/error.rs delete mode 100644 edgedb-errors/src/fields.rs delete mode 100644 edgedb-errors/src/kinds.rs delete mode 100644 edgedb-errors/src/miette.rs delete mode 100644 edgedb-errors/src/traits.rs delete mode 100644 edgedb-protocol/src/annotations.rs delete mode 100644 edgedb-protocol/src/client_message.rs delete mode 100644 edgedb-protocol/src/codec.rs delete mode 100644 edgedb-protocol/src/common.rs delete mode 100644 edgedb-protocol/src/descriptors.rs delete mode 100644 edgedb-protocol/src/encoding.rs delete mode 100644 edgedb-protocol/src/error_response.rs delete mode 100644 edgedb-protocol/src/errors.rs delete mode 100644 edgedb-protocol/src/features.rs delete mode 100644 edgedb-protocol/src/model.rs delete mode 100644 edgedb-protocol/src/model/bignum.rs delete mode 100644 edgedb-protocol/src/model/bignum/bigdecimal_interop.rs delete mode 100644 edgedb-protocol/src/model/bignum/num_bigint_interop.rs delete mode 100644 edgedb-protocol/src/model/json.rs delete mode 100644 edgedb-protocol/src/model/memory.rs delete mode 100644 edgedb-protocol/src/model/range.rs delete mode 100644 edgedb-protocol/src/model/time.rs delete mode 100644 edgedb-protocol/src/model/vector.rs delete mode 100644 edgedb-protocol/src/query_arg.rs delete mode 100644 edgedb-protocol/src/query_result.rs delete mode 100644 edgedb-protocol/src/queryable.rs delete mode 100644 edgedb-protocol/src/serialization.rs delete mode 100644 edgedb-protocol/src/serialization/decode.rs delete mode 100644 edgedb-protocol/src/serialization/decode/chrono.rs delete mode 100644 edgedb-protocol/src/serialization/decode/queryable.rs delete mode 100644 edgedb-protocol/src/serialization/decode/queryable/collections.rs delete mode 100644 edgedb-protocol/src/serialization/decode/queryable/scalars.rs delete mode 100644 edgedb-protocol/src/serialization/decode/queryable/tuples.rs delete mode 100644 edgedb-protocol/src/serialization/decode/raw_composite.rs delete mode 100644 edgedb-protocol/src/serialization/decode/raw_scalar.rs delete mode 100644 edgedb-protocol/src/serialization/test_scalars.rs delete mode 100644 edgedb-protocol/src/server_message.rs delete mode 100644 edgedb-protocol/src/value.rs delete mode 100644 edgedb-protocol/src/value_opt.rs delete mode 100644 edgedb-protocol/tests/base.rs delete mode 100644 edgedb-protocol/tests/client_messages.rs delete mode 100644 edgedb-protocol/tests/codecs.rs delete mode 100644 edgedb-protocol/tests/datetime_chrono.rs delete mode 100644 edgedb-protocol/tests/datetime_system.rs delete mode 100644 edgedb-protocol/tests/decode.rs delete mode 100644 edgedb-protocol/tests/error_response.bin delete mode 100644 edgedb-protocol/tests/parameter_status.bin delete mode 100644 edgedb-protocol/tests/server_key_data.bin delete mode 100644 edgedb-protocol/tests/server_messages.rs delete mode 100644 edgedb-protocol/tests/type_descriptors.rs delete mode 100644 edgedb-tokio/examples/query_args.rs delete mode 100644 edgedb-tokio/examples/simple.rs delete mode 100644 edgedb-tokio/examples/transaction.rs delete mode 100644 edgedb-tokio/examples/transaction_errors.rs delete mode 100644 edgedb-tokio/src/builder.rs delete mode 100644 edgedb-tokio/src/client.rs delete mode 100644 edgedb-tokio/src/credentials.rs delete mode 100644 edgedb-tokio/src/env.rs delete mode 100644 edgedb-tokio/src/errors.rs delete mode 100644 edgedb-tokio/src/letsencrypt_staging.pem delete mode 100644 edgedb-tokio/src/nebula_development.pem delete mode 100644 edgedb-tokio/src/options.rs delete mode 100644 edgedb-tokio/src/query_executor.rs delete mode 100644 edgedb-tokio/src/raw/connection.rs delete mode 100644 edgedb-tokio/src/raw/dumps.rs delete mode 100644 edgedb-tokio/src/raw/mod.rs delete mode 100644 edgedb-tokio/src/raw/options.rs delete mode 100644 edgedb-tokio/src/raw/queries.rs delete mode 100644 edgedb-tokio/src/raw/response.rs delete mode 100644 edgedb-tokio/src/raw/state.rs delete mode 100644 edgedb-tokio/src/sealed.rs delete mode 100644 edgedb-tokio/src/server_params.rs delete mode 100644 edgedb-tokio/src/state.rs delete mode 100644 edgedb-tokio/src/tls.rs delete mode 100644 edgedb-tokio/src/transaction.rs delete mode 100644 edgedb-tokio/src/tutorial.md delete mode 100644 edgedb-tokio/src/tutorial.rs delete mode 100644 edgedb-tokio/tests/credentials1.json delete mode 100644 edgedb-tokio/tests/func/.gitignore delete mode 100644 edgedb-tokio/tests/func/client.rs delete mode 100644 edgedb-tokio/tests/func/dbschema/test.esdl delete mode 100644 edgedb-tokio/tests/func/derive.rs delete mode 100644 edgedb-tokio/tests/func/globals.rs delete mode 100644 edgedb-tokio/tests/func/main.rs delete mode 100644 edgedb-tokio/tests/func/raw.rs delete mode 100644 edgedb-tokio/tests/func/server.rs delete mode 100644 edgedb-tokio/tests/func/transactions.rs delete mode 100644 examples/globals/Cargo.toml delete mode 100644 examples/globals/dbschema/default.esdl delete mode 100644 examples/globals/dbschema/migrations/00001.edgeql delete mode 100644 examples/globals/src/main.rs delete mode 100644 examples/query-error/Cargo.toml delete mode 100644 examples/query-error/src/main.rs diff --git a/.github/workflows/publish-edgedb-derive.yaml b/.github/workflows/publish-edgedb-derive.yaml index bc7cee7c..51734592 100644 --- a/.github/workflows/publish-edgedb-derive.yaml +++ b/.github/workflows/publish-edgedb-derive.yaml @@ -14,16 +14,8 @@ jobs: contents: "read" steps: # checkout and env setup - - uses: actions/checkout@v3 - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - name: Build the nix shell - run: nix develop --command just --version - - uses: Swatinem/rust-cache@v2 - - # test - - name: Test - run: nix develop --command cargo test --all-features --package=edgedb-derive + - uses: actions/checkout@v3 + - uses: dtolnay/rust-toolchain@stable # verify that git tag matches cargo version - run: | @@ -35,4 +27,4 @@ jobs: - working-directory: ./edgedb-derive run: | - nix develop --command cargo publish --token=${{ secrets.CARGO_REGISTRY_TOKEN }} + cargo publish --token=${{ secrets.CARGO_REGISTRY_TOKEN }} --no-verify diff --git a/.github/workflows/publish-edgedb-errors.yaml b/.github/workflows/publish-edgedb-errors.yaml index ffcbc5f8..d4b37589 100644 --- a/.github/workflows/publish-edgedb-errors.yaml +++ b/.github/workflows/publish-edgedb-errors.yaml @@ -14,16 +14,8 @@ jobs: contents: "read" steps: # checkout and env setup - - uses: actions/checkout@v3 - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - name: Build the nix shell - run: nix develop --command just --version - - uses: Swatinem/rust-cache@v2 - - # test - - name: Test - run: nix develop --command cargo test --all-features --package=edgedb-errors + - uses: actions/checkout@v3 + - uses: dtolnay/rust-toolchain@stable # verify that git tag matches cargo version - run: | @@ -35,4 +27,4 @@ jobs: - working-directory: ./edgedb-errors run: | - nix develop --command cargo publish --token=${{ secrets.CARGO_REGISTRY_TOKEN }} + cargo publish --token=${{ secrets.CARGO_REGISTRY_TOKEN }} --no-verify diff --git a/.github/workflows/publish-edgedb-protocol.yaml b/.github/workflows/publish-edgedb-protocol.yaml index c679b0d2..574f91bc 100644 --- a/.github/workflows/publish-edgedb-protocol.yaml +++ b/.github/workflows/publish-edgedb-protocol.yaml @@ -13,17 +13,9 @@ jobs: id-token: "write" contents: "read" steps: - # checkout and env setup - - uses: actions/checkout@v3 - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - name: Build the nix shell - run: nix develop --command just --version - - uses: Swatinem/rust-cache@v2 - - # test - - name: Test - run: nix develop --command cargo test --all-features --package=edgedb-protocol + # checkout and env setup + - uses: actions/checkout@v3 + - uses: dtolnay/rust-toolchain@stable # verify that git tag matches cargo version - run: | @@ -35,4 +27,4 @@ jobs: - working-directory: ./edgedb-protocol run: | - nix develop --command cargo publish --token=${{ secrets.CARGO_REGISTRY_TOKEN }} + cargo publish --token=${{ secrets.CARGO_REGISTRY_TOKEN }} --no-verify diff --git a/.github/workflows/publish-edgedb-tokio.yaml b/.github/workflows/publish-edgedb-tokio.yaml index 522152b1..b1399cb2 100644 --- a/.github/workflows/publish-edgedb-tokio.yaml +++ b/.github/workflows/publish-edgedb-tokio.yaml @@ -13,17 +13,9 @@ jobs: id-token: "write" contents: "read" steps: - # checkout and env setup - - uses: actions/checkout@v3 - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - name: Build the nix shell - run: nix develop --command just --version - - uses: Swatinem/rust-cache@v2 - - # test - - name: Test - run: nix develop --command cargo test --all-features --package=edgedb-tokio + # checkout and env setup + - uses: actions/checkout@v3 + - uses: dtolnay/rust-toolchain@stable # verify that git tag matches cargo version - run: | @@ -35,4 +27,4 @@ jobs: - working-directory: ./edgedb-tokio run: | - nix develop --command cargo publish --token=${{ secrets.CARGO_REGISTRY_TOKEN }} + cargo publish --token=${{ secrets.CARGO_REGISTRY_TOKEN }} --no-verify diff --git a/Cargo.toml b/Cargo.toml index 8ff6b180..5cc0afe1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,8 +5,6 @@ members = [ "edgedb-derive", "edgedb-protocol", "edgedb-tokio", - "examples/globals", - "examples/query-error", ] [profile.release] diff --git a/edgedb-derive/Cargo.toml b/edgedb-derive/Cargo.toml index 600b38ba..14eb7cfb 100644 --- a/edgedb-derive/Cargo.toml +++ b/edgedb-derive/Cargo.toml @@ -1,29 +1,15 @@ [package] name = "edgedb-derive" license = "MIT/Apache-2.0" -version = "0.5.2" +version = "0.6.0" authors = ["MagicStack Inc. "] edition = "2018" description = """ Derive macros for EdgeDB database client. + This crate has been renamed to gel-derive. """ readme = "README.md" rust-version.workspace = true -[dependencies] -syn = {version="2.0", features=["full"]} -proc-macro2 = "1.0.19" -quote = "1.0" -trybuild = "1.0.19" - -[dev-dependencies] -bytes = "1.0.1" -edgedb-protocol = {path="../edgedb-protocol"} -serde = {version="1.0", features=["derive"]} -serde_json = "1.0" - [lib] proc-macro = true - -[lints] -workspace = true diff --git a/edgedb-derive/README.md b/edgedb-derive/README.md index e265fd52..aa766cc7 100644 --- a/edgedb-derive/README.md +++ b/edgedb-derive/README.md @@ -1,19 +1,4 @@ EdgeDB Rust Binding: Derive Crate ================================= -This crate contains derive macros for the EdgeDB client. - -* [Documentation](https://docs.rs/edgedb-derive) -* [Tokio Client](https://docs.rs/edgedb-tokio) - - -License -======= - -Licensed under either of - -* Apache License, Version 2.0, - (./LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) -* MIT license (./LICENSE-MIT or http://opensource.org/licenses/MIT) - -at your option. +> This crate has been renamed to [gel-derive](https://crates.io/crates/gel-derive) diff --git a/edgedb-derive/src/attrib.rs b/edgedb-derive/src/attrib.rs deleted file mode 100644 index bae1ceac..00000000 --- a/edgedb-derive/src/attrib.rs +++ /dev/null @@ -1,103 +0,0 @@ -use syn::parse::{Parse, ParseStream}; -use syn::punctuated::Punctuated; - -#[derive(Debug)] -enum FieldAttr { - Json, -} - -#[derive(Debug)] -enum ContainerAttr { - Json, -} - -struct FieldAttrList(pub Punctuated); -struct ContainerAttrList(pub Punctuated); - -pub struct FieldAttrs { - pub json: bool, -} - -pub struct ContainerAttrs { - pub json: bool, -} - -mod kw { - syn::custom_keyword!(json); -} - -impl Parse for FieldAttr { - fn parse(input: ParseStream) -> syn::Result { - let lookahead = input.lookahead1(); - if lookahead.peek(kw::json) { - let _ident: syn::Ident = input.parse()?; - Ok(FieldAttr::Json) - } else { - Err(lookahead.error()) - } - } -} - -impl Parse for ContainerAttr { - fn parse(input: ParseStream) -> syn::Result { - let lookahead = input.lookahead1(); - if lookahead.peek(kw::json) { - let _ident: syn::Ident = input.parse()?; - Ok(ContainerAttr::Json) - } else { - Err(lookahead.error()) - } - } -} - -impl Parse for ContainerAttrList { - fn parse(input: ParseStream) -> syn::Result { - Punctuated::parse_terminated(input).map(ContainerAttrList) - } -} - -impl Parse for FieldAttrList { - fn parse(input: ParseStream) -> syn::Result { - Punctuated::parse_terminated(input).map(FieldAttrList) - } -} - -impl FieldAttrs { - fn default() -> FieldAttrs { - FieldAttrs { json: false } - } - pub fn from_syn(attrs: &[syn::Attribute]) -> syn::Result { - let mut res = FieldAttrs::default(); - for attr in attrs { - if matches!(attr.style, syn::AttrStyle::Outer) && attr.path().is_ident("edgedb") { - let chunk: FieldAttrList = attr.parse_args()?; - for item in chunk.0 { - match item { - FieldAttr::Json => res.json = true, - } - } - } - } - Ok(res) - } -} - -impl ContainerAttrs { - fn default() -> ContainerAttrs { - ContainerAttrs { json: false } - } - pub fn from_syn(attrs: &[syn::Attribute]) -> syn::Result { - let mut res = ContainerAttrs::default(); - for attr in attrs { - if matches!(attr.style, syn::AttrStyle::Outer) && attr.path().is_ident("edgedb") { - let chunk: ContainerAttrList = attr.parse_args()?; - for item in chunk.0 { - match item { - ContainerAttr::Json => res.json = true, - } - } - } - } - Ok(res) - } -} diff --git a/edgedb-derive/src/enums.rs b/edgedb-derive/src/enums.rs deleted file mode 100644 index 541391b2..00000000 --- a/edgedb-derive/src/enums.rs +++ /dev/null @@ -1,54 +0,0 @@ -use proc_macro2::TokenStream; -use quote::quote; - -pub fn derive_enum(s: &syn::ItemEnum) -> syn::Result { - let type_name = &s.ident; - let (impl_generics, ty_generics, _) = s.generics.split_for_impl(); - let branches = s - .variants - .iter() - .map(|v| match v.fields { - syn::Fields::Unit => { - let name = &v.ident; - let name_bstr = syn::LitByteStr::new(name.to_string().as_bytes(), name.span()); - Ok(quote!(#name_bstr => Ok(#type_name::#name))) - } - _ => Err(syn::Error::new_spanned( - &v.fields, - "fields are not allowed in enum variants", - )), - }) - .collect::, _>>()?; - let expanded = quote! { - impl #impl_generics ::edgedb_protocol::queryable::Queryable - for #type_name #ty_generics { - fn decode(decoder: &::edgedb_protocol::queryable::Decoder, buf: &[u8]) - -> Result - { - match buf { - #(#branches,)* - _ => Err(::edgedb_protocol::errors::ExtraEnumValue.build()), - } - } - fn check_descriptor( - ctx: &::edgedb_protocol::queryable::DescriptorContext, - type_pos: ::edgedb_protocol::descriptors::TypePos) - -> Result<(), ::edgedb_protocol::queryable::DescriptorMismatch> - { - use ::edgedb_protocol::descriptors::Descriptor::Enumeration; - let desc = ctx.get(type_pos)?; - match desc { - // There is no need to check the members of the enumeration - // because schema updates can't be perfectly synchronized - // to app updates. And that means that extra variants - // might be added and only when they are really present in - // data we should issue an error. Removed variants are not a - // problem here. - Enumeration(_) => Ok(()), - _ => Err(ctx.wrong_type(desc, "str")), - } - } - } - }; - Ok(expanded) -} diff --git a/edgedb-derive/src/json.rs b/edgedb-derive/src/json.rs deleted file mode 100644 index 695745fa..00000000 --- a/edgedb-derive/src/json.rs +++ /dev/null @@ -1,44 +0,0 @@ -use proc_macro2::TokenStream; -use quote::quote; - -pub fn derive(item: &syn::Item) -> syn::Result { - let (name, impl_generics, ty_generics) = match item { - syn::Item::Struct(s) => { - let (impl_generics, ty_generics, _) = s.generics.split_for_impl(); - (&s.ident, impl_generics, ty_generics) - } - syn::Item::Enum(e) => { - let (impl_generics, ty_generics, _) = e.generics.split_for_impl(); - (&e.ident, impl_generics, ty_generics) - } - _ => { - return Err(syn::Error::new_spanned( - item, - "can only derive Queryable for structs and enums in JSON mode", - )); - } - }; - let expanded = quote! { - impl #impl_generics ::edgedb_protocol::queryable::Queryable - for #name #ty_generics { - fn decode(decoder: &::edgedb_protocol::queryable::Decoder, buf: &[u8]) - -> Result - { - let json: ::edgedb_protocol::model::Json = - ::edgedb_protocol::queryable::Queryable::decode(decoder, buf)?; - ::serde_json::from_str(json.as_ref()) - .map_err(::edgedb_protocol::errors::decode_error) - } - fn check_descriptor( - ctx: &::edgedb_protocol::queryable::DescriptorContext, - type_pos: ::edgedb_protocol::descriptors::TypePos) - -> Result<(), ::edgedb_protocol::queryable::DescriptorMismatch> - { - <::edgedb_protocol::model::Json as - ::edgedb_protocol::queryable::Queryable> - ::check_descriptor(ctx, type_pos) - } - } - }; - Ok(expanded) -} diff --git a/edgedb-derive/src/lib.rs b/edgedb-derive/src/lib.rs index feb65c0a..d9939e10 100644 --- a/edgedb-derive/src/lib.rs +++ b/edgedb-derive/src/lib.rs @@ -2,140 +2,7 @@ Derive macro that allows structs and enums to be populated by database queries. -This derive can be used on structures with named fields (which correspond -to "shapes" in EdgeDB). Note that field order matters, so the struct below -corresponds to an EdgeDB `User` query with `first_name` followed by `age`. -A `DescriptorMismatch` will be returned if the fields in the Rust struct -are not in the same order as those in the query shape. - -```rust -# use edgedb_derive::Queryable; -#[derive(Queryable)] -struct User { - first_name: String, - age: i32, -} -``` - -This allows a query to directly unpack into the type instead -of working with the [Value](https://docs.rs/edgedb-protocol/latest/edgedb_protocol/value/enum.Value.html) enum. - -```rust,ignore -let query = "select User { first_name, age };"; -// With Queryable: -let query_res: Vec = client.query(query, &()).await?; -// Without Queryable: -let query_res: Vec = client.query(query, &()).await?; -``` - -# Field attributes - -## JSON - -The `#[edgedb(json)]` attribute decodes a field using `serde_json` instead -of the EdgeDB binary protocol. This is useful if some data is stored in -the database as JSON, but you need to process it. The underlying type must -implement `serde::Deserialize`. - -```rust -# use std::collections::HashMap; -# use edgedb_derive::Queryable; - -#[derive(Queryable)] -struct User { - #[edgedb(json)] - user_notes: HashMap, -} -``` - -# Container attributes - -## JSON - -The `#[edgedb(json)]` attribute can be used to unpack the structure from -the returned JSON. The underlying type must implement -`serde::Deserialize`. - -```rust -# use edgedb_derive::Queryable; -#[derive(Queryable, serde::Deserialize)] -#[edgedb(json)] -struct JsonData { - field1: String, - field2: u32, -} -``` - -This allows a query to directly unpack into the type without an intermediate -step using [serde_json::from_str](https://docs.rs/serde_json/latest/serde_json/fn.from_str.html): - -```rust,ignore -let query = "select JsonData { field1, field2 };"; -let query_res: Vec = client.query(query, &()).await?; -``` - +This crate has been renamed to [gel-derive](https://crates.io/crates/gel-derive). */ -extern crate proc_macro; - -use proc_macro::TokenStream; -use syn::parse_macro_input; - -mod attrib; -mod enums; -mod json; -mod shape; -mod variables; - -#[proc_macro_derive(Queryable, attributes(edgedb))] -pub fn edgedb_queryable(input: TokenStream) -> TokenStream { - let s = parse_macro_input!(input as syn::Item); - match derive(&s) { - Ok(stream) => stream.into(), - Err(e) => e.to_compile_error().into(), - } -} - -fn derive(item: &syn::Item) -> syn::Result { - let attrs = match item { - syn::Item::Struct(s) => &s.attrs, - syn::Item::Enum(e) => &e.attrs, - _ => { - return Err(syn::Error::new_spanned( - item, - "can only derive Queryable for structs and enums", - )); - } - }; - let attrs = attrib::ContainerAttrs::from_syn(attrs)?; - if attrs.json { - json::derive(item) - } else { - match item { - syn::Item::Struct(s) => shape::derive_struct(s), - syn::Item::Enum(s) => enums::derive_enum(s), - _ => Err(syn::Error::new_spanned( - item, - "can only derive Queryable for a struct and enum \ - in non-JSON mode", - )), - } - } -} - -#[proc_macro_derive(GlobalsDelta, attributes(edgedb))] -pub fn globals_delta(input: TokenStream) -> TokenStream { - let s = parse_macro_input!(input as syn::ItemStruct); - match variables::derive_globals(&s) { - Ok(stream) => stream.into(), - Err(e) => e.to_compile_error().into(), - } -} -#[proc_macro_derive(ConfigDelta, attributes(edgedb))] -pub fn config_delta(input: TokenStream) -> TokenStream { - let s = parse_macro_input!(input as syn::ItemStruct); - match variables::derive_config(&s) { - Ok(stream) => stream.into(), - Err(e) => e.to_compile_error().into(), - } -} +compile_error!("edgedb-derive has been renamed to gel-derive"); diff --git a/edgedb-derive/src/shape.rs b/edgedb-derive/src/shape.rs deleted file mode 100644 index d40991f0..00000000 --- a/edgedb-derive/src/shape.rs +++ /dev/null @@ -1,190 +0,0 @@ -use proc_macro2::{Span, TokenStream}; -use quote::quote; - -use crate::attrib::FieldAttrs; - -struct Field { - name: syn::Ident, - str_name: syn::LitStr, - ty: syn::Type, - attrs: FieldAttrs, -} - -pub fn derive_struct(s: &syn::ItemStruct) -> syn::Result { - let name = &s.ident; - let decoder = syn::Ident::new("decoder", Span::mixed_site()); - let buf = syn::Ident::new("buf", Span::mixed_site()); - let nfields = syn::Ident::new("nfields", Span::mixed_site()); - let elements = syn::Ident::new("elements", Span::mixed_site()); - let (impl_generics, ty_generics, _) = s.generics.split_for_impl(); - let fields = match &s.fields { - syn::Fields::Named(named) => { - let mut fields = Vec::with_capacity(named.named.len()); - for field in &named.named { - let attrs = FieldAttrs::from_syn(&field.attrs)?; - let name = field.ident.clone().unwrap(); - fields.push(Field { - str_name: syn::LitStr::new(&name.to_string(), name.span()), - name, - ty: field.ty.clone(), - attrs, - }); - } - fields - } - _ => { - return Err(syn::Error::new_spanned( - &s.fields, - "only named fields are supported", - )); - } - }; - let fieldname = fields.iter().map(|f| f.name.clone()).collect::>(); - let base_fields = fields.len(); - let type_id_block = Some(quote! { - if #decoder.has_implicit_tid { - #elements.skip_element()?; - } - }); - let type_name_block = Some(quote! { - if #decoder.has_implicit_tname { - #elements.skip_element()?; - } - }); - let id_block = Some(quote! { - if #decoder.has_implicit_id { - #elements.skip_element()?; - } - }); - let type_id_check = Some(quote! { - if ctx.has_implicit_tid { - if(!shape.elements[idx].flag_implicit) { - return ::std::result::Result::Err(ctx.expected("implicit __tid__")); - } - idx += 1; - } - }); - let type_name_check = Some(quote! { - if ctx.has_implicit_tname { - if(!shape.elements[idx].flag_implicit) { - return ::std::result::Result::Err(ctx.expected("implicit __tname__")); - } - idx += 1; - } - }); - let id_check = Some(quote! { - if ctx.has_implicit_id { - if(!shape.elements[idx].flag_implicit) { - return ::std::result::Result::Err(ctx.expected("implicit id")); - } - idx += 1; - } - }); - let field_decoders = fields - .iter() - .map(|field| { - let fieldname = &field.name; - if field.attrs.json { - quote! { - let #fieldname: ::edgedb_protocol::model::Json = - <::edgedb_protocol::model::Json as - ::edgedb_protocol::queryable::Queryable> - ::decode_optional(#decoder, #elements.read()?)?; - let #fieldname = ::serde_json::from_str(#fieldname.as_ref()) - .map_err(::edgedb_protocol::errors::decode_error)?; - } - } else { - quote! { - let #fieldname = - ::edgedb_protocol::queryable::Queryable - ::decode_optional(#decoder, #elements.read()?)?; - } - } - }) - .collect::(); - let field_checks = fields - .iter() - .map(|field| { - let name_str = &field.str_name; - let mut result = quote! { - let el = &shape.elements[idx]; - if(el.name != #name_str) { - return ::std::result::Result::Err(ctx.wrong_field(#name_str, &el.name)); - } - idx += 1; - }; - let fieldtype = &field.ty; - if field.attrs.json { - result.extend(quote! { - <::edgedb_protocol::model::Json as - ::edgedb_protocol::queryable::Queryable> - ::check_descriptor(ctx, el.type_pos)?; - }); - } else { - result.extend(quote! { - <#fieldtype as ::edgedb_protocol::queryable::Queryable> - ::check_descriptor(ctx, el.type_pos)?; - }); - } - result - }) - .collect::(); - - let field_count = fields.len(); - - let expanded = quote! { - impl #impl_generics ::edgedb_protocol::queryable::Queryable - for #name #ty_generics { - fn decode(#decoder: &::edgedb_protocol::queryable::Decoder, #buf: &[u8]) - -> ::std::result::Result - { - let #nfields = #base_fields - + if #decoder.has_implicit_id { 1 } else { 0 } - + if #decoder.has_implicit_tid { 1 } else { 0 } - + if #decoder.has_implicit_tname { 1 } else { 0 }; - let mut #elements = - ::edgedb_protocol::serialization::decode::DecodeTupleLike - ::new_object(#buf, #nfields)?; - - #type_id_block - #type_name_block - #id_block - #field_decoders - ::std::result::Result::Ok(#name { - #( - #fieldname, - )* - }) - } - fn check_descriptor( - ctx: &::edgedb_protocol::queryable::DescriptorContext, - type_pos: ::edgedb_protocol::descriptors::TypePos) - -> ::std::result::Result<(), ::edgedb_protocol::queryable::DescriptorMismatch> - { - use ::edgedb_protocol::descriptors::Descriptor::ObjectShape; - let desc = ctx.get(type_pos)?; - let shape = match desc { - ObjectShape(shape) => shape, - _ => { - return ::std::result::Result::Err(ctx.wrong_type(desc, "str")) - } - }; - - // TODO(tailhook) cache shape.id somewhere - let mut idx = 0; - - #type_id_check - #type_name_check - #id_check - if(shape.elements.len() != #field_count) { - return ::std::result::Result::Err(ctx.field_number( - #field_count, shape.elements.len()) - ); - } - #field_checks - ::std::result::Result::Ok(()) - } - } - }; - Ok(expanded) -} diff --git a/edgedb-derive/src/variables.rs b/edgedb-derive/src/variables.rs deleted file mode 100644 index 76ab4600..00000000 --- a/edgedb-derive/src/variables.rs +++ /dev/null @@ -1,80 +0,0 @@ -use proc_macro2::Span; -use quote::quote; - -pub fn derive_globals(item: &syn::ItemStruct) -> syn::Result { - // TODO(tailhook) add namespace annotations - - let man = syn::Ident::new("man", Span::mixed_site()); - - let fields = match &item.fields { - syn::Fields::Named(fields) => fields, - _ => { - return Err(syn::Error::new_spanned( - &item.fields, - "only named fields are supported", - )); - } - }; - let set_vars = fields - .named - .iter() - .map(|f| { - let ident = f.ident.as_ref().expect("a named field"); - let name = ident.to_string(); - quote! { #man.set(#name, self.#ident); } - }) - .collect::>(); - - let name = &item.ident; - let (impl_generics, ty_generics, where_c) = item.generics.split_for_impl(); - let expanded = quote! { - impl #impl_generics ::edgedb_tokio::state::GlobalsDelta - for &'_ #name #ty_generics - #where_c - { - fn apply(self, #man: &mut ::edgedb_tokio::state::GlobalsModifier) - { - #(#set_vars)* - } - } - }; - Ok(expanded) -} - -pub fn derive_config(item: &syn::ItemStruct) -> syn::Result { - let man = syn::Ident::new("man", Span::mixed_site()); - - let fields = match &item.fields { - syn::Fields::Named(fields) => fields, - _ => { - return Err(syn::Error::new_spanned( - &item.fields, - "only named fields are supported", - )); - } - }; - let set_vars = fields - .named - .iter() - .map(|f| { - let ident = f.ident.as_ref().expect("a named field"); - let name = ident.to_string(); - quote! { #man.set(#name, self.#ident); } - }) - .collect::>(); - - let name = &item.ident; - let (impl_generics, ty_generics, where_c) = item.generics.split_for_impl(); - let expanded = quote! { - impl #impl_generics ::edgedb_tokio::state::ConfigDelta - for &'_ #name #ty_generics - #where_c - { - fn apply(self, #man: &mut ::edgedb_tokio::state::ConfigModifier) - { - #(#set_vars)* - } - } - }; - Ok(expanded) -} diff --git a/edgedb-derive/tests/enums.rs b/edgedb-derive/tests/enums.rs deleted file mode 100644 index 578822cd..00000000 --- a/edgedb-derive/tests/enums.rs +++ /dev/null @@ -1,24 +0,0 @@ -use edgedb_derive::Queryable; -use edgedb_protocol::queryable::{Decoder, Queryable}; - -#[derive(Queryable, Debug, PartialEq)] -enum Status { - Open, - Closed, - Invalid, -} - -#[test] -fn enumeration() { - let dec = Decoder::default(); - assert_eq!(Status::decode(&dec, &b"Open"[..]).unwrap(), Status::Open); - assert_eq!( - Status::decode(&dec, &b"Closed"[..]).unwrap(), - Status::Closed - ); - assert_eq!( - Status::decode(&dec, &b"Invalid"[..]).unwrap(), - Status::Invalid - ); - Status::decode(&dec, &b"closed"[..]).unwrap_err(); -} diff --git a/edgedb-derive/tests/fail.rs b/edgedb-derive/tests/fail.rs deleted file mode 100644 index 00718467..00000000 --- a/edgedb-derive/tests/fail.rs +++ /dev/null @@ -1,5 +0,0 @@ -#[test] -fn fail() { - let t = trybuild::TestCases::new(); - t.compile_fail("tests/fail/*.rs"); -} diff --git a/edgedb-derive/tests/json.rs b/edgedb-derive/tests/json.rs deleted file mode 100644 index 8c16e7cf..00000000 --- a/edgedb-derive/tests/json.rs +++ /dev/null @@ -1,52 +0,0 @@ -use edgedb_derive::Queryable; -use edgedb_protocol::queryable::{Decoder, Queryable}; -use serde::Deserialize; - -#[derive(Debug, PartialEq, Deserialize)] -struct Data { - field1: u32, -} - -#[derive(Queryable, Debug, PartialEq)] -struct ShapeWithJson { - name: String, - #[edgedb(json)] - data: Data, -} - -#[derive(Queryable, Deserialize, Debug, PartialEq)] -#[edgedb(json)] -struct JsonRow { - field2: u32, -} - -fn old_decoder() -> Decoder { - let mut dec = Decoder::default(); - dec.has_implicit_id = true; - dec.has_implicit_tid = true; - dec -} - -#[test] -fn json_field() { - let data = b"\0\0\0\x04\0\0\x0b\x86\0\0\0\x10\xf2R\ - \x04I\xd7\x04\x11\xea\xaeX\xcf\xdf\xf6\xd0Q\xac\ - \0\0\x0b\x86\0\0\0\x10\xf2\xe6F9\xd7\x04\x11\xea\ - \xa0<\x83\x9f\xd9\xbd\x88\x94\0\0\0\x19\ - \0\0\0\x02id\0\0\x0e\xda\0\0\0\x10\x01{\"field1\": 123}"; - let res = ShapeWithJson::decode(&old_decoder(), data); - assert_eq!( - res.unwrap(), - ShapeWithJson { - name: "id".into(), - data: Data { field1: 123 }, - } - ); -} - -#[test] -fn json_row() { - let data = b"\x01{\"field2\": 234}"; - let res = JsonRow::decode(&old_decoder(), data); - assert_eq!(res.unwrap(), JsonRow { field2: 234 }); -} diff --git a/edgedb-derive/tests/list_scalar_types.rs b/edgedb-derive/tests/list_scalar_types.rs deleted file mode 100644 index 2de8f22b..00000000 --- a/edgedb-derive/tests/list_scalar_types.rs +++ /dev/null @@ -1,50 +0,0 @@ -use edgedb_derive::Queryable; -use edgedb_protocol::queryable::{Decoder, Queryable}; - -#[derive(Queryable, Debug, PartialEq)] -struct ScalarType { - name: String, - extending: String, - kind: String, -} - -fn old_decoder() -> Decoder { - let mut dec = Decoder::default(); - dec.has_implicit_id = true; - dec.has_implicit_tid = true; - dec -} - -#[test] -fn decode_new() { - let data = b"\0\0\0\x03\0\0\0\x19\0\0\0\x0fcal::local_date\ - \0\0\0\x19\0\0\0 std::anyscalar, std::anydiscrete\ - \0\0\0\x19\0\0\0\x06normal"; - let res = ScalarType::decode(&Decoder::default(), data); - assert_eq!( - res.unwrap(), - ScalarType { - name: "cal::local_date".into(), - extending: "std::anyscalar, std::anydiscrete".into(), - kind: "normal".into(), - } - ); -} - -#[test] -fn decode_old() { - let data = b"\0\0\0\x05\0\0\x0b\x86\ - \0\0\0\x10\xb2\xa1\x94\xfb\t\xa4\x11\xeb\x9d\x97\xf9'\ - \xee\xfc\xb6\x12\0\0\x0b\x86\0\0\0\x10\0\0\0\0\0\0\0\0\0\0\0\0\0\0\ - \x01\x0c\0\0\0\x19\0\0\0\x0fcal::local_date\ - \0\0\0\x19\0\0\0\x0estd::anyscalar\0\0\0\x19\0\0\0\x06normal"; - let res = ScalarType::decode(&old_decoder(), data); - assert_eq!( - res.unwrap(), - ScalarType { - name: "cal::local_date".into(), - extending: "std::anyscalar".into(), - kind: "normal".into(), - } - ); -} diff --git a/edgedb-derive/tests/no-prelude.rs b/edgedb-derive/tests/no-prelude.rs deleted file mode 100644 index 9165c20e..00000000 --- a/edgedb-derive/tests/no-prelude.rs +++ /dev/null @@ -1,9 +0,0 @@ -// a macro should always fully qualify names it uses, so it works even if the relevant item hasn't been imported, or a conflicting item has been imported -// this test ensures that the code generated by the derive Queryable macro does not reference names from the prelude -#![no_implicit_prelude] - -#[derive(::edgedb_derive::Queryable)] -#[allow(dead_code)] -struct Test { - field: i64, -} diff --git a/edgedb-derive/tests/varnames.rs b/edgedb-derive/tests/varnames.rs deleted file mode 100644 index f72d4f9b..00000000 --- a/edgedb-derive/tests/varnames.rs +++ /dev/null @@ -1,27 +0,0 @@ -use edgedb_derive::Queryable; -use edgedb_protocol::queryable::{Decoder, Queryable}; - -#[derive(Queryable, Debug, PartialEq)] -struct WeirdStruct { - nfields: i64, - elements: String, - decoder: String, - buf: i64, -} - -#[test] -fn decode() { - let data = b"\0\0\0\x04\0\0\0\x14\0\0\0\x08\0\0\0\0\0\0\x03\0\0\0\ - \0\x19\0\0\0\0\0\0\0\x19\0\0\0\x0bSomeDecoder\ - \0\0\0\x14\0\0\0\x08\0\0\0\0\0\0\0{"; - let res = WeirdStruct::decode(&Decoder::default(), data); - assert_eq!( - res.unwrap(), - WeirdStruct { - decoder: "SomeDecoder".into(), - buf: 123, - nfields: 768, - elements: "".into(), - } - ); -} diff --git a/edgedb-errors/Cargo.toml b/edgedb-errors/Cargo.toml index 8493aa01..82cda80c 100644 --- a/edgedb-errors/Cargo.toml +++ b/edgedb-errors/Cargo.toml @@ -1,20 +1,12 @@ [package] name = "edgedb-errors" license = "MIT/Apache-2.0" -version = "0.4.2" +version = "0.5.0" authors = ["MagicStack Inc. "] edition = "2018" description = """ Error types for EdgeDB database client. + This crate has been renamed to gel-errors. """ readme = "README.md" rust-version.workspace = true - -[dependencies] -bytes = "1.0.1" -miette = { version = "7.2.0", optional = true } - -[lib] - -[lints] -workspace = true diff --git a/edgedb-errors/README.md b/edgedb-errors/README.md index 3dca7a65..c57bce23 100644 --- a/edgedb-errors/README.md +++ b/edgedb-errors/README.md @@ -1,19 +1,4 @@ EdgeDB Rust Binding: Errors Crate ================================= -This crate contains definitions of errors returned from the database. - -* [Documentation](https://docs.rs/edgedb-errors) -* [Tokio Client](https://docs.rs/edgedb-tokio) - -License -======= - - -Licensed under either of - -* Apache License, Version 2.0, - (./LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) -* MIT license (./LICENSE-MIT or http://opensource.org/licenses/MIT) - -at your option. +> This crate has been renamed to [gel-errors](https://crates.io/crates/gel-errors). diff --git a/edgedb-errors/src/bin/edgedb_gen_errors.rs b/edgedb-errors/src/bin/edgedb_gen_errors.rs deleted file mode 100644 index 2d1db48d..00000000 --- a/edgedb-errors/src/bin/edgedb_gen_errors.rs +++ /dev/null @@ -1,143 +0,0 @@ -use std::collections::{BTreeMap, BTreeSet}; -use std::env::args; -use std::fs; - -fn find_tag<'x>(template: &'x str, tag: &str) -> (usize, usize, &'x str) { - let tag_line = format!("// <{}>\n", tag); - let pos = template - .find(&tag_line) - .unwrap_or_else(|| panic!("missing tag <{}>", tag)); - let indent = template[..pos].rfind('\n').unwrap_or(0) + 1; - (pos, pos + tag_line.len(), &template[indent..pos]) -} - -fn find_macro<'x>(template: &'x str, name: &str) -> &'x str { - let macro_line = format!("macro_rules! {} {{", name); - let pos = template - .find(¯o_line) - .map(|pos| pos + macro_line.len()) - .unwrap_or_else(|| panic!("missing macro {}", name)); - let body = template[pos..] - .find('{') - .map(|x| pos + x + 1) - .and_then(|open| { - let mut level = 0; - for (idx, c) in template[open..].char_indices() { - match c { - '}' if level == 0 => return Some((open, open + idx)), - '}' => level -= 1, - '{' => level += 1, - _ => {} - } - } - None - }) - .map(|(begin, end)| template[begin..end].trim()) - .expect("invalid macro"); - body -} - -fn main() -> Result<(), Box> { - let filename = args().nth(1).expect("single argument"); - let mut all_errors = Vec::new(); - let mut all_tags = BTreeSet::<&str>::new(); - let data = fs::read_to_string(filename)?; - for line in data.lines() { - let line = line.trim(); - if line.is_empty() || line.starts_with('#') { - continue; - } - let mut parts = line.split_whitespace(); - let code = u32::from_str_radix( - &parts - .next() - .expect("code always specified") - .strip_prefix("0x") - .expect("code contains 0x") - .replace('_', ""), - 16, - ) - .expect("code is valid hex"); - let name = parts.next().expect("name always specified"); - let tags: Vec<_> = parts - .map(|x| x.strip_prefix('#')) - .collect::>() - .expect("tags must follow name"); - all_tags.extend(&tags); - all_errors.push((code, name, tags)); - } - - let tag_masks = all_tags - .iter() - .enumerate() - .map(|(bit, tag)| (tag, 1 << bit as u32)) - .collect::>(); - - let tmp_errors = all_errors - .into_iter() - .map(|(code, name, tags)| { - let tags = tags.iter().map(|t| *tag_masks.get(t).unwrap()).sum(); - (code, (name, tags)) - }) - .collect::>(); - - let mut all_errors = BTreeMap::::new(); - // propagate tags from error superclasses - for (code, (name, mut tags)) in tmp_errors { - for (&scode, (_, stags)) in all_errors.iter().rev() { - let mask_bits = (scode.trailing_zeros() / 8) * 8; - let mask = 0xFFFFFFFF_u32 << mask_bits; - if code & mask == scode { - tags |= stags; - } - if mask_bits == 24 { - // first byte checked no more matches possible - // (errors are sorted by code) - break; - } - } - all_errors.insert(code, (name, tags)); - } - - let outfile = "./edgedb-errors/src/kinds.rs"; - let template = fs::read_to_string(outfile)?; - let mut out = String::with_capacity(template.len() + 100); - - let (_, def_start, indent) = find_tag(&template, "define_tag"); - out.push_str(&template[..def_start]); - - let define_tag = find_macro(&template, "define_tag"); - for (bit, tag) in all_tags.iter().enumerate() { - out.push_str(indent); - out.push_str( - &define_tag - .replace("$name", tag) - .replace("$bit", &bit.to_string()), - ); - out.push('\n'); - } - - let (def_end, _, _) = find_tag(&template, "/define_tag"); - let (_, err_start, indent) = find_tag(&template, "define_error"); - out.push_str(&template[def_end..err_start]); - - let define_err = find_macro(&template, "define_error"); - for (code, (name, tags)) in all_errors.iter() { - out.push_str(indent); - out.push_str( - &define_err - .replace("$name", name) - .replace("$code", &format!("0x{:08X}u32", code)) - .replace("$tag_bits", &format!("0x{:08x}", tags)), - ); - out.push('\n'); - } - - let (err_end, _, _) = find_tag(&template, "/define_error"); - out.push_str(indent); - out.push_str(&template[err_end..]); - - fs::write(outfile, out)?; - - Ok(()) -} diff --git a/edgedb-errors/src/display.rs b/edgedb-errors/src/display.rs deleted file mode 100644 index 9d2000ec..00000000 --- a/edgedb-errors/src/display.rs +++ /dev/null @@ -1,77 +0,0 @@ -use std::fmt; - -use crate::{Error, InternalServerError}; - -pub struct DisplayError<'a>(&'a Error, bool); -pub struct VerboseError<'a>(&'a Error); - -struct DisplayNum(Option); - -pub fn display_error(e: &Error, verbose: bool) -> DisplayError { - DisplayError(e, verbose) -} -pub fn display_error_verbose(e: &Error) -> VerboseError { - VerboseError(e) -} - -impl fmt::Display for DisplayError<'_> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let DisplayError(ref e, verbose) = self; - write!(f, "{:#}", e)?; - if e.is::() || *verbose { - if let Some(traceback) = e.server_traceback() { - write!(f, "\n Server traceback:")?; - for line in traceback.lines() { - write!(f, "\n {}", line)?; - } - } - } - Ok(()) - } -} - -impl fmt::Display for DisplayNum { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.0 { - Some(x) => x.fmt(f), - None => "?".fmt(f), - } - } -} - -impl fmt::Display for VerboseError<'_> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let e = self.0; - writeln!(f, "Error type: {}", e.kind_debug())?; - writeln!(f, "Message: {:#}", e)?; - let pstart = e.position_start(); - let pend = e.position_end(); - let line = e.line(); - let column = e.column(); - if [pstart, pend, line, column].iter().any(|x| x.is_some()) { - writeln!( - f, - "Span: {}-{}, line {}, column {}", - DisplayNum(pstart), - DisplayNum(pend), - DisplayNum(line), - DisplayNum(column) - )?; - } - if let Some(traceback) = e.server_traceback() { - writeln!(f, "Server traceback:")?; - for line in traceback.lines() { - writeln!(f, " {}", line)?; - } - } - - let attr = e.unknown_headers().collect::>(); - if !attr.is_empty() { - writeln!(f, "Other attributes:")?; - for (k, v) in attr { - writeln!(f, " 0x{:04x}: {:?}", k, v)?; - } - } - Ok(()) - } -} diff --git a/edgedb-errors/src/error.rs b/edgedb-errors/src/error.rs deleted file mode 100644 index 42489439..00000000 --- a/edgedb-errors/src/error.rs +++ /dev/null @@ -1,226 +0,0 @@ -use std::any::{Any, TypeId}; -use std::borrow::Cow; -use std::collections::HashMap; -use std::error::Error as StdError; -use std::fmt; -use std::str; - -use crate::kinds::UserError; -use crate::kinds::{error_name, tag_check}; -use crate::traits::{ErrorKind, Field}; - -const FIELD_HINT: u16 = 0x_00_01; -const FIELD_DETAILS: u16 = 0x_00_02; -const FIELD_SERVER_TRACEBACK: u16 = 0x_01_01; - -// TODO(tailhook) these might be deprecated? -const FIELD_POSITION_START: u16 = 0x_FF_F1; -const FIELD_POSITION_END: u16 = 0x_FF_F2; -const FIELD_LINE: u16 = 0x_FF_F3; -const FIELD_COLUMN: u16 = 0x_FF_F4; - -/// Error type returned from any EdgeDB call. -// This includes boxed error, because propagating through call chain is -// faster when error is just one pointer -#[derive(Debug)] -pub struct Error(pub(crate) Box); - -pub struct Chain<'a>(Option<&'a (dyn StdError + 'static)>); - -/// Tag that is used to group similar errors. -#[derive(Clone, Copy)] -pub struct Tag { - pub(crate) bit: u32, -} - -pub(crate) enum Source { - Box(Box), - Ref(Box + Send + Sync + 'static>), -} - -#[derive(Debug)] -pub(crate) struct Inner { - pub code: u32, - // TODO(tailhook) possibly put message into the fields too - pub messages: Vec>, - pub error: Option, - // TODO(tailhook) put headers into the fields - pub headers: HashMap, - pub fields: HashMap<(&'static str, TypeId), Box>, -} - -impl Error { - pub fn is(&self) -> bool { - T::is_superclass_of(self.0.code) - } - pub fn has_tag(&self, tag: Tag) -> bool { - tag_check(self.0.code, tag.bit) - } - pub fn chain(&self) -> Chain { - Chain(Some(self)) - } - pub fn context>>(mut self, msg: S) -> Error { - self.0.messages.push(msg.into()); - self - } - pub fn headers(&self) -> &HashMap { - &self.0.headers - } - pub fn with_headers(mut self, headers: HashMap) -> Error { - self.0.headers = headers; - self - } - pub fn kind_name(&self) -> &str { - error_name(self.0.code) - } - pub fn kind_debug(&self) -> impl fmt::Display { - format!("{} [0x{:08X}]", error_name(self.0.code), self.0.code) - } - pub fn initial_message(&self) -> Option<&str> { - self.0.messages.first().map(|m| &m[..]) - } - pub fn contexts(&self) -> impl DoubleEndedIterator { - self.0.messages.iter().skip(1).map(|m| &m[..]) - } - fn header(&self, field: u16) -> Option<&str> { - if let Some(value) = self.headers().get(&field) { - if let Ok(value) = str::from_utf8(value) { - return Some(value); - } - } - None - } - fn usize_header(&self, field: u16) -> Option { - self.header(field) - .and_then(|x| x.parse::().ok()) - .map(|x| x as usize) - } - pub fn hint(&self) -> Option<&str> { - self.header(FIELD_HINT) - } - pub fn details(&self) -> Option<&str> { - self.header(FIELD_DETAILS) - } - pub fn server_traceback(&self) -> Option<&str> { - self.header(FIELD_SERVER_TRACEBACK) - } - pub fn position_start(&self) -> Option { - self.usize_header(FIELD_POSITION_START) - } - pub fn position_end(&self) -> Option { - self.usize_header(FIELD_POSITION_END) - } - pub fn line(&self) -> Option { - self.usize_header(FIELD_LINE) - } - pub fn column(&self) -> Option { - self.usize_header(FIELD_COLUMN) - } - pub(crate) fn unknown_headers(&self) -> impl Iterator { - self.headers().iter().filter(|(key, _)| { - **key != FIELD_HINT - && **key != FIELD_DETAILS - && **key != FIELD_POSITION_START - && **key != FIELD_POSITION_END - && **key != FIELD_LINE - && **key != FIELD_COLUMN - }) - } - pub fn from_code(code: u32) -> Error { - Error(Box::new(Inner { - code, - messages: Vec::new(), - error: None, - headers: HashMap::new(), - fields: HashMap::new(), - })) - } - pub fn code(&self) -> u32 { - self.0.code - } - pub fn refine_kind(mut self) -> Error { - self.0.code = T::CODE; - self - } - pub fn set(mut self, value: impl Into) -> Error { - self.0 - .fields - .insert((T::NAME, TypeId::of::()), Box::new(value.into())); - self - } - pub fn get(&self) -> Option<&T::Value> { - self.0 - .fields - .get(&(T::NAME, TypeId::of::())) - .and_then(|bx| bx.downcast_ref::()) - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let kind = self.kind_name(); - if f.alternate() { - write!(f, "{}", kind)?; - for msg in self.0.messages.iter().rev() { - write!(f, ": {}", msg)?; - } - if let Some(mut src) = self.source() { - write!(f, ": {}", src)?; - while let Some(next) = src.source() { - write!(f, ": {}", next)?; - src = next; - } - } - } else if let Some(last) = self.0.messages.last() { - write!(f, "{}: {}", kind, last)?; - } else { - write!(f, "{}", kind)?; - } - if let Some((line, col)) = self.line().zip(self.column()) { - write!(f, " (on line {}, column {})", line, col)?; - } - if let Some(hint) = self.hint() { - write!(f, "\n Hint: {}", hint)?; - } - if let Some(detail) = self.details() { - write!(f, "\n Detail: {}", detail)?; - } - Ok(()) - } -} - -impl StdError for Error { - fn source(&self) -> Option<&(dyn StdError + 'static)> { - self.0.error.as_ref().map(|s| match s { - Source::Box(b) => b.as_ref() as &dyn std::error::Error, - Source::Ref(b) => (**b).as_ref() as &dyn std::error::Error, - }) - } -} - -impl fmt::Debug for Source { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Source::Box(b) => fmt::Debug::fmt(b.as_ref(), f), - Source::Ref(b) => fmt::Debug::fmt((**b).as_ref(), f), - } - } -} - -impl From for Error -where - T: AsRef + Send + Sync + 'static, -{ - fn from(err: T) -> Error { - UserError::with_source_ref(err) - } -} - -impl<'a> Iterator for Chain<'a> { - type Item = &'a (dyn StdError + 'static); - fn next(&mut self) -> Option { - let result = self.0.take(); - self.0 = result.and_then(|e| e.source()); - result - } -} diff --git a/edgedb-errors/src/fields.rs b/edgedb-errors/src/fields.rs deleted file mode 100644 index fd1646c6..00000000 --- a/edgedb-errors/src/fields.rs +++ /dev/null @@ -1,8 +0,0 @@ -use crate::traits::Field; - -pub struct QueryText; - -impl Field for QueryText { - const NAME: &'static str = "source_code"; - type Value = String; -} diff --git a/edgedb-errors/src/kinds.rs b/edgedb-errors/src/kinds.rs deleted file mode 100644 index 40e66d7a..00000000 --- a/edgedb-errors/src/kinds.rs +++ /dev/null @@ -1,180 +0,0 @@ -/*! -All errors from the [EdgeDB protocol](https://www.edgedb.com/docs/reference/protocol/errors#error-codes) -*/ - -use crate::error::Tag; -use crate::traits::{ErrorKind, Sealed}; - -macro_rules! define_errors { - ($( (struct $id:ident, $code:expr, $tags: expr), )*) => { - $( - pub struct $id; - - impl Sealed for $id { - const CODE: u32 = $code; - const NAME: &'static str = stringify!($id); - const TAGS: u32 = $tags; - } - - impl ErrorKind for $id {} - )* - pub(crate) fn tag_check(code: u32, bit: u32) -> bool { - return get_tags(code) & (1 << bit) != 0; - } - pub(crate) fn get_tags(code: u32) -> u32 { - match code { - $( - $code => $tags, - )* - _ => 0, - } - } - pub(crate) fn error_name(code: u32) -> &'static str { - match code { - $( - $code => stringify!($id), - )* - _ => "EdgeDBError", - } - } - } -} - -// AUTOGENERATED BY EdgeDB WITH -// $ cargo run --bin edgedb_gen_errors -- errors.txt - -#[allow(unused_macros)] // fake macro for generator -macro_rules! define_tag { - ($name: ident, $bit: expr) => { - pub static $name: Tag = Tag { bit: $bit }; - }; -} - -// -pub static SHOULD_RECONNECT: Tag = Tag { bit: 0 }; -pub static SHOULD_RETRY: Tag = Tag { bit: 1 }; -// - -#[allow(unused_macros)] // fake macro for generator -macro_rules! define_error { - (struct $name: ident, $code: expr, $tag_bits: expr) => { - (struct $name, $code, $tag_bits), - } -} - -define_errors![ - // - (struct InternalServerError, 0x01000000u32, 0x00000000), - (struct UnsupportedFeatureError, 0x02000000u32, 0x00000000), - (struct ProtocolError, 0x03000000u32, 0x00000000), - (struct BinaryProtocolError, 0x03010000u32, 0x00000000), - (struct UnsupportedProtocolVersionError, 0x03010001u32, 0x00000000), - (struct TypeSpecNotFoundError, 0x03010002u32, 0x00000000), - (struct UnexpectedMessageError, 0x03010003u32, 0x00000000), - (struct InputDataError, 0x03020000u32, 0x00000000), - (struct ParameterTypeMismatchError, 0x03020100u32, 0x00000000), - (struct StateMismatchError, 0x03020200u32, 0x00000002), - (struct ResultCardinalityMismatchError, 0x03030000u32, 0x00000000), - (struct CapabilityError, 0x03040000u32, 0x00000000), - (struct UnsupportedCapabilityError, 0x03040100u32, 0x00000000), - (struct DisabledCapabilityError, 0x03040200u32, 0x00000000), - (struct QueryError, 0x04000000u32, 0x00000000), - (struct InvalidSyntaxError, 0x04010000u32, 0x00000000), - (struct EdgeQLSyntaxError, 0x04010100u32, 0x00000000), - (struct SchemaSyntaxError, 0x04010200u32, 0x00000000), - (struct GraphQLSyntaxError, 0x04010300u32, 0x00000000), - (struct InvalidTypeError, 0x04020000u32, 0x00000000), - (struct InvalidTargetError, 0x04020100u32, 0x00000000), - (struct InvalidLinkTargetError, 0x04020101u32, 0x00000000), - (struct InvalidPropertyTargetError, 0x04020102u32, 0x00000000), - (struct InvalidReferenceError, 0x04030000u32, 0x00000000), - (struct UnknownModuleError, 0x04030001u32, 0x00000000), - (struct UnknownLinkError, 0x04030002u32, 0x00000000), - (struct UnknownPropertyError, 0x04030003u32, 0x00000000), - (struct UnknownUserError, 0x04030004u32, 0x00000000), - (struct UnknownDatabaseError, 0x04030005u32, 0x00000000), - (struct UnknownParameterError, 0x04030006u32, 0x00000000), - (struct SchemaError, 0x04040000u32, 0x00000000), - (struct SchemaDefinitionError, 0x04050000u32, 0x00000000), - (struct InvalidDefinitionError, 0x04050100u32, 0x00000000), - (struct InvalidModuleDefinitionError, 0x04050101u32, 0x00000000), - (struct InvalidLinkDefinitionError, 0x04050102u32, 0x00000000), - (struct InvalidPropertyDefinitionError, 0x04050103u32, 0x00000000), - (struct InvalidUserDefinitionError, 0x04050104u32, 0x00000000), - (struct InvalidDatabaseDefinitionError, 0x04050105u32, 0x00000000), - (struct InvalidOperatorDefinitionError, 0x04050106u32, 0x00000000), - (struct InvalidAliasDefinitionError, 0x04050107u32, 0x00000000), - (struct InvalidFunctionDefinitionError, 0x04050108u32, 0x00000000), - (struct InvalidConstraintDefinitionError, 0x04050109u32, 0x00000000), - (struct InvalidCastDefinitionError, 0x0405010Au32, 0x00000000), - (struct DuplicateDefinitionError, 0x04050200u32, 0x00000000), - (struct DuplicateModuleDefinitionError, 0x04050201u32, 0x00000000), - (struct DuplicateLinkDefinitionError, 0x04050202u32, 0x00000000), - (struct DuplicatePropertyDefinitionError, 0x04050203u32, 0x00000000), - (struct DuplicateUserDefinitionError, 0x04050204u32, 0x00000000), - (struct DuplicateDatabaseDefinitionError, 0x04050205u32, 0x00000000), - (struct DuplicateOperatorDefinitionError, 0x04050206u32, 0x00000000), - (struct DuplicateViewDefinitionError, 0x04050207u32, 0x00000000), - (struct DuplicateFunctionDefinitionError, 0x04050208u32, 0x00000000), - (struct DuplicateConstraintDefinitionError, 0x04050209u32, 0x00000000), - (struct DuplicateCastDefinitionError, 0x0405020Au32, 0x00000000), - (struct DuplicateMigrationError, 0x0405020Bu32, 0x00000000), - (struct SessionTimeoutError, 0x04060000u32, 0x00000000), - (struct IdleSessionTimeoutError, 0x04060100u32, 0x00000002), - (struct QueryTimeoutError, 0x04060200u32, 0x00000000), - (struct TransactionTimeoutError, 0x04060A00u32, 0x00000000), - (struct IdleTransactionTimeoutError, 0x04060A01u32, 0x00000000), - (struct ExecutionError, 0x05000000u32, 0x00000000), - (struct InvalidValueError, 0x05010000u32, 0x00000000), - (struct DivisionByZeroError, 0x05010001u32, 0x00000000), - (struct NumericOutOfRangeError, 0x05010002u32, 0x00000000), - (struct AccessPolicyError, 0x05010003u32, 0x00000000), - (struct QueryAssertionError, 0x05010004u32, 0x00000000), - (struct IntegrityError, 0x05020000u32, 0x00000000), - (struct ConstraintViolationError, 0x05020001u32, 0x00000000), - (struct CardinalityViolationError, 0x05020002u32, 0x00000000), - (struct MissingRequiredError, 0x05020003u32, 0x00000000), - (struct TransactionError, 0x05030000u32, 0x00000000), - (struct TransactionConflictError, 0x05030100u32, 0x00000002), - (struct TransactionSerializationError, 0x05030101u32, 0x00000002), - (struct TransactionDeadlockError, 0x05030102u32, 0x00000002), - (struct WatchError, 0x05040000u32, 0x00000000), - (struct ConfigurationError, 0x06000000u32, 0x00000000), - (struct AccessError, 0x07000000u32, 0x00000000), - (struct AuthenticationError, 0x07010000u32, 0x00000000), - (struct AvailabilityError, 0x08000000u32, 0x00000000), - (struct BackendUnavailableError, 0x08000001u32, 0x00000002), - (struct ServerOfflineError, 0x08000002u32, 0x00000003), - (struct UnknownTenantError, 0x08000003u32, 0x00000003), - (struct ServerBlockedError, 0x08000004u32, 0x00000000), - (struct BackendError, 0x09000000u32, 0x00000000), - (struct UnsupportedBackendFeatureError, 0x09000100u32, 0x00000000), - (struct LogMessage, 0xF0000000u32, 0x00000000), - (struct WarningMessage, 0xF0010000u32, 0x00000000), - (struct ClientError, 0xFF000000u32, 0x00000000), - (struct ClientConnectionError, 0xFF010000u32, 0x00000000), - (struct ClientConnectionFailedError, 0xFF010100u32, 0x00000000), - (struct ClientConnectionFailedTemporarilyError, 0xFF010101u32, 0x00000003), - (struct ClientConnectionTimeoutError, 0xFF010200u32, 0x00000003), - (struct ClientConnectionClosedError, 0xFF010300u32, 0x00000003), - (struct InterfaceError, 0xFF020000u32, 0x00000000), - (struct QueryArgumentError, 0xFF020100u32, 0x00000000), - (struct MissingArgumentError, 0xFF020101u32, 0x00000000), - (struct UnknownArgumentError, 0xFF020102u32, 0x00000000), - (struct InvalidArgumentError, 0xFF020103u32, 0x00000000), - (struct NoDataError, 0xFF030000u32, 0x00000000), - (struct InternalClientError, 0xFF040000u32, 0x00000000), - // - (struct ProtocolTlsError, 0x03FF0000u32, 0x00000000), - (struct ProtocolOutOfOrderError, 0x03FE0000u32, 0x00000000), - (struct ProtocolEncodingError, 0x03FD0000u32, 0x00000000), - (struct PasswordRequired, 0x0701FF00u32, 0x00000000), - (struct ClientInconsistentError, 0xFFFF0000u32, 0x00000000), - (struct ClientEncodingError, 0xFFFE0000u32, 0x00000000), - (struct ClientNoCredentialsError, 0xFF0101FFu32, 0x00000000), - (struct NoCloudConfigFound, 0xFF0101FEu32, 0x00000000), - (struct ClientConnectionEosError, 0xFF01FF00u32, 0x00000000), - (struct NoResultExpected, 0xFF02FF00u32, 0x00000000), - (struct DescriptorMismatch, 0xFF02FE00u32, 0x00000000), - (struct UserError, 0xFE000000u32, 0x00000000), -]; diff --git a/edgedb-errors/src/lib.rs b/edgedb-errors/src/lib.rs index 3527317c..dbca5966 100644 --- a/edgedb-errors/src/lib.rs +++ b/edgedb-errors/src/lib.rs @@ -1,90 +1,7 @@ /*! -# Error Handling for EdgeDB +Error Handling for EdgeDB -All errors that EdgeDB Rust bindings produce are encapsulated into the -[`Error`] structure. The structure is a bit like `Box` or -[`anyhow::Error`], except it can only contain EdgeDB error types. Or -[`UserError`] can be used to encapsulate custom errors (commonly used -to return an error from a transaction). - -A full list of EdgeDB error types on a single page can be found on the [website documentation](https://www.edgedb.com/docs/reference/protocol/errors#error-codes). - -Each error kind is represented as a separate type that implements the -[`ErrorKind`] trait. But error kinds are used like marker structs; you can -use [`Error::is`] for error kinds and use them to create instances of the -error: - -```rust -# use std::io; -# use edgedb_errors::{UserError, ErrorKind}; -let err = UserError::with_source(io::Error::from(io::ErrorKind::NotFound)); -assert!(err.is::()); -``` - -Since errors are hirarchical, [`Error::is`] works with any ancestor: - -```rust -# use edgedb_errors::*; -# let err = MissingArgumentError::with_message("test error"); -assert!(err.is::()); -assert!(err.is::()); // implied by the assertion above -assert!(err.is::()); // and this one -assert!(err.is::()); // and this one -``` - -Error hierarchy doesn't have multiple inheritance (i.e. every error has only -single parent). When we match across different parents we use error tags: - -```rust -# use edgedb_errors::*; -# let err1 = ClientConnectionTimeoutError::with_message("test error"); -# let err2 = TransactionConflictError::with_message("test error"); - -assert!(err1.is::()); -assert!(err2.is::()); -// Both of these are retried -assert!(err1.has_tag(SHOULD_RETRY)); -assert!(err2.has_tag(SHOULD_RETRY)); - -// But they aren't a part of common hierarchy -assert!(err1.is::()); -assert!(!err1.is::()); -assert!(err2.is::()); -assert!(!err2.is::()); -``` - -[`anyhow::Error`]: https://docs.rs/anyhow/latest/anyhow/struct.Error.html - -# Errors in Transactions - -Special care for errors must be taken in transactions. Generally: - -1. Errors from queries should not be ignored, and should be propagagated - up to the transaction function. -2. User errors can be encapsulated into [`UserError`] via one of the - methods: - * [`ErrorKind::with_source`] (for any [`std::error::Error`]) - * [`ErrorKind::with_source_box`] already boxed error - * [`ErrorKind::with_source_ref`] for smart wrappers such as - [`anyhow::Error`] -3. Original query error must be propagated via error chain. It can be in the - `.source()` chain but must not be swallowed, otherwise retrying - transaction may work incorrectly. - -# Nice Error Reporting - -Refer to documentation in the [edgedb-tokio](https://docs.rs/edgedb-tokio) crate. +This crate has been renamed to [gel-errors](https://crates.io/crates/gel-errors). */ -mod error; -mod traits; - -pub mod display; -pub mod fields; -pub mod kinds; - -#[cfg(feature = "miette")] -pub mod miette; -pub use error::{Error, Tag}; -pub use kinds::*; -pub use traits::{ErrorKind, Field, ResultExt}; +compile_error!("edgedb-errors has been renamed to gel-errors"); diff --git a/edgedb-errors/src/miette.rs b/edgedb-errors/src/miette.rs deleted file mode 100644 index 53899040..00000000 --- a/edgedb-errors/src/miette.rs +++ /dev/null @@ -1,28 +0,0 @@ -//! Miette support for EdgeDB errors. Add "miette" feature flag to enable. -//! -//! [miette](https://docs.rs/miette/latest/miette/) allows nice error formatting via its [Diagnostic](https://docs.rs/miette/latest/miette/trait.Diagnostic.html) trait -//! -use miette::{LabeledSpan, SourceCode}; -use std::fmt::Display; - -use crate::fields::QueryText; -use crate::Error; - -impl miette::Diagnostic for Error { - fn code(&self) -> Option> { - Some(Box::new(self.kind_name())) - } - fn source_code(&self) -> Option<&dyn SourceCode> { - self.get::().map(|s| s as _) - } - fn labels(&self) -> Option + '_>> { - let (start, end) = self.position_start().zip(self.position_end())?; - let len = end - start; - Some(Box::new( - Some(LabeledSpan::new(self.hint().map(Into::into), start, len)).into_iter(), - )) - } - fn help(&self) -> Option> { - self.details().map(|v| Box::new(v) as Box) - } -} diff --git a/edgedb-errors/src/traits.rs b/edgedb-errors/src/traits.rs deleted file mode 100644 index 3c234a27..00000000 --- a/edgedb-errors/src/traits.rs +++ /dev/null @@ -1,98 +0,0 @@ -use std::borrow::Cow; -use std::collections::HashMap; - -use crate::error::{Error, Inner, Source}; - -/// Trait that marks EdgeDB errors. -/// -/// This is currently sealed because EdgeDB errors will be changed in future. -pub trait ErrorKind: Sealed { - fn with_message>>(s: S) -> Error { - Self::build().context(s) - } - fn with_source(src: E) -> Error { - Error(Box::new(Inner { - code: Self::CODE, - messages: Vec::new(), - error: Some(Source::Box(src.into())), - headers: HashMap::new(), - fields: HashMap::new(), - })) - } - fn with_source_box(src: Box) -> Error { - Error(Box::new(Inner { - code: Self::CODE, - messages: Vec::new(), - error: Some(Source::Box(src)), - headers: HashMap::new(), - fields: HashMap::new(), - })) - } - fn with_source_ref(src: T) -> Error - where - T: AsRef, - T: Send + Sync + 'static, - { - Error(Box::new(Inner { - code: Self::CODE, - messages: Vec::new(), - error: Some(Source::Ref(Box::new(src))), - headers: HashMap::new(), - fields: HashMap::new(), - })) - } - fn build() -> Error { - Error(Box::new(Inner { - code: Self::CODE, - messages: Vec::new(), - error: None, - headers: HashMap::new(), - fields: HashMap::new(), - })) - } -} - -pub trait Field { - const NAME: &'static str; - type Value: Send + Sync + 'static; -} - -pub trait ResultExt { - fn context(self, context: C) -> Result - where - C: Into>; - fn with_context(self, f: F) -> Result - where - C: Into>, - F: FnOnce() -> C; -} - -impl ResultExt for Result { - fn context(self, context: C) -> Result - where - C: Into>, - { - self.map_err(|e| e.context(context)) - } - fn with_context(self, f: F) -> Result - where - C: Into>, - F: FnOnce() -> C, - { - self.map_err(|e| e.context(f())) - } -} - -pub trait Sealed { - const CODE: u32; - const NAME: &'static str; - const TAGS: u32; - // TODO(tailhook) use uuids of errors instead - fn is_superclass_of(code: u32) -> bool { - let mask = 0xFFFFFFFF_u32 << ((Self::CODE.trailing_zeros() / 8) * 8); - code & mask == Self::CODE - } - fn has_tag(bit: u32) -> bool { - Self::TAGS & (1 << bit) != 0 - } -} diff --git a/edgedb-protocol/Cargo.toml b/edgedb-protocol/Cargo.toml index a462a4ee..efc7e36a 100644 --- a/edgedb-protocol/Cargo.toml +++ b/edgedb-protocol/Cargo.toml @@ -5,40 +5,20 @@ version = "0.6.1" authors = ["MagicStack Inc. "] edition = "2018" description = """ - Low-level protocol implemenentation for EdgeDB database client. + Low-level protocol implementation for EdgeDB database client. Use edgedb-tokio for applications instead. + This crate has been renamed to gel-protocol. """ readme = "README.md" rust-version.workspace = true -[dependencies] -bytes = "1.5.0" -snafu = {version="0.8.0", features=["backtrace"]} -uuid = "1.1.2" -num-bigint = {version="0.4.3", optional=true} -num-traits = {version="0.2.10", optional=true} -bigdecimal = {version="0.4.0", optional=true} -chrono = {version="0.4.23", optional=true, features=["std"], default-features=false} -edgedb-errors = {path = "../edgedb-errors", version = "0.4.0" } -bitflags = "2.4.0" -serde = {version="1.0.190", features = ["derive"], optional=true} -serde_json = {version="1", optional=true} - [features] default = [] -with-num-bigint = ["num-bigint", "num-traits"] -with-bigdecimal = ["bigdecimal", "num-bigint", "num-traits"] -with-chrono = ["chrono"] -all-types = ["with-num-bigint", "with-bigdecimal", "with-chrono"] -with-serde = ["serde", "serde_json"] - -[dev-dependencies] -rand = "0.8" -pretty_assertions = "1.2.1" -test-case = "3.0.0" -humantime = "2.1.0" - -[lib] +with-num-bigint = [] +with-bigdecimal = [] +with-chrono = [] +all-types = [] +with-serde = [] [lints] workspace = true diff --git a/edgedb-protocol/README.md b/edgedb-protocol/README.md index 3222e2b8..5c799883 100644 --- a/edgedb-protocol/README.md +++ b/edgedb-protocol/README.md @@ -1,13 +1,11 @@ EdgeDB Rust Binding: Protocol Crate ================================= +> This crate has been renamed to [gel-protocol](https://crates.io/crates/gel-protocol). + This crate contains data model types and internal protocol implementation for the EdgeDB client. -* [Documentation](https://docs.rs/edgedb-protocol) -* [Tokio Client](https://docs.rs/edgedb-tokio) - - License ======= diff --git a/edgedb-protocol/src/annotations.rs b/edgedb-protocol/src/annotations.rs deleted file mode 100644 index 981e1e6d..00000000 --- a/edgedb-protocol/src/annotations.rs +++ /dev/null @@ -1,197 +0,0 @@ -#[cfg(feature = "with-serde")] -use crate::encoding::Annotations; - -/// CommandDataDescription1 may contain "warnings" annotations, whose value is -/// a JSON array of this [Warning] type. -#[derive(Debug, Clone, PartialEq, Eq)] -#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Warning { - /// User-friendly explanation of the problem - pub message: String, - - /// Name of the Python exception class - pub r#type: String, - - /// Machine-friendly exception id - pub code: u64, - - /// Name of the source file that caused the warning. - #[cfg_attr(feature = "with-serde", serde(default))] - pub filename: Option, - - /// Additional user-friendly info - #[cfg_attr(feature = "with-serde", serde(default))] - pub hint: Option, - - /// Developer-friendly explanation of why this problem occured - #[cfg_attr(feature = "with-serde", serde(default))] - pub details: Option, - - /// Inclusive 0-based position within the source - #[cfg_attr( - feature = "with-serde", - serde(deserialize_with = "deserialize_usize_from_str", default) - )] - pub start: Option, - - /// Exclusive 0-based position within the source - #[cfg_attr( - feature = "with-serde", - serde(deserialize_with = "deserialize_usize_from_str", default) - )] - pub end: Option, - - /// 1-based index of the line of the start - #[cfg_attr( - feature = "with-serde", - serde(deserialize_with = "deserialize_usize_from_str", default) - )] - pub line: Option, - - /// 1-based index of the column of the start - #[cfg_attr( - feature = "with-serde", - serde(deserialize_with = "deserialize_usize_from_str", default) - )] - pub col: Option, -} - -impl std::fmt::Display for Warning { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let Warning { - filename, - line, - col, - r#type, - message, - .. - } = self; - let filename = filename - .as_ref() - .map(|f| format!("{f}:")) - .unwrap_or_default(); - let line = line.clone().unwrap_or(1); - let col = col.clone().unwrap_or(1); - - write!(f, "{type} at {filename}{line}:{col} {message}") - } -} - -#[cfg(feature = "with-serde")] -pub fn decode_warnings(annotations: &Annotations) -> Result, edgedb_errors::Error> { - use edgedb_errors::{ErrorKind, ProtocolEncodingError}; - - const ANN_NAME: &str = "warnings"; - - if let Some(warnings) = annotations.get(ANN_NAME) { - serde_json::from_str::>(warnings).map_err(|e| { - ProtocolEncodingError::with_source(e) - .context("Invalid JSON while decoding 'warnings' annotation") - }) - } else { - Ok(vec![]) - } -} - -#[cfg(feature = "with-serde")] -fn deserialize_usize_from_str<'de, D: serde::Deserializer<'de>>( - deserializer: D, -) -> Result, D::Error> { - use serde::Deserialize; - - #[derive(Deserialize)] - #[serde(untagged)] - enum StringOrInt { - String(String), - Number(usize), - } - - Option::::deserialize(deserializer)? - .map(|x| match x { - StringOrInt::String(s) => s.parse::().map_err(serde::de::Error::custom), - StringOrInt::Number(i) => Ok(i), - }) - .transpose() -} - -#[test] -#[cfg(feature = "with-serde")] -fn deserialize_warning() { - let a: Warning = - serde_json::from_str(r#"{"message": "a", "type": "WarningException", "code": 1}"#).unwrap(); - assert_eq!( - a, - Warning { - message: "a".to_string(), - r#type: "WarningException".to_string(), - code: 1, - filename: None, - hint: None, - details: None, - start: None, - end: None, - line: None, - col: None - } - ); - - let a: Warning = serde_json::from_str( - r#"{"message": "a", "type": "WarningException", "code": 1, "start": null}"#, - ) - .unwrap(); - assert_eq!( - a, - Warning { - message: "a".to_string(), - r#type: "WarningException".to_string(), - code: 1, - filename: None, - hint: None, - details: None, - start: None, - end: None, - line: None, - col: None - } - ); - - let a: Warning = serde_json::from_str( - r#"{"message": "a", "type": "WarningException", "code": 1, "start": 23}"#, - ) - .unwrap(); - assert_eq!( - a, - Warning { - message: "a".to_string(), - r#type: "WarningException".to_string(), - code: 1, - filename: None, - hint: None, - details: None, - start: Some(23), - end: None, - line: None, - col: None - } - ); - - let a: Warning = serde_json::from_str( - r#"{"message": "a", "type": "WarningException", "code": 1, "start": "23"}"#, - ) - .unwrap(); - assert_eq!( - a, - Warning { - message: "a".to_string(), - r#type: "WarningException".to_string(), - code: 1, - filename: None, - hint: None, - details: None, - start: Some(23), - end: None, - line: None, - col: None - } - ); -} diff --git a/edgedb-protocol/src/client_message.rs b/edgedb-protocol/src/client_message.rs deleted file mode 100644 index a19dc872..00000000 --- a/edgedb-protocol/src/client_message.rs +++ /dev/null @@ -1,1016 +0,0 @@ -/*! -([Website reference](https://www.edgedb.com/docs/reference/protocol/messages)) The [ClientMessage] enum and related types. - -```rust,ignore -pub enum ClientMessage { - ClientHandshake(ClientHandshake), - ExecuteScript(ExecuteScript), - Prepare(Prepare), - Parse(Parse), - DescribeStatement(DescribeStatement), - Execute0(Execute0), - Execute1(Execute1), - OptimisticExecute(OptimisticExecute), - UnknownMessage(u8, Bytes), - AuthenticationSaslInitialResponse(SaslInitialResponse), - AuthenticationSaslResponse(SaslResponse), - Dump(Dump), - Restore(Restore), - RestoreBlock(RestoreBlock), - RestoreEof, - Sync, - Flush, - Terminate, -} -``` -*/ - -use std::collections::HashMap; -use std::convert::TryFrom; -use std::sync::Arc; - -use bytes::{Buf, BufMut, Bytes}; -use snafu::{ensure, OptionExt}; -use uuid::Uuid; - -pub use crate::common::CompilationOptions; -pub use crate::common::DumpFlags; -pub use crate::common::{Capabilities, Cardinality, CompilationFlags}; -pub use crate::common::{RawTypedesc, State}; -use crate::encoding::{encode, Decode, Encode, Input, Output}; -use crate::encoding::{Annotations, KeyValues}; -use crate::errors::{self, DecodeError, EncodeError}; - -#[derive(Debug, Clone, PartialEq, Eq)] -#[non_exhaustive] -pub enum ClientMessage { - AuthenticationSaslInitialResponse(SaslInitialResponse), - AuthenticationSaslResponse(SaslResponse), - ClientHandshake(ClientHandshake), - Dump2(Dump2), - Dump3(Dump3), - Parse(Parse), // protocol > 1.0 - ExecuteScript(ExecuteScript), - Execute0(Execute0), - Execute1(Execute1), - Restore(Restore), - RestoreBlock(RestoreBlock), - RestoreEof, - Sync, - Terminate, - Prepare(Prepare), // protocol < 1.0 - DescribeStatement(DescribeStatement), - OptimisticExecute(OptimisticExecute), - UnknownMessage(u8, Bytes), - Flush, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct SaslInitialResponse { - pub method: String, - pub data: Bytes, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct SaslResponse { - pub data: Bytes, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ClientHandshake { - pub major_ver: u16, - pub minor_ver: u16, - pub params: HashMap, - pub extensions: HashMap, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ExecuteScript { - pub headers: KeyValues, - pub script_text: String, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Prepare { - pub headers: KeyValues, - pub io_format: IoFormat, - pub expected_cardinality: Cardinality, - pub statement_name: Bytes, - pub command_text: String, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Parse { - pub annotations: Option>, - pub allowed_capabilities: Capabilities, - pub compilation_flags: CompilationFlags, - pub implicit_limit: Option, - pub output_format: IoFormat, - pub expected_cardinality: Cardinality, - pub command_text: String, - pub state: State, - pub input_language: InputLanguage, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct DescribeStatement { - pub headers: KeyValues, - pub aspect: DescribeAspect, - pub statement_name: Bytes, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Execute0 { - pub headers: KeyValues, - pub statement_name: Bytes, - pub arguments: Bytes, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Execute1 { - pub annotations: Option>, - pub allowed_capabilities: Capabilities, - pub compilation_flags: CompilationFlags, - pub implicit_limit: Option, - pub output_format: IoFormat, - pub expected_cardinality: Cardinality, - pub command_text: String, - pub state: State, - pub input_typedesc_id: Uuid, - pub output_typedesc_id: Uuid, - pub arguments: Bytes, - pub input_language: InputLanguage, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct OptimisticExecute { - pub headers: KeyValues, - pub io_format: IoFormat, - pub expected_cardinality: Cardinality, - pub command_text: String, - pub input_typedesc_id: Uuid, - pub output_typedesc_id: Uuid, - pub arguments: Bytes, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Dump2 { - pub headers: KeyValues, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Dump3 { - pub annotations: Option>, - pub flags: DumpFlags, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Restore { - pub headers: KeyValues, - pub jobs: u16, - pub data: Bytes, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct RestoreBlock { - pub data: Bytes, -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum DescribeAspect { - DataDescription = 0x54, -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum InputLanguage { - EdgeQL = 0x45, - SQL = 0x53, -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum IoFormat { - Binary = 0x62, - Json = 0x6a, - JsonElements = 0x4a, - None = 0x6e, -} - -struct Empty; -impl ClientMessage { - pub fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - use ClientMessage::*; - match self { - ClientHandshake(h) => encode(buf, 0x56, h), - AuthenticationSaslInitialResponse(h) => encode(buf, 0x70, h), - AuthenticationSaslResponse(h) => encode(buf, 0x72, h), - ExecuteScript(h) => encode(buf, 0x51, h), - Prepare(h) => encode(buf, 0x50, h), - Parse(h) => encode(buf, 0x50, h), - DescribeStatement(h) => encode(buf, 0x44, h), - Execute0(h) => encode(buf, 0x45, h), - OptimisticExecute(h) => encode(buf, 0x4f, h), - Execute1(h) => encode(buf, 0x4f, h), - Dump2(h) => encode(buf, 0x3e, h), - Dump3(h) => encode(buf, 0x3e, h), - Restore(h) => encode(buf, 0x3c, h), - RestoreBlock(h) => encode(buf, 0x3d, h), - RestoreEof => encode(buf, 0x2e, &Empty), - Sync => encode(buf, 0x53, &Empty), - Flush => encode(buf, 0x48, &Empty), - Terminate => encode(buf, 0x58, &Empty), - - UnknownMessage(_, _) => errors::UnknownMessageCantBeEncoded.fail()?, - } - } - /// Decode exactly one frame from the buffer. - /// - /// This expects a full frame to already be in the buffer. It can return - /// an arbitrary error or be silent if a message is only partially present - /// in the buffer or if extra data is present. - pub fn decode(buf: &mut Input) -> Result { - use self::ClientMessage as M; - let mut data = buf.slice(5..); - let result = match buf[0] { - 0x56 => ClientHandshake::decode(&mut data).map(M::ClientHandshake)?, - 0x70 => { - SaslInitialResponse::decode(&mut data).map(M::AuthenticationSaslInitialResponse)? - } - 0x72 => SaslResponse::decode(&mut data).map(M::AuthenticationSaslResponse)?, - 0x51 => ExecuteScript::decode(&mut data).map(M::ExecuteScript)?, - 0x50 => { - if buf.proto().is_1() { - Parse::decode(&mut data).map(M::Parse)? - } else { - Prepare::decode(&mut data).map(M::Prepare)? - } - } - 0x45 => Execute0::decode(&mut data).map(M::Execute0)?, - 0x4f => { - if buf.proto().is_1() { - Execute1::decode(&mut data).map(M::Execute1)? - } else { - OptimisticExecute::decode(&mut data).map(M::OptimisticExecute)? - } - } - 0x3e => { - if buf.proto().is_3() { - Dump3::decode(&mut data).map(M::Dump3)? - } else { - Dump2::decode(&mut data).map(M::Dump2)? - } - } - 0x3c => Restore::decode(&mut data).map(M::Restore)?, - 0x3d => RestoreBlock::decode(&mut data).map(M::RestoreBlock)?, - 0x2e => M::RestoreEof, - 0x53 => M::Sync, - 0x48 => M::Flush, - 0x58 => M::Terminate, - 0x44 => DescribeStatement::decode(&mut data).map(M::DescribeStatement)?, - code => M::UnknownMessage(code, data.copy_to_bytes(data.remaining())), - }; - ensure!(data.remaining() == 0, errors::ExtraData); - Ok(result) - } -} - -impl Encode for Empty { - fn encode(&self, _buf: &mut Output) -> Result<(), EncodeError> { - Ok(()) - } -} - -impl Encode for ClientHandshake { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(8); - buf.put_u16(self.major_ver); - buf.put_u16(self.minor_ver); - buf.put_u16( - u16::try_from(self.params.len()) - .ok() - .context(errors::TooManyParams)?, - ); - for (k, v) in &self.params { - k.encode(buf)?; - v.encode(buf)?; - } - buf.reserve(2); - buf.put_u16( - u16::try_from(self.extensions.len()) - .ok() - .context(errors::TooManyExtensions)?, - ); - for (name, headers) in &self.extensions { - name.encode(buf)?; - buf.reserve(2); - buf.put_u16( - u16::try_from(headers.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (&name, value) in headers { - buf.reserve(2); - buf.put_u16(name); - value.encode(buf)?; - } - } - Ok(()) - } -} - -impl Decode for ClientHandshake { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 8, errors::Underflow); - let major_ver = buf.get_u16(); - let minor_ver = buf.get_u16(); - let num_params = buf.get_u16(); - let mut params = HashMap::new(); - for _ in 0..num_params { - params.insert(String::decode(buf)?, String::decode(buf)?); - } - - ensure!(buf.remaining() >= 2, errors::Underflow); - let num_ext = buf.get_u16(); - let mut extensions = HashMap::new(); - for _ in 0..num_ext { - let name = String::decode(buf)?; - ensure!(buf.remaining() >= 2, errors::Underflow); - let num_headers = buf.get_u16(); - let mut headers = HashMap::new(); - for _ in 0..num_headers { - ensure!(buf.remaining() >= 4, errors::Underflow); - headers.insert(buf.get_u16(), Bytes::decode(buf)?); - } - extensions.insert(name, headers); - } - Ok(ClientHandshake { - major_ver, - minor_ver, - params, - extensions, - }) - } -} - -impl Encode for SaslInitialResponse { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - self.method.encode(buf)?; - self.data.encode(buf)?; - Ok(()) - } -} - -impl Decode for SaslInitialResponse { - fn decode(buf: &mut Input) -> Result { - let method = String::decode(buf)?; - let data = Bytes::decode(buf)?; - Ok(SaslInitialResponse { method, data }) - } -} - -impl Encode for SaslResponse { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - self.data.encode(buf)?; - Ok(()) - } -} - -impl Decode for SaslResponse { - fn decode(buf: &mut Input) -> Result { - let data = Bytes::decode(buf)?; - Ok(SaslResponse { data }) - } -} - -impl Encode for ExecuteScript { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(6); - buf.put_u16( - u16::try_from(self.headers.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (&name, value) in &self.headers { - buf.reserve(2); - buf.put_u16(name); - value.encode(buf)?; - } - self.script_text.encode(buf)?; - Ok(()) - } -} - -impl Decode for ExecuteScript { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 6, errors::Underflow); - let num_headers = buf.get_u16(); - let mut headers = HashMap::new(); - for _ in 0..num_headers { - ensure!(buf.remaining() >= 4, errors::Underflow); - headers.insert(buf.get_u16(), Bytes::decode(buf)?); - } - let script_text = String::decode(buf)?; - Ok(ExecuteScript { - script_text, - headers, - }) - } -} - -impl Encode for Prepare { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - debug_assert!(!buf.proto().is_1()); - buf.reserve(12); - buf.put_u16( - u16::try_from(self.headers.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (&name, value) in &self.headers { - buf.reserve(2); - buf.put_u16(name); - value.encode(buf)?; - } - buf.reserve(10); - buf.put_u8(self.io_format as u8); - buf.put_u8(self.expected_cardinality as u8); - self.statement_name.encode(buf)?; - self.command_text.encode(buf)?; - Ok(()) - } -} - -impl Decode for Prepare { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 12, errors::Underflow); - let num_headers = buf.get_u16(); - let mut headers = HashMap::new(); - for _ in 0..num_headers { - ensure!(buf.remaining() >= 4, errors::Underflow); - headers.insert(buf.get_u16(), Bytes::decode(buf)?); - } - ensure!(buf.remaining() >= 8, errors::Underflow); - let io_format = match buf.get_u8() { - 0x62 => IoFormat::Binary, - 0x6a => IoFormat::Json, - 0x4a => IoFormat::JsonElements, - c => errors::InvalidIoFormat { io_format: c }.fail()?, - }; - let expected_cardinality = TryFrom::try_from(buf.get_u8())?; - let statement_name = Bytes::decode(buf)?; - let command_text = String::decode(buf)?; - Ok(Prepare { - headers, - io_format, - expected_cardinality, - statement_name, - command_text, - }) - } -} - -impl Encode for DescribeStatement { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(7); - buf.put_u16( - u16::try_from(self.headers.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - buf.reserve(5); - buf.put_u8(self.aspect as u8); - self.statement_name.encode(buf)?; - Ok(()) - } -} - -impl Decode for DescribeStatement { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 12, errors::Underflow); - let num_headers = buf.get_u16(); - let mut headers = HashMap::new(); - for _ in 0..num_headers { - ensure!(buf.remaining() >= 4, errors::Underflow); - headers.insert(buf.get_u16(), Bytes::decode(buf)?); - } - ensure!(buf.remaining() >= 8, errors::Underflow); - let aspect = match buf.get_u8() { - 0x54 => DescribeAspect::DataDescription, - c => errors::InvalidAspect { aspect: c }.fail()?, - }; - let statement_name = Bytes::decode(buf)?; - Ok(DescribeStatement { - headers, - aspect, - statement_name, - }) - } -} - -impl Encode for Execute0 { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - debug_assert!(!buf.proto().is_1()); - buf.reserve(10); - buf.put_u16( - u16::try_from(self.headers.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (&name, value) in &self.headers { - buf.reserve(2); - buf.put_u16(name); - value.encode(buf)?; - } - self.statement_name.encode(buf)?; - self.arguments.encode(buf)?; - Ok(()) - } -} - -impl Decode for Execute0 { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 12, errors::Underflow); - let num_headers = buf.get_u16(); - let mut headers = HashMap::new(); - for _ in 0..num_headers { - ensure!(buf.remaining() >= 4, errors::Underflow); - headers.insert(buf.get_u16(), Bytes::decode(buf)?); - } - let statement_name = Bytes::decode(buf)?; - let arguments = Bytes::decode(buf)?; - Ok(Execute0 { - headers, - statement_name, - arguments, - }) - } -} - -impl OptimisticExecute { - pub fn new( - flags: &CompilationOptions, - query: &str, - arguments: impl Into, - input_typedesc_id: Uuid, - output_typedesc_id: Uuid, - ) -> OptimisticExecute { - let mut headers = KeyValues::new(); - if let Some(limit) = flags.implicit_limit { - headers.insert(0xFF01, Bytes::from(limit.to_string())); - } - if flags.implicit_typenames { - headers.insert(0xFF02, "true".into()); - } - if flags.implicit_typeids { - headers.insert(0xFF03, "true".into()); - } - let caps = flags.allow_capabilities.bits().to_be_bytes(); - headers.insert(0xFF04, caps[..].to_vec().into()); - if flags.explicit_objectids { - headers.insert(0xFF03, "true".into()); - } - OptimisticExecute { - headers, - io_format: flags.io_format, - expected_cardinality: flags.expected_cardinality, - command_text: query.into(), - input_typedesc_id, - output_typedesc_id, - arguments: arguments.into(), - } - } -} - -impl Encode for OptimisticExecute { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(2 + 1 + 1 + 4 + 16 + 16 + 4); - buf.put_u16( - u16::try_from(self.headers.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (&name, value) in &self.headers { - buf.reserve(2); - buf.put_u16(name); - value.encode(buf)?; - } - buf.reserve(1 + 1 + 4 + 16 + 16 + 4); - buf.put_u8(self.io_format as u8); - buf.put_u8(self.expected_cardinality as u8); - self.command_text.encode(buf)?; - self.input_typedesc_id.encode(buf)?; - self.output_typedesc_id.encode(buf)?; - self.arguments.encode(buf)?; - Ok(()) - } -} - -impl Decode for OptimisticExecute { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 12, errors::Underflow); - let num_headers = buf.get_u16(); - let mut headers = HashMap::new(); - for _ in 0..num_headers { - ensure!(buf.remaining() >= 4, errors::Underflow); - headers.insert(buf.get_u16(), Bytes::decode(buf)?); - } - let io_format = match buf.get_u8() { - 0x62 => IoFormat::Binary, - 0x6a => IoFormat::Json, - 0x4a => IoFormat::JsonElements, - c => errors::InvalidIoFormat { io_format: c }.fail()?, - }; - let expected_cardinality = TryFrom::try_from(buf.get_u8())?; - let command_text = String::decode(buf)?; - let input_typedesc_id = Uuid::decode(buf)?; - let output_typedesc_id = Uuid::decode(buf)?; - let arguments = Bytes::decode(buf)?; - Ok(OptimisticExecute { - headers, - io_format, - expected_cardinality, - command_text, - input_typedesc_id, - output_typedesc_id, - arguments, - }) - } -} - -impl Encode for Execute1 { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(2 + 3 * 8 + 1 + 1 + 4 + 16 + 4 + 16 + 16 + 4); - if let Some(annotations) = self.annotations.as_deref() { - buf.put_u16( - u16::try_from(annotations.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (name, value) in annotations { - buf.reserve(4); - name.encode(buf)?; - value.encode(buf)?; - } - } else { - buf.put_u16(0); - } - buf.reserve(3 * 8 + 1 + 1 + 4 + 16 + 4 + 16 + 16 + 4); - buf.put_u64(self.allowed_capabilities.bits()); - buf.put_u64(self.compilation_flags.bits()); - buf.put_u64(self.implicit_limit.unwrap_or(0)); - if buf.proto().is_multilingual() { - buf.put_u8(self.input_language as u8); - } - buf.put_u8(self.output_format as u8); - buf.put_u8(self.expected_cardinality as u8); - self.command_text.encode(buf)?; - self.state.typedesc_id.encode(buf)?; - self.state.data.encode(buf)?; - self.input_typedesc_id.encode(buf)?; - self.output_typedesc_id.encode(buf)?; - self.arguments.encode(buf)?; - Ok(()) - } -} - -impl Decode for Execute1 { - fn decode(buf: &mut Input) -> Result { - ensure!( - buf.remaining() >= 2 + 3 * 8 + 2 + 4 + 16 + 4 + 16 + 16 + 4, - errors::Underflow - ); - let num_annotations = buf.get_u16(); - let annotations = if num_annotations == 0 { - None - } else { - let mut annotations = HashMap::new(); - for _ in 0..num_annotations { - ensure!(buf.remaining() >= 4, errors::Underflow); - annotations.insert(String::decode(buf)?, String::decode(buf)?); - } - Some(Arc::new(annotations)) - }; - ensure!( - buf.remaining() >= 3 * 8 + 2 + 4 + 16 + 4 + 16 + 16 + 4, - errors::Underflow - ); - let allowed_capabilities = decode_capabilities(buf.get_u64())?; - let compilation_flags = decode_compilation_flags(buf.get_u64())?; - let implicit_limit = match buf.get_u64() { - 0 => None, - val => Some(val), - }; - let input_language = if buf.proto().is_multilingual() { - TryFrom::try_from(buf.get_u8())? - } else { - InputLanguage::EdgeQL - }; - let output_format = match buf.get_u8() { - 0x62 => IoFormat::Binary, - 0x6a => IoFormat::Json, - 0x4a => IoFormat::JsonElements, - c => errors::InvalidIoFormat { io_format: c }.fail()?, - }; - let expected_cardinality = TryFrom::try_from(buf.get_u8())?; - let command_text = String::decode(buf)?; - let state = State { - typedesc_id: Uuid::decode(buf)?, - data: Bytes::decode(buf)?, - }; - let input_typedesc_id = Uuid::decode(buf)?; - let output_typedesc_id = Uuid::decode(buf)?; - let arguments = Bytes::decode(buf)?; - Ok(Execute1 { - annotations, - allowed_capabilities, - compilation_flags, - implicit_limit, - output_format, - expected_cardinality, - command_text, - state, - input_typedesc_id, - output_typedesc_id, - arguments, - input_language, - }) - } -} - -impl Encode for Dump2 { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(10); - buf.put_u16( - u16::try_from(self.headers.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (&name, value) in &self.headers { - buf.reserve(2); - buf.put_u16(name); - value.encode(buf)?; - } - Ok(()) - } -} - -impl Decode for Dump2 { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 12, errors::Underflow); - let num_headers = buf.get_u16(); - let mut headers = HashMap::new(); - for _ in 0..num_headers { - ensure!(buf.remaining() >= 4, errors::Underflow); - headers.insert(buf.get_u16(), Bytes::decode(buf)?); - } - Ok(Dump2 { headers }) - } -} - -impl Encode for Dump3 { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(2 + 8); - if let Some(annotations) = self.annotations.as_deref() { - buf.put_u16( - u16::try_from(annotations.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (name, value) in annotations { - buf.reserve(4); - name.encode(buf)?; - value.encode(buf)?; - } - } else { - buf.put_u16(0); - } - buf.put_u64(self.flags.bits()); - Ok(()) - } -} - -impl Decode for Dump3 { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 2, errors::Underflow); - let num_headers = buf.get_u16(); - let annotations = if num_headers == 0 { - None - } else { - let mut annotations = HashMap::new(); - for _ in 0..num_headers { - ensure!(buf.remaining() >= 8, errors::Underflow); - annotations.insert(String::decode(buf)?, String::decode(buf)?); - } - Some(Arc::new(annotations)) - }; - ensure!(buf.remaining() >= 8, errors::Underflow); - let flags = decode_dump_flags(buf.get_u64())?; - Ok(Dump3 { annotations, flags }) - } -} - -impl Encode for Restore { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(4 + self.data.len()); - buf.put_u16( - u16::try_from(self.headers.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (&name, value) in &self.headers { - buf.reserve(2); - buf.put_u16(name); - value.encode(buf)?; - } - buf.put_u16(self.jobs); - buf.extend(&self.data); - Ok(()) - } -} - -impl Decode for Restore { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 4, errors::Underflow); - - let num_headers = buf.get_u16(); - let mut headers = HashMap::new(); - for _ in 0..num_headers { - ensure!(buf.remaining() >= 4, errors::Underflow); - headers.insert(buf.get_u16(), Bytes::decode(buf)?); - } - - let jobs = buf.get_u16(); - - let data = buf.copy_to_bytes(buf.remaining()); - Ok(Restore { - jobs, - headers, - data, - }) - } -} - -impl Encode for RestoreBlock { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.extend(&self.data); - Ok(()) - } -} - -impl Decode for RestoreBlock { - fn decode(buf: &mut Input) -> Result { - let data = buf.copy_to_bytes(buf.remaining()); - Ok(RestoreBlock { data }) - } -} - -impl Parse { - pub fn new( - opts: &CompilationOptions, - query: &str, - state: State, - annotations: Option>, - ) -> Parse { - Parse { - annotations, - allowed_capabilities: opts.allow_capabilities, - compilation_flags: opts.flags(), - implicit_limit: opts.implicit_limit, - output_format: opts.io_format, - expected_cardinality: opts.expected_cardinality, - command_text: query.into(), - state, - input_language: opts.input_language, - } - } -} - -impl Prepare { - pub fn new(flags: &CompilationOptions, query: &str) -> Prepare { - let mut headers = KeyValues::new(); - if let Some(limit) = flags.implicit_limit { - headers.insert(0xFF01, Bytes::from(limit.to_string())); - } - if flags.implicit_typenames { - headers.insert(0xFF02, "true".into()); - } - if flags.implicit_typeids { - headers.insert(0xFF03, "true".into()); - } - let caps = flags.allow_capabilities.bits().to_be_bytes(); - headers.insert(0xFF04, caps[..].to_vec().into()); - if flags.explicit_objectids { - headers.insert(0xFF03, "true".into()); - } - Prepare { - headers, - io_format: flags.io_format, - expected_cardinality: flags.expected_cardinality, - statement_name: Bytes::from(""), - command_text: query.into(), - } - } -} - -fn decode_capabilities(val: u64) -> Result { - Capabilities::from_bits(val) - .ok_or_else(|| errors::InvalidCapabilities { capabilities: val }.build()) -} - -fn decode_compilation_flags(val: u64) -> Result { - CompilationFlags::from_bits(val).ok_or_else(|| { - errors::InvalidCompilationFlags { - compilation_flags: val, - } - .build() - }) -} - -fn decode_dump_flags(val: u64) -> Result { - DumpFlags::from_bits(val).ok_or_else(|| errors::InvalidDumpFlags { dump_flags: val }.build()) -} - -impl Decode for Parse { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 52, errors::Underflow); - let num_headers = buf.get_u16(); - let annotations = if num_headers == 0 { - None - } else { - let mut annotations = HashMap::new(); - for _ in 0..num_headers { - ensure!(buf.remaining() >= 8, errors::Underflow); - annotations.insert(String::decode(buf)?, String::decode(buf)?); - } - Some(Arc::new(annotations)) - }; - ensure!(buf.remaining() >= 50, errors::Underflow); - let allowed_capabilities = decode_capabilities(buf.get_u64())?; - let compilation_flags = decode_compilation_flags(buf.get_u64())?; - let implicit_limit = match buf.get_u64() { - 0 => None, - val => Some(val), - }; - let input_language = if buf.proto().is_multilingual() { - TryFrom::try_from(buf.get_u8())? - } else { - InputLanguage::EdgeQL - }; - let output_format = match buf.get_u8() { - 0x62 => IoFormat::Binary, - 0x6a => IoFormat::Json, - 0x4a => IoFormat::JsonElements, - c => errors::InvalidIoFormat { io_format: c }.fail()?, - }; - let expected_cardinality = TryFrom::try_from(buf.get_u8())?; - let command_text = String::decode(buf)?; - let state = State { - typedesc_id: Uuid::decode(buf)?, - data: Bytes::decode(buf)?, - }; - Ok(Parse { - annotations, - allowed_capabilities, - compilation_flags, - implicit_limit, - output_format, - expected_cardinality, - command_text, - state, - input_language, - }) - } -} - -impl Encode for Parse { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - debug_assert!(buf.proto().is_1()); - buf.reserve(52); - if let Some(annotations) = self.annotations.as_deref() { - buf.put_u16( - u16::try_from(annotations.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (name, value) in annotations { - buf.reserve(8); - name.encode(buf)?; - value.encode(buf)?; - } - } else { - buf.put_u16(0); - } - buf.reserve(50); - buf.put_u64(self.allowed_capabilities.bits()); - buf.put_u64(self.compilation_flags.bits()); - buf.put_u64(self.implicit_limit.unwrap_or(0)); - if buf.proto().is_multilingual() { - buf.put_u8(self.input_language as u8); - } - buf.put_u8(self.output_format as u8); - buf.put_u8(self.expected_cardinality as u8); - self.command_text.encode(buf)?; - self.state.typedesc_id.encode(buf)?; - self.state.data.encode(buf)?; - Ok(()) - } -} diff --git a/edgedb-protocol/src/codec.rs b/edgedb-protocol/src/codec.rs deleted file mode 100644 index 16269882..00000000 --- a/edgedb-protocol/src/codec.rs +++ /dev/null @@ -1,1633 +0,0 @@ -/*! -Implementations of the [Codec] trait into types found in the [Value] enum. -*/ - -use std::any::type_name; -use std::collections::HashSet; -use std::convert::{TryFrom, TryInto}; -use std::fmt; -use std::ops::Deref; -use std::str; -use std::sync::Arc; - -use bytes::{Buf, BufMut, BytesMut}; -use snafu::{ensure, OptionExt, ResultExt}; -use uuid::Uuid as UuidVal; - -use crate::common::Cardinality; -use crate::descriptors::{self, Descriptor, TypePos}; -use crate::errors::{self, CodecError, DecodeError, EncodeError}; -use crate::model; -use crate::model::range; -use crate::serialization::decode::DecodeRange; -use crate::serialization::decode::{DecodeArrayLike, DecodeTupleLike, RawCodec}; -use crate::value::{SparseObject, Value}; - -pub const STD_UUID: UuidVal = UuidVal::from_u128(0x100); -pub const STD_STR: UuidVal = UuidVal::from_u128(0x101); -pub const STD_BYTES: UuidVal = UuidVal::from_u128(0x102); -pub const STD_INT16: UuidVal = UuidVal::from_u128(0x103); -pub const STD_INT32: UuidVal = UuidVal::from_u128(0x104); -pub const STD_INT64: UuidVal = UuidVal::from_u128(0x105); -pub const STD_FLOAT32: UuidVal = UuidVal::from_u128(0x106); -pub const STD_FLOAT64: UuidVal = UuidVal::from_u128(0x107); -pub const STD_DECIMAL: UuidVal = UuidVal::from_u128(0x108); -pub const STD_BOOL: UuidVal = UuidVal::from_u128(0x109); -pub const STD_DATETIME: UuidVal = UuidVal::from_u128(0x10a); -pub const CAL_LOCAL_DATETIME: UuidVal = UuidVal::from_u128(0x10b); -pub const CAL_LOCAL_DATE: UuidVal = UuidVal::from_u128(0x10c); -pub const CAL_LOCAL_TIME: UuidVal = UuidVal::from_u128(0x10d); -pub const STD_DURATION: UuidVal = UuidVal::from_u128(0x10e); -pub const CAL_RELATIVE_DURATION: UuidVal = UuidVal::from_u128(0x111); -pub const CAL_DATE_DURATION: UuidVal = UuidVal::from_u128(0x112); -pub const STD_JSON: UuidVal = UuidVal::from_u128(0x10f); -pub const STD_BIGINT: UuidVal = UuidVal::from_u128(0x110); -pub const CFG_MEMORY: UuidVal = UuidVal::from_u128(0x130); -pub const PGVECTOR_VECTOR: UuidVal = UuidVal::from_u128(0x9565dd88_04f5_11ee_a691_0b6ebe179825); -pub const STD_PG_JSON: UuidVal = UuidVal::from_u128(0x1000001); -pub const STD_PG_TIMESTAMPTZ: UuidVal = UuidVal::from_u128(0x1000002); -pub const STD_PG_TIMESTAMP: UuidVal = UuidVal::from_u128(0x1000003); -pub const STD_PG_DATE: UuidVal = UuidVal::from_u128(0x1000004); -pub const STD_PG_INTERVAL: UuidVal = UuidVal::from_u128(0x1000005); -pub const POSTGIS_GEOMETRY: UuidVal = UuidVal::from_u128(0x44c901c0_d922_4894_83c8_061bd05e4840); -pub const POSTGIS_GEOGRAPHY: UuidVal = UuidVal::from_u128(0x4d738878_3a5f_4821_ab76_9d8e7d6b32c4); -pub const POSTGIS_BOX_2D: UuidVal = UuidVal::from_u128(0x7fae5536_6311_4f60_8eb9_096a5d972f48); -pub const POSTGIS_BOX_3D: UuidVal = UuidVal::from_u128(0xc1a50ff8_fded_48b0_85c2_4905a8481433); - -pub(crate) fn uuid_to_known_name(uuid: &UuidVal) -> Option<&'static str> { - match *uuid { - STD_UUID => Some("BaseScalar(uuid)"), - STD_STR => Some("BaseScalar(str)"), - STD_BYTES => Some("BaseScalar(bytes)"), - STD_INT16 => Some("BaseScalar(int16)"), - STD_INT32 => Some("BaseScalar(int32)"), - STD_INT64 => Some("BaseScalar(int64)"), - STD_FLOAT32 => Some("BaseScalar(float32)"), - STD_FLOAT64 => Some("BaseScalar(float64)"), - STD_DECIMAL => Some("BaseScalar(decimal)"), - STD_BOOL => Some("BaseScalar(bool)"), - STD_DATETIME => Some("BaseScalar(datetime)"), - CAL_LOCAL_DATETIME => Some("BaseScalar(cal::local_datetime)"), - CAL_LOCAL_DATE => Some("BaseScalar(cal::local_date)"), - CAL_LOCAL_TIME => Some("BaseScalar(cal::local_time)"), - STD_DURATION => Some("BaseScalar(duration)"), - CAL_RELATIVE_DURATION => Some("BaseScalar(cal::relative_duration)"), - CAL_DATE_DURATION => Some("BaseScalar(cal::date_duration)"), - STD_JSON => Some("BaseScalar(std::json)"), - STD_BIGINT => Some("BaseScalar(bigint)"), - CFG_MEMORY => Some("BaseScalar(cfg::memory)"), - PGVECTOR_VECTOR => Some("BaseScalar(ext::pgvector::vector)"), - STD_PG_JSON => Some("BaseScalar(std::pg::json)"), - STD_PG_TIMESTAMPTZ => Some("BaseScalar(std::pg::timestamptz)"), - STD_PG_TIMESTAMP => Some("BaseScalar(std::pg::timestamp)"), - STD_PG_DATE => Some("BaseScalar(std::pg::date)"), - STD_PG_INTERVAL => Some("BaseScalar(std::pg::interval)"), - POSTGIS_GEOMETRY => Some("BaseScalar(ext::postgis::geometry)"), - POSTGIS_GEOGRAPHY => Some("BaseScalar(ext::postgis::geography)"), - POSTGIS_BOX_2D => Some("BaseScalar(ext::postgis::box2d)"), - POSTGIS_BOX_3D => Some("BaseScalar(ext::postgis::box3d)"), - _ => None, - } -} - -pub trait Codec: fmt::Debug + Send + Sync + 'static { - fn decode(&self, buf: &[u8]) -> Result; - fn encode(&self, buf: &mut BytesMut, value: &Value) -> Result<(), EncodeError>; -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct EnumValue(Arc); -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ObjectShape(pub(crate) Arc); -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct NamedTupleShape(Arc); -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct SQLRowShape(Arc); - -#[derive(Debug, PartialEq, Eq)] -pub struct ObjectShapeInfo { - pub elements: Vec, -} - -#[derive(Debug, PartialEq, Eq)] -pub struct ShapeElement { - pub flag_implicit: bool, - pub flag_link_property: bool, - pub flag_link: bool, - pub cardinality: Option, - pub name: String, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct InputObjectShape(pub(crate) Arc); - -#[derive(Debug, PartialEq, Eq)] -pub struct InputObjectShapeInfo { - pub elements: Vec, -} - -#[derive(Debug, PartialEq, Eq)] -pub struct InputShapeElement { - pub cardinality: Option, - pub name: String, -} - -#[derive(Debug, PartialEq, Eq)] -pub struct NamedTupleShapeInfo { - pub elements: Vec, -} - -#[derive(Debug, PartialEq, Eq)] -pub struct TupleElement { - pub name: String, -} - -#[derive(Debug, PartialEq, Eq)] -pub struct SQLRowShapeInfo { - pub elements: Vec, -} - -#[derive(Debug, PartialEq, Eq)] -pub struct SQLRowElement { - pub name: String, -} - -#[derive(Debug)] -pub struct Uuid; - -#[derive(Debug)] -pub struct Int16; - -#[derive(Debug)] -pub struct Int32; - -#[derive(Debug)] -pub struct Int64; - -#[derive(Debug)] -pub struct Float32; - -#[derive(Debug)] -pub struct Float64; - -#[derive(Debug)] -pub struct Str; - -#[derive(Debug)] -pub struct Bytes; - -#[derive(Debug)] -pub struct Duration; - -#[derive(Debug)] -pub struct RelativeDuration; - -#[derive(Debug)] -pub struct DateDuration; - -#[derive(Debug)] -pub struct Datetime; - -#[derive(Debug)] -pub struct LocalDatetime; - -#[derive(Debug)] -pub struct LocalDate; - -#[derive(Debug)] -pub struct LocalTime; - -#[derive(Debug)] -pub struct Decimal; - -#[derive(Debug)] -pub struct BigInt; - -#[derive(Debug)] -pub struct ConfigMemory; - -#[derive(Debug)] -pub struct Bool; - -#[derive(Debug)] -pub struct Json; - -#[derive(Debug)] -pub struct PgTextJson; - -#[derive(Debug)] -pub struct Nothing; - -#[derive(Debug)] -pub struct Object { - shape: ObjectShape, - codecs: Vec>, -} - -#[derive(Debug)] -pub struct Input { - shape: InputObjectShape, - codecs: Vec>, -} - -#[derive(Debug)] -pub struct Set { - element: Arc, -} - -#[derive(Debug)] -pub struct Scalar { - inner: Arc, -} - -#[derive(Debug)] -pub struct Tuple { - elements: Vec>, -} - -#[derive(Debug)] -pub struct NamedTuple { - shape: NamedTupleShape, - codecs: Vec>, -} - -#[derive(Debug)] -pub struct SQLRow { - shape: SQLRowShape, - codecs: Vec>, -} - -#[derive(Debug)] -pub struct Array { - element: Arc, -} - -#[derive(Debug)] -pub struct Vector {} - -#[derive(Debug)] -pub struct Range { - element: Arc, -} - -#[derive(Debug)] -pub struct MultiRange { - element: Arc, -} - -#[derive(Debug)] -pub struct ArrayAdapter(Array); - -#[derive(Debug)] -pub struct Enum { - members: HashSet>, -} - -#[derive(Debug)] -pub struct PostGisGeometry {} - -#[derive(Debug)] -pub struct PostGisGeography {} - -#[derive(Debug)] -pub struct PostGisBox2d {} - -#[derive(Debug)] -pub struct PostGisBox3d {} - -struct CodecBuilder<'a> { - descriptors: &'a [Descriptor], -} - -impl ObjectShape { - pub fn new(elements: Vec) -> ObjectShape { - ObjectShape(Arc::new(ObjectShapeInfo { elements })) - } -} - -impl Deref for ObjectShape { - type Target = ObjectShapeInfo; - fn deref(&self) -> &ObjectShapeInfo { - &self.0 - } -} - -impl InputObjectShape { - pub fn new(elements: Vec) -> Self { - InputObjectShape(Arc::new(InputObjectShapeInfo { elements })) - } -} - -impl Deref for InputObjectShape { - type Target = InputObjectShapeInfo; - fn deref(&self) -> &InputObjectShapeInfo { - &self.0 - } -} - -impl Deref for NamedTupleShape { - type Target = NamedTupleShapeInfo; - fn deref(&self) -> &NamedTupleShapeInfo { - &self.0 - } -} - -impl Deref for SQLRowShape { - type Target = SQLRowShapeInfo; - fn deref(&self) -> &SQLRowShapeInfo { - &self.0 - } -} - -impl<'a> CodecBuilder<'a> { - fn build(&self, pos: TypePos) -> Result, CodecError> { - use Descriptor as D; - if let Some(item) = self.descriptors.get(pos.0 as usize) { - match item { - D::BaseScalar(base) => scalar_codec(&base.id), - D::Set(d) => Ok(Arc::new(Set::build(d, self)?)), - D::ObjectShape(d) => Ok(Arc::new(Object::build(d, self)?)), - D::Scalar(d) => Ok(Arc::new(Scalar { - inner: match d.base_type_pos { - Some(type_pos) => self.build(type_pos)?, - None => scalar_codec(&d.id)?, - }, - })), - D::Tuple(d) => Ok(Arc::new(Tuple::build(d, self)?)), - D::NamedTuple(d) => Ok(Arc::new(NamedTuple::build(d, self)?)), - D::Array(d) => Ok(Arc::new(Array { - element: self.build(d.type_pos)?, - })), - D::Range(d) => Ok(Arc::new(Range { - element: self.build(d.type_pos)?, - })), - D::MultiRange(d) => Ok(Arc::new(MultiRange { - element: Arc::new(Range { - element: self.build(d.type_pos)?, - }), - })), - D::Enumeration(d) => Ok(Arc::new(Enum { - members: d.members.iter().map(|x| x[..].into()).collect(), - })), - D::Object(_) => Ok(Arc::new(Nothing {})), - D::Compound(_) => Ok(Arc::new(Nothing {})), - D::InputShape(d) => Ok(Arc::new(Input::build(d, self)?)), - D::SQLRow(d) => Ok(Arc::new(SQLRow::build(d, self)?)), - // type annotations are stripped from codecs array before - // building a codec - D::TypeAnnotation(..) => unreachable!(), - } - } else { - errors::UnexpectedTypePos { position: pos.0 }.fail()? - } - } -} - -pub fn build_codec( - root_pos: Option, - descriptors: &[Descriptor], -) -> Result, CodecError> { - let dec = CodecBuilder { descriptors }; - match root_pos { - Some(pos) => dec.build(pos), - None => Ok(Arc::new(Nothing {})), - } -} - -pub fn scalar_codec(uuid: &UuidVal) -> Result, CodecError> { - match *uuid { - STD_UUID => Ok(Arc::new(Uuid {})), - STD_STR => Ok(Arc::new(Str {})), - STD_BYTES => Ok(Arc::new(Bytes {})), - STD_INT16 => Ok(Arc::new(Int16 {})), - STD_INT32 => Ok(Arc::new(Int32 {})), - STD_INT64 => Ok(Arc::new(Int64 {})), - STD_FLOAT32 => Ok(Arc::new(Float32 {})), - STD_FLOAT64 => Ok(Arc::new(Float64 {})), - STD_DECIMAL => Ok(Arc::new(Decimal {})), - STD_BOOL => Ok(Arc::new(Bool {})), - STD_DATETIME => Ok(Arc::new(Datetime {})), - CAL_LOCAL_DATETIME => Ok(Arc::new(LocalDatetime {})), - CAL_LOCAL_DATE => Ok(Arc::new(LocalDate {})), - CAL_LOCAL_TIME => Ok(Arc::new(LocalTime {})), - STD_DURATION => Ok(Arc::new(Duration {})), - CAL_RELATIVE_DURATION => Ok(Arc::new(RelativeDuration {})), - CAL_DATE_DURATION => Ok(Arc::new(DateDuration {})), - STD_JSON => Ok(Arc::new(Json {})), - STD_BIGINT => Ok(Arc::new(BigInt {})), - CFG_MEMORY => Ok(Arc::new(ConfigMemory {})), - PGVECTOR_VECTOR => Ok(Arc::new(Vector {})), - STD_PG_JSON => Ok(Arc::new(PgTextJson {})), - STD_PG_TIMESTAMPTZ => Ok(Arc::new(Datetime {})), - STD_PG_TIMESTAMP => Ok(Arc::new(LocalDatetime {})), - STD_PG_DATE => Ok(Arc::new(LocalDate {})), - STD_PG_INTERVAL => Ok(Arc::new(RelativeDuration {})), - POSTGIS_GEOMETRY => Ok(Arc::new(PostGisGeometry {})), - POSTGIS_GEOGRAPHY => Ok(Arc::new(PostGisGeography {})), - POSTGIS_BOX_2D => Ok(Arc::new(PostGisBox2d {})), - POSTGIS_BOX_3D => Ok(Arc::new(PostGisBox3d {})), - _ => errors::UndefinedBaseScalar { uuid: uuid.clone() }.fail()?, - } -} - -impl Codec for Int16 { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::Int16) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let &val = match val { - Value::Int16(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.reserve(2); - buf.put_i16(val); - Ok(()) - } -} - -impl Codec for Int32 { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::Int32) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::Int32(val) => *val, - Value::Int16(val) => *val as i32, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.reserve(4); - buf.put_i32(val); - Ok(()) - } -} - -impl Codec for Int64 { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::Int64) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::Int64(val) => *val, - Value::Int32(val) => *val as i64, - Value::Int16(val) => *val as i64, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.reserve(8); - buf.put_i64(val); - Ok(()) - } -} - -impl Codec for ConfigMemory { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::ConfigMemory) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let &val = match val { - Value::ConfigMemory(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.reserve(8); - buf.put_i64(val.0); - Ok(()) - } -} - -impl Codec for Float32 { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::Float32) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let &val = match val { - Value::Float32(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.reserve(4); - buf.put_f32(val); - Ok(()) - } -} - -impl Codec for Float64 { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::Float64) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::Float64(val) => *val, - Value::Float32(val) => *val as f64, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.reserve(8); - buf.put_f64(val); - Ok(()) - } -} - -impl Codec for Str { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::Str) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::Str(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.extend(val.as_bytes()); - Ok(()) - } -} - -impl Codec for Bytes { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::Bytes) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::Bytes(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.extend(val); - Ok(()) - } -} - -impl Codec for Duration { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::Duration) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::Duration(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - encode_duration(buf, val) - } -} - -pub(crate) fn encode_duration( - buf: &mut BytesMut, - val: &model::Duration, -) -> Result<(), EncodeError> { - buf.reserve(16); - buf.put_i64(val.micros); - buf.put_u32(0); - buf.put_u32(0); - Ok(()) -} - -impl Codec for RelativeDuration { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::RelativeDuration) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::RelativeDuration(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - encode_relative_duration(buf, val) - } -} - -pub(crate) fn encode_relative_duration( - buf: &mut BytesMut, - val: &model::RelativeDuration, -) -> Result<(), EncodeError> { - buf.reserve(16); - buf.put_i64(val.micros); - buf.put_i32(val.days); - buf.put_i32(val.months); - Ok(()) -} - -impl Codec for DateDuration { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::DateDuration) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::DateDuration(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - encode_date_duration(buf, val) - } -} - -pub(crate) fn encode_date_duration( - buf: &mut BytesMut, - val: &model::DateDuration, -) -> Result<(), EncodeError> { - buf.reserve(16); - buf.put_i64(0); - buf.put_i32(val.days); - buf.put_i32(val.months); - Ok(()) -} - -impl Codec for Uuid { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::Uuid) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let &val = match val { - Value::Uuid(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.extend(val.as_bytes()); - Ok(()) - } -} - -impl Codec for Nothing { - fn decode(&self, _buf: &[u8]) -> Result { - Ok(Value::Nothing) - } - fn encode(&self, _buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - match val { - Value::Nothing => Ok(()), - _ => Err(errors::invalid_value(type_name::(), val))?, - } - } -} - -impl Object { - fn build( - d: &descriptors::ObjectShapeDescriptor, - dec: &CodecBuilder, - ) -> Result { - Ok(Object { - shape: d.elements.as_slice().into(), - codecs: d - .elements - .iter() - .map(|e| dec.build(e.type_pos)) - .collect::>()?, - }) - } -} - -impl Input { - fn build( - d: &descriptors::InputShapeTypeDescriptor, - dec: &CodecBuilder, - ) -> Result { - Ok(Input { - shape: d.elements.as_slice().into(), - codecs: d - .elements - .iter() - .map(|e| dec.build(e.type_pos)) - .collect::>()?, - }) - } -} - -impl Tuple { - fn build( - d: &descriptors::TupleTypeDescriptor, - dec: &CodecBuilder, - ) -> Result { - Ok(Tuple { - elements: d - .element_types - .iter() - .map(|&t| dec.build(t)) - .collect::>()?, - }) - } -} - -impl NamedTuple { - fn build( - d: &descriptors::NamedTupleTypeDescriptor, - dec: &CodecBuilder, - ) -> Result { - Ok(NamedTuple { - shape: d.elements.as_slice().into(), - codecs: d - .elements - .iter() - .map(|e| dec.build(e.type_pos)) - .collect::>()?, - }) - } -} - -impl SQLRow { - fn build(d: &descriptors::SQLRowDescriptor, dec: &CodecBuilder) -> Result { - Ok(SQLRow { - shape: d.elements.as_slice().into(), - codecs: d - .elements - .iter() - .map(|e| dec.build(e.type_pos)) - .collect::>()?, - }) - } -} - -fn decode_tuple( - mut elements: DecodeTupleLike, - codecs: &[Arc], -) -> Result, DecodeError> { - codecs - .iter() - .map(|codec| { - codec.decode( - elements - .read()? - .ok_or_else(|| errors::MissingRequiredElement.build())?, - ) - }) - .collect::, DecodeError>>() -} - -fn decode_array_like( - elements: DecodeArrayLike<'_>, - codec: &dyn Codec, -) -> Result, DecodeError> { - elements - .map(|element| codec.decode(element?)) - .collect::, DecodeError>>() -} - -impl Codec for Object { - fn decode(&self, buf: &[u8]) -> Result { - let mut elements = DecodeTupleLike::new_object(buf, self.codecs.len())?; - let fields = self - .codecs - .iter() - .map(|codec| { - elements - .read()? - .map(|element| codec.decode(element)) - .transpose() - }) - .collect::>, DecodeError>>()?; - - Ok(Value::Object { - shape: self.shape.clone(), - fields, - }) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let (shape, fields) = match val { - Value::Object { shape, fields } => (shape, fields), - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - ensure!(shape == &self.shape, errors::ObjectShapeMismatch); - ensure!( - self.codecs.len() == fields.len(), - errors::ObjectShapeMismatch - ); - debug_assert_eq!(self.codecs.len(), shape.0.elements.len()); - buf.reserve(4 + 8 * self.codecs.len()); - buf.put_u32( - self.codecs - .len() - .try_into() - .ok() - .context(errors::TooManyElements)?, - ); - for (codec, field) in self.codecs.iter().zip(fields) { - buf.reserve(8); - buf.put_u32(0); - match field { - Some(v) => { - let pos = buf.len(); - buf.put_i32(0); // replaced after serializing a value - codec.encode(buf, v)?; - let len = buf.len() - pos - 4; - buf[pos..pos + 4].copy_from_slice( - &i32::try_from(len) - .ok() - .context(errors::ElementTooLong)? - .to_be_bytes(), - ); - } - None => { - buf.put_i32(-1); - } - } - } - Ok(()) - } -} - -impl Codec for Input { - fn decode(&self, mut buf: &[u8]) -> Result { - ensure!(buf.remaining() >= 4, errors::Underflow); - let count = buf.get_u32() as usize; - let mut fields = vec![None; self.codecs.len()]; - for _ in 0..count { - ensure!(buf.remaining() >= 8, errors::Underflow); - let index = buf.get_u32() as usize; - ensure!(index < self.codecs.len(), errors::InvalidIndex { index }); - let length = buf.get_i32(); - if length < 0 { - fields[index] = Some(None); - } else { - let length = length as usize; - let value = self.codecs[index].decode(&buf[..length])?; - buf.advance(length); - fields[index] = Some(Some(value)); - } - } - Ok(Value::SparseObject(SparseObject { - shape: self.shape.clone(), - fields, - })) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let ob = match val { - Value::SparseObject(ob) => ob, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - let mut items = Vec::with_capacity(self.codecs.len()); - let dest_els = &self.shape.0.elements; - for (fld, el) in ob.fields.iter().zip(&ob.shape.0.elements) { - if let Some(value) = fld { - if let Some(index) = dest_els.iter().position(|x| x.name == el.name) { - items.push((index, value)); - } - } - } - buf.reserve(4 + 8 * items.len()); - buf.put_u32( - items - .len() - .try_into() - .ok() - .context(errors::TooManyElements)?, - ); - for (index, value) in items { - buf.reserve(8); - buf.put_u32(index as u32); - let pos = buf.len(); - if let Some(value) = value { - buf.put_i32(0); // replaced after serializing a value - self.codecs[index].encode(buf, value)?; - let len = buf.len() - pos - 4; - buf[pos..pos + 4].copy_from_slice( - &i32::try_from(len) - .ok() - .context(errors::ElementTooLong)? - .to_be_bytes(), - ); - } else { - buf.put_i32(-1); - } - } - Ok(()) - } -} - -impl Codec for ArrayAdapter { - fn decode(&self, mut buf: &[u8]) -> Result { - ensure!(buf.remaining() >= 12, errors::Underflow); - let count = buf.get_u32() as usize; - ensure!(count == 1, errors::InvalidArrayShape); - let _reserved = buf.get_i32() as usize; - let len = buf.get_i32() as usize; - ensure!(buf.remaining() >= len, errors::Underflow); - ensure!(buf.remaining() <= len, errors::ExtraData); - self.0.decode(buf) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - buf.reserve(12); - buf.put_u32(1); - buf.put_u32(0); - let pos = buf.len(); - buf.put_i32(0); // replaced after serializing a value - self.0.encode(buf, val)?; - let len = buf.len() - pos - 4; - buf[pos..pos + 4].copy_from_slice( - &i32::try_from(len) - .ok() - .context(errors::ElementTooLong)? - .to_be_bytes(), - ); - Ok(()) - } -} - -impl<'a> From<&'a [descriptors::ShapeElement]> for ObjectShape { - fn from(shape: &'a [descriptors::ShapeElement]) -> ObjectShape { - ObjectShape(Arc::new(ObjectShapeInfo { - elements: shape.iter().map(ShapeElement::from).collect(), - })) - } -} - -impl<'a> From<&'a descriptors::ShapeElement> for ShapeElement { - fn from(e: &'a descriptors::ShapeElement) -> ShapeElement { - let descriptors::ShapeElement { - flag_implicit, - flag_link_property, - flag_link, - cardinality, - name, - type_pos: _, - source_type_pos: _, - } = e; - ShapeElement { - flag_implicit: *flag_implicit, - flag_link_property: *flag_link_property, - flag_link: *flag_link, - cardinality: *cardinality, - name: name.clone(), - } - } -} - -impl<'a> From<&'a [descriptors::InputShapeElement]> for InputObjectShape { - fn from(shape: &'a [descriptors::InputShapeElement]) -> InputObjectShape { - InputObjectShape(Arc::new(InputObjectShapeInfo { - elements: shape.iter().map(InputShapeElement::from).collect(), - })) - } -} - -impl<'a> From<&'a descriptors::InputShapeElement> for InputShapeElement { - fn from(e: &'a descriptors::InputShapeElement) -> InputShapeElement { - let descriptors::InputShapeElement { - cardinality, - name, - type_pos: _, - } = e; - InputShapeElement { - cardinality: *cardinality, - name: name.clone(), - } - } -} - -impl<'a> From<&'a [descriptors::TupleElement]> for NamedTupleShape { - fn from(shape: &'a [descriptors::TupleElement]) -> NamedTupleShape { - NamedTupleShape(Arc::new(NamedTupleShapeInfo { - elements: shape - .iter() - .map(|e| { - let descriptors::TupleElement { name, type_pos: _ } = e; - TupleElement { name: name.clone() } - }) - .collect(), - })) - } -} - -impl<'a> From<&'a [descriptors::SQLRowElement]> for SQLRowShape { - fn from(shape: &'a [descriptors::SQLRowElement]) -> SQLRowShape { - SQLRowShape(Arc::new(SQLRowShapeInfo { - elements: shape - .iter() - .map(|e| { - let descriptors::SQLRowElement { name, type_pos: _ } = e; - SQLRowElement { name: name.clone() } - }) - .collect(), - })) - } -} - -impl From<&str> for EnumValue { - fn from(s: &str) -> EnumValue { - EnumValue(s.into()) - } -} - -impl std::ops::Deref for EnumValue { - type Target = str; - fn deref(&self) -> &str { - &self.0 - } -} - -impl Set { - fn build(d: &descriptors::SetDescriptor, dec: &CodecBuilder) -> Result { - let element = match dec.descriptors.get(d.type_pos.0 as usize) { - Some(Descriptor::Array(d)) => Arc::new(ArrayAdapter(Array { - element: dec.build(d.type_pos)?, - })), - _ => dec.build(d.type_pos)?, - }; - Ok(Set { element }) - } -} - -impl Codec for Set { - fn decode(&self, buf: &[u8]) -> Result { - let elements = DecodeArrayLike::new_set(buf)?; - let items = decode_array_like(elements, &*self.element)?; - Ok(Value::Set(items)) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let items = match val { - Value::Set(items) => items, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - if items.is_empty() { - buf.reserve(12); - buf.put_u32(0); // ndims - buf.put_u32(0); // reserved0 - buf.put_u32(0); // reserved1 - return Ok(()); - } - buf.reserve(20); - buf.put_u32(1); // ndims - buf.put_u32(0); // reserved0 - buf.put_u32(0); // reserved1 - buf.put_u32(items.len().try_into().ok().context(errors::ArrayTooLong)?); - buf.put_u32(1); // lower - for item in items { - buf.reserve(4); - let pos = buf.len(); - buf.put_u32(0); // replaced after serializing a value - self.element.encode(buf, item)?; - let len = buf.len() - pos - 4; - buf[pos..pos + 4].copy_from_slice( - &u32::try_from(len) - .ok() - .context(errors::ElementTooLong)? - .to_be_bytes(), - ); - } - Ok(()) - } -} - -impl Codec for Decimal { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::Decimal) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::Decimal(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - encode_decimal(buf, val) - } -} - -pub(crate) fn encode_decimal(buf: &mut BytesMut, val: &model::Decimal) -> Result<(), EncodeError> { - buf.reserve(8 + val.digits.len() * 2); - buf.put_u16( - val.digits - .len() - .try_into() - .ok() - .context(errors::BigIntTooLong)?, - ); - buf.put_i16(val.weight); - buf.put_u16(if val.negative { 0x4000 } else { 0x0000 }); - buf.put_u16(val.decimal_digits); - for &dig in &val.digits { - buf.put_u16(dig); - } - Ok(()) -} - -impl Codec for BigInt { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::BigInt) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::BigInt(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - encode_big_int(buf, val) - } -} - -pub(crate) fn encode_big_int(buf: &mut BytesMut, val: &model::BigInt) -> Result<(), EncodeError> { - buf.reserve(8 + val.digits.len() * 2); - buf.put_u16( - val.digits - .len() - .try_into() - .ok() - .context(errors::BigIntTooLong)?, - ); - buf.put_i16(val.weight); - buf.put_u16(if val.negative { 0x4000 } else { 0x0000 }); - buf.put_u16(0); - for &dig in &val.digits { - buf.put_u16(dig); - } - Ok(()) -} - -impl Codec for Bool { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::Bool) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::Bool(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.reserve(1); - buf.put_u8(match val { - true => 1, - false => 0, - }); - Ok(()) - } -} - -impl Codec for Datetime { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::Datetime) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::Datetime(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - encode_datetime(buf, val) - } -} - -pub(crate) fn encode_datetime( - buf: &mut BytesMut, - val: &model::Datetime, -) -> Result<(), EncodeError> { - buf.reserve(8); - buf.put_i64(val.micros); - Ok(()) -} - -impl Codec for LocalDatetime { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::LocalDatetime) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::LocalDatetime(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - encode_local_datetime(buf, val) - } -} - -pub(crate) fn encode_local_datetime( - buf: &mut BytesMut, - val: &model::LocalDatetime, -) -> Result<(), EncodeError> { - buf.reserve(8); - buf.put_i64(val.micros); - Ok(()) -} - -impl Codec for LocalDate { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::LocalDate) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::LocalDate(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - encode_local_date(buf, val) - } -} - -pub(crate) fn encode_local_date( - buf: &mut BytesMut, - val: &model::LocalDate, -) -> Result<(), EncodeError> { - buf.reserve(4); - buf.put_i32(val.days); - Ok(()) -} - -impl Codec for LocalTime { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::LocalTime) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::LocalTime(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - encode_local_time(buf, val) - } -} - -pub(crate) fn encode_local_time( - buf: &mut BytesMut, - val: &model::LocalTime, -) -> Result<(), EncodeError> { - buf.reserve(8); - buf.put_i64(val.micros as i64); - Ok(()) -} - -impl Codec for Json { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(|json: model::Json| Value::Json(json)) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::Json(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.reserve(1 + val.len()); - buf.put_u8(1); - buf.extend(val.as_bytes()); - Ok(()) - } -} - -impl Codec for PgTextJson { - fn decode(&self, buf: &[u8]) -> Result { - let val = str::from_utf8(buf).context(errors::InvalidUtf8)?.to_owned(); - let json = model::Json::new_unchecked(val); - Ok(Value::Json(json)) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::Json(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.reserve(val.len()); - buf.extend(val.as_bytes()); - Ok(()) - } -} - -impl Codec for Scalar { - fn decode(&self, buf: &[u8]) -> Result { - self.inner.decode(buf) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - self.inner.encode(buf, val) - } -} - -impl Codec for Tuple { - fn decode(&self, buf: &[u8]) -> Result { - let elements = DecodeTupleLike::new_object(buf, self.elements.len())?; - let items = decode_tuple(elements, &self.elements)?; - Ok(Value::Tuple(items)) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let items = match val { - Value::Tuple(items) => items, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - ensure!( - self.elements.len() == items.len(), - errors::TupleShapeMismatch - ); - buf.reserve(4 + 8 * self.elements.len()); - buf.put_u32( - self.elements - .len() - .try_into() - .ok() - .context(errors::TooManyElements)?, - ); - for (codec, item) in self.elements.iter().zip(items) { - buf.reserve(8); - buf.put_u32(0); - let pos = buf.len(); - buf.put_u32(0); // replaced after serializing a value - codec.encode(buf, item)?; - let len = buf.len() - pos - 4; - buf[pos..pos + 4].copy_from_slice( - &u32::try_from(len) - .ok() - .context(errors::ElementTooLong)? - .to_be_bytes(), - ); - } - Ok(()) - } -} - -impl Codec for NamedTuple { - fn decode(&self, buf: &[u8]) -> Result { - let elements = DecodeTupleLike::new_tuple(buf, self.codecs.len())?; - let fields = decode_tuple(elements, &self.codecs)?; - Ok(Value::NamedTuple { - shape: self.shape.clone(), - fields, - }) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let (shape, fields) = match val { - Value::NamedTuple { shape, fields } => (shape, fields), - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - ensure!(shape == &self.shape, errors::TupleShapeMismatch); - ensure!( - self.codecs.len() == fields.len(), - errors::ObjectShapeMismatch - ); - debug_assert_eq!(self.codecs.len(), shape.0.elements.len()); - buf.reserve(4 + 8 * self.codecs.len()); - buf.put_u32( - self.codecs - .len() - .try_into() - .ok() - .context(errors::TooManyElements)?, - ); - for (codec, field) in self.codecs.iter().zip(fields) { - buf.reserve(8); - buf.put_u32(0); - let pos = buf.len(); - buf.put_u32(0); // replaced after serializing a value - codec.encode(buf, field)?; - let len = buf.len() - pos - 4; - buf[pos..pos + 4].copy_from_slice( - &u32::try_from(len) - .ok() - .context(errors::ElementTooLong)? - .to_be_bytes(), - ); - } - Ok(()) - } -} - -impl Codec for SQLRow { - fn decode(&self, buf: &[u8]) -> Result { - let elements = DecodeTupleLike::new_tuple(buf, self.codecs.len())?; - let fields = decode_tuple(elements, &self.codecs)?; - Ok(Value::SQLRow { - shape: self.shape.clone(), - fields, - }) - } - fn encode(&self, _buf: &mut BytesMut, _val: &Value) -> Result<(), EncodeError> { - errors::UnknownMessageCantBeEncoded.fail()? - } -} - -impl Codec for Array { - fn decode(&self, buf: &[u8]) -> Result { - let elements = DecodeArrayLike::new_array(buf)?; - let items = decode_array_like(elements, &*self.element)?; - Ok(Value::Array(items)) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let items = match val { - Value::Array(items) => items, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - if items.is_empty() { - buf.reserve(12); - buf.put_u32(0); // ndims - buf.put_u32(0); // reserved0 - buf.put_u32(0); // reserved1 - return Ok(()); - } - buf.reserve(20); - buf.put_u32(1); // ndims - buf.put_u32(0); // reserved0 - buf.put_u32(0); // reserved1 - buf.put_u32(items.len().try_into().ok().context(errors::ArrayTooLong)?); - buf.put_u32(1); // lower - for item in items { - buf.reserve(4); - let pos = buf.len(); - buf.put_u32(0); // replaced after serializing a value - self.element.encode(buf, item)?; - let len = buf.len() - pos - 4; - buf[pos..pos + 4].copy_from_slice( - &u32::try_from(len) - .ok() - .context(errors::ElementTooLong)? - .to_be_bytes(), - ); - } - Ok(()) - } -} - -impl Codec for Vector { - fn decode(&self, mut buf: &[u8]) -> Result { - ensure!(buf.remaining() >= 4, errors::Underflow); - let length = buf.get_u16() as usize; - let _reserved = buf.get_u16(); - ensure!(buf.remaining() >= length * 4, errors::Underflow); - let vec = (0..length).map(|_| f32::from_bits(buf.get_u32())).collect(); - Ok(Value::Vector(vec)) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let items = match val { - Value::Vector(items) => items, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - if items.is_empty() { - buf.reserve(4); - buf.put_i16(0); // length - buf.put_i16(0); // reserved - return Ok(()); - } - buf.reserve(4 + items.len() * 4); - buf.put_i16(items.len().try_into().ok().context(errors::ArrayTooLong)?); - buf.put_i16(0); // reserved - for item in items { - buf.put_u32(item.to_bits()); - } - Ok(()) - } -} - -impl Codec for Range { - fn decode(&self, mut buf: &[u8]) -> Result { - ensure!(buf.remaining() >= 1, errors::Underflow); - let flags = buf.get_u8() as usize; - - let empty = (flags & range::EMPTY) != 0; - let inc_lower = (flags & range::LB_INC) != 0; - let inc_upper = (flags & range::UB_INC) != 0; - let has_lower = (flags & (range::EMPTY | range::LB_INF)) == 0; - let has_upper = (flags & (range::EMPTY | range::UB_INF)) == 0; - - let mut range = DecodeRange::new(buf)?; - - let lower = if has_lower { - Some(Box::new(self.element.decode(range.read()?)?)) - } else { - None - }; - let upper = if has_upper { - Some(Box::new(self.element.decode(range.read()?)?)) - } else { - None - }; - - Ok(Value::Range(model::Range { - lower, - upper, - inc_lower, - inc_upper, - empty, - })) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let rng = match val { - Value::Range(rng) => rng, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - - let flags = if rng.empty { - range::EMPTY - } else { - (if rng.inc_lower { range::LB_INC } else { 0 }) - | (if rng.inc_upper { range::UB_INC } else { 0 }) - | (if rng.lower.is_none() { - range::LB_INF - } else { - 0 - }) - | (if rng.upper.is_none() { - range::UB_INF - } else { - 0 - }) - }; - buf.reserve(1); - buf.put_u8(flags as u8); - - if let Some(lower) = &rng.lower { - let pos = buf.len(); - buf.reserve(4); - buf.put_u32(0); // replaced after serializing a value - self.element.encode(buf, lower)?; - let len = buf.len() - pos - 4; - buf[pos..pos + 4].copy_from_slice( - &u32::try_from(len) - .ok() - .context(errors::ElementTooLong)? - .to_be_bytes(), - ); - } - - if let Some(upper) = &rng.upper { - let pos = buf.len(); - buf.reserve(4); - buf.put_u32(0); // replaced after serializing a value - self.element.encode(buf, upper)?; - let len = buf.len() - pos - 4; - buf[pos..pos + 4].copy_from_slice( - &u32::try_from(len) - .ok() - .context(errors::ElementTooLong)? - .to_be_bytes(), - ); - } - - Ok(()) - } -} - -impl Codec for MultiRange { - fn decode(&self, buf: &[u8]) -> Result { - let elements = DecodeArrayLike::new_tuple_header(buf)?; - let items = decode_array_like(elements, &*self.element)?; - Ok(Value::Array(items)) - } - - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let items = match val { - Value::Array(items) => items, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.reserve(4); - buf.put_u32(items.len().try_into().ok().context(errors::ArrayTooLong)?); - for item in items { - buf.reserve(4); - let pos = buf.len(); - buf.put_u32(0); // replaced after serializing a value - self.element.encode(buf, item)?; - let len = buf.len() - pos - 4; - buf[pos..pos + 4].copy_from_slice( - &u32::try_from(len) - .ok() - .context(errors::ElementTooLong)? - .to_be_bytes(), - ); - } - Ok(()) - } -} - -impl Codec for Enum { - fn decode(&self, buf: &[u8]) -> Result { - let val: &str = RawCodec::decode(buf)?; - let val = self.members.get(val).context(errors::ExtraEnumValue)?; - Ok(Value::Enum(EnumValue(val.clone()))) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::Enum(val) => val.0.as_ref(), - Value::Str(val) => val.as_str(), - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - ensure!(self.members.contains(val), errors::MissingEnumValue); - buf.extend(val.as_bytes()); - Ok(()) - } -} - -impl Codec for PostGisGeometry { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::PostGisGeometry) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::PostGisGeometry(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.extend(val); - Ok(()) - } -} - -impl Codec for PostGisGeography { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::PostGisGeography) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::PostGisGeography(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.extend(val); - Ok(()) - } -} - -impl Codec for PostGisBox2d { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::PostGisBox2d) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::PostGisBox2d(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.extend(val); - Ok(()) - } -} - -impl Codec for PostGisBox3d { - fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(Value::PostGisBox3d) - } - fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> { - let val = match val { - Value::PostGisBox3d(val) => val, - _ => Err(errors::invalid_value(type_name::(), val))?, - }; - buf.extend(val); - Ok(()) - } -} diff --git a/edgedb-protocol/src/common.rs b/edgedb-protocol/src/common.rs deleted file mode 100644 index 03fbfb8b..00000000 --- a/edgedb-protocol/src/common.rs +++ /dev/null @@ -1,154 +0,0 @@ -/*! -([Website reference](https://www.edgedb.com/docs/reference/protocol/messages#parse)) Capabilities, CompilationFlags etc. from the message protocol. -*/ - -use crate::errors; -use crate::model::Uuid; -use bytes::Bytes; - -use crate::descriptors::Typedesc; -use crate::encoding::Input; -use crate::errors::DecodeError; -use crate::features::ProtocolVersion; - -pub use crate::client_message::{InputLanguage, IoFormat}; - -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum Cardinality { - NoResult = 0x6e, - AtMostOne = 0x6f, - One = 0x41, - Many = 0x6d, - AtLeastOne = 0x4d, -} - -bitflags::bitflags! { - #[derive(Clone, Copy, Debug, PartialEq, Eq)] - pub struct Capabilities: u64 { - const MODIFICATIONS = 0b00000001; - const SESSION_CONFIG = 0b00000010; - const TRANSACTION = 0b00000100; - const DDL = 0b00001000; - const PERSISTENT_CONFIG = 0b00010000; - const ALL = 0b00011111; - } -} - -bitflags::bitflags! { - #[derive(Clone, Copy, Debug, PartialEq, Eq)] - pub struct CompilationFlags: u64 { - const INJECT_OUTPUT_TYPE_IDS = 0b00000001; - const INJECT_OUTPUT_TYPE_NAMES = 0b00000010; - const INJECT_OUTPUT_OBJECT_IDS = 0b00000100; - } -} - -bitflags::bitflags! { - #[derive(Clone, Copy, Debug, PartialEq, Eq)] - pub struct DumpFlags: u64 { - const DUMP_SECRETS = 0b00000001; - } -} - -#[derive(Debug, Clone)] -pub struct CompilationOptions { - pub implicit_limit: Option, - pub implicit_typenames: bool, - pub implicit_typeids: bool, - pub allow_capabilities: Capabilities, - pub explicit_objectids: bool, - pub io_format: IoFormat, - pub expected_cardinality: Cardinality, - pub input_language: InputLanguage, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct State { - pub typedesc_id: Uuid, - pub data: Bytes, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct RawTypedesc { - pub proto: ProtocolVersion, - pub id: Uuid, - pub data: Bytes, -} - -impl RawTypedesc { - pub fn uninitialized() -> RawTypedesc { - RawTypedesc { - proto: ProtocolVersion::current(), - id: Uuid::from_u128(0), - data: Bytes::new(), - } - } - pub fn decode(&self) -> Result { - let cur = &mut Input::new(self.proto.clone(), self.data.clone()); - Typedesc::decode_with_id(self.id, cur) - } -} - -impl std::convert::TryFrom for Cardinality { - type Error = errors::DecodeError; - fn try_from(cardinality: u8) -> Result { - match cardinality { - 0x6e => Ok(Cardinality::NoResult), - 0x6f => Ok(Cardinality::AtMostOne), - 0x41 => Ok(Cardinality::One), - 0x6d => Ok(Cardinality::Many), - 0x4d => Ok(Cardinality::AtLeastOne), - _ => Err(errors::InvalidCardinality { cardinality }.build()), - } - } -} - -impl Cardinality { - pub fn is_optional(&self) -> bool { - use Cardinality::*; - match self { - NoResult => true, - AtMostOne => true, - One => false, - Many => true, - AtLeastOne => false, - } - } -} - -impl std::convert::TryFrom for InputLanguage { - type Error = errors::DecodeError; - fn try_from(input_language: u8) -> Result { - match input_language { - 0x45 => Ok(InputLanguage::EdgeQL), - 0x53 => Ok(InputLanguage::SQL), - _ => Err(errors::InvalidInputLanguage { input_language }.build()), - } - } -} - -impl State { - pub fn empty() -> State { - State { - typedesc_id: Uuid::from_u128(0), - data: Bytes::new(), - } - } - pub fn descriptor_id(&self) -> Uuid { - self.typedesc_id - } -} - -impl CompilationOptions { - pub fn flags(&self) -> CompilationFlags { - let mut cflags = CompilationFlags::empty(); - if self.implicit_typenames { - cflags |= CompilationFlags::INJECT_OUTPUT_TYPE_NAMES; - } - if self.implicit_typeids { - cflags |= CompilationFlags::INJECT_OUTPUT_TYPE_IDS; - } - // TODO(tailhook) object ids - cflags - } -} diff --git a/edgedb-protocol/src/descriptors.rs b/edgedb-protocol/src/descriptors.rs deleted file mode 100644 index ad512ec1..00000000 --- a/edgedb-protocol/src/descriptors.rs +++ /dev/null @@ -1,1129 +0,0 @@ -/*! -([Website reference](https://www.edgedb.com/docs/reference/protocol/typedesc)) Types for the [Descriptor] enum. - -```rust,ignore -pub enum Descriptor { - Set(SetDescriptor), - ObjectShape(ObjectShapeDescriptor), - BaseScalar(BaseScalarTypeDescriptor), - Scalar(ScalarTypeDescriptor), - Tuple(TupleTypeDescriptor), - NamedTuple(NamedTupleTypeDescriptor), - Array(ArrayTypeDescriptor), - Enumeration(EnumerationTypeDescriptor), - InputShape(InputShapeTypeDescriptor), - Range(RangeTypeDescriptor), - MultiRange(MultiRangeTypeDescriptor), - TypeAnnotation(TypeAnnotationDescriptor), -} -``` - -From the website: - ->The type descriptor is essentially a list of type information blocks: ->* each block encodes one type; ->* blocks can reference other blocks. - ->While parsing the _blocks_, a database driver can assemble an _encoder_ or a _decoder_ of the EdgeDB binary data. - ->An _encoder_ is used to encode objects, native to the driver’s runtime, to binary data that EdegDB can decode and work with. - ->A _decoder_ is used to decode data from EdgeDB native format to data types native to the driver. -*/ - -use std::collections::{BTreeMap, BTreeSet}; -use std::convert::{TryFrom, TryInto}; -use std::fmt::{Debug, Formatter}; -use std::ops::Deref; -use std::sync::Arc; - -use bytes::{Buf, BufMut, BytesMut}; -use edgedb_errors::{ClientEncodingError, DescriptorMismatch, Error, ErrorKind}; -use snafu::{ensure, OptionExt}; -use uuid::Uuid; - -use crate::codec::{build_codec, uuid_to_known_name, Codec}; -use crate::common::{Cardinality, State}; -use crate::encoding::{Decode, Input}; -use crate::errors::{self, CodecError, DecodeError}; -use crate::errors::{InvalidTypeDescriptor, UnexpectedTypePos}; -use crate::features::ProtocolVersion; -use crate::query_arg::{self, Encoder, QueryArg}; -use crate::queryable; -use crate::value::Value; - -pub use crate::common::RawTypedesc; - -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub struct TypePos(pub u16); - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Descriptor { - Set(SetDescriptor), - ObjectShape(ObjectShapeDescriptor), - BaseScalar(BaseScalarTypeDescriptor), - Scalar(ScalarTypeDescriptor), - Tuple(TupleTypeDescriptor), - NamedTuple(NamedTupleTypeDescriptor), - Array(ArrayTypeDescriptor), - Enumeration(EnumerationTypeDescriptor), - InputShape(InputShapeTypeDescriptor), - Range(RangeTypeDescriptor), - MultiRange(MultiRangeTypeDescriptor), - Object(ObjectTypeDescriptor), - Compound(CompoundTypeDescriptor), - SQLRow(SQLRowDescriptor), - TypeAnnotation(TypeAnnotationDescriptor), -} - -#[derive(Clone, PartialEq, Eq)] -pub struct DescriptorUuid(Uuid); - -impl Debug for DescriptorUuid { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - match uuid_to_known_name(&self.0) { - Some(known_name) => write!(f, "{known_name}"), - None => write!(f, "{}", &self.0), - } - } -} - -impl Deref for DescriptorUuid { - type Target = Uuid; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl From for DescriptorUuid { - fn from(uuid: Uuid) -> Self { - Self(uuid) - } -} - -impl PartialEq for DescriptorUuid { - fn eq(&self, other: &Uuid) -> bool { - self.0 == *other - } -} - -#[derive(Debug)] -pub struct Typedesc { - pub(crate) proto: ProtocolVersion, - pub(crate) array: Vec, - pub(crate) root_id: Uuid, - pub(crate) root_pos: Option, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct SetDescriptor { - pub id: DescriptorUuid, - pub type_pos: TypePos, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct ObjectShapeDescriptor { - pub id: DescriptorUuid, - pub ephemeral_free_shape: bool, - pub type_pos: Option, - pub elements: Vec, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct InputShapeTypeDescriptor { - pub id: DescriptorUuid, - pub elements: Vec, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ShapeElement { - pub flag_implicit: bool, - pub flag_link_property: bool, - pub flag_link: bool, - pub cardinality: Option, - pub name: String, - pub type_pos: TypePos, - pub source_type_pos: Option, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct InputShapeElement { - pub cardinality: Option, - pub name: String, - pub type_pos: TypePos, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct BaseScalarTypeDescriptor { - pub id: DescriptorUuid, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct ScalarTypeDescriptor { - pub id: DescriptorUuid, - pub base_type_pos: Option, - pub name: Option, - pub schema_defined: Option, - pub ancestors: Vec, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct TupleTypeDescriptor { - pub id: DescriptorUuid, - pub element_types: Vec, - pub name: Option, - pub schema_defined: Option, - pub ancestors: Vec, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct NamedTupleTypeDescriptor { - pub id: DescriptorUuid, - pub elements: Vec, - pub name: Option, - pub schema_defined: Option, - pub ancestors: Vec, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct ObjectTypeDescriptor { - pub id: DescriptorUuid, - pub name: Option, - pub schema_defined: Option, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct SQLRowDescriptor { - pub id: DescriptorUuid, - pub elements: Vec, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct SQLRowElement { - pub name: String, - pub type_pos: TypePos, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -#[repr(u8)] -pub enum TypeOperation { - UNION = 1, - INTERSECTION = 2, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct CompoundTypeDescriptor { - pub id: DescriptorUuid, - pub name: Option, - pub schema_defined: Option, - pub op: TypeOperation, - pub components: Vec, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct TupleElement { - pub name: String, - pub type_pos: TypePos, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct ArrayTypeDescriptor { - pub id: DescriptorUuid, - pub type_pos: TypePos, - pub dimensions: Vec>, - pub name: Option, - pub schema_defined: Option, - pub ancestors: Vec, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct RangeTypeDescriptor { - pub id: DescriptorUuid, - pub type_pos: TypePos, - pub name: Option, - pub schema_defined: Option, - pub ancestors: Vec, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct MultiRangeTypeDescriptor { - pub id: DescriptorUuid, - pub type_pos: TypePos, - pub name: Option, - pub schema_defined: Option, - pub ancestors: Vec, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct EnumerationTypeDescriptor { - pub id: DescriptorUuid, - pub members: Vec, - pub name: Option, - pub schema_defined: Option, - pub ancestors: Vec, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct TypeAnnotationDescriptor { - pub annotated_type: u8, - pub id: DescriptorUuid, - pub annotation: String, -} - -pub struct StateBorrow<'a> { - pub module: &'a Option, - pub aliases: &'a BTreeMap, - pub config: &'a BTreeMap, - pub globals: &'a BTreeMap, -} - -impl Typedesc { - pub fn id(&self) -> &Uuid { - &self.root_id - } - pub fn descriptors(&self) -> &[Descriptor] { - &self.array - } - pub fn root_pos(&self) -> Option { - self.root_pos - } - pub fn build_codec(&self) -> Result, CodecError> { - build_codec(self.root_pos(), self.descriptors()) - } - pub fn get(&self, type_pos: TypePos) -> Result<&Descriptor, CodecError> { - self.array - .get(type_pos.0 as usize) - .context(UnexpectedTypePos { - position: type_pos.0, - }) - } - pub fn nothing(protocol: &ProtocolVersion) -> Typedesc { - Typedesc { - proto: protocol.clone(), - array: Vec::new(), - root_id: Uuid::from_u128(0), - root_pos: None, - } - } - pub fn is_empty_tuple(&self) -> bool { - match self.root() { - Some(Descriptor::Tuple(t)) => { - *t.id == Uuid::from_u128(0xFF) && t.element_types.is_empty() - } - _ => false, - } - } - pub fn root(&self) -> Option<&Descriptor> { - self.root_pos.and_then(|pos| self.array.get(pos.0 as usize)) - } - pub(crate) fn decode_with_id(root_id: Uuid, buf: &mut Input) -> Result { - let mut descriptors = Vec::new(); - while buf.remaining() > 0 { - match Descriptor::decode(buf)? { - Descriptor::TypeAnnotation(_) => {} - item => descriptors.push(item), - } - } - let root_pos = if root_id == Uuid::from_u128(0) { - None - } else { - let idx = descriptors - .iter() - .position(|x| *x.id() == root_id) - .context(errors::UuidNotFound { uuid: root_id })?; - let pos = idx - .try_into() - .ok() - .context(errors::TooManyDescriptors { index: idx })?; - Some(TypePos(pos)) - }; - Ok(Typedesc { - proto: buf.proto().clone(), - array: descriptors, - root_id, - root_pos, - }) - } - pub fn as_query_arg_context(&self) -> query_arg::DescriptorContext { - query_arg::DescriptorContext { - proto: &self.proto, - descriptors: self.descriptors(), - root_pos: self.root_pos, - } - } - pub fn as_queryable_context(&self) -> queryable::DescriptorContext { - let mut ctx = queryable::DescriptorContext::new(self.descriptors()); - ctx.has_implicit_id = self.proto.has_implicit_id(); - ctx.has_implicit_tid = self.proto.has_implicit_tid(); - ctx - } - pub fn serialize_state(&self, state: &StateBorrow) -> Result { - #[derive(Debug)] - struct Indices { - module: (u32, TypePos), - aliases: (u32, TypePos), - config: (u32, TypePos), - globals: (u32, TypePos), - } - let mut buf = BytesMut::with_capacity(128); - let ctx = self.as_query_arg_context(); - let mut enc = Encoder::new(&ctx, &mut buf); - - let root = enc - .ctx - .root_pos - .ok_or_else(|| DescriptorMismatch::with_message("invalid state descriptor")) - .and_then(|p| enc.ctx.get(p))?; - let indices = match root { - Descriptor::InputShape(desc) => { - let mut module = None; - let mut aliases = None; - let mut config = None; - let mut globals = None; - for (i, elem) in desc.elements.iter().enumerate() { - let i = i as u32; - match &elem.name[..] { - "module" => module = Some((i, elem.type_pos)), - "aliases" => aliases = Some((i, elem.type_pos)), - "config" => config = Some((i, elem.type_pos)), - "globals" => globals = Some((i, elem.type_pos)), - _ => {} - } - } - Indices { - module: module.ok_or_else(|| { - DescriptorMismatch::with_message("no `module` field in state") - })?, - aliases: aliases.ok_or_else(|| { - DescriptorMismatch::with_message("no `aliases` field in state") - })?, - config: config.ok_or_else(|| { - DescriptorMismatch::with_message("no `config` field in state") - })?, - globals: globals.ok_or_else(|| { - DescriptorMismatch::with_message("no `globals` field in state") - })?, - } - } - _ => return Err(DescriptorMismatch::with_message("invalid state descriptor")), - }; - - enc.buf.reserve(4 + 8 * 4); - enc.buf.put_u32(4); - - let module = state.module.as_deref().unwrap_or("default"); - module.check_descriptor(enc.ctx, indices.module.1)?; - - enc.buf.reserve(8); - enc.buf.put_u32(indices.module.0); - module.encode_slot(&mut enc)?; - - match enc.ctx.get(indices.aliases.1)? { - Descriptor::Array(arr) => match enc.ctx.get(arr.type_pos)? { - Descriptor::Tuple(tup) => { - if tup.element_types.len() != 2 { - return Err(DescriptorMismatch::with_message( - "invalid type descriptor for aliases", - )); - } - "".check_descriptor(enc.ctx, tup.element_types[0])?; - "".check_descriptor(enc.ctx, tup.element_types[1])?; - } - _ => { - return Err(DescriptorMismatch::with_message( - "invalid type descriptor for aliases", - )); - } - }, - _ => { - return Err(DescriptorMismatch::with_message( - "invalid type descriptor for aliases", - )); - } - } - - enc.buf - .reserve(4 + 16 + state.aliases.len() * (4 + (8 + 4) * 2)); - enc.buf.put_u32(indices.aliases.0); - enc.length_prefixed(|enc| { - enc.buf.put_u32( - state - .aliases - .len() - .try_into() - .map_err(|_| ClientEncodingError::with_message("too many aliases"))?, - ); - for (key, value) in state.aliases { - enc.length_prefixed(|enc| { - enc.buf.reserve(4 + (8 + 4) * 2); - enc.buf.put_u32(2); - enc.buf.put_u32(0); // reserved - - key.encode_slot(enc)?; - value.encode_slot(enc)?; - Ok(()) - })?; - } - Ok(()) - })?; - - enc.buf.reserve(4); - enc.buf.put_u32(indices.config.0); - enc.length_prefixed(|enc| { - serialize_variables(enc, state.config, indices.config.1, "config") - })?; - enc.buf.reserve(4); - enc.buf.put_u32(indices.globals.0); - enc.length_prefixed(|enc| { - serialize_variables(enc, state.globals, indices.globals.1, "globals") - })?; - let data = buf.freeze(); - Ok(State { - typedesc_id: self.root_id, - data, - }) - } - pub fn proto(&self) -> &ProtocolVersion { - &self.proto - } -} - -fn serialize_variables( - enc: &mut Encoder, - variables: &BTreeMap, - - type_pos: TypePos, - tag: &str, -) -> Result<(), Error> { - enc.buf.reserve(4 + variables.len() * (4 + 4)); - enc.buf.put_u32( - variables - .len() - .try_into() - .map_err(|_| ClientEncodingError::with_message(format!("too many items in {}", tag)))?, - ); - - let desc = match enc.ctx.get(type_pos)? { - Descriptor::InputShape(desc) => desc, - _ => { - return Err(DescriptorMismatch::with_message(format!( - "invalid type descriptor for {}", - tag - ))); - } - }; - - let mut serialized = 0; - for (idx, el) in desc.elements.iter().enumerate() { - if let Some(value) = variables.get(&el.name) { - value.check_descriptor(enc.ctx, el.type_pos)?; - serialized += 1; - enc.buf.reserve(8); - enc.buf.put_u32(idx as u32); - value.encode_slot(enc)?; - } - } - - if serialized != variables.len() { - let mut extra_vars = variables.keys().collect::>(); - for el in &desc.elements { - extra_vars.remove(&el.name); - } - return Err(ClientEncodingError::with_message(format!( - "non-existing entries {} of {}", - extra_vars - .into_iter() - .map(|x| &x[..]) - .collect::>() - .join(", "), - tag - ))); - } - - Ok(()) -} - -impl Descriptor { - pub fn id(&self) -> &Uuid { - use Descriptor::*; - match self { - Set(i) => &i.id, - ObjectShape(i) => &i.id, - BaseScalar(i) => &i.id, - Scalar(i) => &i.id, - Tuple(i) => &i.id, - NamedTuple(i) => &i.id, - Array(i) => &i.id, - Range(i) => &i.id, - MultiRange(i) => &i.id, - Enumeration(i) => &i.id, - InputShape(i) => &i.id, - Object(i) => &i.id, - Compound(i) => &i.id, - SQLRow(i) => &i.id, - TypeAnnotation(i) => &i.id, - } - } - pub fn decode(buf: &mut Input) -> Result { - ::decode(buf) - } - pub fn normalize_to_base( - &self, - ctx: &query_arg::DescriptorContext, - ) -> Result { - let norm = match self { - Descriptor::Scalar(d) if d.base_type_pos.is_some() => { - match ctx.get(d.base_type_pos.unwrap())? { - Descriptor::Scalar(d) => { - Descriptor::BaseScalar(BaseScalarTypeDescriptor { id: d.id.clone() }) - } - desc => desc.clone(), - } - } - Descriptor::Scalar(d) => { - if ctx.proto.is_2() { - Descriptor::BaseScalar(BaseScalarTypeDescriptor { id: d.id.clone() }) - } else { - unreachable!("scalar dereference to a non-base type") - } - } - desc => desc.clone(), - }; - - Ok(norm) - } -} - -impl Decode for Vec { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 2, errors::Underflow); - let element_count = buf.get_u16(); - let mut elements = Vec::with_capacity(element_count as usize); - for _ in 0..element_count { - elements.push(T::decode(buf)?); - } - Ok(elements) - } -} - -impl Decode for Option { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 4, errors::Underflow); - - let val = match buf.get_i32() { - -1 => None, - n if n > 0 => Some(n as u32), - _ => errors::InvalidOptionU32.fail()?, - }; - - Ok(val) - } -} - -impl Decode for TypePos { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 2, errors::Underflow); - Ok(Self(buf.get_u16())) - } -} - -impl Decode for Descriptor { - fn decode(buf: &mut Input) -> Result { - use Descriptor as D; - if buf.proto().is_2() { - ensure!(buf.remaining() >= 4, errors::Underflow); - let desc_len = buf.get_u32() as u64; - ensure!((buf.remaining() as u64) >= desc_len, errors::Underflow); - } - ensure!(buf.remaining() >= 1, errors::Underflow); - match buf.chunk()[0] { - 0x00 => SetDescriptor::decode(buf).map(D::Set), - 0x01 => ObjectShapeDescriptor::decode(buf).map(D::ObjectShape), - 0x02 => BaseScalarTypeDescriptor::decode(buf).map(D::BaseScalar), - 0x03 => ScalarTypeDescriptor::decode(buf).map(D::Scalar), - 0x04 => TupleTypeDescriptor::decode(buf).map(D::Tuple), - 0x05 => NamedTupleTypeDescriptor::decode(buf).map(D::NamedTuple), - 0x06 => ArrayTypeDescriptor::decode(buf).map(D::Array), - 0x07 => EnumerationTypeDescriptor::decode(buf).map(D::Enumeration), - 0x08 => InputShapeTypeDescriptor::decode(buf).map(D::InputShape), - 0x09 => RangeTypeDescriptor::decode(buf).map(D::Range), - 0x0A => ObjectTypeDescriptor::decode(buf).map(D::Object), - 0x0B => CompoundTypeDescriptor::decode(buf).map(D::Compound), - 0x0C => MultiRangeTypeDescriptor::decode(buf).map(D::MultiRange), - 0x0D => SQLRowDescriptor::decode(buf).map(D::SQLRow), - 0x7F..=0xFF => TypeAnnotationDescriptor::decode(buf).map(D::TypeAnnotation), - descriptor => InvalidTypeDescriptor { descriptor }.fail()?, - } - } -} - -impl Decode for SetDescriptor { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 19, errors::Underflow); - assert!(buf.get_u8() == 0); - let id = Uuid::decode(buf)?.into(); - let type_pos = TypePos(buf.get_u16()); - Ok(SetDescriptor { id, type_pos }) - } -} - -impl Decode for ObjectShapeDescriptor { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 19, errors::Underflow); - assert!(buf.get_u8() == 1); - let id = Uuid::decode(buf)?.into(); - let type_desc = if buf.proto().is_2() { - let ephemeral_free_shape = bool::decode(buf)?; - let type_pos = Some(TypePos::decode(buf)?); - let elements = Vec::::decode(buf)?; - ObjectShapeDescriptor { - id, - elements, - ephemeral_free_shape, - type_pos, - } - } else { - let elements = Vec::::decode(buf)?; - ObjectShapeDescriptor { - id, - elements, - ephemeral_free_shape: false, - type_pos: None, - } - }; - Ok(type_desc) - } -} - -impl Decode for ShapeElement { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 7, errors::Underflow); - let (flags, cardinality) = if buf.proto().is_at_least(0, 11) { - let flags = buf.get_u32(); - let cardinality = TryFrom::try_from(buf.get_u8())?; - (flags, Some(cardinality)) - } else { - (buf.get_u8() as u32, None) - }; - let name = String::decode(buf)?; - let type_pos = TypePos::decode(buf)?; - let source_type_pos = if buf.proto().is_2() { - Some(TypePos::decode(buf)?) - } else { - None - }; - Ok(ShapeElement { - flag_implicit: flags & 0b001 != 0, - flag_link_property: flags & 0b010 != 0, - flag_link: flags & 0b100 != 0, - cardinality, - name, - type_pos, - source_type_pos, - }) - } -} - -impl Decode for InputShapeTypeDescriptor { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 19, errors::Underflow); - assert!(buf.get_u8() == 8); - let id = Uuid::decode(buf)?.into(); - let elements = Vec::::decode(buf)?; - Ok(InputShapeTypeDescriptor { id, elements }) - } -} - -impl Decode for InputShapeElement { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 7, errors::Underflow); - let _flags = buf.get_u32(); - let cardinality = Some(TryFrom::try_from(buf.get_u8())?); - let name = String::decode(buf)?; - let type_pos = TypePos::decode(buf)?; - Ok(InputShapeElement { - cardinality, - name, - type_pos, - }) - } -} - -impl Decode for BaseScalarTypeDescriptor { - fn decode(buf: &mut Input) -> Result { - let desc_byte = buf.get_u8(); - assert!(desc_byte == 2); - ensure!( - !buf.proto().is_2(), - InvalidTypeDescriptor { - descriptor: desc_byte - } - ); - let id = Uuid::decode(buf)?.into(); - Ok(BaseScalarTypeDescriptor { id }) - } -} - -impl Decode for SQLRowDescriptor { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 19, errors::Underflow); - assert!(buf.get_u8() == 0x0D); - let id = Uuid::decode(buf)?.into(); - let elements = Vec::::decode(buf)?; - Ok(SQLRowDescriptor { id, elements }) - } -} - -impl Decode for SQLRowElement { - fn decode(buf: &mut Input) -> Result { - let name = String::decode(buf)?; - let type_pos = TypePos::decode(buf)?; - Ok(SQLRowElement { name, type_pos }) - } -} - -impl Decode for ObjectTypeDescriptor { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 19, errors::Underflow); - assert!(buf.get_u8() == 0x0A); - let id = Uuid::decode(buf)?.into(); - let name = Some(String::decode(buf)?); - let schema_defined = Some(bool::decode(buf)?); - let type_desc = ObjectTypeDescriptor { - id, - name, - schema_defined, - }; - Ok(type_desc) - } -} - -impl Decode for TypeOperation { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 1, errors::Underflow); - let val = match buf.get_u8() { - 0x00 => TypeOperation::UNION, - 0x01 => TypeOperation::INTERSECTION, - _ => errors::InvalidTypeOperation.fail()?, - }; - Ok(val) - } -} - -impl Decode for CompoundTypeDescriptor { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 19, errors::Underflow); - assert!(buf.get_u8() == 0x0B); - let id = Uuid::decode(buf)?.into(); - let name = Some(String::decode(buf)?); - let schema_defined = Some(bool::decode(buf)?); - ensure!(buf.remaining() >= 1, errors::Underflow); - let op = TypeOperation::decode(buf)?; - let components = Vec::::decode(buf)?; - let type_desc = CompoundTypeDescriptor { - id, - name, - schema_defined, - op, - components, - }; - Ok(type_desc) - } -} - -impl Decode for ScalarTypeDescriptor { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 19, errors::Underflow); - assert!(buf.get_u8() == 3); - let id = Uuid::decode(buf)?.into(); - let type_desc = if buf.proto().is_2() { - let name = Some(String::decode(buf)?); - let schema_defined = Some(bool::decode(buf)?); - let ancestors = Vec::::decode(buf)?; - let base_type_pos = ancestors.last().copied(); - ScalarTypeDescriptor { - id, - base_type_pos, - name, - schema_defined, - ancestors, - } - } else { - let base_type_pos = Some(TypePos(buf.get_u16())); - ScalarTypeDescriptor { - id, - base_type_pos, - name: None, - schema_defined: None, - ancestors: vec![], - } - }; - Ok(type_desc) - } -} - -impl Decode for TupleTypeDescriptor { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 19, errors::Underflow); - assert!(buf.get_u8() == 4); - let id = Uuid::decode(buf)?.into(); - - let type_desc = if buf.proto().is_2() { - let name = Some(String::decode(buf)?); - let schema_defined = Some(bool::decode(buf)?); - let ancestors = Vec::::decode(buf)?; - let element_types = Vec::::decode(buf)?; - TupleTypeDescriptor { - id, - element_types, - name, - schema_defined, - ancestors, - } - } else { - let element_types = Vec::::decode(buf)?; - TupleTypeDescriptor { - id, - element_types, - name: None, - schema_defined: None, - ancestors: vec![], - } - }; - Ok(type_desc) - } -} - -impl Decode for NamedTupleTypeDescriptor { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 19, errors::Underflow); - assert!(buf.get_u8() == 5); - let id = Uuid::decode(buf)?.into(); - - let type_desc = if buf.proto().is_2() { - let name = Some(String::decode(buf)?); - let schema_defined = Some(bool::decode(buf)?); - let ancestors = Vec::::decode(buf)?; - let elements = Vec::::decode(buf)?; - NamedTupleTypeDescriptor { - id, - elements, - name, - schema_defined, - ancestors, - } - } else { - let elements = Vec::::decode(buf)?; - NamedTupleTypeDescriptor { - id, - elements, - name: None, - schema_defined: None, - ancestors: vec![], - } - }; - - Ok(type_desc) - } -} - -impl Decode for TupleElement { - fn decode(buf: &mut Input) -> Result { - let name = String::decode(buf)?; - let type_pos = TypePos::decode(buf)?; - Ok(TupleElement { name, type_pos }) - } -} - -impl Decode for ArrayTypeDescriptor { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 21, errors::Underflow); - assert!(buf.get_u8() == 6); - let id = Uuid::decode(buf)?.into(); - let type_desc = if buf.proto().is_2() { - let name = Some(String::decode(buf)?); - let schema_defined = Some(bool::decode(buf)?); - let ancestors = Vec::::decode(buf)?; - let type_pos = TypePos::decode(buf)?; - let dimensions = Vec::>::decode(buf)?; - ArrayTypeDescriptor { - id, - type_pos, - dimensions, - name, - schema_defined, - ancestors, - } - } else { - let type_pos = TypePos::decode(buf)?; - let dimensions = Vec::>::decode(buf)?; - ArrayTypeDescriptor { - id, - type_pos, - dimensions, - name: None, - schema_defined: None, - ancestors: vec![], - } - }; - - Ok(type_desc) - } -} - -impl Decode for RangeTypeDescriptor { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 19, errors::Underflow); - assert!(buf.get_u8() == 9); - let id = Uuid::decode(buf)?.into(); - let type_desc = if buf.proto().is_2() { - let name = Some(String::decode(buf)?); - let schema_defined = Some(bool::decode(buf)?); - let ancestors = Vec::::decode(buf)?; - let type_pos = TypePos::decode(buf)?; - RangeTypeDescriptor { - id, - type_pos, - name, - schema_defined, - ancestors, - } - } else { - let type_pos = TypePos::decode(buf)?; - RangeTypeDescriptor { - id, - type_pos, - name: None, - schema_defined: None, - ancestors: vec![], - } - }; - - Ok(type_desc) - } -} - -impl Decode for MultiRangeTypeDescriptor { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 19, errors::Underflow); - assert!(buf.get_u8() == 0x0C); - let id = Uuid::decode(buf)?.into(); - let type_desc = if buf.proto().is_2() { - let name = Some(String::decode(buf)?); - let schema_defined = Some(bool::decode(buf)?); - let ancestors = Vec::::decode(buf)?; - let type_pos = TypePos::decode(buf)?; - MultiRangeTypeDescriptor { - id, - type_pos, - name, - schema_defined, - ancestors, - } - } else { - let type_pos = TypePos::decode(buf)?; - MultiRangeTypeDescriptor { - id, - type_pos, - name: None, - schema_defined: None, - ancestors: vec![], - } - }; - - Ok(type_desc) - } -} - -impl Decode for EnumerationTypeDescriptor { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 19, errors::Underflow); - assert!(buf.get_u8() == 7); - let id = Uuid::decode(buf)?.into(); - let type_desc = if buf.proto().is_2() { - let name = Some(String::decode(buf)?); - let schema_defined = Some(bool::decode(buf)?); - let ancestors = Vec::::decode(buf)?; - let members = Vec::::decode(buf)?; - EnumerationTypeDescriptor { - id, - members, - name, - schema_defined, - ancestors, - } - } else { - let members = Vec::::decode(buf)?; - EnumerationTypeDescriptor { - id, - members, - name: None, - schema_defined: None, - ancestors: vec![], - } - }; - - Ok(type_desc) - } -} - -impl Decode for TypeAnnotationDescriptor { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 21, errors::Underflow); - let annotated_type = buf.get_u8(); - assert!(annotated_type >= 0x7F); - let id = Uuid::decode(buf)?.into(); - let annotation = String::decode(buf)?; - Ok(TypeAnnotationDescriptor { - annotated_type, - id, - annotation, - }) - } -} - -#[cfg(test)] -mod tests { - use crate::descriptors::{ - BaseScalarTypeDescriptor, Descriptor, DescriptorUuid, SetDescriptor, TypePos, - }; - use uuid::Uuid; - - #[test] - fn descriptor_uuid_debug_outputs() { - let float_32: Uuid = "00000000-0000-0000-0000-000000000106".parse().unwrap(); - let descriptor_id = DescriptorUuid::from(float_32); - assert_eq!(format!("{descriptor_id:?}"), "BaseScalar(float32)"); - - let random_uuid: Uuid = "7cc7e050-ef76-4ae9-b8a6-053ca9baa3d5".parse().unwrap(); - let descriptor_id = DescriptorUuid::from(random_uuid); - assert_eq!( - format!("{descriptor_id:?}"), - "7cc7e050-ef76-4ae9-b8a6-053ca9baa3d5" - ); - - let base_scalar = Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000106" - .parse::() - .unwrap() - .into(), - }); - assert_eq!( - format!("{base_scalar:?}"), - "BaseScalar(BaseScalarTypeDescriptor { id: BaseScalar(float32) })" - ); - - let set_descriptor_with_float32 = Descriptor::Set(SetDescriptor { - id: "00000000-0000-0000-0000-000000000106" - .parse::() - .unwrap() - .into(), - type_pos: TypePos(0), - }); - assert_eq!( - format!("{set_descriptor_with_float32:?}"), - "Set(SetDescriptor { id: BaseScalar(float32), type_pos: TypePos(0) })" - ); - } -} diff --git a/edgedb-protocol/src/encoding.rs b/edgedb-protocol/src/encoding.rs deleted file mode 100644 index 7c59d415..00000000 --- a/edgedb-protocol/src/encoding.rs +++ /dev/null @@ -1,218 +0,0 @@ -use std::collections::HashMap; -use std::convert::TryFrom; -use std::ops::{Deref, DerefMut, RangeBounds}; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use snafu::{ensure, OptionExt, ResultExt}; -use uuid::Uuid; - -use crate::errors::{self, DecodeError, EncodeError}; -use crate::features::ProtocolVersion; - -pub type KeyValues = HashMap; -pub type Annotations = HashMap; - -pub struct Input { - #[allow(dead_code)] - proto: ProtocolVersion, - bytes: Bytes, -} - -pub struct Output<'a> { - #[allow(dead_code)] - proto: &'a ProtocolVersion, - bytes: &'a mut BytesMut, -} - -pub(crate) trait Encode { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError>; -} - -pub(crate) trait Decode: Sized { - fn decode(buf: &mut Input) -> Result; -} - -impl Input { - pub fn new(proto: ProtocolVersion, bytes: Bytes) -> Input { - Input { proto, bytes } - } - pub fn proto(&self) -> &ProtocolVersion { - &self.proto - } - pub fn slice(&self, range: impl RangeBounds) -> Input { - Input { - proto: self.proto.clone(), - bytes: self.bytes.slice(range), - } - } -} - -impl Buf for Input { - fn remaining(&self) -> usize { - self.bytes.remaining() - } - - fn chunk(&self) -> &[u8] { - self.bytes.chunk() - } - - fn advance(&mut self, cnt: usize) { - self.bytes.advance(cnt) - } - - fn copy_to_bytes(&mut self, len: usize) -> Bytes { - self.bytes.copy_to_bytes(len) - } -} - -impl Deref for Input { - type Target = [u8]; - fn deref(&self) -> &[u8] { - &self.bytes[..] - } -} - -impl Deref for Output<'_> { - type Target = [u8]; - fn deref(&self) -> &[u8] { - &self.bytes[..] - } -} - -impl DerefMut for Output<'_> { - fn deref_mut(&mut self) -> &mut [u8] { - &mut self.bytes[..] - } -} - -impl Output<'_> { - pub fn new<'x>(proto: &'x ProtocolVersion, bytes: &'x mut BytesMut) -> Output<'x> { - Output { proto, bytes } - } - pub fn proto(&self) -> &ProtocolVersion { - self.proto - } - pub fn reserve(&mut self, size: usize) { - self.bytes.reserve(size) - } - pub fn extend(&mut self, slice: &[u8]) { - self.bytes.extend(slice) - } -} - -unsafe impl BufMut for Output<'_> { - fn remaining_mut(&self) -> usize { - self.bytes.remaining_mut() - } - unsafe fn advance_mut(&mut self, cnt: usize) { - self.bytes.advance_mut(cnt) - } - fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice { - self.bytes.chunk_mut() - } -} - -pub(crate) fn encode(buf: &mut Output, code: u8, msg: &T) -> Result<(), EncodeError> { - buf.reserve(5); - buf.put_u8(code); - let base = buf.len(); - buf.put_slice(&[0; 4]); - - msg.encode(buf)?; - - let size = u32::try_from(buf.len() - base) - .ok() - .context(errors::MessageTooLong)?; - buf[base..base + 4].copy_from_slice(&size.to_be_bytes()[..]); - Ok(()) -} - -impl Encode for String { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(2 + self.len()); - buf.put_u32( - u32::try_from(self.len()) - .ok() - .context(errors::StringTooLong)?, - ); - buf.extend(self.as_bytes()); - Ok(()) - } -} - -impl Encode for Bytes { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(2 + self.len()); - buf.put_u32( - u32::try_from(self.len()) - .ok() - .context(errors::StringTooLong)?, - ); - buf.extend(&self[..]); - Ok(()) - } -} - -impl Decode for String { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 4, errors::Underflow); - let len = buf.get_u32() as usize; - // TODO(tailhook) ensure size < i32::MAX - ensure!(buf.remaining() >= len, errors::Underflow); - let mut data = vec![0u8; len]; - buf.copy_to_slice(&mut data[..]); - - String::from_utf8(data) - .map_err(|e| e.utf8_error()) - .context(errors::InvalidUtf8) - } -} - -impl Decode for Bytes { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 4, errors::Underflow); - let len = buf.get_u32() as usize; - // TODO(tailhook) ensure size < i32::MAX - ensure!(buf.remaining() >= len, errors::Underflow); - Ok(buf.copy_to_bytes(len)) - } -} - -impl Decode for Uuid { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 16, errors::Underflow); - let mut bytes = [0u8; 16]; - buf.copy_to_slice(&mut bytes[..]); - let result = Uuid::from_slice(&bytes).context(errors::InvalidUuid)?; - Ok(result) - } -} - -impl Encode for Uuid { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.extend(self.as_bytes()); - Ok(()) - } -} - -impl Decode for bool { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 1, errors::Underflow); - let res = match buf.get_u8() { - 0x00 => false, - 0x01 => true, - v => errors::InvalidBool { val: v }.fail()?, - }; - Ok(res) - } -} - -impl Encode for bool { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.extend(match self { - true => &[0x01], - false => &[0x00], - }); - Ok(()) - } -} diff --git a/edgedb-protocol/src/error_response.rs b/edgedb-protocol/src/error_response.rs deleted file mode 100644 index 0ff0f626..00000000 --- a/edgedb-protocol/src/error_response.rs +++ /dev/null @@ -1,11 +0,0 @@ -use edgedb_errors::Error; - -use crate::server_message::ErrorResponse; - -impl From for Error { - fn from(val: ErrorResponse) -> Self { - Error::from_code(val.code) - .context(val.message) - .with_headers(val.attributes) - } -} diff --git a/edgedb-protocol/src/errors.rs b/edgedb-protocol/src/errors.rs deleted file mode 100644 index 2a8053ad..00000000 --- a/edgedb-protocol/src/errors.rs +++ /dev/null @@ -1,194 +0,0 @@ -use std::error::Error; -use std::str; - -use snafu::{Backtrace, IntoError, Snafu}; - -use crate::value::Value; - -#[derive(Snafu, Debug)] -#[snafu(visibility(pub), context(suffix(false)))] -#[non_exhaustive] -pub enum DecodeError { - #[snafu(display("unexpected end of frame"))] - Underflow { backtrace: Backtrace }, - #[snafu(display("frame contains extra data after decoding"))] - ExtraData { backtrace: Backtrace }, - #[snafu(display("invalid utf8 when decoding string: {}", source))] - InvalidUtf8 { - backtrace: Backtrace, - source: str::Utf8Error, - }, - #[snafu(display("invalid auth status: {:x}", auth_status))] - AuthStatusInvalid { - backtrace: Backtrace, - auth_status: u32, - }, - #[snafu(display("unsupported transaction state: {:x}", transaction_state))] - InvalidTransactionState { - backtrace: Backtrace, - transaction_state: u8, - }, - #[snafu(display("unsupported io format: {:x}", io_format))] - InvalidIoFormat { backtrace: Backtrace, io_format: u8 }, - #[snafu(display("unsupported cardinality: {:x}", cardinality))] - InvalidCardinality { - backtrace: Backtrace, - cardinality: u8, - }, - #[snafu(display("unsupported input language: {:x}", input_language))] - InvalidInputLanguage { - backtrace: Backtrace, - input_language: u8, - }, - #[snafu(display("unsupported capability: {:b}", capabilities))] - InvalidCapabilities { - backtrace: Backtrace, - capabilities: u64, - }, - #[snafu(display("unsupported compilation flags: {:b}", compilation_flags))] - InvalidCompilationFlags { - backtrace: Backtrace, - compilation_flags: u64, - }, - #[snafu(display("unsupported dump flags: {:b}", dump_flags))] - InvalidDumpFlags { - backtrace: Backtrace, - dump_flags: u64, - }, - #[snafu(display("unsupported describe aspect: {:x}", aspect))] - InvalidAspect { backtrace: Backtrace, aspect: u8 }, - #[snafu(display("unsupported type descriptor: {:x}", descriptor))] - InvalidTypeDescriptor { - backtrace: Backtrace, - descriptor: u8, - }, - #[snafu(display("invalid uuid: {}", source))] - InvalidUuid { - backtrace: Backtrace, - source: uuid::Error, - }, - #[snafu(display("non-zero reserved bytes received in data"))] - NonZeroReservedBytes { backtrace: Backtrace }, - #[snafu(display("object data size does not match its shape"))] - ObjectSizeMismatch { backtrace: Backtrace }, - #[snafu(display("tuple size does not match its shape"))] - TupleSizeMismatch { backtrace: Backtrace }, - #[snafu(display("unknown negative length marker"))] - InvalidMarker { backtrace: Backtrace }, - #[snafu(display("array shape for the Set codec is invalid"))] - InvalidSetShape { backtrace: Backtrace }, - #[snafu(display("array shape is invalid"))] - InvalidArrayShape { backtrace: Backtrace }, - #[snafu(display("array or set shape is invalid"))] - InvalidArrayOrSetShape { backtrace: Backtrace }, - #[snafu(display("decimal or bigint sign bytes have invalid value"))] - BadSign { backtrace: Backtrace }, - #[snafu(display("invalid boolean value: {val:?}"))] - InvalidBool { backtrace: Backtrace, val: u8 }, - #[snafu(display("invalid optional u32 value"))] - InvalidOptionU32 { backtrace: Backtrace }, - #[snafu(display("datetime is out of range"))] - InvalidDate { backtrace: Backtrace }, - #[snafu(display("json format is invalid"))] - InvalidJsonFormat { backtrace: Backtrace }, - #[snafu(display("enum value returned is not in type descriptor"))] - ExtraEnumValue { backtrace: Backtrace }, - #[snafu(display("too may descriptors ({})", index))] - TooManyDescriptors { backtrace: Backtrace, index: usize }, - #[snafu(display("invalid index in input shape ({})", index))] - InvalidIndex { backtrace: Backtrace, index: usize }, - #[snafu(display("uuid {} not found", uuid))] - UuidNotFound { - backtrace: Backtrace, - uuid: uuid::Uuid, - }, - #[snafu(display("error decoding value"))] - DecodeValue { - backtrace: Backtrace, - source: Box, - }, - #[snafu(display("missing required link or property"))] - MissingRequiredElement { backtrace: Backtrace }, - #[snafu(display("invalid format of {annotation} annotation"))] - InvalidAnnotationFormat { - backtrace: Backtrace, - annotation: &'static str, - }, - #[snafu(display("invalid type operation value"))] - InvalidTypeOperation { backtrace: Backtrace }, -} - -#[derive(Snafu, Debug)] -#[snafu(visibility(pub(crate)), context(suffix(false)))] -#[non_exhaustive] -pub enum EncodeError { - #[snafu(display("message doesn't fit 4GiB"))] - MessageTooLong { backtrace: Backtrace }, - #[snafu(display("string is larger than 64KiB"))] - StringTooLong { backtrace: Backtrace }, - #[snafu(display("more than 64Ki extensions"))] - TooManyExtensions { backtrace: Backtrace }, - #[snafu(display("more than 64Ki headers"))] - TooManyHeaders { backtrace: Backtrace }, - #[snafu(display("more than 64Ki params"))] - TooManyParams { backtrace: Backtrace }, - #[snafu(display("more than 64Ki attributes"))] - TooManyAttributes { backtrace: Backtrace }, - #[snafu(display("more than 64Ki authentication methods"))] - TooManyMethods { backtrace: Backtrace }, - #[snafu(display("more than 4Gi elements in the object"))] - TooManyElements { backtrace: Backtrace }, - #[snafu(display("single element larger than 4Gi"))] - ElementTooLong { backtrace: Backtrace }, - #[snafu(display("array or set has more than 4Gi elements"))] - ArrayTooLong { backtrace: Backtrace }, - #[snafu(display("bigint has more than 256Ki digits"))] - BigIntTooLong { backtrace: Backtrace }, - #[snafu(display("decimal has more than 256Ki digits"))] - DecimalTooLong { backtrace: Backtrace }, - #[snafu(display("unknown message types cannot be encoded"))] - UnknownMessageCantBeEncoded { backtrace: Backtrace }, - #[snafu(display( - "trying to encode invalid value type {} with codec {}", - value_type, - codec - ))] - InvalidValue { - backtrace: Backtrace, - value_type: &'static str, - codec: &'static str, - }, - #[snafu(display("shape of data does not match shape of encoder"))] - ObjectShapeMismatch { backtrace: Backtrace }, - #[snafu(display("datetime value is out of range"))] - DatetimeRange { backtrace: Backtrace }, - #[snafu(display("tuple size doesn't match encoder"))] - TupleShapeMismatch { backtrace: Backtrace }, - #[snafu(display("enum value is not in type descriptor"))] - MissingEnumValue { backtrace: Backtrace }, -} - -#[derive(Snafu, Debug)] -#[snafu(visibility(pub(crate)), context(suffix(false)))] -#[non_exhaustive] -pub enum CodecError { - #[snafu(display("type position {} is absent", position))] - UnexpectedTypePos { backtrace: Backtrace, position: u16 }, - #[snafu(display("base scalar with uuid {} not found", uuid))] - UndefinedBaseScalar { - backtrace: Backtrace, - uuid: uuid::Uuid, - }, -} - -pub fn invalid_value(codec: &'static str, value: &Value) -> EncodeError { - InvalidValue { - codec, - value_type: value.kind(), - } - .build() -} - -pub fn decode_error(e: E) -> DecodeError { - DecodeValue.into_error(Box::new(e)) -} diff --git a/edgedb-protocol/src/features.rs b/edgedb-protocol/src/features.rs deleted file mode 100644 index a71af0a3..00000000 --- a/edgedb-protocol/src/features.rs +++ /dev/null @@ -1,54 +0,0 @@ -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct ProtocolVersion { - pub(crate) major_ver: u16, - pub(crate) minor_ver: u16, -} - -impl ProtocolVersion { - pub fn current() -> ProtocolVersion { - ProtocolVersion { - major_ver: 3, - minor_ver: 0, - } - } - pub fn new(major_ver: u16, minor_ver: u16) -> ProtocolVersion { - ProtocolVersion { - major_ver, - minor_ver, - } - } - pub fn version_tuple(&self) -> (u16, u16) { - (self.major_ver, self.minor_ver) - } - pub fn is_1(&self) -> bool { - self.major_ver >= 1 - } - pub fn is_2(&self) -> bool { - self.major_ver >= 2 - } - pub fn is_3(&self) -> bool { - self.major_ver >= 3 - } - pub fn supports_inline_typenames(&self) -> bool { - self.version_tuple() >= (0, 9) - } - pub fn has_implicit_tid(&self) -> bool { - self.version_tuple() <= (0, 8) - } - pub fn has_implicit_id(&self) -> bool { - // Some of pre 1.0 protocols required implicit id. - // Later it was opt out. In 1.0 it's opt-in. - // We never opt-in or opt-out so whether it's present only on pre 1.0 - // portocols. - !self.is_1() - } - pub fn is_multilingual(&self) -> bool { - self.is_at_least(3, 0) - } - pub fn is_at_least(&self, major_ver: u16, minor_ver: u16) -> bool { - self.major_ver > major_ver || self.major_ver == major_ver && self.minor_ver >= minor_ver - } - pub fn is_at_most(&self, major_ver: u16, minor_ver: u16) -> bool { - self.major_ver < major_ver || self.major_ver == major_ver && self.minor_ver <= minor_ver - } -} diff --git a/edgedb-protocol/src/lib.rs b/edgedb-protocol/src/lib.rs index 8062a608..1a216f20 100644 --- a/edgedb-protocol/src/lib.rs +++ b/edgedb-protocol/src/lib.rs @@ -1,80 +1,7 @@ /*! -([Website reference](https://www.edgedb.com/docs/reference/protocol/index)) The EdgeDB protocol for Edgedb-Rust. +The EdgeDB protocol for EdgeDB-Rust -EdgeDB types used for data modeling can be seen on the [model] crate, in which the [Value](crate::value::Value) -enum provides the quickest overview of all the possible types encountered using the client. Many of the variants hold Rust -standard library types while others contain types defined in this protocol. Some types such as [Duration](crate::model::Duration) -appear to be standard library types but are unique to the EdgeDB protocol. - -Other parts of this crate pertain to the rest of the EdgeDB protocol (e.g. client + server message formats), plus various traits -for working with the client such as: - -* [QueryArg](crate::query_arg::QueryArg): a single argument for a query -* [QueryArgs](crate::query_arg::QueryArgs): a tuple of query arguments -* [Queryable](crate::queryable::Queryable): for the Queryable derive macro -* [QueryResult]: single result from a query (scalars and tuples) - -The Value enum: - -```rust,ignore -pub enum Value { - Nothing, - Uuid(Uuid), - Str(String), - Bytes(Bytes), - Int16(i16), - Int32(i32), - Int64(i64), - Float32(f32), - Float64(f64), - BigInt(BigInt), - ConfigMemory(ConfigMemory), - Decimal(Decimal), - Bool(bool), - Datetime(Datetime), - LocalDatetime(LocalDatetime), - LocalDate(LocalDate), - LocalTime(LocalTime), - Duration(Duration), - RelativeDuration(RelativeDuration), - DateDuration(DateDuration), - Json(Json), - Set(Vec), - Object { - shape: ObjectShape, - fields: Vec>, - }, - SparseObject(SparseObject), - Tuple(Vec), - NamedTuple { - shape: NamedTupleShape, - fields: Vec, - }, - Array(Vec), - Enum(EnumValue), - Range(Range>), -} -``` +This crate has been renamed to [gel-protocol](https://crates.io/crates/gel-protocol). */ -mod query_result; // sealed trait should remain non-public - -pub mod client_message; -pub mod codec; -pub mod common; -pub mod descriptors; -pub mod encoding; -pub mod error_response; -pub mod errors; -pub mod features; -pub mod queryable; -pub mod serialization; -pub mod server_message; -pub mod value; -#[macro_use] -pub mod value_opt; -pub mod annotations; -pub mod model; -pub mod query_arg; - -pub use query_result::QueryResult; +compile_error!("edgedb-protocol has been renamed to gel-protocol"); \ No newline at end of file diff --git a/edgedb-protocol/src/model.rs b/edgedb-protocol/src/model.rs deleted file mode 100644 index 6930b45e..00000000 --- a/edgedb-protocol/src/model.rs +++ /dev/null @@ -1,81 +0,0 @@ -//! # EdgeDB Types Used for Data Modelling - -mod bignum; -mod json; -mod memory; -mod time; -mod vector; - -pub(crate) mod range; - -pub use self::bignum::{BigInt, Decimal}; -pub use self::json::Json; -pub use self::time::{DateDuration, RelativeDuration}; -pub use self::time::{Datetime, Duration, LocalDate, LocalDatetime, LocalTime}; -pub use memory::ConfigMemory; -pub use range::Range; -pub use uuid::Uuid; -pub use vector::Vector; - -use std::fmt; -use std::num::ParseIntError; - -/// Error converting an out of range value to/from EdgeDB type. -#[derive(Debug, PartialEq)] -pub struct OutOfRangeError; - -impl std::error::Error for OutOfRangeError {} -impl fmt::Display for OutOfRangeError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - "value is out of range".fmt(f) - } -} - -impl From for OutOfRangeError { - fn from(_: std::num::TryFromIntError) -> OutOfRangeError { - OutOfRangeError - } -} - -/// Error parsing string into EdgeDB Duration type. -#[derive(Debug, PartialEq)] -pub struct ParseDurationError { - pub(crate) message: String, - pub(crate) pos: usize, - pub(crate) is_final: bool, -} - -impl std::error::Error for ParseDurationError {} -impl fmt::Display for ParseDurationError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - format!( - "Error parsing input at position {}: {}", - self.pos, self.message, - ) - .fmt(f) - } -} - -impl From for ParseDurationError { - fn from(e: ParseIntError) -> Self { - Self::new(format!("{}", e)) - } -} - -impl ParseDurationError { - pub(crate) fn new(message: impl Into) -> Self { - Self { - pos: 0, - message: message.into(), - is_final: true, - } - } - pub(crate) fn not_final(mut self) -> Self { - self.is_final = false; - self - } - pub(crate) fn pos(mut self, value: usize) -> Self { - self.pos = value; - self - } -} diff --git a/edgedb-protocol/src/model/bignum.rs b/edgedb-protocol/src/model/bignum.rs deleted file mode 100644 index 7aab644a..00000000 --- a/edgedb-protocol/src/model/bignum.rs +++ /dev/null @@ -1,421 +0,0 @@ -use std::fmt::Write; - -#[cfg(feature = "num-bigint")] -mod num_bigint_interop; - -#[cfg(feature = "bigdecimal")] -mod bigdecimal_interop; - -/// Virtually unlimited precision integer. -/// -/// See EdgeDB [protocol documentation](https://docs.edgedb.com/database/reference/protocol/dataformats#std-bigint). -#[derive(Clone, Debug, PartialEq)] -#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))] -pub struct BigInt { - pub(crate) negative: bool, - pub(crate) weight: i16, - pub(crate) digits: Vec, -} - -/// High-precision decimal number. -/// -/// See EdgeDB [protocol documentation](https://docs.edgedb.com/database/reference/protocol/dataformats#std-decimal). -#[derive(Clone, Debug, PartialEq)] -#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Decimal { - pub(crate) negative: bool, - pub(crate) weight: i16, - pub(crate) decimal_digits: u16, - pub(crate) digits: Vec, -} - -impl BigInt { - pub fn negative(&self) -> bool { - self.negative - } - - pub fn weight(&self) -> i16 { - self.weight - } - - pub fn digits(&self) -> &[u16] { - &self.digits - } - - fn normalize(mut self) -> BigInt { - while let Some(0) = self.digits.last() { - self.digits.pop(); - } - while let Some(0) = self.digits.first() { - self.digits.remove(0); - self.weight -= 1; - } - self - } - - fn trailing_zero_groups(&self) -> i16 { - self.weight - self.digits.len() as i16 + 1 - } -} - -impl std::fmt::Display for BigInt { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.negative { - write!(f, "-")?; - } - if let Some(digit) = self.digits.first() { - write!(f, "{}", digit)?; - for digit in &mut self.digits.iter().skip(1) { - write!(f, "{:04}", digit)?; - } - let trailing_zero_groups = self.trailing_zero_groups(); - debug_assert!(trailing_zero_groups >= 0); - for _ in 0..trailing_zero_groups { - write!(f, "0000")?; - } - } else { - write!(f, "0")?; - } - Ok(()) - } -} - -impl From for BigInt { - fn from(v: u64) -> BigInt { - BigInt { - negative: false, - weight: 4, - digits: vec![ - (v / 10_000_000_000_000_000 % 10000) as u16, - (v / 1_000_000_000_000 % 10000) as u16, - (v / 100_000_000 % 10000) as u16, - (v / 10000 % 10000) as u16, - (v % 10000) as u16, - ], - } - .normalize() - } -} - -impl From for BigInt { - fn from(v: i64) -> BigInt { - let (abs, negative) = if v < 0 { - (u64::MAX - v as u64 + 1, true) - } else { - (v as u64, false) - }; - BigInt { - negative, - weight: 4, - digits: vec![ - (abs / 10_000_000_000_000_000 % 10000) as u16, - (abs / 1_000_000_000_000 % 10000) as u16, - (abs / 100_000_000 % 10000) as u16, - (abs / 10000 % 10000) as u16, - (abs % 10000) as u16, - ], - } - .normalize() - } -} - -impl From for BigInt { - fn from(v: u32) -> BigInt { - BigInt { - negative: false, - weight: 2, - digits: vec![ - (v / 100_000_000) as u16, - (v / 10000 % 10000) as u16, - (v % 10000) as u16, - ], - } - .normalize() - } -} - -impl From for BigInt { - fn from(v: i32) -> BigInt { - let (abs, negative) = if v < 0 { - (u32::MAX - v as u32 + 1, true) - } else { - (v as u32, false) - }; - BigInt { - negative, - weight: 2, - digits: vec![ - (abs / 100_000_000) as u16, - (abs / 10000 % 10000) as u16, - (abs % 10000) as u16, - ], - } - .normalize() - } -} - -impl Decimal { - pub fn negative(&self) -> bool { - self.negative - } - - pub fn weight(&self) -> i16 { - self.weight - } - - pub fn decimal_digits(&self) -> u16 { - self.decimal_digits - } - - pub fn digits(&self) -> &[u16] { - &self.digits - } - - #[allow(dead_code)] // isn't used when BigDecimal is disabled - fn normalize(mut self) -> Decimal { - while let Some(0) = self.digits.last() { - self.digits.pop(); - } - while let Some(0) = self.digits.first() { - self.digits.remove(0); - self.weight -= 1; - } - self - } -} - -impl std::fmt::Display for Decimal { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.negative { - write!(f, "-")?; - } - - let mut index = 0; - - // integer part - while self.weight - index >= 0 { - if let Some(digit) = self.digits.get(index as usize) { - if index == 0 { - write!(f, "{}", digit)?; - } else { - write!(f, "{:04}", digit)?; - } - index += 1; - } else { - break; - } - } - - // trailing zeros of the integer part - for _ in 0..(self.weight - index + 1) { - f.write_str("0000")?; - } - - if index == 0 { - write!(f, "0")?; - } - - // dot - write!(f, ".")?; - - // leading zeros of the decimal part - let mut decimals = u16::max(self.decimal_digits, 1); - if index == 0 && self.weight < 0 { - for _ in 0..(-1 - self.weight) { - f.write_str("0000")?; - decimals -= 4; - } - } - - while decimals > 0 { - if let Some(digit) = self.digits.get(index as usize) { - let digit = format!("{digit:04}"); - let consumed = u16::min(4, decimals); - f.write_str(&digit[0..consumed as usize])?; - decimals -= consumed; - index += 1; - } else { - break; - } - } - // trailing zeros - for _ in 0..decimals { - f.write_char('0')?; - } - Ok(()) - } -} - -#[cfg(test)] -#[allow(dead_code)] // used by optional tests -mod test_helpers { - use rand::Rng; - - pub fn gen_u64(rng: &mut T) -> u64 { - // change distribution to generate different length more frequently - let max = 10_u64.pow(rng.gen_range(0..20)); - rng.gen_range(0..max) - } - - pub fn gen_i64(rng: &mut T) -> i64 { - // change distribution to generate different length more frequently - let max = 10_i64.pow(rng.gen_range(0..19)); - rng.gen_range(-max..max) - } - - pub fn gen_f64(rng: &mut T) -> f64 { - rng.gen::() - } -} - -#[cfg(test)] -#[allow(unused_imports)] // because of optional tests -mod test { - use super::{BigInt, Decimal}; - use std::convert::TryFrom; - use std::str::FromStr; - - #[test] - fn big_int_conversion() { - assert_eq!(BigInt::from(125u32).weight, 0); - assert_eq!(&BigInt::from(125u32).digits, &[125]); - assert_eq!(BigInt::from(30000u32).weight, 1); - assert_eq!(&BigInt::from(30000u32).digits, &[3]); - assert_eq!(BigInt::from(30001u32).weight, 1); - assert_eq!(&BigInt::from(30001u32).digits, &[3, 1]); - assert_eq!(BigInt::from(u32::MAX).weight, 2); - assert_eq!(BigInt::from(u32::MAX).digits, &[42, 9496, 7295]); - - assert_eq!(BigInt::from(125i32).weight, 0); - assert_eq!(&BigInt::from(125i32).digits, &[125]); - assert_eq!(BigInt::from(30000i32).weight, 1); - assert_eq!(&BigInt::from(30000i32).digits, &[3]); - assert_eq!(BigInt::from(30001i32).weight, 1); - assert_eq!(&BigInt::from(30001i32).digits, &[3, 1]); - assert_eq!(BigInt::from(i32::MAX).weight, 2); - assert_eq!(BigInt::from(i32::MAX).digits, &[21, 4748, 3647]); - - assert_eq!(BigInt::from(-125i32).weight, 0); - assert_eq!(&BigInt::from(-125i32).digits, &[125]); - assert_eq!(BigInt::from(-30000i32).weight, 1); - assert_eq!(&BigInt::from(-30000i32).digits, &[3]); - assert_eq!(BigInt::from(-30001i32).weight, 1); - assert_eq!(&BigInt::from(-30001i32).digits, &[3, 1]); - assert_eq!(BigInt::from(i32::MIN).weight, 2); - assert_eq!(BigInt::from(i32::MIN).digits, &[21, 4748, 3648]); - - assert_eq!(BigInt::from(125u64).weight, 0); - assert_eq!(&BigInt::from(125u64).digits, &[125]); - assert_eq!(BigInt::from(30000u64).weight, 1); - assert_eq!(&BigInt::from(30000u64).digits, &[3]); - assert_eq!(BigInt::from(30001u64).weight, 1); - assert_eq!(&BigInt::from(30001u64).digits, &[3, 1]); - assert_eq!(BigInt::from(u64::MAX).weight, 4); - assert_eq!(BigInt::from(u64::MAX).digits, &[1844, 6744, 737, 955, 1615]); - - assert_eq!(BigInt::from(125i64).weight, 0); - assert_eq!(&BigInt::from(125i64).digits, &[125]); - assert_eq!(BigInt::from(30000i64).weight, 1); - assert_eq!(&BigInt::from(30000i64).digits, &[3]); - assert_eq!(BigInt::from(30001i64).weight, 1); - assert_eq!(&BigInt::from(30001i64).digits, &[3, 1]); - assert_eq!(BigInt::from(i64::MAX).weight, 4); - assert_eq!(BigInt::from(i64::MAX).digits, &[922, 3372, 368, 5477, 5807]); - - assert_eq!(BigInt::from(-125i64).weight, 0); - assert_eq!(&BigInt::from(-125i64).digits, &[125]); - assert_eq!(BigInt::from(-30000i64).weight, 1); - assert_eq!(&BigInt::from(-30000i64).digits, &[3]); - assert_eq!(BigInt::from(-30001i64).weight, 1); - assert_eq!(&BigInt::from(-30001i64).digits, &[3, 1]); - assert_eq!(BigInt::from(i64::MIN).weight, 4); - assert_eq!(BigInt::from(i64::MIN).digits, &[922, 3372, 368, 5477, 5808]); - } - - #[test] - fn bigint_display() { - let cases = [0, 1, -1, 1_0000, -1_0000, 1_2345_6789, i64::MAX, i64::MIN]; - for i in cases.iter() { - assert_eq!(BigInt::from(*i).to_string(), i.to_string()); - } - } - - #[test] - fn bigint_display_rand() { - use rand::{rngs::StdRng, Rng, SeedableRng}; - let mut rng = StdRng::seed_from_u64(4); - for _ in 0..1000 { - let i = super::test_helpers::gen_i64(&mut rng); - assert_eq!(BigInt::from(i).to_string(), i.to_string()); - } - } - - #[test] - fn decimal_display() { - assert_eq!( - Decimal { - negative: false, - weight: 0, - decimal_digits: 0, - digits: vec![42], - } - .to_string(), - "42.0" - ); - - assert_eq!( - Decimal { - negative: false, - weight: 0, - decimal_digits: 10, - digits: vec![42], - } - .to_string(), - "42.0000000000" - ); - - assert_eq!( - Decimal { - negative: true, - weight: 0, - decimal_digits: 1, - digits: vec![42], - } - .to_string(), - "-42.0" - ); - - assert_eq!( - Decimal { - negative: false, - weight: 1, - decimal_digits: 10, - digits: vec![42], - } - .to_string(), - "420000.0000000000" - ); - - assert_eq!( - Decimal { - negative: false, - weight: -2, - decimal_digits: 10, - digits: vec![42], - } - .to_string(), - "0.0000004200" - ); - - assert_eq!( - Decimal { - negative: false, - weight: -6, - decimal_digits: 21, - digits: vec![1000], - } - .to_string(), - "0.000000000000000000001" - ); - } -} diff --git a/edgedb-protocol/src/model/bignum/bigdecimal_interop.rs b/edgedb-protocol/src/model/bignum/bigdecimal_interop.rs deleted file mode 100644 index bc5106ea..00000000 --- a/edgedb-protocol/src/model/bignum/bigdecimal_interop.rs +++ /dev/null @@ -1,402 +0,0 @@ -use super::Decimal; -use crate::model::OutOfRangeError; - -impl std::convert::TryFrom for Decimal { - type Error = OutOfRangeError; - fn try_from(dec: bigdecimal::BigDecimal) -> Result { - use num_traits::{ToPrimitive, Zero}; - use std::cmp::max; - use std::convert::TryInto; - - let mut digits = Vec::new(); - let (v, scale) = dec.into_bigint_and_exponent(); - let (negative, mut val) = match v.sign() { - num_bigint::Sign::Minus => (true, -v), - num_bigint::Sign::NoSign => (false, v), - num_bigint::Sign::Plus => (false, v), - }; - let scale_4digits = if scale < 0 { scale / 4 } else { scale / 4 + 1 }; - let pad = scale_4digits * 4 - scale; - - if pad > 0 { - val *= 10u16.pow(pad as u32); - } - while !val.is_zero() { - digits.push((&val % 10000u16).to_u16().unwrap()); - val /= 10000; - } - digits.reverse(); - - // These return "out of range integral type conversion attempted" - // which should be good enough for this error - let decimal_digits = max(0, scale).try_into()?; - let weight = i16::try_from(digits.len() as i64 - scale_4digits - 1)?; - - // TODO(tailhook) normalization can be optimized here - Ok(Decimal { - negative, - weight, - decimal_digits, - digits, - } - .normalize()) - } -} - -impl From for bigdecimal::BigDecimal { - fn from(v: Decimal) -> bigdecimal::BigDecimal { - (&v).into() - } -} - -impl From<&Decimal> for bigdecimal::BigDecimal { - fn from(val: &Decimal) -> bigdecimal::BigDecimal { - use bigdecimal::BigDecimal; - use num_bigint::BigInt; - use num_traits::pow; - use std::cmp::max; - if val.digits.is_empty() { - return BigDecimal::from(0); - } - - let mut r = BigInt::from(0); - // TODO(tailhook) this is quite slow, use preallocated vector - for &digit in &val.digits { - r *= 10000; - r += digit; - } - let decimal_stored = 4 * max(0, val.digits.len() as i64 - val.weight as i64 - 1) as usize; - let pad = if decimal_stored > 0 { - let pad = decimal_stored as i64 - val.decimal_digits as i64; - match pad { - 1.. => { - r /= pow(10, pad as usize); - } - 0 => {} - ..=-1 => { - r *= pow(10, (-pad) as usize); - } - } - - pad - } else { - 0 - }; - - let scale = if val.decimal_digits == 0 { - -(val.weight as i64 + 1 - val.digits.len() as i64) * 4 - pad as i64 - } else { - if decimal_stored == 0 { - let power = - (val.weight as usize + 1 - val.digits.len()) * 4 + val.decimal_digits as usize; - if power > 0 { - r *= pow(BigInt::from(10), power); - } - } - val.decimal_digits as i64 - }; - if val.negative { - r = -r; - } - BigDecimal::new(r, scale) - } -} - -#[cfg(test)] -mod test { - use super::super::test_helpers::{gen_i64, gen_u64}; - use super::Decimal; - use bigdecimal::BigDecimal; - use rand::{rngs::StdRng, Rng, SeedableRng}; - use std::convert::TryFrom; - use std::str::FromStr; - - #[test] - fn decimal_conversion() -> Result<(), Box> { - let x = Decimal::try_from(BigDecimal::from_str("42.00")?)?; - assert_eq!(x.weight, 0); - assert_eq!(x.decimal_digits, 2); - assert_eq!(x.digits, &[42]); - let x = Decimal::try_from(BigDecimal::from_str("42.07")?)?; - assert_eq!(x.weight, 0); - assert_eq!(x.decimal_digits, 2); - assert_eq!(x.digits, &[42, 700]); - let x = Decimal::try_from(BigDecimal::from_str("0.07")?)?; - assert_eq!(x.weight, -1); - assert_eq!(x.decimal_digits, 2); - assert_eq!(x.digits, &[700]); - let x = Decimal::try_from(BigDecimal::from_str("420000.00")?)?; - assert_eq!(x.weight, 1); - assert_eq!(x.decimal_digits, 2); - assert_eq!(x.digits, &[42]); - - let x = Decimal::try_from(BigDecimal::from_str("-42.00")?)?; - assert_eq!(x.weight, 0); - assert_eq!(x.decimal_digits, 2); - assert_eq!(x.digits, &[42]); - let x = Decimal::try_from(BigDecimal::from_str("-42.07")?)?; - assert_eq!(x.weight, 0); - assert_eq!(x.decimal_digits, 2); - assert_eq!(x.digits, &[42, 700]); - let x = Decimal::try_from(BigDecimal::from_str("-0.07")?)?; - assert_eq!(x.weight, -1); - assert_eq!(x.decimal_digits, 2); - assert_eq!(x.digits, &[700]); - let x = Decimal::try_from(BigDecimal::from_str( - "10000000000000000000000000000000000000.00000", - )?)?; - assert_eq!(x.digits, &[10]); - assert_eq!(x.weight, 9); - assert_eq!(x.decimal_digits, 5); - let x = Decimal::try_from(BigDecimal::from_str("1e100")?)?; - assert_eq!(x.weight, 25); - assert_eq!(x.decimal_digits, 0); - assert_eq!(x.digits, &[1]); - let x = Decimal::try_from(BigDecimal::from_str( - "-703367234220692490200000000000000000000000000", - )?)?; - assert_eq!(x.weight, 11); - assert_eq!(x.decimal_digits, 0); - assert_eq!(x.digits, &[7, 336, 7234, 2206, 9249, 200]); - let x = Decimal::try_from(BigDecimal::from_str("-7033672342206924902e26")?)?; - assert_eq!(x.weight, 11); - assert_eq!(x.decimal_digits, 0); - assert_eq!(x.digits, &[7, 336, 7234, 2206, 9249, 200]); - - let x = Decimal::try_from(BigDecimal::from_str( - "6545218855030988517.14400196897187081925e47", - )?)?; - assert_eq!(x.weight, 16); - assert_eq!(x.decimal_digits, 0); - assert_eq!( - x.digits, - &[65, 4521, 8855, 309, 8851, 7144, 19, 6897, 1870, 8192, 5000] - ); - let x = Decimal::try_from(BigDecimal::from_str( - "-260399300000000000000000000000000000000000000.\ - 000000000007745502260", - )?)?; - assert_eq!(x.weight, 11); - assert_eq!(x.decimal_digits, 21); - assert_eq!( - x.digits, - &[ - 2, 6039, 9300, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // decimal digits start here - 0, 7, 7455, 226, - ] - ); - - Ok(()) - } - - #[test] - fn convert_special() { - let orig = Decimal { - negative: false, - weight: 0, - decimal_digits: 1, - digits: Vec::new(), - }; - let big: BigDecimal = orig.into(); - assert_eq!(big.to_string(), "0"); - let orig = Decimal { - negative: false, - weight: 0, - decimal_digits: 0, - digits: Vec::new(), - }; - let big: BigDecimal = orig.into(); - assert_eq!(big.to_string(), "0"); - } - - fn dec_roundtrip(s: &str) -> BigDecimal { - let rust = BigDecimal::from_str(s).expect("can parse big decimal"); - let edgedb = Decimal::try_from(rust).expect("can convert for edgedb"); - BigDecimal::from(edgedb) - } - - #[test] - fn decimal_roundtrip() -> Result<(), Box> { - use bigdecimal::BigDecimal as B; - - assert_eq!(dec_roundtrip("1"), B::from_str("1")?); - assert_eq!(dec_roundtrip("1000"), B::from_str("1000")?); - assert_eq!(dec_roundtrip("1e100"), B::from_str("1e100")?); - assert_eq!(dec_roundtrip("0"), B::from_str("0")?); - assert_eq!(dec_roundtrip("-1000"), B::from_str("-1000")?); - assert_eq!(dec_roundtrip("1.01"), B::from_str("1.01")?); - assert_eq!(dec_roundtrip("1000.0070"), B::from_str("1000.0070")?); - assert_eq!(dec_roundtrip("0.00008"), B::from_str("0.00008")?); - assert_eq!(dec_roundtrip("-1000.1"), B::from_str("-1000.1")?); - assert_eq!( - dec_roundtrip("10000000000000000000000000000000000000.00001"), - B::from_str("10000000000000000000000000000000000000.00001")? - ); - assert_eq!( - dec_roundtrip("12345678901234567890012345678901234567890123"), - B::from_str("12345678901234567890012345678901234567890123")? - ); - assert_eq!( - dec_roundtrip("1234567890123456789.012345678901234567890123"), - B::from_str("1234567890123456789.012345678901234567890123")? - ); - assert_eq!( - dec_roundtrip("0.000000000000000000000000000000000000017238"), - B::from_str("0.000000000000000000000000000000000000017238")? - ); - assert_eq!(dec_roundtrip("1234.00000"), B::from_str("1234.00000")?); - assert_eq!( - dec_roundtrip("10000000000000000000000000000000000000.00000"), - B::from_str("10000000000000000000000000000000000000.00000")? - ); - assert_eq!( - dec_roundtrip("100010001000000000000000000000000000"), - B::from_str("100010001000000000000000000000000000")? - ); - - Ok(()) - } - - #[test] - fn decimal_rand_i64() -> Result<(), Box> { - use bigdecimal::BigDecimal as B; - - let mut rng = StdRng::seed_from_u64(1); - for _ in 0..10000 { - let head = gen_u64(&mut rng); - let txt = format!("{}", head); - assert_eq!(dec_roundtrip(&txt), B::from_str(&txt)?, "parsing: {}", txt); - } - Ok(()) - } - - #[test] - fn decimal_rand_nulls() -> Result<(), Box> { - use bigdecimal::BigDecimal as B; - - let mut rng = StdRng::seed_from_u64(2); - for iter in 0..10000 { - let head = gen_u64(&mut rng); - let nulls = rng.gen_range(0..100); - let txt = format!("{0}{1:0<2$}", head, "", nulls); - assert_eq!( - dec_roundtrip(&txt), - B::from_str(&txt)?, - "parsing {}: {}", - iter, - txt - ); - } - Ok(()) - } - - #[test] - fn decimal_rand_eplus() -> Result<(), Box> { - use bigdecimal::BigDecimal as B; - - let mut rng = StdRng::seed_from_u64(3); - for iter in 0..10000 { - let head = gen_u64(&mut rng); - let nulls = rng.gen_range(-100..100); - let txt = format!("{}e{}", head, nulls); - assert_eq!( - dec_roundtrip(&txt), - B::from_str(&txt)?, - "parsing {}: {}", - iter, - txt - ); - } - Ok(()) - } - - #[test] - fn decimal_rand_fract_eplus() -> Result<(), Box> { - use bigdecimal::BigDecimal as B; - - let mut rng = StdRng::seed_from_u64(4); - for iter in 0..10000 { - let head = gen_i64(&mut rng); - let fract = gen_u64(&mut rng); - let nulls = rng.gen_range(-100..100); - let txt = format!("{}.{}e{}", head, fract, nulls); - let rt = dec_roundtrip(&txt); - let dec = if head == 0 && fract == 0 { - // Zeros are normalized - B::from(0) - } else { - B::from_str(&txt)? - }; - assert_eq!(rt, dec, "parsing {}: {}", iter, txt); - if dec.as_bigint_and_exponent().1 > 0 { - // check precision - // (if scale is negative it's integer, we don't have precision) - assert_eq!( - rt.as_bigint_and_exponent().1, - dec.as_bigint_and_exponent().1, - "precision: {}", - txt - ); - } - } - Ok(()) - } - - #[test] - fn decimal_rand_nulls_eplus() -> Result<(), Box> { - use bigdecimal::BigDecimal as B; - - let mut rng = StdRng::seed_from_u64(5); - for iter in 0..10000 { - let head = gen_i64(&mut rng); - let nulls1 = rng.gen_range(0..100); - let nulls2 = rng.gen_range(0..100); - let txt = format!("{0}{1:0<2$}e{3}", head, "", nulls1, nulls2); - let rt = dec_roundtrip(&txt); - let dec = B::from_str(&txt)?; - assert_eq!(rt, dec, "parsing {}: {}", iter, txt); - if dec.as_bigint_and_exponent().1 > 0 { - // check precision - // (if scale is negative it's integer, we don't have precision) - assert_eq!( - rt.as_bigint_and_exponent().1, - dec.as_bigint_and_exponent().1, - "precision: {}", - txt - ); - } - } - Ok(()) - } - - #[test] - fn decimal_rand_decim() -> Result<(), Box> { - use bigdecimal::BigDecimal as B; - - let mut rng = StdRng::seed_from_u64(6); - for iter in 0..10000 { - let head = gen_i64(&mut rng); - let nulls1 = rng.gen_range(0..100); - let nulls2 = rng.gen_range(0..100); - let decimals = gen_u64(&mut rng); - let txt = format!( - "{0}{1:0<2$}.{1:0<3$}{4}", - head, "", nulls1, nulls2, decimals - ); - let dec = if head == 0 && decimals == 0 { - // Zeros are normalized - B::from(0) - } else { - B::from_str(&txt)? - }; - assert_eq!(dec_roundtrip(&txt), dec, "parsing {}: {}", iter, txt); - assert_eq!( - dec_roundtrip(&txt).as_bigint_and_exponent().1, - dec.as_bigint_and_exponent().1, - "precision: {}", - txt - ); - } - Ok(()) - } -} diff --git a/edgedb-protocol/src/model/bignum/num_bigint_interop.rs b/edgedb-protocol/src/model/bignum/num_bigint_interop.rs deleted file mode 100644 index 432376a9..00000000 --- a/edgedb-protocol/src/model/bignum/num_bigint_interop.rs +++ /dev/null @@ -1,216 +0,0 @@ -use super::BigInt; -use crate::model::OutOfRangeError; - -impl std::convert::TryFrom for BigInt { - type Error = OutOfRangeError; - fn try_from(v: num_bigint::BigInt) -> Result { - use num_traits::{ToPrimitive, Zero}; - use std::convert::TryInto; - - if v.is_zero() { - return Ok(BigInt { - negative: false, - weight: 0, - digits: Vec::new(), - }); - } - - let mut digits = Vec::new(); - let (negative, mut val) = match v.sign() { - num_bigint::Sign::Minus => (true, -v), - num_bigint::Sign::NoSign => (false, v), - num_bigint::Sign::Plus => (false, v), - }; - while !val.is_zero() { - digits.push((&val % 10000u16).to_u16().unwrap()); - val /= 10000; - } - digits.reverse(); - - // This returns "out of range integral type conversion attempted" - // which should be good enough for this error - let weight = (digits.len() - 1).try_into()?; - - // TODO(tailhook) normalization can be optimized here - Ok(BigInt { - negative, - weight, - digits, - } - .normalize()) - } -} - -impl From for num_bigint::BigInt { - fn from(v: BigInt) -> num_bigint::BigInt { - (&v).into() - } -} - -impl From<&BigInt> for num_bigint::BigInt { - fn from(v: &BigInt) -> num_bigint::BigInt { - use num_bigint::BigInt; - use num_traits::pow; - - let mut r = BigInt::from(0); - for &digit in &v.digits { - r *= 10000; - r += digit; - } - if (v.weight + 1) as usize > v.digits.len() { - r *= pow( - BigInt::from(10000), - (v.weight + 1) as usize - v.digits.len(), - ); - } - if v.negative { - -r - } else { - r - } - } -} - -#[cfg(test)] -mod test { - use super::BigInt; - use std::convert::TryFrom; - use std::str::FromStr; - - #[test] - fn big_big_int_conversion() -> Result<(), Box> { - let x = BigInt::try_from(num_bigint::BigInt::from_str( - "10000000000000000000000000000000000000", - )?)?; - assert_eq!(x.weight, 9); - assert_eq!(&x.digits, &[10]); - Ok(()) - } -} - -// conceptually these tests work on BigInt, but depend on the bigdecimal feature -#[cfg(all(test, feature = "bigdecimal"))] -mod test_with_decimal { - use super::super::test_helpers::gen_i64; - use super::BigInt; - use bigdecimal::BigDecimal; - use num_bigint::ToBigInt; - use rand::{rngs::StdRng, Rng, SeedableRng}; - use std::convert::TryFrom; - use std::str::FromStr; - - #[test] - fn bigint_conversion() -> Result<(), Box> { - let x = BigInt::try_from(BigDecimal::from_str("1e20")?.to_bigint().unwrap())?; - assert_eq!(x.weight, 5); - assert_eq!(x.digits, &[1]); - Ok(()) - } - - fn int_roundtrip(s: &str) -> num_bigint::BigInt { - let decimal = BigDecimal::from_str(s).expect("can parse decimal"); - let rust = decimal.to_bigint().expect("can convert to big int"); - let edgedb = BigInt::try_from(rust).expect("can convert for edgedb"); - num_bigint::BigInt::from(edgedb) - } - - #[test] - fn big_int_roundtrip() -> Result<(), Box> { - use num_bigint::BigInt as N; - - assert_eq!(int_roundtrip("1"), N::from_str("1")?); - assert_eq!(int_roundtrip("1000"), N::from_str("1000")?); - assert_eq!(int_roundtrip("1e20"), N::from_str("100000000000000000000")?); - assert_eq!(int_roundtrip("0"), N::from_str("0")?); - assert_eq!(int_roundtrip("-1000"), N::from_str("-1000")?); - assert_eq!( - int_roundtrip("10000000000000000000000000000000000000000000"), - N::from_str("10000000000000000000000000000000000000000000")? - ); - assert_eq!( - int_roundtrip("12345678901234567890012345678901234567890123"), - N::from_str("12345678901234567890012345678901234567890123")? - ); - assert_eq!( - int_roundtrip("10000000000000000000000000000000000000"), - N::from_str("10000000000000000000000000000000000000")? - ); - Ok(()) - } - - #[test] - fn int_rand_i64() -> Result<(), Box> { - use num_bigint::BigInt as B; - - let mut rng = StdRng::seed_from_u64(7); - for _ in 0..10000 { - let head = gen_i64(&mut rng); - let txt = format!("{}", head); - assert_eq!(int_roundtrip(&txt), B::from_str(&txt)?, "parsing: {}", txt); - } - Ok(()) - } - - #[test] - fn int_rand_nulls() -> Result<(), Box> { - use num_bigint::BigInt as B; - - let mut rng = StdRng::seed_from_u64(8); - for iter in 0..10000 { - let head = gen_i64(&mut rng); - let nulls = rng.gen_range(0..100); - let txt = format!("{0}{1:0<2$}", head, "", nulls); - assert_eq!( - int_roundtrip(&txt), - B::from_str(&txt)?, - "parsing {}: {}", - iter, - txt - ); - } - Ok(()) - } - - #[test] - fn int_rand_eplus() -> Result<(), Box> { - use num_bigint::BigInt as B; - - let mut rng = StdRng::seed_from_u64(9); - for iter in 0..10000 { - let head = gen_i64(&mut rng); - let nulls = rng.gen_range(0..100); - let edb = format!("{}e{}", head, nulls); - let bigint = format!("{}{1:0<2$}", head, "", nulls); - assert_eq!( - int_roundtrip(&edb), - B::from_str(&bigint)?, - "parsing {}: {}", - iter, - edb - ); - } - Ok(()) - } - - #[test] - fn int_rand_nulls_eplus() -> Result<(), Box> { - use num_bigint::BigInt as B; - - let mut rng = StdRng::seed_from_u64(10); - for iter in 0..10000 { - let head = gen_i64(&mut rng); - let nulls1 = rng.gen_range(0..100); - let nulls2 = rng.gen_range(0..100); - let edb = format!("{0}{1:0<2$}e{3}", head, "", nulls1, nulls2); - let bigint = format!("{}{1:0<2$}", head, "", nulls1 + nulls2); - assert_eq!( - int_roundtrip(&edb), - B::from_str(&bigint)?, - "parsing {}: {}", - iter, - edb - ); - } - Ok(()) - } -} diff --git a/edgedb-protocol/src/model/json.rs b/edgedb-protocol/src/model/json.rs deleted file mode 100644 index d0dbeeaa..00000000 --- a/edgedb-protocol/src/model/json.rs +++ /dev/null @@ -1,44 +0,0 @@ -/// A newtype for JSON received from the database -#[derive(Debug, Clone, PartialEq)] -#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Json(String); - -impl Json { - /// Create a JSON value without checking the contents. - /// - /// Two examples of use: - /// - /// 1) To construct values with the data received from the - /// database, because we trust database to produce valid JSON. - /// - /// 2) By client users who are using data that is guaranteed - /// to be valid JSON. If unsure, using a method such as serde_json's - /// [to_string](https://docs.rs/serde_json/latest/serde_json/ser/fn.to_string.html) - /// to construct a String is highly recommended. - /// - /// When used in a client query method, EdgeDB itself will recognize if the - /// String inside `Json` is invalid JSON by returning `InvalidValueError: - /// invalid input syntax for type json`. - pub fn new_unchecked(value: String) -> Json { - Json(value) - } -} - -impl AsRef for Json { - fn as_ref(&self) -> &str { - &self.0 - } -} - -impl std::ops::Deref for Json { - type Target = str; - fn deref(&self) -> &str { - &self.0 - } -} - -impl From for String { - fn from(val: Json) -> Self { - val.0 - } -} diff --git a/edgedb-protocol/src/model/memory.rs b/edgedb-protocol/src/model/memory.rs deleted file mode 100644 index fcd89e52..00000000 --- a/edgedb-protocol/src/model/memory.rs +++ /dev/null @@ -1,32 +0,0 @@ -use std::fmt::{Debug, Display}; - -/// A type for cfg::memory received from the database -#[derive(Copy, Debug, Clone, PartialEq)] -#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))] -pub struct ConfigMemory(pub i64); - -impl ConfigMemory {} - -static KIB: i64 = 1024; -static MIB: i64 = 1024 * KIB; -static GIB: i64 = 1024 * MIB; -static TIB: i64 = 1024 * GIB; - -impl Display for ConfigMemory { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // Use the same rendering logic we have in EdgeDB server - // to cast cfg::memory to std::str. - let v = self.0; - if v >= TIB && v % TIB == 0 { - write!(f, "{}TiB", v / TIB) - } else if v >= GIB && v % GIB == 0 { - write!(f, "{}GiB", v / GIB) - } else if v >= MIB && v % MIB == 0 { - write!(f, "{}MiB", v / MIB) - } else if v >= KIB && v % KIB == 0 { - write!(f, "{}KiB", v / KIB) - } else { - write!(f, "{}B", v) - } - } -} diff --git a/edgedb-protocol/src/model/range.rs b/edgedb-protocol/src/model/range.rs deleted file mode 100644 index 35a9c9f1..00000000 --- a/edgedb-protocol/src/model/range.rs +++ /dev/null @@ -1,75 +0,0 @@ -use crate::value::Value; - -pub(crate) const EMPTY: usize = 0x01; -pub(crate) const LB_INC: usize = 0x02; -pub(crate) const UB_INC: usize = 0x04; -pub(crate) const LB_INF: usize = 0x08; -pub(crate) const UB_INF: usize = 0x10; - -#[derive(Clone, Debug, PartialEq)] -#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Range { - pub(crate) lower: Option, - pub(crate) upper: Option, - pub(crate) inc_lower: bool, - pub(crate) inc_upper: bool, - pub(crate) empty: bool, -} - -impl From> for Range { - fn from(src: std::ops::Range) -> Range { - Range { - lower: Some(src.start), - upper: Some(src.end), - inc_lower: true, - inc_upper: false, - empty: false, - } - } -} - -impl> From> for Value { - fn from(src: std::ops::Range) -> Value { - Range::from(src).into_value() - } -} - -impl Range { - /// Constructor of the empty range - pub fn empty() -> Range { - Range { - lower: None, - upper: None, - inc_lower: true, - inc_upper: false, - empty: true, - } - } - pub fn lower(&self) -> Option<&T> { - self.lower.as_ref() - } - pub fn upper(&self) -> Option<&T> { - self.upper.as_ref() - } - pub fn inc_lower(&self) -> bool { - self.inc_lower - } - pub fn inc_upper(&self) -> bool { - self.inc_upper - } - pub fn is_empty(&self) -> bool { - self.empty - } -} - -impl> Range { - pub fn into_value(self) -> Value { - Value::Range(Range { - lower: self.lower.map(|v| Box::new(v.into())), - upper: self.upper.map(|v| Box::new(v.into())), - inc_lower: self.inc_lower, - inc_upper: self.inc_upper, - empty: self.empty, - }) - } -} diff --git a/edgedb-protocol/src/model/time.rs b/edgedb-protocol/src/model/time.rs deleted file mode 100644 index 67129e69..00000000 --- a/edgedb-protocol/src/model/time.rs +++ /dev/null @@ -1,1898 +0,0 @@ -use crate::model::{OutOfRangeError, ParseDurationError}; -use std::convert::{TryFrom, TryInto}; -use std::fmt::{self, Debug, Display}; -use std::str::FromStr; -use std::time::{SystemTime, UNIX_EPOCH}; - -/// A span of time. -/// -/// Precision: microseconds. -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Duration { - pub(crate) micros: i64, -} - -/// A combination [`LocalDate`] and [`LocalTime`]. -#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))] -pub struct LocalDatetime { - pub(crate) micros: i64, -} - -/// Naive date without a timezone. -#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))] -pub struct LocalDate { - pub(crate) days: i32, -} - -/// Naive time without a timezone. -/// -/// Can't be more than 24 hours. -/// -/// Precision: microseconds. -#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))] -pub struct LocalTime { - pub(crate) micros: u64, -} - -/// A UTC date and time. -#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Datetime { - pub(crate) micros: i64, -} - -/// A type that can represent a human-friendly duration like 1 month or two days. -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))] -pub struct RelativeDuration { - pub(crate) micros: i64, - pub(crate) days: i32, - pub(crate) months: i32, -} - -/// A type that can represent a human-friendly date duration like 1 month or two days. -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))] -pub struct DateDuration { - pub(crate) days: i32, - pub(crate) months: i32, -} - -const SECS_PER_DAY: u64 = 86_400; -const MICROS_PER_DAY: u64 = SECS_PER_DAY * 1_000_000; - -// leap years repeat every 400 years -const DAYS_IN_400_YEARS: u32 = 400 * 365 + 97; - -const MIN_YEAR: i32 = 1; -const MAX_YEAR: i32 = 9999; - -// year -4800 is a multiple of 400 smaller than the minimum supported year -const BASE_YEAR: i32 = -4800; - -#[allow(dead_code)] // only used by specific features -const DAYS_IN_2000_YEARS: i32 = 5 * DAYS_IN_400_YEARS as i32; - -const DAY_TO_MONTH_365: [u32; 13] = [0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365]; -const DAY_TO_MONTH_366: [u32; 13] = [0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335, 366]; - -const MICROS_PER_MS: i64 = 1_000; -const MICROS_PER_SECOND: i64 = MICROS_PER_MS * 1_000; -const MICROS_PER_MINUTE: i64 = MICROS_PER_SECOND * 60; -const MICROS_PER_HOUR: i64 = MICROS_PER_MINUTE * 60; - -impl Duration { - pub const MIN: Duration = Duration { micros: i64::MIN }; - pub const MAX: Duration = Duration { micros: i64::MAX }; - - pub fn from_micros(micros: i64) -> Duration { - Duration { micros } - } - - pub fn to_micros(self) -> i64 { - self.micros - } - - // Returns true if self is positive and false if the duration - // is zero or negative. - pub fn is_positive(&self) -> bool { - self.micros.is_positive() - } - // Returns true if self is negative and false if the duration - // is zero or positive. - pub fn is_negative(&self) -> bool { - self.micros.is_negative() - } - // Returns absolute values as stdlib's duration - // - // Note: `std::time::Duration` can't be negative - pub fn abs_duration(&self) -> std::time::Duration { - if self.micros.is_negative() { - std::time::Duration::from_micros(u64::MAX - self.micros as u64 + 1) - } else { - std::time::Duration::from_micros(self.micros as u64) - } - } - - fn try_from_pg_simple_format(input: &str) -> Result { - let mut split = input.trim_end().splitn(3, ':'); - let mut value: i64 = 0; - let negative; - let mut pos: usize = 0; - - { - let hour_str = split.next().filter(|s| !s.is_empty()).ok_or_else(|| { - ParseDurationError::new("EOF met, expecting `+`, `-` or int") - .not_final() - .pos(input.len()) - })?; - pos += hour_str.len() - 1; - let hour_str = hour_str.trim_start(); - let hour = hour_str - .strip_prefix('-') - .unwrap_or(hour_str) - .parse::() - .map_err(|e| ParseDurationError::from(e).not_final().pos(pos))?; - negative = hour_str.starts_with('-'); - value += (hour.abs() as i64) * MICROS_PER_HOUR; - } - - { - pos += 1; - let minute_str = split.next().ok_or_else(|| { - ParseDurationError::new("EOF met, expecting `:`") - .not_final() - .pos(pos) - })?; - if !minute_str.is_empty() { - pos += minute_str.len(); - let minute = minute_str - .parse::() - .map_err(|e| ParseDurationError::from(e).pos(pos)) - .and_then(|m| { - if m <= 59 { - Ok(m) - } else { - Err(ParseDurationError::new("minutes value out of range").pos(pos)) - } - })?; - value += (minute as i64) * MICROS_PER_MINUTE; - } - } - - if let Some(remaining) = split.last() { - pos += 1; - let mut sec_split = remaining.splitn(2, '.'); - - { - let second_str = sec_split.next().unwrap(); - pos += second_str.len(); - let second = second_str - .parse::() - .map_err(|e| ParseDurationError::from(e).pos(pos)) - .and_then(|s| { - if s <= 59 { - Ok(s) - } else { - Err(ParseDurationError::new("seconds value out of range").pos(pos)) - } - })?; - value += (second as i64) * MICROS_PER_SECOND; - } - - if let Some(sub_sec_str) = sec_split.last() { - pos += 1; - for (i, c) in sub_sec_str.char_indices() { - let d = c - .to_digit(10) - .ok_or_else(|| ParseDurationError::new("not a digit").pos(pos + i + 1))?; - if i < 6 { - value += (d * 10_u32.pow((5 - i) as u32)) as i64; - } else { - if d >= 5 { - value += 1; - } - break; - } - } - } - } - - if negative { - value = -value; - } - Ok(Self { micros: value }) - } - - fn try_from_iso_format(input: &str) -> Result { - if let Some(input) = input.strip_prefix("PT") { - let mut pos = 2; - let mut result = 0; - let mut parts = input.split_inclusive(|c: char| c.is_alphabetic()); - let mut current = parts.next(); - - if let Some(part) = current { - if let Some(hour_str) = part.strip_suffix('H') { - let hour = hour_str - .parse::() - .map_err(|e| ParseDurationError::from(e).pos(pos))?; - result += (hour as i64) * MICROS_PER_HOUR; - pos += part.len(); - current = parts.next(); - } - } - - if let Some(part) = current { - if let Some(minute_str) = part.strip_suffix('M') { - let minute = minute_str - .parse::() - .map_err(|e| ParseDurationError::from(e).pos(pos))?; - result += (minute as i64) * MICROS_PER_MINUTE; - pos += part.len(); - current = parts.next(); - } - } - - if let Some(part) = current { - if let Some(second_str) = part.strip_suffix('S') { - let (second_str, subsec_str) = second_str - .split_once('.') - .map(|(sec, sub)| (sec, sub.get(..6).or(Some(sub)))) - .unwrap_or_else(|| (second_str, None)); - - let second = second_str - .parse::() - .map_err(|e| ParseDurationError::from(e).pos(pos))?; - result += (second as i64) * MICROS_PER_SECOND; - pos += second_str.len() + 1; - - if let Some(subsec_str) = subsec_str { - let subsec = subsec_str - .parse::() - .map_err(|e| ParseDurationError::from(e).pos(pos))?; - result += (subsec as i64) - * 10_i64.pow((6 - subsec_str.len()) as u32) - * if second < 0 { -1 } else { 1 }; - pos += subsec_str.len() - } - current = parts.next(); - } - } - - if current.is_some() { - Err(ParseDurationError::new("expecting EOF").pos(pos)) - } else { - Ok(Self { micros: result }) - } - } else { - Err(ParseDurationError::new("not ISO format").not_final()) - } - } - - fn get_pg_format_value( - input: &str, - start: usize, - end: usize, - ) -> Result { - if let Some(val) = input.get(start..end) { - match val.parse::() { - Ok(v) => Ok(v as i64), - Err(e) => Err(ParseDurationError::from(e).pos(end.saturating_sub(1))), - } - } else { - Err(ParseDurationError::new("expecting value").pos(end)) - } - } - - fn try_from_pg_format(input: &str) -> Result { - enum Expect { - Numeric { begin: usize }, - Alphabetic { begin: usize, numeric: i64 }, - Whitespace { numeric: Option }, - } - let mut seen = Vec::new(); - let mut get_unit = |start: usize, end: usize, default: Option<&str>| { - input - .get(start..end) - .or(default) - .and_then(|u| match u.to_lowercase().as_str() { - "h" | "hr" | "hrs" | "hour" | "hours" => Some(MICROS_PER_HOUR), - "m" | "min" | "mins" | "minute" | "minutes" => Some(MICROS_PER_MINUTE), - "ms" | "millisecon" | "millisecons" | "millisecond" | "milliseconds" => { - Some(MICROS_PER_MS) - } - "us" | "microsecond" | "microseconds" => Some(1), - "s" | "sec" | "secs" | "second" | "seconds" => Some(MICROS_PER_SECOND), - _ => None, - }) - .ok_or_else(|| ParseDurationError::new("unknown unit").pos(start)) - .and_then(|u| { - if seen.contains(&u) { - Err(ParseDurationError::new("specified more than once").pos(start)) - } else { - seen.push(u.clone()); - Ok(u) - } - }) - }; - let mut state = Expect::Whitespace { numeric: None }; - let mut result = 0; - for (pos, c) in input.char_indices() { - let is_whitespace = c.is_whitespace(); - let is_numeric = c.is_numeric() || c == '+' || c == '-'; - let is_alphabetic = c.is_alphabetic(); - if !(is_whitespace || is_numeric || is_alphabetic) { - return Err(ParseDurationError::new("unexpected character").pos(pos)); - } - match state { - Expect::Numeric { begin } if !is_numeric => { - let numeric = Self::get_pg_format_value(input, begin, pos)?; - if is_alphabetic { - state = Expect::Alphabetic { - begin: pos, - numeric, - }; - } else { - state = Expect::Whitespace { - numeric: Some(numeric), - }; - } - } - Expect::Alphabetic { begin, numeric } if !is_alphabetic => { - result += numeric * get_unit(begin, pos, None)?; - if is_numeric { - state = Expect::Numeric { begin: pos }; - } else { - state = Expect::Whitespace { numeric: None }; - } - } - Expect::Whitespace { numeric: None } if !is_whitespace => { - if is_numeric { - state = Expect::Numeric { begin: pos }; - } else { - return Err( - ParseDurationError::new("expecting whitespace or numeric").pos(pos) - ); - } - } - Expect::Whitespace { - numeric: Some(numeric), - } if !is_whitespace => { - if is_alphabetic { - state = Expect::Alphabetic { - begin: pos, - numeric, - }; - } else { - return Err( - ParseDurationError::new("expecting whitespace or alphabetic").pos(pos), - ); - } - } - _ => {} - } - } - match state { - Expect::Numeric { begin } => { - result += Self::get_pg_format_value(input, begin, input.len())? * MICROS_PER_SECOND; - } - Expect::Alphabetic { begin, numeric } => { - result += numeric * get_unit(begin, input.len(), Some("s"))?; - } - Expect::Whitespace { - numeric: Some(numeric), - } => { - result += numeric * MICROS_PER_SECOND; - } - _ => {} - } - Ok(Self { micros: result }) - } -} - -impl FromStr for Duration { - type Err = ParseDurationError; - - fn from_str(input: &str) -> Result { - if let Ok(seconds) = input.trim().parse::() { - seconds - .checked_mul(MICROS_PER_SECOND) - .map(Self::from_micros) - .ok_or_else(|| Self::Err::new("seconds value out of range").pos(input.len() - 1)) - } else { - Self::try_from_pg_simple_format(input) - .or_else(|e| { - if e.is_final { - Err(e) - } else { - Self::try_from_iso_format(input) - } - }) - .or_else(|e| { - if e.is_final { - Err(e) - } else { - Self::try_from_pg_format(input) - } - }) - } - } -} - -impl LocalDatetime { - // 0001-01-01T00:00:00 - pub const MIN: LocalDatetime = LocalDatetime { - micros: -63082281600000000, - }; - // 9999-12-31T23:59:59.999999 - pub const MAX: LocalDatetime = LocalDatetime { - micros: 252455615999999999, - }; - - pub(crate) fn from_postgres_micros(micros: i64) -> Result { - if !(Self::MIN.micros..=Self::MAX.micros).contains(µs) { - return Err(OutOfRangeError); - } - Ok(LocalDatetime { micros }) - } - - #[deprecated( - since = "0.5.0", - note = "use Datetime::try_from_unix_micros(v).into() instead" - )] - pub fn from_micros(micros: i64) -> LocalDatetime { - Self::from_postgres_micros(micros).unwrap_or_else(|_| { - panic!( - "LocalDatetime::from_micros({}) is outside the valid datetime range", - micros - ) - }) - } - - #[deprecated(since = "0.5.0", note = "use .to_utc().to_unix_micros() instead")] - pub fn to_micros(self) -> i64 { - self.micros - } - - pub fn new(date: LocalDate, time: LocalTime) -> LocalDatetime { - let micros = date.to_days() as i64 * MICROS_PER_DAY as i64 + time.to_micros() as i64; - LocalDatetime { micros } - } - - pub fn date(self) -> LocalDate { - LocalDate::from_days(self.micros.wrapping_div_euclid(MICROS_PER_DAY as i64) as i32) - } - - pub fn time(self) -> LocalTime { - LocalTime::from_micros(self.micros.wrapping_rem_euclid(MICROS_PER_DAY as i64) as u64) - } - - pub fn to_utc(self) -> Datetime { - Datetime { - micros: self.micros, - } - } -} - -impl From for LocalDatetime { - fn from(d: Datetime) -> LocalDatetime { - LocalDatetime { micros: d.micros } - } -} - -impl Display for LocalDatetime { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{} {}", self.date(), self.time()) - } -} - -impl Debug for LocalDatetime { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}T{}", self.date(), self.time()) - } -} - -impl LocalTime { - pub const MIN: LocalTime = LocalTime { micros: 0 }; - pub const MIDNIGHT: LocalTime = LocalTime { micros: 0 }; - pub const MAX: LocalTime = LocalTime { - micros: MICROS_PER_DAY - 1, - }; - - pub(crate) fn try_from_micros(micros: u64) -> Result { - if micros < MICROS_PER_DAY { - Ok(LocalTime { micros }) - } else { - Err(OutOfRangeError) - } - } - - pub fn from_micros(micros: u64) -> LocalTime { - Self::try_from_micros(micros).expect("LocalTime is out of range") - } - - pub fn to_micros(self) -> u64 { - self.micros - } - - fn to_hmsu(self) -> (u8, u8, u8, u32) { - let micros = self.micros; - - let microsecond = (micros % 1_000_000) as u32; - let micros = micros / 1_000_000; - - let second = (micros % 60) as u8; - let micros = micros / 60; - - let minute = (micros % 60) as u8; - let micros = micros / 60; - - let hour = (micros % 24) as u8; - let micros = micros / 24; - debug_assert_eq!(0, micros); - - (hour, minute, second, microsecond) - } - - #[cfg(test)] // currently only used by tests, will be used by parsing later - fn from_hmsu(hour: u8, minute: u8, second: u8, microsecond: u32) -> LocalTime { - assert!(microsecond < 1_000_000); - assert!(second < 60); - assert!(minute < 60); - assert!(hour < 24); - - let micros = microsecond as u64 - + 1_000_000 * (second as u64 + 60 * (minute as u64 + 60 * (hour as u64))); - LocalTime::from_micros(micros) - } -} - -impl Display for LocalTime { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Debug::fmt(self, f) - } -} - -impl Debug for LocalTime { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let (hour, minute, second, microsecond) = self.to_hmsu(); - write!(f, "{:02}:{:02}:{:02}", hour, minute, second)?; - // like chrono::NaiveTime it outputs either 0, 3 or 6 decimal digits - if microsecond != 0 { - if microsecond % 1000 == 0 { - write!(f, ".{:03}", microsecond / 1000)?; - } else { - write!(f, ".{:06}", microsecond)?; - } - }; - Ok(()) - } -} - -impl LocalDate { - pub const MIN: LocalDate = LocalDate { days: -730119 }; // 0001-01-01 - pub const MAX: LocalDate = LocalDate { days: 2921939 }; // 9999-12-31 - pub const UNIX_EPOCH: LocalDate = LocalDate { - days: -(30 * 365 + 7), - }; // 1970-01-01 - - fn try_from_days(days: i32) -> Result { - if !(Self::MIN.days..=Self::MAX.days).contains(&days) { - return Err(OutOfRangeError); - } - Ok(LocalDate { days }) - } - - pub fn from_days(days: i32) -> LocalDate { - Self::try_from_days(days).unwrap_or_else(|_| { - panic!( - "LocalDate::from_days({}) is outside the valid date range", - days - ) - }) - } - - pub fn to_days(self) -> i32 { - self.days - } - - pub fn from_ymd(year: i32, month: u8, day: u8) -> LocalDate { - Self::try_from_ymd(year, month, day) - .unwrap_or_else(|_| panic!("invalid date {:04}-{:02}-{:02}", year, month, day)) - } - - fn try_from_ymd(year: i32, month: u8, day: u8) -> Result { - if !(1..=31).contains(&day) { - return Err(OutOfRangeError); - } - if !(1..=12).contains(&month) { - return Err(OutOfRangeError); - } - if !(MIN_YEAR..=MAX_YEAR).contains(&year) { - return Err(OutOfRangeError); - } - - let passed_years = (year - BASE_YEAR - 1) as u32; - let days_from_year = - 365 * passed_years + passed_years / 4 - passed_years / 100 + passed_years / 400 + 366; - - let is_leap_year = (year % 400 == 0) || (year % 4 == 0 && year % 100 != 0); - let day_to_month = if is_leap_year { - DAY_TO_MONTH_366 - } else { - DAY_TO_MONTH_365 - }; - - let day_in_year = (day - 1) as u32 + day_to_month[month as usize - 1]; - if day_in_year >= day_to_month[month as usize] { - return Err(OutOfRangeError); - } - - LocalDate::try_from_days( - (days_from_year + day_in_year) as i32 - - DAYS_IN_400_YEARS as i32 * ((2000 - BASE_YEAR) / 400), - ) - } - - fn to_ymd(self) -> (i32, u8, u8) { - const DAYS_IN_100_YEARS: u32 = 100 * 365 + 24; - const DAYS_IN_4_YEARS: u32 = 4 * 365 + 1; - const DAYS_IN_1_YEAR: u32 = 365; - const DAY_TO_MONTH_MARCH: [u32; 12] = - [0, 31, 61, 92, 122, 153, 184, 214, 245, 275, 306, 337]; - const MARCH_1: u32 = 31 + 29; - const MARCH_1_MINUS_BASE_YEAR_TO_POSTGRES_EPOCH: u32 = - (2000 - BASE_YEAR) as u32 / 400 * DAYS_IN_400_YEARS - MARCH_1; - - let days = (self.days as u32).wrapping_add(MARCH_1_MINUS_BASE_YEAR_TO_POSTGRES_EPOCH); - - let years400 = days / DAYS_IN_400_YEARS; - let days = days % DAYS_IN_400_YEARS; - - let mut years100 = days / DAYS_IN_100_YEARS; - if years100 == 4 { - years100 = 3 - }; // prevent 400 year leap day from overflowing - let days = days - DAYS_IN_100_YEARS * years100; - - let years4 = days / DAYS_IN_4_YEARS; - let days = days % DAYS_IN_4_YEARS; - - let mut years1 = days / DAYS_IN_1_YEAR; - if years1 == 4 { - years1 = 3 - }; // prevent 4 year leap day from overflowing - let days = days - DAYS_IN_1_YEAR * years1; - - let years = years1 + years4 * 4 + years100 * 100 + years400 * 400; - let month_entry = DAY_TO_MONTH_MARCH - .iter() - .filter(|d| days >= **d) - .enumerate() - .last() - .unwrap(); - let months = years * 12 + 2 + month_entry.0 as u32; - let year = (months / 12) as i32 + BASE_YEAR; - let month = (months % 12 + 1) as u8; - let day = (days - month_entry.1 + 1) as u8; - - (year, month, day) - } -} - -impl Display for LocalDate { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Debug::fmt(self, f) - } -} - -impl Debug for LocalDate { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let (year, month, day) = self.to_ymd(); - if year >= 10_000 { - // ISO format requires a + on dates longer than 4 digits - write!(f, "+")?; - } - if year >= 0 { - write!(f, "{:04}-{:02}-{:02}", year, month, day) - } else { - // rust counts the sign as a digit when padding - write!(f, "{:05}-{:02}-{:02}", year, month, day) - } - } -} - -impl Datetime { - // -63082281600000000 micros = Jan. 1 year 1 - pub const MIN: Datetime = Datetime { - micros: LocalDatetime::MIN.micros, - }; - // 252455615999999999 micros = Dec. 31 year 9999 - pub const MAX: Datetime = Datetime { - micros: LocalDatetime::MAX.micros, - }; - pub const UNIX_EPOCH: Datetime = Datetime { - //micros: 0 - micros: LocalDate::UNIX_EPOCH.days as i64 * MICROS_PER_DAY as i64, - }; - - /// Convert microseconds since unix epoch into a datetime - pub fn try_from_unix_micros(micros: i64) -> Result { - Self::_from_micros(micros).ok_or(OutOfRangeError) - } - - #[deprecated(since = "0.5.0", note = "use try_from_unix_micros instead")] - pub fn try_from_micros(micros: i64) -> Result { - Self::from_postgres_micros(micros) - } - - pub(crate) fn from_postgres_micros(micros: i64) -> Result { - if !(Self::MIN.micros..=Self::MAX.micros).contains(µs) { - return Err(OutOfRangeError); - } - Ok(Datetime { micros }) - } - - fn _from_micros(micros: i64) -> Option { - let micros = micros.checked_add(Self::UNIX_EPOCH.micros)?; - if !(Self::MIN.micros..=Self::MAX.micros).contains(µs) { - return None; - } - Some(Datetime { micros }) - } - - #[deprecated(since = "0.5.0", note = "use from_unix_micros instead")] - pub fn from_micros(micros: i64) -> Datetime { - Self::from_postgres_micros(micros).unwrap_or_else(|_| { - panic!( - "Datetime::from_micros({}) is outside the valid datetime range", - micros - ) - }) - } - - /// Convert microseconds since unix epoch into a datetime - /// - /// # Panics - /// - /// When value is out of range. - pub fn from_unix_micros(micros: i64) -> Datetime { - if let Some(result) = Self::_from_micros(micros) { - return result; - } - panic!( - "Datetime::from_micros({}) is outside the valid datetime range", - micros - ); - } - - #[deprecated(since = "0.5.0", note = "use to_unix_micros instead")] - pub fn to_micros(self) -> i64 { - self.micros - } - - /// Convert datetime to microseconds since Unix Epoch - pub fn to_unix_micros(self) -> i64 { - // i64 is enough to fit our range with both epochs - self.micros - Datetime::UNIX_EPOCH.micros - } - - fn postgres_epoch_unix() -> SystemTime { - use std::time::Duration; - // postgres epoch starts at 2000-01-01 - UNIX_EPOCH + Duration::from_micros((-Datetime::UNIX_EPOCH.micros) as u64) - } -} - -impl Display for Datetime { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{} UTC", - LocalDatetime { - micros: self.micros - } - ) - } -} - -impl Debug for Datetime { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{:?}Z", - LocalDatetime { - micros: self.micros - } - ) - } -} - -impl TryFrom for SystemTime { - type Error = OutOfRangeError; - - fn try_from(value: Datetime) -> Result { - use std::time::Duration; - - if value.micros > 0 { - Datetime::postgres_epoch_unix().checked_add(Duration::from_micros(value.micros as u64)) - } else { - Datetime::postgres_epoch_unix() - .checked_sub(Duration::from_micros((-value.micros) as u64)) - } - .ok_or(OutOfRangeError) - } -} - -impl TryFrom for Duration { - type Error = OutOfRangeError; - - fn try_from(value: std::time::Duration) -> Result { - TryFrom::try_from(&value) - } -} - -impl TryFrom<&std::time::Duration> for Duration { - type Error = OutOfRangeError; - - fn try_from(value: &std::time::Duration) -> Result { - let secs = value.as_secs(); - let subsec_nanos = value.subsec_nanos(); - let subsec_micros = nanos_to_micros(subsec_nanos.into()); - let micros = i64::try_from(secs) - .ok() - .and_then(|x| x.checked_mul(1_000_000)) - .and_then(|x| x.checked_add(subsec_micros)) - .ok_or(OutOfRangeError)?; - Ok(Duration { micros }) - } -} - -impl TryFrom<&Duration> for std::time::Duration { - type Error = OutOfRangeError; - - fn try_from(value: &Duration) -> Result { - let micros = value.micros.try_into().map_err(|_| OutOfRangeError)?; - Ok(std::time::Duration::from_micros(micros)) - } -} -impl TryFrom for std::time::Duration { - type Error = OutOfRangeError; - - fn try_from(value: Duration) -> Result { - (&value).try_into() - } -} - -impl TryFrom for Datetime { - type Error = OutOfRangeError; - - fn try_from(value: SystemTime) -> Result { - match value.duration_since(UNIX_EPOCH) { - Ok(duration) => { - let secs = duration.as_secs(); - let subsec_nanos = duration.subsec_nanos(); - let subsec_micros = nanos_to_micros(subsec_nanos.into()); - let micros = i64::try_from(secs) - .ok() - .and_then(|x| x.checked_mul(1_000_000)) - .and_then(|x| x.checked_add(subsec_micros)) - .and_then(|x| x.checked_add(Datetime::UNIX_EPOCH.micros)) - .ok_or(OutOfRangeError)?; - if micros > Datetime::MAX.micros { - return Err(OutOfRangeError); - } - Ok(Datetime { micros }) - } - Err(e) => { - let mut secs = e.duration().as_secs(); - let mut subsec_nanos = e.duration().subsec_nanos(); - if subsec_nanos > 0 { - secs = secs.checked_add(1).ok_or(OutOfRangeError)?; - subsec_nanos = 1_000_000_000 - subsec_nanos; - } - let subsec_micros = nanos_to_micros(subsec_nanos.into()); - let micros = i64::try_from(secs) - .ok() - .and_then(|x| x.checked_mul(1_000_000)) - .and_then(|x| Datetime::UNIX_EPOCH.micros.checked_sub(x)) - .and_then(|x| x.checked_add(subsec_micros)) - .ok_or(OutOfRangeError)?; - if micros < Datetime::MIN.micros { - return Err(OutOfRangeError); - } - Ok(Datetime { micros }) - } - } - } -} - -impl std::ops::Add<&'_ std::time::Duration> for Datetime { - type Output = Datetime; - fn add(self, other: &std::time::Duration) -> Datetime { - let Ok(duration) = Duration::try_from(other) else { - debug_assert!(false, "duration is out of range"); - return Datetime::MAX; - }; - if let Some(micros) = self.micros.checked_add(duration.micros) { - Datetime { micros } - } else { - debug_assert!(false, "duration is out of range"); - Datetime::MAX - } - } -} - -impl std::ops::Add for Datetime { - type Output = Datetime; - #[allow(clippy::op_ref)] - fn add(self, other: std::time::Duration) -> Datetime { - self + &other - } -} - -impl Display for Duration { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let abs = if self.micros < 0 { - write!(f, "-")?; - -self.micros - } else { - self.micros - }; - let (sec, micros) = (abs / 1_000_000, abs % 1_000_000); - if micros != 0 { - let mut fract = micros; - let mut zeros = 0; - while fract % 10 == 0 { - zeros += 1; - fract /= 10; - } - write!( - f, - "{hours}:{minutes:02}:{seconds:02}.{fract:0>fsize$}", - hours = sec / 3600, - minutes = sec / 60 % 60, - seconds = sec % 60, - fract = fract, - fsize = 6 - zeros, - ) - } else { - write!(f, "{}:{:02}:{:02}", sec / 3600, sec / 60 % 60, sec % 60) - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn micros_conv() { - let datetime = Datetime::from_unix_micros(1645681383000002); - assert_eq!(datetime.micros, 698996583000002); - assert_eq!(to_debug(datetime), "2022-02-24T05:43:03.000002Z"); - } - - #[test] - fn big_duration_abs() { - use super::Duration as Src; - use std::time::Duration as Trg; - assert_eq!(Src { micros: -1 }.abs_duration(), Trg::new(0, 1000)); - assert_eq!(Src { micros: -1000 }.abs_duration(), Trg::new(0, 1000000)); - assert_eq!(Src { micros: -1000000 }.abs_duration(), Trg::new(1, 0)); - assert_eq!( - Src { micros: i64::MIN }.abs_duration(), - Trg::new(9223372036854, 775808000) - ); - } - - #[test] - fn local_date_from_ymd() { - assert_eq!(0, LocalDate::from_ymd(2000, 1, 1).to_days()); - assert_eq!(-365, LocalDate::from_ymd(1999, 1, 1).to_days()); - assert_eq!(366, LocalDate::from_ymd(2001, 1, 1).to_days()); - assert_eq!(-730119, LocalDate::from_ymd(1, 1, 1).to_days()); - assert_eq!(2921575, LocalDate::from_ymd(9999, 1, 1).to_days()); - - assert_eq!(Err(OutOfRangeError), LocalDate::try_from_ymd(2001, 1, 32)); - assert_eq!(Err(OutOfRangeError), LocalDate::try_from_ymd(2001, 2, 29)); - } - - #[test] - fn local_date_from_ymd_leap_year() { - let days_in_month_leap = [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]; - let mut total_days = 0; - let start_of_year = 365 * 4 + 1; - for month in 1..=12 { - let start_of_month = LocalDate::from_ymd(2004, month as u8, 1).to_days(); - assert_eq!(total_days, start_of_month - start_of_year); - - let days_in_current_month = days_in_month_leap[month - 1]; - total_days += days_in_current_month; - - let end_of_month = - LocalDate::from_ymd(2004, month as u8, days_in_current_month as u8).to_days(); - assert_eq!(total_days - 1, end_of_month - start_of_year); - } - assert_eq!(366, total_days); - } - - const DAYS_IN_MONTH_LEAP: [u8; 12] = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]; - - #[test] - fn local_date_from_ymd_normal_year() { - let mut total_days = 0; - let start_of_year = 365 + 1; - for month in 1..=12 { - let start_of_month = LocalDate::from_ymd(2001, month as u8, 1).to_days(); - assert_eq!(total_days, start_of_month - start_of_year); - - let days_in_current_month = DAYS_IN_MONTH_LEAP[month - 1]; - total_days += days_in_current_month as i32; - - let end_of_month = - LocalDate::from_ymd(2001, month as u8, days_in_current_month).to_days(); - assert_eq!(total_days - 1, end_of_month - start_of_year); - } - assert_eq!(365, total_days); - } - - pub const CHRONO_MAX_YEAR: i32 = 262_143; - - fn extended_test_dates() -> impl Iterator { - const YEARS: [i32; 36] = [ - 1, - 2, - 1000, - 1969, - 1970, // unix epoch - 1971, - 1999, - 2000, // postgres epoch - 2001, - 2002, - 2003, - 2004, - 2008, - 2009, - 2010, - 2100, - 2200, - 2300, - 2400, - 9000, - 9999, - 10_000, - 10_001, - 11_000, - 20_000, - 100_000, - 200_000, - CHRONO_MAX_YEAR - 1, - CHRONO_MAX_YEAR, - CHRONO_MAX_YEAR + 1, - MAX_YEAR - 1000, - MAX_YEAR - 31, - MAX_YEAR - 30, // maximum unix based - MAX_YEAR - 29, // less than 30 years before maximum, so a unix epoch in microseconds overflows - MAX_YEAR - 1, - MAX_YEAR, - ]; - - const MONTHS: std::ops::RangeInclusive = 1u8..=12; - const DAYS: [u8; 6] = [1u8, 13, 28, 29, 30, 31]; - let dates = MONTHS.flat_map(|month| DAYS.iter().map(move |day| (month, *day))); - - YEARS - .iter() - .flat_map(move |year| dates.clone().map(move |date| (*year, date.0, date.1))) - } - - pub fn valid_test_dates() -> impl Iterator { - extended_test_dates().filter(|date| LocalDate::try_from_ymd(date.0, date.1, date.2).is_ok()) - } - - pub fn test_times() -> impl Iterator { - const TIMES: [u64; 7] = [ - 0, - 10, - 10_020, - 12345 * 1_000_000, - 12345 * 1_001_000, - 12345 * 1_001_001, - MICROS_PER_DAY - 1, - ]; - TIMES.iter().copied() - } - - #[test] - fn check_test_dates() { - assert!(valid_test_dates().count() > 1000); - } - - #[test] - fn local_date_ymd_roundtrip() { - for (year, month, day) in valid_test_dates() { - let date = LocalDate::from_ymd(year, month, day); - assert_eq!((year, month, day), date.to_ymd()); - } - } - - #[test] - fn local_time_parts_roundtrip() { - for time in test_times() { - let expected_time = LocalTime::from_micros(time); - let (hour, minute, second, microsecond) = expected_time.to_hmsu(); - let actual_time = LocalTime::from_hmsu(hour, minute, second, microsecond); - assert_eq!(expected_time, actual_time); - } - } - - #[test] - fn format_local_date() { - assert_eq!("2000-01-01", LocalDate::from_days(0).to_string()); - assert_eq!("0001-01-01", LocalDate::MIN.to_string()); - assert_eq!("9999-12-31", LocalDate::MAX.to_string()); - } - - #[test] - fn format_local_time() { - assert_eq!("00:00:00", LocalTime::MIDNIGHT.to_string()); - assert_eq!("00:00:00.010", LocalTime::from_micros(10_000).to_string()); - assert_eq!( - "00:00:00.010020", - LocalTime::from_micros(10_020).to_string() - ); - assert_eq!("23:59:59.999999", LocalTime::MAX.to_string()); - } - - pub fn to_debug(x: T) -> String { - format!("{:?}", x) - } - - #[test] - #[allow(deprecated)] - fn format_local_datetime() { - assert_eq!( - "2039-02-13 23:31:30.123456", - LocalDatetime::from_micros(1_234_567_890_123_456).to_string() - ); - assert_eq!( - "2039-02-13T23:31:30.123456", - to_debug(LocalDatetime::from_micros(1_234_567_890_123_456)) - ); - - assert_eq!("0001-01-01 00:00:00", LocalDatetime::MIN.to_string()); - assert_eq!("0001-01-01T00:00:00", to_debug(LocalDatetime::MIN)); - - assert_eq!("9999-12-31 23:59:59.999999", LocalDatetime::MAX.to_string()); - assert_eq!("9999-12-31T23:59:59.999999", to_debug(LocalDatetime::MAX)); - } - - #[test] - #[allow(deprecated)] - fn format_datetime() { - assert_eq!( - "2039-02-13 23:31:30.123456 UTC", - Datetime::from_micros(1_234_567_890_123_456).to_string() - ); - assert_eq!( - "2039-02-13T23:31:30.123456Z", - to_debug(Datetime::from_micros(1_234_567_890_123_456)) - ); - - assert_eq!("0001-01-01 00:00:00 UTC", Datetime::MIN.to_string()); - assert_eq!("0001-01-01T00:00:00Z", to_debug(Datetime::MIN)); - - assert_eq!("9999-12-31 23:59:59.999999 UTC", Datetime::MAX.to_string()); - assert_eq!("9999-12-31T23:59:59.999999Z", to_debug(Datetime::MAX)); - } - - #[test] - fn format_duration() { - fn dur_str(msec: i64) -> String { - Duration::from_micros(msec).to_string() - } - assert_eq!(dur_str(1_000_000), "0:00:01"); - assert_eq!(dur_str(1), "0:00:00.000001"); - assert_eq!(dur_str(7_015_000), "0:00:07.015"); - assert_eq!(dur_str(10_000_000_015_000), "2777:46:40.015"); - assert_eq!(dur_str(12_345_678_000_000), "3429:21:18"); - } - - #[test] - fn parse_duration_str() { - fn micros(input: &str) -> i64 { - Duration::from_str(input).unwrap().micros - } - assert_eq!(micros(" 100 "), 100_000_000); - assert_eq!(micros("123"), 123_000_000); - assert_eq!(micros("-123"), -123_000_000); - assert_eq!(micros(" 20 mins 1hr "), 4_800_000_000); - assert_eq!(micros(" 20 mins -1hr "), -2_400_000_000); - assert_eq!(micros(" 20us 1h 20 "), 3_620_000_020); - assert_eq!(micros(" -20us 1h 20 "), 3_619_999_980); - assert_eq!(micros(" -20US 1H 20 "), 3_619_999_980); - assert_eq!( - micros("1 hour 20 minutes 30 seconds 40 milliseconds 50 microseconds"), - 4_830_040_050 - ); - assert_eq!( - micros("1 hour 20 minutes +30seconds 40 milliseconds -50microseconds"), - 4_830_039_950 - ); - assert_eq!( - micros("1 houR 20 minutes 30SECOND 40 milliseconds 50 us"), - 4_830_040_050 - ); - assert_eq!(micros(" 20 us 1H 20 minutes "), 4_800_000_020); - assert_eq!(micros("-1h"), -3_600_000_000); - assert_eq!(micros("100h"), 360_000_000_000); - let h12 = 12 * 3_600_000_000_i64; - let m12 = 12 * 60_000_000_i64; - assert_eq!(micros(" 12:12:12.2131 "), h12 + m12 + 12_213_100); - assert_eq!(micros("-12:12:12.21313"), -(h12 + m12 + 12_213_130)); - assert_eq!(micros("-12:12:12.213134"), -(h12 + m12 + 12_213_134)); - assert_eq!(micros("-12:12:12.2131341"), -(h12 + m12 + 12_213_134)); - assert_eq!(micros("-12:12:12.2131341111111"), -(h12 + m12 + 12_213_134)); - assert_eq!(micros("-12:12:12.2131315111111"), -(h12 + m12 + 12_213_132)); - assert_eq!(micros("-12:12:12.2131316111111"), -(h12 + m12 + 12_213_132)); - assert_eq!(micros("-12:12:12.2131314511111"), -(h12 + m12 + 12_213_131)); - assert_eq!(micros("-0:12:12.2131"), -(m12 + 12_213_100)); - assert_eq!(micros("12:12"), h12 + m12); - assert_eq!(micros("-12:12"), -(h12 + m12)); - assert_eq!(micros("-12:1:1"), -(h12 + 61_000_000)); - assert_eq!(micros("+12:1:1"), h12 + 61_000_000); - assert_eq!(micros("-12:1:1.1234"), -(h12 + 61_123_400)); - assert_eq!(micros("1211:59:59.9999"), h12 * 100 + h12 - 100); - assert_eq!(micros("-12:"), -h12); - assert_eq!(micros("0"), 0); - assert_eq!(micros("00:00:00"), 0); - assert_eq!(micros("00:00:10.9"), 10_900_000); - assert_eq!(micros("00:00:10.09"), 10_090_000); - assert_eq!(micros("00:00:10.009"), 10_009_000); - assert_eq!(micros("00:00:10.0009"), 10_000_900); - assert_eq!(micros("00:00:00.5"), 500_000); - assert_eq!(micros(" +00005"), 5_000_000); - assert_eq!(micros(" -00005"), -5_000_000); - assert_eq!(micros("PT"), 0); - assert_eq!(micros("PT1H1M1S"), 3_661_000_000); - assert_eq!(micros("PT1M1S"), 61_000_000); - assert_eq!(micros("PT1S"), 1_000_000); - assert_eq!(micros("PT1H1S"), 3_601_000_000); - assert_eq!(micros("PT1H1M1.1S"), 3_661_100_000); - assert_eq!(micros("PT1H1M1.01S"), 3_661_010_000); - assert_eq!(micros("PT1H1M1.10S"), 3_661_100_000); - assert_eq!(micros("PT1H1M1.1234567S"), 3_661_123_456); - assert_eq!(micros("PT1H1M1.1234564S"), 3_661_123_456); - assert_eq!(micros("PT-1H1M1.1S"), -3_538_900_000); - assert_eq!(micros("PT+1H-1M1.1S"), 3_541_100_000); - assert_eq!(micros("PT1H+1M-1.1S"), 3_658_900_000); - - fn assert_error(input: &str, expected_pos: usize, pat: &str) { - let ParseDurationError { - pos, - message, - is_final: _, - } = Duration::from_str(input).unwrap_err(); - assert_eq!(pos, expected_pos); - assert!( - message.contains(pat), - "`{}` not found in `{}`", - pat, - message, - ); - } - assert_error("blah", 0, "numeric"); - assert_error("!", 0, "unexpected"); - assert_error("-", 0, "invalid digit"); - assert_error("+", 0, "invalid digit"); - assert_error(" 20 us 1H 20 30 minutes ", 14, "alphabetic"); - assert_error(" 12:12:121.2131 ", 11, "seconds"); - assert_error(" 12:60:21.2131 ", 7, "minutes"); - assert_error(" 20us 20 1h ", 12, "alphabetic"); - assert_error(" 20us $ 20 1h ", 7, "unexpected"); - assert_error( - "1 houR 20 minutes 30SECOND 40 milliseconds 50 uss", - 47, - "unit", - ); - assert_error("PT1M1H", 4, "EOF"); - assert_error("PT1S1M", 4, "EOF"); - } - - #[test] - fn add_duration_rounding() { - // round down - assert_eq!( - Datetime::UNIX_EPOCH + std::time::Duration::new(17, 500), - Datetime::UNIX_EPOCH + std::time::Duration::new(17, 0), - ); - // round up - assert_eq!( - Datetime::UNIX_EPOCH + std::time::Duration::new(12345, 1500), - Datetime::UNIX_EPOCH + std::time::Duration::new(12345, 2000), - ); - } - - #[test] - #[allow(deprecated)] - fn to_and_from_unix_micros_roundtrip() { - let zero_micros = 0; - let datetime = Datetime::from_unix_micros(0); - // Unix micros should equal 0 - assert_eq!(zero_micros, datetime.to_unix_micros()); - // Datetime (Postgres epoch-based) micros should be negative - // Micros = negative micros to go from 2000 to 1970 - assert_eq!(datetime.micros, datetime.to_micros()); - assert_eq!(datetime.micros, Datetime::UNIX_EPOCH.micros); - assert_eq!(datetime.micros, -946684800000000); - } -} - -impl RelativeDuration { - pub fn try_from_years(years: i32) -> Result { - Ok(RelativeDuration { - months: years.checked_mul(12).ok_or(OutOfRangeError)?, - days: 0, - micros: 0, - }) - } - pub fn from_years(years: i32) -> RelativeDuration { - RelativeDuration::try_from_years(years).unwrap() - } - pub fn try_from_months(months: i32) -> Result { - Ok(RelativeDuration { - months, - days: 0, - micros: 0, - }) - } - pub fn from_months(months: i32) -> RelativeDuration { - RelativeDuration::try_from_months(months).unwrap() - } - pub fn try_from_days(days: i32) -> Result { - Ok(RelativeDuration { - months: 0, - days, - micros: 0, - }) - } - pub fn from_days(days: i32) -> RelativeDuration { - RelativeDuration::try_from_days(days).unwrap() - } - pub fn try_from_hours(hours: i64) -> Result { - Ok(RelativeDuration { - months: 0, - days: 0, - micros: hours.checked_mul(3_600_000_000).ok_or(OutOfRangeError)?, - }) - } - pub fn from_hours(hours: i64) -> RelativeDuration { - RelativeDuration::try_from_hours(hours).unwrap() - } - pub fn try_from_minutes(minutes: i64) -> Result { - Ok(RelativeDuration { - months: 0, - days: 0, - micros: minutes.checked_mul(60_000_000).ok_or(OutOfRangeError)?, - }) - } - pub fn from_minutes(minutes: i64) -> RelativeDuration { - RelativeDuration::try_from_minutes(minutes).unwrap() - } - pub fn try_from_secs(secs: i64) -> Result { - Ok(RelativeDuration { - months: 0, - days: 0, - micros: secs.checked_mul(1_000_000).ok_or(OutOfRangeError)?, - }) - } - pub fn from_secs(secs: i64) -> RelativeDuration { - RelativeDuration::try_from_secs(secs).unwrap() - } - pub fn try_from_millis(millis: i64) -> Result { - Ok(RelativeDuration { - months: 0, - days: 0, - micros: millis.checked_mul(1_000).ok_or(OutOfRangeError)?, - }) - } - pub fn from_millis(millis: i64) -> RelativeDuration { - RelativeDuration::try_from_millis(millis).unwrap() - } - pub fn try_from_micros(micros: i64) -> Result { - Ok(RelativeDuration { - months: 0, - days: 0, - micros, - }) - } - pub fn from_micros(micros: i64) -> RelativeDuration { - RelativeDuration::try_from_micros(micros).unwrap() - } - pub fn checked_add(self, other: Self) -> Option { - Some(RelativeDuration { - months: self.months.checked_add(other.months)?, - days: self.days.checked_add(other.days)?, - micros: self.micros.checked_add(other.micros)?, - }) - } - pub fn checked_sub(self, other: Self) -> Option { - Some(RelativeDuration { - months: self.months.checked_sub(other.months)?, - days: self.days.checked_sub(other.days)?, - micros: self.micros.checked_sub(other.micros)?, - }) - } -} - -impl std::ops::Add for RelativeDuration { - type Output = Self; - fn add(self, other: Self) -> Self { - RelativeDuration { - months: self.months + other.months, - days: self.days + other.days, - micros: self.micros + other.micros, - } - } -} - -impl std::ops::Sub for RelativeDuration { - type Output = Self; - fn sub(self, other: Self) -> Self { - RelativeDuration { - months: self.months - other.months, - days: self.days - other.days, - micros: self.micros - other.micros, - } - } -} - -impl Display for RelativeDuration { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if self.months == 0 && self.days == 0 && self.micros == 0 { - return write!(f, "PT0S"); - } - write!(f, "P")?; - if self.months.abs() >= 12 { - write!(f, "{}Y", self.months / 12)?; - } - if (self.months % 12).abs() > 0 { - write!(f, "{}M", self.months % 12)?; - } - if self.days.abs() > 0 { - write!(f, "{}D", self.days)?; - } - if self.micros.abs() > 0 { - write!(f, "T")?; - if self.micros.abs() >= 3_600_000_000 { - write!(f, "{}H", self.micros / 3_600_000_000)?; - } - let minutes = self.micros % 3_600_000_000; - if minutes.abs() >= 60_000_000 { - write!(f, "{}M", minutes / 60_000_000)?; - } - let seconds = minutes % 60_000_000; - if seconds.abs() >= 1_000_000 { - write!(f, "{}", seconds / 1_000_000)?; - } - let micros = seconds % 1_000_000; - if micros.abs() > 0 { - let mut buf = [0u8; 6]; - let text = { - use std::io::{Cursor, Write}; - - let mut cur = Cursor::new(&mut buf[..]); - write!(cur, "{:06}", micros.abs()).unwrap(); - let mut len = buf.len(); - while buf[len - 1] == b'0' { - len -= 1; - } - std::str::from_utf8(&buf[..len]).unwrap() - }; - write!(f, ".{}", text)?; - } - if seconds.abs() > 0 { - write!(f, "S")?; - } - } - Ok(()) - } -} - -#[test] -fn relative_duration_display() { - let dur = RelativeDuration::from_years(2) - + RelativeDuration::from_months(56) - + RelativeDuration::from_days(-16) - + RelativeDuration::from_hours(48) - + RelativeDuration::from_minutes(245) - + RelativeDuration::from_secs(7) - + RelativeDuration::from_millis(600); - assert_eq!(dur.to_string(), "P6Y8M-16DT52H5M7.6S"); - - let dur = RelativeDuration::from_years(2) - + RelativeDuration::from_months(-56) - + RelativeDuration::from_days(-16) - + RelativeDuration::from_minutes(-245) - + RelativeDuration::from_secs(7) - + RelativeDuration::from_millis(600); - assert_eq!(dur.to_string(), "P-2Y-8M-16DT-4H-4M-52.4S"); - - let dur = RelativeDuration::from_years(1); - assert_eq!(dur.to_string(), "P1Y"); - let dur = RelativeDuration::from_months(1); - assert_eq!(dur.to_string(), "P1M"); - let dur = RelativeDuration::from_hours(1); - assert_eq!(dur.to_string(), "PT1H"); - let dur = RelativeDuration::from_minutes(1); - assert_eq!(dur.to_string(), "PT1M"); - let dur = RelativeDuration::from_secs(1); - assert_eq!(dur.to_string(), "PT1S"); -} - -impl DateDuration { - pub fn try_from_years(years: i32) -> Result { - Ok(DateDuration { - months: years.checked_mul(12).ok_or(OutOfRangeError)?, - days: 0, - }) - } - pub fn from_years(years: i32) -> DateDuration { - DateDuration::try_from_years(years).unwrap() - } - pub fn try_from_months(months: i32) -> Result { - Ok(DateDuration { months, days: 0 }) - } - pub fn from_months(months: i32) -> DateDuration { - DateDuration::try_from_months(months).unwrap() - } - pub fn try_from_days(days: i32) -> Result { - Ok(DateDuration { months: 0, days }) - } - pub fn from_days(days: i32) -> DateDuration { - DateDuration::try_from_days(days).unwrap() - } - pub fn checked_add(self, other: Self) -> Option { - Some(DateDuration { - months: self.months.checked_add(other.months)?, - days: self.days.checked_add(other.days)?, - }) - } - pub fn checked_sub(self, other: Self) -> Option { - Some(DateDuration { - months: self.months.checked_sub(other.months)?, - days: self.days.checked_sub(other.days)?, - }) - } -} - -impl std::ops::Add for DateDuration { - type Output = Self; - fn add(self, other: Self) -> Self { - DateDuration { - months: self.months + other.months, - days: self.days + other.days, - } - } -} - -impl std::ops::Sub for DateDuration { - type Output = Self; - fn sub(self, other: Self) -> Self { - DateDuration { - months: self.months - other.months, - days: self.days - other.days, - } - } -} - -impl Display for DateDuration { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if self.months == 0 && self.days == 0 { - return write!(f, "PT0D"); // XXX - } - write!(f, "P")?; - if self.months.abs() >= 12 { - write!(f, "{}Y", self.months / 12)?; - } - if (self.months % 12).abs() > 0 { - write!(f, "{}M", self.months % 12)?; - } - if self.days.abs() > 0 { - write!(f, "{}D", self.days)?; - } - Ok(()) - } -} - -fn nanos_to_micros(nanos: i64) -> i64 { - // round to the nearest even - let mut micros = nanos / 1000; - let remainder = nanos % 1000; - if remainder == 500 && micros % 2 == 1 || remainder > 500 { - micros += 1; - } - micros -} - -#[cfg(feature = "chrono")] -mod chrono_interop { - use super::*; - use chrono::naive::{NaiveDate, NaiveDateTime, NaiveTime}; - use chrono::DateTime; - - type ChronoDatetime = chrono::DateTime; - - impl From<&LocalDatetime> for NaiveDateTime { - fn from(value: &LocalDatetime) -> NaiveDateTime { - let timestamp_seconds = value.micros.wrapping_div_euclid(1_000_000) - - (Datetime::UNIX_EPOCH.micros / 1_000_000); - let timestamp_nanos = (value.micros.wrapping_rem_euclid(1_000_000) * 1000) as u32; - DateTime::from_timestamp(timestamp_seconds, timestamp_nanos) - .expect("NaiveDateTime range is bigger than LocalDatetime") - .naive_utc() - } - } - - impl TryFrom<&NaiveDateTime> for LocalDatetime { - type Error = OutOfRangeError; - fn try_from(d: &NaiveDateTime) -> Result { - let secs = d.and_utc().timestamp(); - let subsec_nanos = d.and_utc().timestamp_subsec_nanos(); - let subsec_micros = nanos_to_micros(subsec_nanos.into()); - let micros = secs - .checked_mul(1_000_000) - .and_then(|x| x.checked_add(subsec_micros)) - .and_then(|x| x.checked_add(Datetime::UNIX_EPOCH.micros)) - .ok_or(OutOfRangeError)?; - if !(LocalDatetime::MIN.micros..=LocalDatetime::MAX.micros).contains(µs) { - return Err(OutOfRangeError); - } - Ok(LocalDatetime { micros }) - } - } - - impl From<&Datetime> for ChronoDatetime { - fn from(value: &Datetime) -> ChronoDatetime { - use chrono::TimeZone; - - let pg_epoch = chrono::Utc.with_ymd_and_hms(2000, 1, 1, 0, 0, 0).unwrap(); - let duration = chrono::Duration::microseconds(value.micros); - pg_epoch - .checked_add_signed(duration) - .expect("EdgeDB datetime range is smaller than Chrono's") - } - } - - impl From for ChronoDatetime { - fn from(value: Datetime) -> ChronoDatetime { - (&value).into() - } - } - - impl TryFrom<&ChronoDatetime> for Datetime { - type Error = OutOfRangeError; - - fn try_from(value: &ChronoDatetime) -> Result { - let min = ChronoDatetime::from(Datetime::MIN); - let duration = value - .signed_duration_since(min) - .to_std() - .map_err(|_| OutOfRangeError)?; - let secs = duration.as_secs(); - let subsec_micros = nanos_to_micros(duration.subsec_nanos().into()); - let micros = i64::try_from(secs) - .ok() - .and_then(|x| x.checked_mul(1_000_000)) - .and_then(|x| x.checked_add(subsec_micros)) - .and_then(|x| x.checked_add(Datetime::MIN.micros)) - .ok_or(OutOfRangeError)?; - if micros > Datetime::MAX.micros { - return Err(OutOfRangeError); - } - Ok(Datetime { micros }) - } - } - - impl TryFrom<&NaiveDate> for LocalDate { - type Error = OutOfRangeError; - fn try_from(d: &NaiveDate) -> Result { - let days = chrono::Datelike::num_days_from_ce(d); - Ok(LocalDate { - days: days - .checked_sub(DAYS_IN_2000_YEARS - 365) - .ok_or(OutOfRangeError)?, - }) - } - } - - impl From<&LocalDate> for NaiveDate { - fn from(value: &LocalDate) -> NaiveDate { - value - .days - .checked_add(DAYS_IN_2000_YEARS - 365) - .and_then(NaiveDate::from_num_days_from_ce_opt) - .expect("NaiveDate range is bigger than LocalDate") - } - } - - impl From<&LocalTime> for NaiveTime { - fn from(value: &LocalTime) -> NaiveTime { - NaiveTime::from_num_seconds_from_midnight_opt( - (value.micros / 1_000_000) as u32, - ((value.micros % 1_000_000) * 1000) as u32, - ) - .expect("localtime and native time have equal range") - } - } - - impl From<&NaiveTime> for LocalTime { - fn from(time: &NaiveTime) -> LocalTime { - let sec = chrono::Timelike::num_seconds_from_midnight(time); - let nanos = nanos_to_micros(chrono::Timelike::nanosecond(time) as i64) as u64; - let mut micros = sec as u64 * 1_000_000 + nanos; - - if micros >= 86_400_000_000 { - // this is only possible due to rounding: - // >= 23:59:59.999999500 - micros -= 86_400_000_000; - } - - LocalTime { micros } - } - } - - impl From for NaiveDateTime { - fn from(value: LocalDatetime) -> NaiveDateTime { - (&value).into() - } - } - - impl From for NaiveDate { - fn from(value: LocalDate) -> NaiveDate { - (&value).into() - } - } - - impl TryFrom for LocalDate { - type Error = OutOfRangeError; - fn try_from(d: NaiveDate) -> Result { - std::convert::TryFrom::try_from(&d) - } - } - - impl From for NaiveTime { - fn from(value: LocalTime) -> NaiveTime { - (&value).into() - } - } - - impl TryFrom for LocalDatetime { - type Error = OutOfRangeError; - fn try_from(d: NaiveDateTime) -> Result { - std::convert::TryFrom::try_from(&d) - } - } - - impl TryFrom for Datetime { - type Error = OutOfRangeError; - fn try_from(d: ChronoDatetime) -> Result { - std::convert::TryFrom::try_from(&d) - } - } - - impl From for LocalTime { - fn from(time: NaiveTime) -> LocalTime { - From::from(&time) - } - } - - #[cfg(test)] - mod test { - use super::*; - use crate::model::time::test::{test_times, to_debug, valid_test_dates, CHRONO_MAX_YEAR}; - - #[test] - fn chrono_roundtrips() -> Result<(), Box> { - let naive = NaiveDateTime::from_str("2019-12-27T01:02:03.123456")?; - assert_eq!( - naive, - Into::::into(LocalDatetime::try_from(naive)?) - ); - let naive = NaiveDate::from_str("2019-12-27")?; - assert_eq!(naive, Into::::into(LocalDate::try_from(naive)?)); - let naive = NaiveTime::from_str("01:02:03.123456")?; - assert_eq!(naive, Into::::into(LocalTime::from(naive))); - Ok(()) - } - - fn check_display(expected_value: E, actual_value: A) { - let expected_display = expected_value.to_string(); - let actual_display = actual_value.to_string(); - assert_eq!(expected_display, actual_display); - } - - fn check_debug(expected_value: E, actual_value: A) { - let expected_debug = to_debug(expected_value); - let actual_debug = to_debug(actual_value); - assert_eq!(expected_debug, actual_debug); - } - - #[test] - fn format_local_time() { - for time in test_times() { - let actual_value = LocalTime::from_micros(time); - let expected_value = NaiveTime::from(actual_value); - - check_display(expected_value, actual_value); - check_debug(expected_value, actual_value); - } - } - - #[test] - fn format_local_date() { - let dates = valid_test_dates().filter(|d| d.0 <= CHRONO_MAX_YEAR); - for (y, m, d) in dates { - let actual_value = LocalDate::from_ymd(y, m, d); - let expected = NaiveDate::from_ymd_opt(y, m as u32, d as u32).unwrap(); - - check_display(expected, actual_value); - check_debug(expected, actual_value); - } - } - - #[test] - fn format_local_datetime() { - let dates = valid_test_dates().filter(|d| d.0 <= CHRONO_MAX_YEAR); - for date in dates { - for time in test_times() { - let actual_date = LocalDate::from_ymd(date.0, date.1, date.2); - let actual_time = LocalTime::from_micros(time); - let actual_value = LocalDatetime::new(actual_date, actual_time); - let expected_value = NaiveDateTime::from(actual_value); - - check_display(expected_value, actual_value); - check_debug(expected_value, actual_value); - } - } - } - - #[test] - fn format_datetime() { - let dates = valid_test_dates().filter(|d| d.0 <= CHRONO_MAX_YEAR); - for date in dates { - for time in test_times() { - let actual_date = LocalDate::from_ymd(date.0, date.1, date.2); - let actual_time = LocalTime::from_micros(time); - let local_datetime = LocalDatetime::new(actual_date, actual_time); - let actual_value = local_datetime.to_utc(); - let expected_value = ChronoDatetime::from(actual_value); - - check_display(expected_value, actual_value); - check_debug(expected_value, actual_value); - } - } - } - - #[test] - fn date_duration() -> Result<(), Box> { - assert_eq!(DateDuration::from_years(1).to_string(), "P1Y"); - assert_eq!(DateDuration::from_months(1).to_string(), "P1M"); - assert_eq!(DateDuration::from_days(1).to_string(), "P1D"); - assert_eq!(DateDuration::from_months(10).to_string(), "P10M"); - assert_eq!(DateDuration::from_months(20).to_string(), "P1Y8M"); - assert_eq!(DateDuration::from_days(131).to_string(), "P131D"); - assert_eq!( - (DateDuration::from_months(7) + DateDuration::from_days(131)).to_string(), - "P7M131D" - ); - Ok(()) - } - } -} diff --git a/edgedb-protocol/src/model/vector.rs b/edgedb-protocol/src/model/vector.rs deleted file mode 100644 index d6b5bb24..00000000 --- a/edgedb-protocol/src/model/vector.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::ops::{Deref, DerefMut}; - -use bytes::Buf; -use snafu::ensure; - -use crate::codec; -use crate::descriptors::TypePos; -use crate::errors::{self, DecodeError}; -use crate::queryable::DescriptorMismatch; -use crate::queryable::{Decoder, DescriptorContext, Queryable}; -use crate::serialization::decode::queryable::scalars::check_scalar; - -/// A structure that represents `ext::pgvector::vector` -#[derive(Debug, PartialEq, Clone)] -#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Vector(pub Vec); - -impl Deref for Vector { - type Target = Vec; - fn deref(&self) -> &Vec { - &self.0 - } -} - -impl DerefMut for Vector { - fn deref_mut(&mut self) -> &mut Vec { - &mut self.0 - } -} - -impl Queryable for Vector { - fn decode(_decoder: &Decoder, mut buf: &[u8]) -> Result { - ensure!(buf.remaining() >= 4, errors::Underflow); - let length = buf.get_u16() as usize; - let _reserved = buf.get_u16(); - ensure!(buf.remaining() >= length * 4, errors::Underflow); - let vec = (0..length).map(|_| f32::from_bits(buf.get_u32())).collect(); - Ok(Vector(vec)) - } - - fn check_descriptor( - ctx: &DescriptorContext, - type_pos: TypePos, - ) -> Result<(), DescriptorMismatch> { - check_scalar( - ctx, - type_pos, - codec::PGVECTOR_VECTOR, - "ext::pgvector::vector", - ) - } -} diff --git a/edgedb-protocol/src/query_arg.rs b/edgedb-protocol/src/query_arg.rs deleted file mode 100644 index d7deacea..00000000 --- a/edgedb-protocol/src/query_arg.rs +++ /dev/null @@ -1,555 +0,0 @@ -/*! -Contains the [QueryArg] and [QueryArgs] traits. -*/ - -use std::convert::{TryFrom, TryInto}; -use std::ops::Deref; -use std::sync::Arc; - -use bytes::{BufMut, BytesMut}; -use snafu::OptionExt; -use uuid::Uuid; - -use edgedb_errors::ParameterTypeMismatchError; -use edgedb_errors::{ClientEncodingError, DescriptorMismatch, ProtocolError}; -use edgedb_errors::{Error, ErrorKind, InvalidReferenceError}; - -use crate::codec::{self, build_codec, Codec}; -use crate::descriptors::TypePos; -use crate::descriptors::{Descriptor, EnumerationTypeDescriptor}; -use crate::errors; -use crate::features::ProtocolVersion; -use crate::model::range; -use crate::value::Value; - -pub struct Encoder<'a> { - pub ctx: &'a DescriptorContext<'a>, - pub buf: &'a mut BytesMut, -} - -/// A single argument for a query. -pub trait QueryArg: Send + Sync + Sized { - fn encode_slot(&self, encoder: &mut Encoder) -> Result<(), Error>; - fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error>; - fn to_value(&self) -> Result; -} - -pub trait ScalarArg: Send + Sync + Sized { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error>; - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error>; - fn to_value(&self) -> Result; -} - -/// A tuple of query arguments. -/// -/// This trait is implemented for tuples of sizes up to twelve. You can derive -/// it for a structure in this case it's treated as a named tuple (i.e. query -/// should include named arguments rather than numeric ones). -pub trait QueryArgs: Send + Sync { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error>; -} - -pub struct DescriptorContext<'a> { - #[allow(dead_code)] - pub(crate) proto: &'a ProtocolVersion, - pub(crate) root_pos: Option, - pub(crate) descriptors: &'a [Descriptor], -} - -impl<'a> Encoder<'a> { - pub fn new(ctx: &'a DescriptorContext<'a>, buf: &'a mut BytesMut) -> Encoder<'a> { - Encoder { ctx, buf } - } - pub fn length_prefixed( - &mut self, - f: impl FnOnce(&mut Encoder) -> Result<(), Error>, - ) -> Result<(), Error> { - self.buf.reserve(4); - let pos = self.buf.len(); - self.buf.put_u32(0); // replaced after serializing a value - // - f(self)?; - - let len = self.buf.len() - pos - 4; - self.buf[pos..pos + 4].copy_from_slice( - &u32::try_from(len) - .map_err(|_| ClientEncodingError::with_message("alias is too long"))? - .to_be_bytes(), - ); - - Ok(()) - } -} - -impl DescriptorContext<'_> { - pub fn get(&self, type_pos: TypePos) -> Result<&Descriptor, Error> { - self.descriptors - .get(type_pos.0 as usize) - .ok_or_else(|| ProtocolError::with_message("invalid type descriptor")) - } - pub fn build_codec(&self) -> Result, Error> { - build_codec(self.root_pos, self.descriptors) - .map_err(|e| ProtocolError::with_source(e).context("error decoding input codec")) - } - pub fn wrong_type(&self, descriptor: &Descriptor, expected: &str) -> Error { - DescriptorMismatch::with_message(format!( - "server returned unexpected type {descriptor:?} when client expected {expected}" - )) - } - pub fn field_number(&self, expected: usize, unexpected: usize) -> Error { - DescriptorMismatch::with_message(format!( - "expected {} fields, got {}", - expected, unexpected - )) - } -} - -impl ScalarArg for &T { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - (*self).encode(encoder) - } - - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - T::check_descriptor(ctx, pos) - } - - fn to_value(&self) -> Result { - (*self).to_value() - } -} - -impl QueryArgs for () { - fn encode(&self, enc: &mut Encoder) -> Result<(), Error> { - if enc.ctx.root_pos.is_some() { - if enc.ctx.proto.is_at_most(0, 11) { - let root = enc.ctx.root_pos.and_then(|p| enc.ctx.get(p).ok()); - match root { - Some(Descriptor::Tuple(t)) - if t.id == Uuid::from_u128(0xFF) && t.element_types.is_empty() => {} - _ => { - return Err(ParameterTypeMismatchError::with_message( - "query arguments expected", - )) - } - }; - } else { - return Err(ParameterTypeMismatchError::with_message( - "query arguments expected", - )); - } - } - if enc.ctx.proto.is_at_most(0, 11) { - enc.buf.reserve(4); - enc.buf.put_u32(0); - } - Ok(()) - } -} - -impl QueryArg for Value { - fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> { - use Value::*; - match self { - Nothing => { - enc.buf.reserve(4); - enc.buf.put_i32(-1); - } - Uuid(v) => v.encode_slot(enc)?, - Str(v) => v.encode_slot(enc)?, - Bytes(v) => v.encode_slot(enc)?, - Int16(v) => v.encode_slot(enc)?, - Int32(v) => v.encode_slot(enc)?, - Int64(v) => v.encode_slot(enc)?, - Float32(v) => v.encode_slot(enc)?, - Float64(v) => v.encode_slot(enc)?, - BigInt(v) => v.encode_slot(enc)?, - ConfigMemory(v) => v.encode_slot(enc)?, - Decimal(v) => v.encode_slot(enc)?, - Bool(v) => v.encode_slot(enc)?, - Datetime(v) => v.encode_slot(enc)?, - LocalDatetime(v) => v.encode_slot(enc)?, - LocalDate(v) => v.encode_slot(enc)?, - LocalTime(v) => v.encode_slot(enc)?, - Duration(v) => v.encode_slot(enc)?, - RelativeDuration(v) => v.encode_slot(enc)?, - DateDuration(v) => v.encode_slot(enc)?, - Json(v) => v.encode_slot(enc)?, - Set(_) => { - return Err(ClientEncodingError::with_message( - "set cannot be query argument", - )) - } - Object { .. } => { - return Err(ClientEncodingError::with_message( - "object cannot be query argument", - )) - } - SparseObject(_) => { - return Err(ClientEncodingError::with_message( - "sparse object cannot be query argument", - )) - } - Tuple(_) => { - return Err(ClientEncodingError::with_message( - "tuple object cannot be query argument", - )) - } - NamedTuple { .. } => { - return Err(ClientEncodingError::with_message( - "named tuple object cannot be query argument", - )) - } - Array(v) => v.encode_slot(enc)?, - Enum(v) => v.encode_slot(enc)?, - Range(v) => v.encode_slot(enc)?, - Vector(v) => v.encode_slot(enc)?, - PostGisGeometry(v) => v.encode_slot(enc)?, - PostGisGeography(v) => v.encode_slot(enc)?, - PostGisBox2d(v) => v.encode_slot(enc)?, - PostGisBox3d(v) => v.encode_slot(enc)?, - SQLRow { .. } => { - return Err(ClientEncodingError::with_message( - "SQL row cannot be query argument", - )) - } - } - - Ok(()) - } - fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - use Descriptor::*; - use Value::*; - let desc = ctx.get(pos)?.normalize_to_base(ctx)?; - - match (self, desc) { - (Nothing, _) => Ok(()), // any descriptor works - (BigInt(_), BaseScalar(d)) if d.id == codec::STD_BIGINT => Ok(()), - (Bool(_), BaseScalar(d)) if d.id == codec::STD_BOOL => Ok(()), - (Bytes(_), BaseScalar(d)) if d.id == codec::STD_BYTES => Ok(()), - (ConfigMemory(_), BaseScalar(d)) if d.id == codec::CFG_MEMORY => Ok(()), - (DateDuration(_), BaseScalar(d)) if d.id == codec::CAL_DATE_DURATION => Ok(()), - (Datetime(_), BaseScalar(d)) if d.id == codec::STD_DATETIME => Ok(()), - (Decimal(_), BaseScalar(d)) if d.id == codec::STD_DECIMAL => Ok(()), - (Duration(_), BaseScalar(d)) if d.id == codec::STD_DURATION => Ok(()), - (Float32(_), BaseScalar(d)) if d.id == codec::STD_FLOAT32 => Ok(()), - (Float64(_), BaseScalar(d)) if d.id == codec::STD_FLOAT64 => Ok(()), - (Int16(_), BaseScalar(d)) if d.id == codec::STD_INT16 => Ok(()), - (Int32(_), BaseScalar(d)) if d.id == codec::STD_INT32 => Ok(()), - (Int64(_), BaseScalar(d)) if d.id == codec::STD_INT64 => Ok(()), - (Json(_), BaseScalar(d)) if d.id == codec::STD_JSON => Ok(()), - (LocalDate(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_DATE => Ok(()), - (LocalDatetime(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_DATETIME => Ok(()), - (LocalTime(_), BaseScalar(d)) if d.id == codec::CAL_LOCAL_TIME => Ok(()), - (RelativeDuration(_), BaseScalar(d)) if d.id == codec::CAL_RELATIVE_DURATION => Ok(()), - (Str(_), BaseScalar(d)) if d.id == codec::STD_STR => Ok(()), - (Uuid(_), BaseScalar(d)) if d.id == codec::STD_UUID => Ok(()), - (Enum(val), Enumeration(EnumerationTypeDescriptor { members, .. })) => { - let val = val.deref(); - check_enum(val, &members) - } - (PostGisGeometry(_), BaseScalar(d)) if d.id == codec::POSTGIS_GEOMETRY => Ok(()), - (PostGisGeography(_), BaseScalar(d)) if d.id == codec::POSTGIS_GEOGRAPHY => Ok(()), - (PostGisBox2d(_), BaseScalar(d)) if d.id == codec::POSTGIS_BOX_2D => Ok(()), - (PostGisBox3d(_), BaseScalar(d)) if d.id == codec::POSTGIS_BOX_3D => Ok(()), - // TODO(tailhook) all types - (_, desc) => Err(ctx.wrong_type(&desc, self.kind())), - } - } - fn to_value(&self) -> Result { - Ok(self.clone()) - } -} - -pub(crate) fn check_enum(variant_name: &str, expected_members: &[String]) -> Result<(), Error> { - if expected_members.iter().any(|c| c == variant_name) { - Ok(()) - } else { - let mut members = expected_members - .iter() - .map(|c| format!("'{c}'")) - .collect::>(); - members.sort_unstable(); - let members = members.join(", "); - Err(InvalidReferenceError::with_message(format!( - "Expected one of: {members}, while enum value '{variant_name}' was provided" - ))) - } -} - -impl QueryArgs for Value { - fn encode(&self, enc: &mut Encoder) -> Result<(), Error> { - let codec = enc.ctx.build_codec()?; - codec - .encode(enc.buf, self) - .map_err(ClientEncodingError::with_source) - } -} - -impl QueryArg for T { - fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> { - enc.buf.reserve(4); - let pos = enc.buf.len(); - enc.buf.put_u32(0); // will fill after encoding - ScalarArg::encode(self, enc)?; - let len = enc.buf.len() - pos - 4; - enc.buf[pos..pos + 4].copy_from_slice( - &i32::try_from(len) - .ok() - .context(errors::ElementTooLong) - .map_err(ClientEncodingError::with_source)? - .to_be_bytes(), - ); - Ok(()) - } - fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - T::check_descriptor(ctx, pos) - } - fn to_value(&self) -> Result { - ScalarArg::to_value(self) - } -} - -impl QueryArg for Option { - fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> { - if let Some(val) = self { - QueryArg::encode_slot(val, enc) - } else { - enc.buf.reserve(4); - enc.buf.put_i32(-1); - Ok(()) - } - } - fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - T::check_descriptor(ctx, pos) - } - fn to_value(&self) -> Result { - match self.as_ref() { - Some(v) => ScalarArg::to_value(v), - None => Ok(Value::Nothing), - } - } -} - -impl QueryArg for Vec { - fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> { - enc.buf.reserve(8); - enc.length_prefixed(|enc| { - if self.is_empty() { - enc.buf.reserve(12); - enc.buf.put_u32(0); // ndims - enc.buf.put_u32(0); // reserved0 - enc.buf.put_u32(0); // reserved1 - return Ok(()); - } - enc.buf.reserve(20); - enc.buf.put_u32(1); // ndims - enc.buf.put_u32(0); // reserved0 - enc.buf.put_u32(0); // reserved1 - enc.buf.put_u32( - self.len() - .try_into() - .map_err(|_| ClientEncodingError::with_message("array is too long"))?, - ); - enc.buf.put_u32(1); // lower - for item in self { - enc.length_prefixed(|enc| item.encode(enc))?; - } - Ok(()) - }) - } - fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - let desc = ctx.get(pos)?; - if let Descriptor::Array(arr) = desc { - T::check_descriptor(ctx, arr.type_pos) - } else { - Err(ctx.wrong_type(desc, "array")) - } - } - fn to_value(&self) -> Result { - Ok(Value::Array( - self.iter() - .map(|v| v.to_value()) - .collect::>()?, - )) - } -} - -impl QueryArg for Vec { - fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> { - enc.buf.reserve(8); - enc.length_prefixed(|enc| { - if self.is_empty() { - enc.buf.reserve(12); - enc.buf.put_u32(0); // ndims - enc.buf.put_u32(0); // reserved0 - enc.buf.put_u32(0); // reserved1 - return Ok(()); - } - enc.buf.reserve(20); - enc.buf.put_u32(1); // ndims - enc.buf.put_u32(0); // reserved0 - enc.buf.put_u32(0); // reserved1 - enc.buf.put_u32( - self.len() - .try_into() - .map_err(|_| ClientEncodingError::with_message("array is too long"))?, - ); - enc.buf.put_u32(1); // lower - for item in self { - enc.length_prefixed(|enc| item.encode(enc))?; - } - Ok(()) - }) - } - fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - let desc = ctx.get(pos)?; - if let Descriptor::Array(arr) = desc { - for val in self { - val.check_descriptor(ctx, arr.type_pos)? - } - Ok(()) - } else { - Err(ctx.wrong_type(desc, "array")) - } - } - fn to_value(&self) -> Result { - Ok(Value::Array( - self.iter() - .map(|v| v.to_value()) - .collect::>()?, - )) - } -} - -impl QueryArg for range::Range> { - fn encode_slot(&self, encoder: &mut Encoder) -> Result<(), Error> { - encoder.length_prefixed(|encoder| { - let flags = if self.empty { - range::EMPTY - } else { - (if self.inc_lower { range::LB_INC } else { 0 }) - | (if self.inc_upper { range::UB_INC } else { 0 }) - | (if self.lower.is_none() { - range::LB_INF - } else { - 0 - }) - | (if self.upper.is_none() { - range::UB_INF - } else { - 0 - }) - }; - encoder.buf.reserve(1); - encoder.buf.put_u8(flags as u8); - - if let Some(lower) = &self.lower { - encoder.length_prefixed(|encoder| lower.encode(encoder))? - } - - if let Some(upper) = &self.upper { - encoder.length_prefixed(|encoder| upper.encode(encoder))?; - } - Ok(()) - }) - } - fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - let desc = ctx.get(pos)?; - if let Descriptor::Range(rng) = desc { - self.lower - .as_ref() - .map(|v| v.check_descriptor(ctx, rng.type_pos)) - .transpose()?; - self.upper - .as_ref() - .map(|v| v.check_descriptor(ctx, rng.type_pos)) - .transpose()?; - Ok(()) - } else { - Err(ctx.wrong_type(desc, "range")) - } - } - fn to_value(&self) -> Result { - Ok(Value::Range(self.clone())) - } -} - -macro_rules! implement_tuple { - ( $count:expr, $($name:ident,)+ ) => { - impl<$($name:QueryArg),+> QueryArgs for ($($name,)+) { - fn encode(&self, enc: &mut Encoder) - -> Result<(), Error> - { - #![allow(non_snake_case)] - let root_pos = enc.ctx.root_pos - .ok_or_else(|| DescriptorMismatch::with_message( - format!( - "provided {} positional arguments, \ - but no arguments expected by the server", - $count)))?; - let desc = enc.ctx.get(root_pos)?; - match desc { - Descriptor::ObjectShape(desc) - if enc.ctx.proto.is_at_least(0, 12) - => { - if desc.elements.len() != $count { - return Err(enc.ctx.field_number( - desc.elements.len(), $count)); - } - let mut els = desc.elements.iter().enumerate(); - let ($(ref $name,)+) = self; - $( - let (idx, el) = els.next().unwrap(); - if el.name.parse() != Ok(idx) { - return Err(DescriptorMismatch::with_message( - format!("expected positional arguments, \ - got {} instead of {}", - el.name, idx))); - } - $name.check_descriptor(enc.ctx, el.type_pos)?; - )+ - } - Descriptor::Tuple(desc) if enc.ctx.proto.is_at_most(0, 11) - => { - if desc.element_types.len() != $count { - return Err(enc.ctx.field_number( - desc.element_types.len(), $count)); - } - let mut els = desc.element_types.iter(); - let ($(ref $name,)+) = self; - $( - let type_pos = els.next().unwrap(); - $name.check_descriptor(enc.ctx, *type_pos)?; - )+ - } - _ => return Err(enc.ctx.wrong_type(desc, - if enc.ctx.proto.is_at_least(0, 12) { "object" } - else { "tuple" })) - } - - enc.buf.reserve(4 + 8*$count); - enc.buf.put_u32($count); - let ($(ref $name,)+) = self; - $( - enc.buf.reserve(8); - enc.buf.put_u32(0); - QueryArg::encode_slot($name, enc)?; - )* - Ok(()) - } - } - } -} - -implement_tuple! {1, T0, } -implement_tuple! {2, T0, T1, } -implement_tuple! {3, T0, T1, T2, } -implement_tuple! {4, T0, T1, T2, T3, } -implement_tuple! {5, T0, T1, T2, T3, T4, } -implement_tuple! {6, T0, T1, T2, T3, T4, T5, } -implement_tuple! {7, T0, T1, T2, T3, T4, T5, T6, } -implement_tuple! {8, T0, T1, T2, T3, T4, T5, T6, T7, } -implement_tuple! {9, T0, T1, T2, T3, T4, T5, T6, T7, T8, } -implement_tuple! {10, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, } -implement_tuple! {11, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, } -implement_tuple! {12, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, } diff --git a/edgedb-protocol/src/query_result.rs b/edgedb-protocol/src/query_result.rs deleted file mode 100644 index 5f163e64..00000000 --- a/edgedb-protocol/src/query_result.rs +++ /dev/null @@ -1,67 +0,0 @@ -/*! -Contains the [QueryResult](crate::query_result::QueryResult) trait. -*/ - -use std::sync::Arc; - -use bytes::Bytes; - -use edgedb_errors::{DescriptorMismatch, ProtocolEncodingError}; -use edgedb_errors::{Error, ErrorKind}; - -use crate::codec::Codec; -use crate::descriptors::TypePos; -use crate::queryable::{Decoder, DescriptorContext, Queryable}; -use crate::value::Value; - -pub trait Sealed: Sized {} - -/// A trait representing single result from a query. -/// -/// This is implemented for scalars and tuples. To receive a shape from EdgeDB -/// derive [`Queryable`](Queryable) for a structure. This will automatically -/// implement `QueryResult` for you. -pub trait QueryResult: Sealed { - type State; - fn prepare(ctx: &DescriptorContext, root_pos: TypePos) -> Result; - fn decode(state: &mut Self::State, msg: &Bytes) -> Result; -} - -impl Sealed for T {} - -impl Sealed for Value {} - -impl QueryResult for T { - type State = Decoder; - fn prepare(ctx: &DescriptorContext, root_pos: TypePos) -> Result { - T::check_descriptor(ctx, root_pos).map_err(DescriptorMismatch::with_source)?; - Ok(Decoder { - has_implicit_id: ctx.has_implicit_id, - has_implicit_tid: ctx.has_implicit_tid, - has_implicit_tname: ctx.has_implicit_tname, - }) - } - fn decode(decoder: &mut Decoder, msg: &Bytes) -> Result { - Queryable::decode(decoder, msg).map_err(ProtocolEncodingError::with_source) - } -} - -impl QueryResult for Value { - type State = Arc; - fn prepare(ctx: &DescriptorContext, root_pos: TypePos) -> Result, Error> { - ctx.build_codec(root_pos) - } - fn decode(codec: &mut Arc, msg: &Bytes) -> Result { - let res = codec.decode(msg); - - match res { - Ok(v) => Ok(v), - Err(e) => { - if let Some(bt) = snafu::ErrorCompat::backtrace(&e) { - eprintln!("{bt}"); - } - Err(ProtocolEncodingError::with_source(e)) - } - } - } -} diff --git a/edgedb-protocol/src/queryable.rs b/edgedb-protocol/src/queryable.rs deleted file mode 100644 index 0c8f09af..00000000 --- a/edgedb-protocol/src/queryable.rs +++ /dev/null @@ -1,101 +0,0 @@ -/*! -Contains the [Queryable] trait. -*/ -use snafu::{ensure, Snafu}; -use std::default::Default; -use std::sync::Arc; - -use crate::codec::{build_codec, Codec}; -use crate::descriptors::{Descriptor, TypePos}; -use crate::errors::{self, DecodeError}; -use edgedb_errors::{Error, ErrorKind, ProtocolEncodingError}; - -#[non_exhaustive] -#[derive(Default)] -pub struct Decoder { - pub has_implicit_id: bool, - pub has_implicit_tid: bool, - pub has_implicit_tname: bool, -} - -pub trait Queryable: Sized { - fn decode(decoder: &Decoder, buf: &[u8]) -> Result; - fn decode_optional(decoder: &Decoder, buf: Option<&[u8]>) -> Result { - ensure!(buf.is_some(), errors::MissingRequiredElement); - Self::decode(decoder, buf.unwrap()) - } - fn check_descriptor( - ctx: &DescriptorContext, - type_pos: TypePos, - ) -> Result<(), DescriptorMismatch>; -} - -#[derive(Snafu, Debug)] -#[non_exhaustive] -pub enum DescriptorMismatch { - #[snafu(display("unexpected type {}, expected {}", unexpected, expected))] - WrongType { - unexpected: String, - expected: String, - }, - #[snafu(display("unexpected field {}, expected {}", unexpected, expected))] - WrongField { - unexpected: String, - expected: String, - }, - #[snafu(display("expected {} fields, got {}", expected, unexpected))] - FieldNumber { unexpected: usize, expected: usize }, - #[snafu(display("expected {}", expected))] - Expected { expected: String }, - #[snafu(display("invalid type descriptor"))] - InvalidDescriptor, -} - -pub struct DescriptorContext<'a> { - pub has_implicit_id: bool, - pub has_implicit_tid: bool, - pub has_implicit_tname: bool, - descriptors: &'a [Descriptor], -} - -impl DescriptorContext<'_> { - pub(crate) fn new(descriptors: &[Descriptor]) -> DescriptorContext { - DescriptorContext { - descriptors, - has_implicit_id: false, - has_implicit_tid: false, - has_implicit_tname: false, - } - } - pub fn build_codec(&self, root_pos: TypePos) -> Result, Error> { - build_codec(Some(root_pos), self.descriptors).map_err(ProtocolEncodingError::with_source) - } - pub fn get(&self, type_pos: TypePos) -> Result<&Descriptor, DescriptorMismatch> { - self.descriptors - .get(type_pos.0 as usize) - .ok_or(DescriptorMismatch::InvalidDescriptor) - } - pub fn wrong_type(&self, descriptor: &Descriptor, expected: &str) -> DescriptorMismatch { - DescriptorMismatch::WrongType { - unexpected: format!("{descriptor:?}"), - expected: expected.into(), - } - } - pub fn field_number(&self, expected: usize, unexpected: usize) -> DescriptorMismatch { - DescriptorMismatch::FieldNumber { - expected, - unexpected, - } - } - pub fn wrong_field(&self, expected: &str, unexpected: &str) -> DescriptorMismatch { - DescriptorMismatch::WrongField { - expected: expected.into(), - unexpected: unexpected.into(), - } - } - pub fn expected(&self, expected: &str) -> DescriptorMismatch { - DescriptorMismatch::Expected { - expected: expected.into(), - } - } -} diff --git a/edgedb-protocol/src/serialization.rs b/edgedb-protocol/src/serialization.rs deleted file mode 100644 index 7fb30ae1..00000000 --- a/edgedb-protocol/src/serialization.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod decode; - -#[cfg(test)] -mod test_scalars; diff --git a/edgedb-protocol/src/serialization/decode.rs b/edgedb-protocol/src/serialization/decode.rs deleted file mode 100644 index 403319cc..00000000 --- a/edgedb-protocol/src/serialization/decode.rs +++ /dev/null @@ -1,11 +0,0 @@ -pub(crate) mod queryable; -mod raw_composite; -mod raw_scalar; - -#[cfg(feature = "chrono")] -mod chrono; - -pub(crate) use self::raw_composite::DecodeArrayLike; -pub(crate) use self::raw_composite::DecodeRange; -pub use self::raw_composite::DecodeTupleLike; -pub(crate) use self::raw_scalar::RawCodec; diff --git a/edgedb-protocol/src/serialization/decode/chrono.rs b/edgedb-protocol/src/serialization/decode/chrono.rs deleted file mode 100644 index 2ea64e74..00000000 --- a/edgedb-protocol/src/serialization/decode/chrono.rs +++ /dev/null @@ -1,27 +0,0 @@ -use crate::errors::DecodeError; -use crate::serialization::decode::raw_scalar::RawCodec; -use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; - -impl<'t> RawCodec<'t> for DateTime { - fn decode(buf: &[u8]) -> Result { - crate::model::Datetime::decode(buf).map(Into::into) - } -} - -impl<'t> RawCodec<'t> for NaiveDateTime { - fn decode(buf: &[u8]) -> Result { - crate::model::LocalDatetime::decode(buf).map(Into::into) - } -} - -impl<'t> RawCodec<'t> for NaiveDate { - fn decode(buf: &[u8]) -> Result { - crate::model::LocalDate::decode(buf).map(Into::into) - } -} - -impl<'t> RawCodec<'t> for NaiveTime { - fn decode(buf: &[u8]) -> Result { - crate::model::LocalTime::decode(buf).map(Into::into) - } -} diff --git a/edgedb-protocol/src/serialization/decode/queryable.rs b/edgedb-protocol/src/serialization/decode/queryable.rs deleted file mode 100644 index ad8a6cce..00000000 --- a/edgedb-protocol/src/serialization/decode/queryable.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub(crate) mod collections; -pub(crate) mod scalars; -pub(crate) mod tuples; diff --git a/edgedb-protocol/src/serialization/decode/queryable/collections.rs b/edgedb-protocol/src/serialization/decode/queryable/collections.rs deleted file mode 100644 index aae9be90..00000000 --- a/edgedb-protocol/src/serialization/decode/queryable/collections.rs +++ /dev/null @@ -1,73 +0,0 @@ -use crate::descriptors::{Descriptor, TypePos}; -use crate::errors::DecodeError; -use crate::queryable::DescriptorMismatch; -use crate::queryable::{Decoder, DescriptorContext, Queryable}; -use crate::serialization::decode::DecodeArrayLike; -use std::iter::FromIterator; - -impl Queryable for Option { - fn decode(decoder: &Decoder, buf: &[u8]) -> Result { - Ok(Some(T::decode(decoder, buf)?)) - } - - fn decode_optional(decoder: &Decoder, buf: Option<&[u8]>) -> Result { - buf.map(|buf| T::decode(decoder, buf)).transpose() - } - - fn check_descriptor( - ctx: &DescriptorContext, - type_pos: TypePos, - ) -> Result<(), DescriptorMismatch> { - T::check_descriptor(ctx, type_pos) - } -} - -struct Collection(T); - -impl::Item>> Collection -where - ::Item: Queryable, -{ - fn decode(decoder: &Decoder, buf: &[u8]) -> Result { - let elements = DecodeArrayLike::new_collection(buf)?; - let elements = elements.map(|e| ::Item::decode(decoder, e?)); - elements.collect::>() - } - - fn decode_optional(decoder: &Decoder, buf: Option<&[u8]>) -> Result { - match buf { - Some(buf) => Self::decode(decoder, buf), - None => Ok(T::from_iter(std::iter::empty())), - } - } - - fn check_descriptor( - ctx: &DescriptorContext, - type_pos: TypePos, - ) -> Result<(), DescriptorMismatch> { - let desc = ctx.get(type_pos)?; - let element_type_pos = match desc { - Descriptor::Set(desc) => desc.type_pos, - Descriptor::Array(desc) => desc.type_pos, - _ => return Err(ctx.wrong_type(desc, "array or set")), - }; - ::Item::check_descriptor(ctx, element_type_pos) - } -} - -impl Queryable for Vec { - fn decode(decoder: &Decoder, buf: &[u8]) -> Result { - Collection::>::decode(decoder, buf) - } - - fn decode_optional(decoder: &Decoder, buf: Option<&[u8]>) -> Result { - Collection::>::decode_optional(decoder, buf) - } - - fn check_descriptor( - ctx: &DescriptorContext, - type_pos: TypePos, - ) -> Result<(), DescriptorMismatch> { - Collection::>::check_descriptor(ctx, type_pos) - } -} diff --git a/edgedb-protocol/src/serialization/decode/queryable/scalars.rs b/edgedb-protocol/src/serialization/decode/queryable/scalars.rs deleted file mode 100644 index 85d30bdc..00000000 --- a/edgedb-protocol/src/serialization/decode/queryable/scalars.rs +++ /dev/null @@ -1,309 +0,0 @@ -use bytes::Bytes; - -use crate::queryable::DescriptorMismatch; -use crate::queryable::{Decoder, DescriptorContext, Queryable}; - -use crate::codec; -use crate::descriptors::TypePos; -use crate::errors::DecodeError; -use crate::model::{BigInt, Decimal, Json, RelativeDuration, Uuid}; -use crate::model::{ConfigMemory, DateDuration}; -use crate::model::{Datetime, Duration, LocalDate, LocalDatetime, LocalTime}; -use crate::serialization::decode::RawCodec; -use std::time::SystemTime; - -pub(crate) fn check_scalar( - ctx: &DescriptorContext, - type_pos: TypePos, - type_id: Uuid, - name: &str, -) -> Result<(), DescriptorMismatch> { - use crate::descriptors::Descriptor::{BaseScalar, Scalar}; - let desc = ctx.get(type_pos)?; - match desc { - Scalar(scalar) if scalar.base_type_pos.is_some() => { - return check_scalar(ctx, scalar.base_type_pos.unwrap(), type_id, name); - } - Scalar(scalar) if *scalar.id == type_id => { - return Ok(()); - } - BaseScalar(base) if *base.id == type_id => { - return Ok(()); - } - _ => {} - } - Err(ctx.wrong_type(desc, name)) -} - -pub trait DecodeScalar: for<'a> RawCodec<'a> + Sized { - fn uuid() -> Uuid; - fn typename() -> &'static str; -} - -impl Queryable for T { - fn decode(_decoder: &Decoder, buf: &[u8]) -> Result { - RawCodec::decode(buf) - } - fn check_descriptor( - ctx: &DescriptorContext, - type_pos: TypePos, - ) -> Result<(), DescriptorMismatch> { - check_scalar(ctx, type_pos, T::uuid(), T::typename()) - } -} - -impl DecodeScalar for String { - fn uuid() -> Uuid { - codec::STD_STR - } - fn typename() -> &'static str { - "std::str" - } -} - -impl DecodeScalar for Bytes { - fn uuid() -> Uuid { - codec::STD_BYTES - } - fn typename() -> &'static str { - "std::bytes" - } -} - -impl DecodeScalar for Json { - fn uuid() -> Uuid { - codec::STD_JSON - } - fn typename() -> &'static str { - "std::json" - } -} - -/* -impl DecodeScalar for Vec { - fn uuid() -> Uuid { codec::STD_BYTES } - fn typename() -> &'static str { "std::bytes" } -} -*/ - -impl DecodeScalar for i16 { - fn uuid() -> Uuid { - codec::STD_INT16 - } - fn typename() -> &'static str { - "std::int16" - } -} - -impl DecodeScalar for i32 { - fn uuid() -> Uuid { - codec::STD_INT32 - } - fn typename() -> &'static str { - "std::int32" - } -} - -impl DecodeScalar for i64 { - fn uuid() -> Uuid { - codec::STD_INT64 - } - fn typename() -> &'static str { - "std::int64" - } -} - -impl DecodeScalar for f32 { - fn uuid() -> Uuid { - codec::STD_FLOAT32 - } - fn typename() -> &'static str { - "std::float32" - } -} - -impl DecodeScalar for f64 { - fn uuid() -> Uuid { - codec::STD_FLOAT64 - } - fn typename() -> &'static str { - "std::float64" - } -} - -impl DecodeScalar for Uuid { - fn uuid() -> Uuid { - codec::STD_UUID - } - fn typename() -> &'static str { - "std::uuid" - } -} - -impl DecodeScalar for bool { - fn uuid() -> Uuid { - codec::STD_BOOL - } - fn typename() -> &'static str { - "std::bool" - } -} - -impl DecodeScalar for BigInt { - fn uuid() -> Uuid { - codec::STD_BIGINT - } - fn typename() -> &'static str { - "std::bigint" - } -} - -#[cfg(feature = "num-bigint")] -impl DecodeScalar for num_bigint::BigInt { - fn uuid() -> Uuid { - codec::STD_BIGINT - } - fn typename() -> &'static str { - "std::bigint" - } -} - -impl DecodeScalar for Decimal { - fn uuid() -> Uuid { - codec::STD_DECIMAL - } - fn typename() -> &'static str { - "std::decimal" - } -} - -#[cfg(feature = "bigdecimal")] -impl DecodeScalar for bigdecimal::BigDecimal { - fn uuid() -> Uuid { - codec::STD_DECIMAL - } - fn typename() -> &'static str { - "std::decimal" - } -} - -impl DecodeScalar for LocalDatetime { - fn uuid() -> Uuid { - codec::CAL_LOCAL_DATETIME - } - fn typename() -> &'static str { - "cal::local_datetime" - } -} - -#[cfg(feature = "chrono")] -impl DecodeScalar for chrono::NaiveDateTime { - fn uuid() -> Uuid { - codec::CAL_LOCAL_DATETIME - } - fn typename() -> &'static str { - "cal::local_datetime" - } -} - -impl DecodeScalar for LocalDate { - fn uuid() -> Uuid { - codec::CAL_LOCAL_DATE - } - fn typename() -> &'static str { - "cal::local_date" - } -} - -#[cfg(feature = "chrono")] -impl DecodeScalar for chrono::NaiveDate { - fn uuid() -> Uuid { - codec::CAL_LOCAL_DATE - } - fn typename() -> &'static str { - "cal::local_date" - } -} - -impl DecodeScalar for LocalTime { - fn uuid() -> Uuid { - codec::CAL_LOCAL_TIME - } - fn typename() -> &'static str { - "cal::local_time" - } -} - -#[cfg(feature = "chrono")] -impl DecodeScalar for chrono::NaiveTime { - fn uuid() -> Uuid { - codec::CAL_LOCAL_TIME - } - fn typename() -> &'static str { - "cal::local_time" - } -} - -impl DecodeScalar for Duration { - fn uuid() -> Uuid { - codec::STD_DURATION - } - fn typename() -> &'static str { - "std::duration" - } -} - -impl DecodeScalar for RelativeDuration { - fn uuid() -> Uuid { - codec::CAL_RELATIVE_DURATION - } - fn typename() -> &'static str { - "cal::relative_duration" - } -} - -impl DecodeScalar for SystemTime { - fn uuid() -> Uuid { - codec::STD_DATETIME - } - fn typename() -> &'static str { - "std::datetime" - } -} - -impl DecodeScalar for Datetime { - fn uuid() -> Uuid { - codec::STD_DATETIME - } - fn typename() -> &'static str { - "std::datetime" - } -} - -#[cfg(feature = "chrono")] -impl DecodeScalar for chrono::DateTime { - fn uuid() -> Uuid { - codec::STD_DATETIME - } - fn typename() -> &'static str { - "std::datetime" - } -} - -impl DecodeScalar for ConfigMemory { - fn uuid() -> Uuid { - codec::CFG_MEMORY - } - fn typename() -> &'static str { - "cfg::memory" - } -} - -impl DecodeScalar for DateDuration { - fn uuid() -> Uuid { - codec::CAL_DATE_DURATION - } - fn typename() -> &'static str { - "cal::date_duration" - } -} diff --git a/edgedb-protocol/src/serialization/decode/queryable/tuples.rs b/edgedb-protocol/src/serialization/decode/queryable/tuples.rs deleted file mode 100644 index 51cd006e..00000000 --- a/edgedb-protocol/src/serialization/decode/queryable/tuples.rs +++ /dev/null @@ -1,53 +0,0 @@ -use crate::descriptors::{Descriptor, TypePos}; -use crate::errors::DecodeError; -use crate::queryable::DescriptorMismatch; -use crate::queryable::{Decoder, DescriptorContext, Queryable}; -use crate::serialization::decode::DecodeTupleLike; - -macro_rules! implement_tuple { - ( $count:expr, $($name:ident,)+ ) => ( - impl<$($name:Queryable),+> Queryable for ($($name,)+) { - fn decode(decoder: &Decoder, buf: &[u8]) - -> Result - { - let mut elements = DecodeTupleLike::new_tuple(buf, $count)?; - Ok(( - $( - <$name as crate::queryable::Queryable>:: - decode_optional(decoder, elements.read()?)?, - )+ - )) - } - - fn check_descriptor(ctx: &DescriptorContext, type_pos: TypePos) - -> Result<(), DescriptorMismatch> - { - let desc = ctx.get(type_pos)?; - match desc { - Descriptor::Tuple(desc) => { - if desc.element_types.len() != $count { - return Err(ctx.field_number($count, desc.element_types.len())); - } - let mut element_types = desc.element_types.iter().copied(); - $($name::check_descriptor(ctx, element_types.next().unwrap())?;)+ - Ok(()) - } - _ => Err(ctx.wrong_type(desc, "tuple")) - } - } - } - ) -} - -implement_tuple! {1, T0, } -implement_tuple! {2, T0, T1, } -implement_tuple! {3, T0, T1, T2, } -implement_tuple! {4, T0, T1, T2, T3, } -implement_tuple! {5, T0, T1, T2, T3, T4, } -implement_tuple! {6, T0, T1, T2, T3, T4, T5, } -implement_tuple! {7, T0, T1, T2, T3, T4, T5, T6, } -implement_tuple! {8, T0, T1, T2, T3, T4, T5, T6, T7, } -implement_tuple! {9, T0, T1, T2, T3, T4, T5, T6, T7, T8, } -implement_tuple! {10, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, } -implement_tuple! {11, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, } -implement_tuple! {12, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, } diff --git a/edgedb-protocol/src/serialization/decode/raw_composite.rs b/edgedb-protocol/src/serialization/decode/raw_composite.rs deleted file mode 100644 index 0d537837..00000000 --- a/edgedb-protocol/src/serialization/decode/raw_composite.rs +++ /dev/null @@ -1,212 +0,0 @@ -use self::inner::DecodeCompositeInner; -use crate::errors::{self, DecodeError}; -use snafu::ensure; - -pub struct DecodeTupleLike<'t> { - inner: DecodeCompositeInner<'t>, -} - -impl<'t> DecodeTupleLike<'t> { - fn new(buf: &'t [u8]) -> Result { - let inner = DecodeCompositeInner::read_tuple_like_header(buf)?; - Ok(DecodeTupleLike { inner }) - } - - pub fn new_object(buf: &'t [u8], expected_count: usize) -> Result { - let elements = Self::new(buf)?; - ensure!( - elements.inner.count() == expected_count, - errors::ObjectSizeMismatch - ); - Ok(elements) - } - - pub fn new_tuple(buf: &'t [u8], expected_count: usize) -> Result { - let elements = Self::new(buf)?; - ensure!( - elements.inner.count() == expected_count, - errors::TupleSizeMismatch - ); - Ok(elements) - } - - pub fn read(&mut self) -> Result, DecodeError> { - self.inner.read_object_element() - } - - pub fn skip_element(&mut self) -> Result<(), DecodeError> { - self.read()?; - Ok(()) - } -} - -pub struct DecodeArrayLike<'t> { - inner: DecodeCompositeInner<'t>, -} - -impl<'t> DecodeArrayLike<'t> { - pub fn new_array(buf: &'t [u8]) -> Result { - let inner = DecodeCompositeInner::read_array_like_header(buf, || { - errors::InvalidArrayShape.build() - })?; - Ok(DecodeArrayLike { inner }) - } - - pub fn new_set(buf: &'t [u8]) -> Result { - let inner = - DecodeCompositeInner::read_array_like_header(buf, || errors::InvalidSetShape.build())?; - Ok(DecodeArrayLike { inner }) - } - - pub fn new_collection(buf: &'t [u8]) -> Result { - let inner = DecodeCompositeInner::read_array_like_header(buf, || { - errors::InvalidArrayOrSetShape.build() - })?; - Ok(DecodeArrayLike { inner }) - } - - pub fn new_tuple_header(buf: &'t [u8]) -> Result { - let inner = DecodeCompositeInner::read_tuple_like_header(buf)?; - Ok(Self { inner }) - } -} - -pub struct DecodeRange<'t> { - inner: DecodeCompositeInner<'t>, -} - -impl<'t> DecodeRange<'t> { - pub fn new(buf: &'t [u8]) -> Result { - // flags header should already have been read externally - let inner = DecodeCompositeInner { raw: buf, count: 2 }; - Ok(DecodeRange { inner }) - } - pub fn read(&mut self) -> Result<&[u8], DecodeError> { - self.inner.read_array_like_element() - } -} - -impl<'t> Iterator for DecodeArrayLike<'t> { - type Item = Result<&'t [u8], DecodeError>; - - fn next(&mut self) -> Option { - if self.len() > 0 { - Some(self.inner.read_array_like_element()) - } else { - None - } - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.len(); - (len, Some(len)) - } -} - -impl<'t> ExactSizeIterator for DecodeArrayLike<'t> { - fn len(&self) -> usize { - self.inner.count() - } -} - -mod inner { - use crate::errors::{self, DecodeError}; - use bytes::Buf; - use snafu::ensure; - - pub(super) struct DecodeCompositeInner<'t> { - pub raw: &'t [u8], - pub count: usize, - } - - impl<'t> std::fmt::Debug for DecodeCompositeInner<'t> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!( - "count = {} data = {:x?}", - self.count, self.raw - )) - } - } - - impl<'t> DecodeCompositeInner<'t> { - fn underflow(&mut self) -> errors::Underflow { - // after one underflow happened, all further reads should underflow as well - // all other errors should be recoverable, since they only affect the content of one element and not the size of that element - self.raw = &[0u8; 0]; - errors::Underflow - } - - pub fn count(&self) -> usize { - self.count - } - - fn new(bytes: &'t [u8], count: usize) -> Self { - DecodeCompositeInner { raw: bytes, count } - } - - fn read_element(&mut self, position: usize) -> Result<&'t [u8], DecodeError> { - assert!( - self.count() > 0, - "reading from a finished elements sequence" - ); - self.count -= 1; - ensure!(self.raw.len() >= position, self.underflow()); - let result = &self.raw[..position]; - self.raw.advance(position); - ensure!(self.count > 0 || self.raw.is_empty(), errors::ExtraData); - Ok(result) - } - - pub fn read_raw_object_element(&mut self) -> Result, DecodeError> { - ensure!(self.raw.remaining() >= 4, self.underflow()); - let len = self.raw.get_i32(); - if len < 0 { - ensure!(len == -1, errors::InvalidMarker); - return Ok(None); - } - let len = len as usize; - Ok(Some(self.read_element(len)?)) - } - - pub fn read_object_element(&mut self) -> Result, DecodeError> { - ensure!(self.raw.remaining() >= 8, self.underflow()); - let _reserved = self.raw.get_i32(); - self.read_raw_object_element() - } - - pub fn read_array_like_element(&mut self) -> Result<&'t [u8], DecodeError> { - ensure!(self.raw.remaining() >= 4, self.underflow()); - let len = self.raw.get_i32() as usize; - self.read_element(len) - } - - pub fn read_tuple_like_header(mut buf: &'t [u8]) -> Result { - ensure!(buf.remaining() >= 4, errors::Underflow); - let count = buf.get_u32() as usize; - Ok(Self::new(buf, count)) - } - - pub fn read_array_like_header( - mut buf: &'t [u8], - error: impl Fn() -> DecodeError, - ) -> Result { - ensure!(buf.remaining() >= 12, errors::Underflow); - let ndims = buf.get_u32(); - let _reserved0 = buf.get_u32(); - let _reserved1 = buf.get_u32(); - if ndims == 0 { - return Ok(Self::new(buf, 0)); - } - if ndims != 1 { - return Err(error()); - } - ensure!(buf.remaining() >= 8, errors::Underflow); - let size = buf.get_u32() as usize; - let lower = buf.get_u32(); - if lower != 1 { - return Err(error()); - } - Ok(Self::new(buf, size)) - } - } -} diff --git a/edgedb-protocol/src/serialization/decode/raw_scalar.rs b/edgedb-protocol/src/serialization/decode/raw_scalar.rs deleted file mode 100644 index 52abc6b2..00000000 --- a/edgedb-protocol/src/serialization/decode/raw_scalar.rs +++ /dev/null @@ -1,736 +0,0 @@ -use std::convert::TryInto; -use std::mem::size_of; -use std::str; -use std::time::SystemTime; - -use bytes::{Buf, BufMut, Bytes}; -use edgedb_errors::{ClientEncodingError, Error, ErrorKind}; -use snafu::{ensure, ResultExt}; - -use crate::codec; -use crate::descriptors::{Descriptor, TypePos}; -use crate::errors::{self, DecodeError}; -use crate::model::range; -use crate::model::{BigInt, Decimal}; -use crate::model::{ConfigMemory, Range}; -use crate::model::{DateDuration, RelativeDuration}; -use crate::model::{Datetime, Duration, LocalDate, LocalDatetime, LocalTime}; -use crate::model::{Json, Uuid}; -use crate::query_arg::{DescriptorContext, Encoder, ScalarArg}; -use crate::serialization::decode::queryable::scalars::DecodeScalar; -use crate::value::{EnumValue, Value}; - -pub trait RawCodec<'t>: Sized { - fn decode(buf: &'t [u8]) -> Result; -} - -fn ensure_exact_size(buf: &[u8], expected_size: usize) -> Result<(), DecodeError> { - if buf.len() != expected_size { - if buf.len() < expected_size { - return errors::Underflow.fail(); - } else { - return errors::ExtraData.fail(); - } - } - Ok(()) -} - -impl<'t> RawCodec<'t> for String { - fn decode(buf: &[u8]) -> Result { - <&str>::decode(buf).map(|s| s.to_owned()) - } -} - -fn check_scalar( - ctx: &DescriptorContext, - type_pos: TypePos, - type_id: Uuid, - name: &str, -) -> Result<(), Error> { - use crate::descriptors::Descriptor::{BaseScalar, Scalar}; - let desc = ctx.get(type_pos)?; - match desc { - Scalar(scalar) if scalar.base_type_pos.is_some() => { - return check_scalar(ctx, scalar.base_type_pos.unwrap(), type_id, name); - } - Scalar(scalar) if ctx.proto.is_2() && *scalar.id == type_id => { - return Ok(()); - } - BaseScalar(base) if *base.id == type_id => { - return Ok(()); - } - _ => {} - } - Err(ctx.wrong_type(desc, name)) -} - -impl ScalarArg for String { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - encoder.buf.extend(self.as_bytes()); - Ok(()) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::Str(self.clone())) - } -} - -impl ScalarArg for &'_ str { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - encoder.buf.extend(self.as_bytes()); - Ok(()) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - // special case: &str can express an enum variant - if let Descriptor::Enumeration(_) = ctx.get(pos)? { - return Ok(()); - } - - check_scalar(ctx, pos, String::uuid(), String::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::Str(self.to_string())) - } -} - -impl<'t> RawCodec<'t> for &'t str { - fn decode(buf: &'t [u8]) -> Result { - let val = str::from_utf8(buf).context(errors::InvalidUtf8)?; - Ok(val) - } -} - -impl ScalarArg for Json { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - encoder.buf.reserve(self.len() + 1); - encoder.buf.put_u8(1); - encoder.buf.extend(self.as_bytes()); - Ok(()) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Json::uuid(), Json::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::Json(self.clone())) - } -} - -impl<'t> RawCodec<'t> for Json { - fn decode(mut buf: &[u8]) -> Result { - ensure!(buf.remaining() >= 1, errors::Underflow); - let format = buf.get_u8(); - ensure!(format == 1, errors::InvalidJsonFormat); - let val = str::from_utf8(buf).context(errors::InvalidUtf8)?.to_owned(); - Ok(Json::new_unchecked(val)) - } -} - -impl<'t> RawCodec<'t> for Uuid { - fn decode(buf: &[u8]) -> Result { - ensure_exact_size(buf, 16)?; - let uuid = Uuid::from_slice(buf).unwrap(); - Ok(uuid) - } -} - -impl ScalarArg for Uuid { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - encoder.buf.reserve(16); - encoder.buf.extend(self.as_bytes()); - Ok(()) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::Uuid(self.clone())) - } -} - -impl<'t> RawCodec<'t> for bool { - fn decode(buf: &[u8]) -> Result { - ensure_exact_size(buf, 1)?; - let res = match buf[0] { - 0x00 => false, - 0x01 => true, - v => errors::InvalidBool { val: v }.fail()?, - }; - Ok(res) - } -} - -impl ScalarArg for bool { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - encoder.buf.reserve(1); - encoder.buf.put_u8(match self { - false => 0x00, - true => 0x01, - }); - Ok(()) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::Bool(self.clone())) - } -} - -impl<'t> RawCodec<'t> for i16 { - fn decode(mut buf: &[u8]) -> Result { - ensure_exact_size(buf, size_of::())?; - Ok(buf.get_i16()) - } -} - -impl ScalarArg for i16 { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - encoder.buf.reserve(2); - encoder.buf.put_i16(*self); - Ok(()) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::Int16(self.clone())) - } -} - -impl<'t> RawCodec<'t> for i32 { - fn decode(mut buf: &[u8]) -> Result { - ensure_exact_size(buf, size_of::())?; - Ok(buf.get_i32()) - } -} - -impl ScalarArg for i32 { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - encoder.buf.reserve(4); - encoder.buf.put_i32(*self); - Ok(()) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::Int32(self.clone())) - } -} - -impl<'t> RawCodec<'t> for i64 { - fn decode(mut buf: &[u8]) -> Result { - ensure_exact_size(buf, size_of::())?; - Ok(buf.get_i64()) - } -} - -impl<'t> RawCodec<'t> for ConfigMemory { - fn decode(mut buf: &[u8]) -> Result { - ensure_exact_size(buf, size_of::())?; - Ok(ConfigMemory(buf.get_i64())) - } -} - -impl ScalarArg for i64 { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - encoder.buf.reserve(8); - encoder.buf.put_i64(*self); - Ok(()) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::Int64(self.clone())) - } -} - -impl<'t> RawCodec<'t> for f32 { - fn decode(mut buf: &[u8]) -> Result { - ensure_exact_size(buf, size_of::())?; - Ok(buf.get_f32()) - } -} - -impl ScalarArg for f32 { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - encoder.buf.reserve(4); - encoder.buf.put_f32(*self); - Ok(()) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::Float32(self.clone())) - } -} - -impl<'t> RawCodec<'t> for f64 { - fn decode(mut buf: &[u8]) -> Result { - ensure_exact_size(buf, size_of::())?; - Ok(buf.get_f64()) - } -} - -impl ScalarArg for f64 { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - encoder.buf.reserve(8); - encoder.buf.put_f64(*self); - Ok(()) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::Float64(self.clone())) - } -} - -impl<'t> RawCodec<'t> for &'t [u8] { - fn decode(buf: &'t [u8]) -> Result { - Ok(buf) - } -} - -impl ScalarArg for &'_ [u8] { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - encoder.buf.extend(*self); - Ok(()) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, codec::STD_BYTES, "std::bytes") - } - fn to_value(&self) -> Result { - Ok(Value::Bytes(Bytes::copy_from_slice(self))) - } -} - -impl<'t> RawCodec<'t> for Bytes { - fn decode(buf: &[u8]) -> Result { - Ok(Bytes::copy_from_slice(buf)) - } -} - -impl ScalarArg for Bytes { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - encoder.buf.extend(&self[..]); - Ok(()) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, codec::STD_BYTES, "std::bytes") - } - fn to_value(&self) -> Result { - Ok(Value::Bytes(self.clone())) - } -} - -impl ScalarArg for ConfigMemory { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - encoder.buf.reserve(8); - encoder.buf.put_i64(self.0); - Ok(()) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::ConfigMemory(self.clone())) - } -} - -impl<'t> RawCodec<'t> for Decimal { - fn decode(mut buf: &[u8]) -> Result { - ensure!(buf.remaining() >= 8, errors::Underflow); - let ndigits = buf.get_u16() as usize; - let weight = buf.get_i16(); - let negative = match buf.get_u16() { - 0x0000 => false, - 0x4000 => true, - _ => errors::BadSign.fail()?, - }; - let decimal_digits = buf.get_u16(); - ensure_exact_size(buf, ndigits * 2)?; - let mut digits = Vec::with_capacity(ndigits); - for _ in 0..ndigits { - digits.push(buf.get_u16()); - } - Ok(Decimal { - negative, - weight, - decimal_digits, - digits, - }) - } -} - -#[cfg(feature = "bigdecimal")] -impl<'t> RawCodec<'t> for bigdecimal::BigDecimal { - fn decode(buf: &[u8]) -> Result { - let dec: Decimal = RawCodec::decode(buf)?; - Ok(dec.into()) - } -} - -impl ScalarArg for Decimal { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - codec::encode_decimal(encoder.buf, self).map_err(ClientEncodingError::with_source) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::Decimal(self.clone())) - } -} - -#[cfg(feature = "num-bigint")] -impl<'t> RawCodec<'t> for num_bigint::BigInt { - fn decode(buf: &[u8]) -> Result { - let dec: BigInt = RawCodec::decode(buf)?; - Ok(dec.into()) - } -} - -#[cfg(feature = "bigdecimal")] -impl ScalarArg for bigdecimal::BigDecimal { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - let val = self.clone().try_into().map_err(|e| { - ClientEncodingError::with_source(e).context("cannot serialize BigDecimal value") - })?; - codec::encode_decimal(encoder.buf, &val).map_err(ClientEncodingError::with_source) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::Decimal( - self.clone() - .try_into() - .map_err(ClientEncodingError::with_source)?, - )) - } -} - -impl<'t> RawCodec<'t> for BigInt { - fn decode(mut buf: &[u8]) -> Result { - ensure!(buf.remaining() >= 8, errors::Underflow); - let ndigits = buf.get_u16() as usize; - let weight = buf.get_i16(); - let negative = match buf.get_u16() { - 0x0000 => false, - 0x4000 => true, - _ => errors::BadSign.fail()?, - }; - let decimal_digits = buf.get_u16(); - ensure!(decimal_digits == 0, errors::NonZeroReservedBytes); - let mut digits = Vec::with_capacity(ndigits); - ensure_exact_size(buf, ndigits * 2)?; - for _ in 0..ndigits { - digits.push(buf.get_u16()); - } - Ok(BigInt { - negative, - weight, - digits, - }) - } -} - -impl ScalarArg for BigInt { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - codec::encode_big_int(encoder.buf, self).map_err(ClientEncodingError::with_source) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::BigInt(self.clone())) - } -} - -#[cfg(feature = "bigdecimal")] -impl ScalarArg for num_bigint::BigInt { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - let val = self.clone().try_into().map_err(|e| { - ClientEncodingError::with_source(e).context("cannot serialize BigInt value") - })?; - codec::encode_big_int(encoder.buf, &val).map_err(ClientEncodingError::with_source) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - let val = self.clone().try_into().map_err(|e| { - ClientEncodingError::with_source(e).context("cannot serialize BigInt value") - })?; - Ok(Value::BigInt(val)) - } -} - -impl<'t> RawCodec<'t> for Duration { - fn decode(mut buf: &[u8]) -> Result { - ensure_exact_size(buf, 16)?; - let micros = buf.get_i64(); - let days = buf.get_u32(); - let months = buf.get_u32(); - ensure!(months == 0 && days == 0, errors::NonZeroReservedBytes); - Ok(Duration { micros }) - } -} - -impl<'t> RawCodec<'t> for std::time::Duration { - fn decode(buf: &[u8]) -> Result { - let dur = Duration::decode(buf)?; - dur.try_into().map_err(|_| errors::InvalidDate.build()) - } -} - -impl ScalarArg for Duration { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - codec::encode_duration(encoder.buf, self).map_err(ClientEncodingError::with_source) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::Duration(self.clone())) - } -} - -impl<'t> RawCodec<'t> for RelativeDuration { - fn decode(mut buf: &[u8]) -> Result { - ensure_exact_size(buf, 16)?; - let micros = buf.get_i64(); - let days = buf.get_i32(); - let months = buf.get_i32(); - Ok(RelativeDuration { - micros, - days, - months, - }) - } -} - -impl<'t> RawCodec<'t> for DateDuration { - fn decode(mut buf: &[u8]) -> Result { - ensure_exact_size(buf, 16)?; - let micros = buf.get_i64(); - let days = buf.get_i32(); - let months = buf.get_i32(); - ensure!(micros == 0, errors::NonZeroReservedBytes); - Ok(DateDuration { days, months }) - } -} - -impl ScalarArg for RelativeDuration { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - codec::encode_relative_duration(encoder.buf, self).map_err(ClientEncodingError::with_source) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::RelativeDuration(self.clone())) - } -} - -impl<'t> RawCodec<'t> for SystemTime { - fn decode(buf: &[u8]) -> Result { - let dur = Datetime::decode(buf)?; - dur.try_into().map_err(|_| errors::InvalidDate.build()) - } -} - -impl ScalarArg for SystemTime { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - let val = self.clone().try_into().map_err(|e| { - ClientEncodingError::with_source(e).context("cannot serialize SystemTime value") - })?; - codec::encode_datetime(encoder.buf, &val).map_err(ClientEncodingError::with_source) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - let val = self.clone().try_into().map_err(|e| { - ClientEncodingError::with_source(e).context("cannot serialize SystemTime value") - })?; - Ok(Value::Datetime(val)) - } -} - -impl<'t> RawCodec<'t> for Datetime { - fn decode(buf: &[u8]) -> Result { - let micros = i64::decode(buf)?; - Datetime::from_postgres_micros(micros).map_err(|_| errors::InvalidDate.build()) - } -} - -impl ScalarArg for Datetime { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - codec::encode_datetime(encoder.buf, self).map_err(ClientEncodingError::with_source) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::Datetime(self.clone())) - } -} - -impl<'t> RawCodec<'t> for LocalDatetime { - fn decode(buf: &[u8]) -> Result { - let micros = i64::decode(buf)?; - LocalDatetime::from_postgres_micros(micros).map_err(|_| errors::InvalidDate.build()) - } -} - -impl ScalarArg for LocalDatetime { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - codec::encode_local_datetime(encoder.buf, self).map_err(ClientEncodingError::with_source) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::LocalDatetime(self.clone())) - } -} - -impl<'t> RawCodec<'t> for LocalDate { - fn decode(buf: &[u8]) -> Result { - let days = i32::decode(buf)?; - Ok(LocalDate { days }) - } -} - -impl ScalarArg for LocalDate { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - codec::encode_local_date(encoder.buf, self).map_err(ClientEncodingError::with_source) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::LocalDate(self.clone())) - } -} - -impl<'t> RawCodec<'t> for LocalTime { - fn decode(buf: &[u8]) -> Result { - let micros = i64::decode(buf)?; - ensure!( - (0..86_400 * 1_000_000).contains(µs), - errors::InvalidDate - ); - Ok(LocalTime { - micros: micros as u64, - }) - } -} - -impl ScalarArg for DateDuration { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - codec::encode_date_duration(encoder.buf, self).map_err(ClientEncodingError::with_source) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::DateDuration(self.clone())) - } -} - -impl ScalarArg for LocalTime { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - codec::encode_local_time(encoder.buf, self).map_err(ClientEncodingError::with_source) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - check_scalar(ctx, pos, Self::uuid(), Self::typename()) - } - fn to_value(&self) -> Result { - Ok(Value::LocalTime(self.clone())) - } -} - -impl ScalarArg for EnumValue { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - encoder.buf.extend(self.as_bytes()); - Ok(()) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - use crate::descriptors::Descriptor::Enumeration; - - let desc = ctx.get(pos)?; - if let Enumeration(_) = desc { - // Should we check enum members? - // Should we override `QueryArg` check descriptor for that? - // Or maybe implement just `QueryArg` for enum? - } - Err(ctx.wrong_type(desc, "enum")) - } - fn to_value(&self) -> Result { - Ok(Value::Enum(self.clone())) - } -} - -impl ScalarArg for Range { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - let flags = if self.empty { - range::EMPTY - } else { - (if self.inc_lower { range::LB_INC } else { 0 }) - | (if self.inc_upper { range::UB_INC } else { 0 }) - | (if self.lower.is_none() { - range::LB_INF - } else { - 0 - }) - | (if self.upper.is_none() { - range::UB_INF - } else { - 0 - }) - }; - encoder.buf.reserve(1); - encoder.buf.put_u8(flags as u8); - - if let Some(lower) = &self.lower { - encoder.length_prefixed(|encoder| lower.encode(encoder))? - } - - if let Some(upper) = &self.upper { - encoder.length_prefixed(|encoder| upper.encode(encoder))?; - } - Ok(()) - } - fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> { - let desc = ctx.get(pos)?; - if let Descriptor::Range(rng) = desc { - T::check_descriptor(ctx, rng.type_pos) - } else { - Err(ctx.wrong_type(desc, "range")) - } - } - fn to_value(&self) -> Result { - Ok(Value::Range(Range { - lower: self - .lower - .as_ref() - .map(|v| v.to_value().map(Box::new)) - .transpose()?, - upper: self - .upper - .as_ref() - .map(|v| v.to_value().map(Box::new)) - .transpose()?, - inc_lower: self.inc_lower, - inc_upper: self.inc_upper, - empty: self.empty, - })) - } -} diff --git a/edgedb-protocol/src/serialization/test_scalars.rs b/edgedb-protocol/src/serialization/test_scalars.rs deleted file mode 100644 index 1c6a7b43..00000000 --- a/edgedb-protocol/src/serialization/test_scalars.rs +++ /dev/null @@ -1,225 +0,0 @@ -use bytes::{Bytes, BytesMut}; -use std::str::FromStr; -use uuid::Uuid; - -use crate::features::ProtocolVersion; -use crate::model::Json; -use crate::query_arg::{DescriptorContext, Encoder, ScalarArg}; -use crate::serialization::decode::RawCodec; - -fn encode(val: impl ScalarArg) -> Bytes { - let proto = ProtocolVersion::current(); - let ctx = DescriptorContext { - proto: &proto, - root_pos: None, - descriptors: &[], - }; - let mut buf = BytesMut::new(); - let mut encoder = Encoder::new(&ctx, &mut buf); - ScalarArg::encode(&val, &mut encoder).expect("encoded"); - buf.freeze() -} - -fn decode<'x, T: RawCodec<'x>>(bytes: &'x [u8]) -> T { - ::decode(bytes).expect("decoded") -} - -macro_rules! encoding_eq { - ($data: expr, $bytes: expr) => { - let lambda = |_a| (); // type inference hack - let data = $data; - lambda(&data); // type inference hack - let val = decode($bytes); - lambda(&val); // type inference hack - assert_eq!(val, data, "decoding failed"); - println!("Decoded value: {:?}", val); - - let buf = encode($data); - println!("Encoded value: {:?}", &buf[..]); - assert_eq!(&buf[..], $bytes, "encoding failed"); - }; -} - -#[test] -fn bool() { - encoding_eq!(true, b"\x01"); - encoding_eq!(false, b"\x00"); -} - -#[test] -fn str() { - encoding_eq!("hello", b"hello"); - encoding_eq!(r#""world!""#, b"\"world!\""); - encoding_eq!(String::from("hello"), b"hello"); - encoding_eq!(String::from(r#""world!""#), b"\"world!\""); -} - -#[test] -fn json() { - let val = Json::new_unchecked("{}".into()); - assert_eq!(&encode(val)[..], b"\x01{}"); - assert_eq!(&decode::(b"\x01{}")[..], "{}"); -} - -#[test] -fn int16() { - encoding_eq!(0i16, b"\0\0"); - encoding_eq!(0x105i16, b"\x01\x05"); - encoding_eq!(i16::MAX, b"\x7F\xFF"); - encoding_eq!(i16::MIN, b"\x80\x00"); - encoding_eq!(-1i16, b"\xFF\xFF"); -} - -#[test] -fn int32() { - encoding_eq!(0i32, b"\0\0\0\0"); - encoding_eq!(0x105i32, b"\0\0\x01\x05"); - encoding_eq!(i32::MAX, b"\x7F\xFF\xFF\xFF"); - encoding_eq!(i32::MIN, b"\x80\x00\x00\x00"); - encoding_eq!(-1i32, b"\xFF\xFF\xFF\xFF"); -} - -#[test] -fn int64() { - encoding_eq!(0i64, b"\0\0\0\0\0\0\0\0"); - encoding_eq!(0x105i64, b"\0\0\0\0\0\0\x01\x05"); - encoding_eq!(i64::MAX, b"\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF"); - encoding_eq!(i64::MIN, b"\x80\x00\x00\x00\x00\x00\x00\x00"); - encoding_eq!(-1i64, b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"); -} - -#[test] -fn float32() { - encoding_eq!(0.0f32, b"\0\0\0\0"); - encoding_eq!(-0.0f32, b"\x80\0\0\0"); - encoding_eq!(1.0f32, b"?\x80\0\0"); - encoding_eq!(-1.123f32, b"\xbf\x8f\xbew"); - assert_eq!(&encode(f32::NAN)[..], b"\x7f\xc0\0\0"); - assert_eq!(&encode(f32::INFINITY)[..], b"\x7f\x80\0\0"); - assert_eq!(&encode(f32::NEG_INFINITY)[..], b"\xff\x80\0\0"); - assert!(decode::(b"\x7f\xc0\0\0").is_nan()); - assert!(decode::(b"\x7f\x80\0\0").is_infinite()); - assert!(decode::(b"\x7f\x80\0\0").is_sign_positive()); - assert!(decode::(b"\xff\x80\0\0").is_infinite()); - assert!(decode::(b"\xff\x80\0\0").is_sign_negative()); -} - -#[test] -fn float64() { - encoding_eq!(0.0, b"\0\0\0\0\0\0\0\0"); - encoding_eq!(-0.0, b"\x80\0\0\0\0\0\0\0"); - encoding_eq!(1.0, b"?\xf0\0\0\0\0\0\0"); - encoding_eq!(1e100, b"T\xb2I\xad%\x94\xc3}"); - assert_eq!(&encode(f64::NAN)[..], b"\x7f\xf8\0\0\0\0\0\0"); - assert_eq!(&encode(f64::INFINITY)[..], b"\x7f\xf0\0\0\0\0\0\0"); - assert_eq!(&encode(f64::NEG_INFINITY)[..], b"\xff\xf0\0\0\0\0\0\0"); - assert!(decode::(b"\x7f\xf8\0\0\0\0\0\0").is_nan()); - assert!(decode::(b"\x7f\xf0\0\0\0\0\0\0").is_infinite()); - assert!(decode::(b"\x7f\xf0\0\0\0\0\0\0").is_sign_positive()); - assert!(decode::(b"\xff\xf0\0\0\0\0\0\0").is_infinite()); - assert!(decode::(b"\xff\xf0\0\0\0\0\0\0").is_sign_negative()); -} - -#[test] -fn bytes() { - encoding_eq!(&b"hello"[..], b"hello"); - encoding_eq!(&b""[..], b""); - encoding_eq!(&b"\x00\x01\x02\x03\x81"[..], b"\x00\x01\x02\x03\x81"); - encoding_eq!(Bytes::copy_from_slice(b"hello"), b"hello"); - encoding_eq!(Bytes::new(), b""); - encoding_eq!( - Bytes::copy_from_slice(b"\x00\x01\x02\x03\x81"), - b"\x00\x01\x02\x03\x81" - ); -} - -#[test] -#[cfg(feature = "bigdecimal")] -fn decimal() { - use crate::model::Decimal; - use bigdecimal::BigDecimal; - use std::convert::TryInto; - - fn dec(s: &str) -> Decimal { - bdec(s).try_into().expect("bigdecimal -> decimal") - } - - fn bdec(s: &str) -> BigDecimal { - BigDecimal::from_str(s).expect("bigdecimal") - } - - encoding_eq!(bdec("42.00"), b"\0\x01\0\0\0\0\0\x02\0*"); - encoding_eq!(dec("42.00"), b"\0\x01\0\0\0\0\0\x02\0*"); - - encoding_eq!( - bdec("12345678.901234567"), - b"\0\x05\0\x01\0\0\0\t\x04\xd2\x16.#4\r\x80\x1bX" - ); - encoding_eq!( - dec("12345678.901234567"), - b"\0\x05\0\x01\0\0\0\t\x04\xd2\x16.#4\r\x80\x1bX" - ); - encoding_eq!(bdec("1e100"), b"\0\x01\0\x19\0\0\0\0\0\x01"); - encoding_eq!(dec("1e100"), b"\0\x01\0\x19\0\0\0\0\0\x01"); - encoding_eq!( - bdec("-703367234220692490200000000000000000000000000"), - b"\0\x06\0\x0b@\0\0\0\0\x07\x01P\x1cB\x08\x9e$!\0\xc8" - ); - encoding_eq!( - dec("-703367234220692490200000000000000000000000000"), - b"\0\x06\0\x0b@\0\0\0\0\x07\x01P\x1cB\x08\x9e$!\0\xc8" - ); - encoding_eq!( - bdec("-7033672342206924902e26"), - b"\0\x06\0\x0b@\0\0\0\0\x07\x01P\x1cB\x08\x9e$!\0\xc8" - ); - encoding_eq!( - dec("-7033672342206924902e26"), - b"\0\x06\0\x0b@\0\0\0\0\x07\x01P\x1cB\x08\x9e$!\0\xc8" - ); -} - -#[test] -#[cfg(feature = "num-bigint")] -fn bigint() { - use crate::model::BigInt; - use std::convert::TryInto; - - fn bint1(val: i32) -> num_bigint::BigInt { - val.into() - } - fn bint2(val: i32) -> BigInt { - bint1(val).try_into().unwrap() - } - - fn bint1s(val: &str) -> num_bigint::BigInt { - val.parse().unwrap() - } - fn bint2s(val: &str) -> BigInt { - bint1s(val).try_into().unwrap() - } - - encoding_eq!(bint1(42), b"\0\x01\0\0\0\0\0\0\0*"); - encoding_eq!(bint2(42), b"\0\x01\0\0\0\0\0\0\0*"); - encoding_eq!(bint1(30000), b"\0\x01\0\x01\0\0\0\0\0\x03"); - encoding_eq!(bint2(30000), b"\0\x01\0\x01\0\0\0\0\0\x03"); - encoding_eq!(bint1(30001), b"\0\x02\0\x01\0\0\0\0\0\x03\0\x01"); - encoding_eq!(bint1(-15000), b"\0\x02\0\x01@\0\0\0\0\x01\x13\x88"); - encoding_eq!(bint2(-15000), b"\0\x02\0\x01@\0\0\0\0\x01\x13\x88"); - encoding_eq!( - bint1s("1000000000000000000000"), - b"\0\x01\0\x05\0\0\0\0\0\n" - ); - encoding_eq!( - bint2s("1000000000000000000000"), - b"\0\x01\0\x05\0\0\0\0\0\n" - ); -} - -#[test] -fn uuid() { - encoding_eq!( - Uuid::from_str("4928cc1e-2065-11ea-8848-7b53a6adb383").unwrap(), - b"I(\xcc\x1e e\x11\xea\x88H{S\xa6\xad\xb3\x83" - ); -} diff --git a/edgedb-protocol/src/server_message.rs b/edgedb-protocol/src/server_message.rs deleted file mode 100644 index 16f6c2db..00000000 --- a/edgedb-protocol/src/server_message.rs +++ /dev/null @@ -1,1026 +0,0 @@ -/*! -The [ServerMessage] enum and related types. EdgeDB website documentation on messages [here](https://www.edgedb.com/docs/reference/protocol/messages). - -```rust,ignore -pub enum ServerMessage { - ServerHandshake(ServerHandshake), - UnknownMessage(u8, Bytes), - LogMessage(LogMessage), - ErrorResponse(ErrorResponse), - Authentication(Authentication), - ReadyForCommand(ReadyForCommand), - ServerKeyData(ServerKeyData), - ParameterStatus(ParameterStatus), - CommandComplete0(CommandComplete0), - CommandComplete1(CommandComplete1), - PrepareComplete(PrepareComplete), - CommandDataDescription0(CommandDataDescription0), // protocol < 1.0 - CommandDataDescription1(CommandDataDescription1), // protocol >= 1.0 - StateDataDescription(StateDataDescription), - Data(Data), - RestoreReady(RestoreReady), - DumpHeader(RawPacket), - DumpBlock(RawPacket), -} -``` -*/ - -use std::collections::HashMap; -use std::convert::{TryFrom, TryInto}; - -use bytes::{Buf, BufMut, Bytes}; -use snafu::{ensure, OptionExt}; -use uuid::Uuid; - -use crate::common::Capabilities; -pub use crate::common::{Cardinality, RawTypedesc, State}; -use crate::descriptors::Typedesc; -use crate::encoding::{Annotations, Decode, Encode, Input, KeyValues, Output}; -use crate::errors::{self, DecodeError, EncodeError}; -use crate::features::ProtocolVersion; - -#[derive(Debug, Clone, PartialEq, Eq)] -#[non_exhaustive] -pub enum ServerMessage { - Authentication(Authentication), - CommandComplete0(CommandComplete0), - CommandComplete1(CommandComplete1), - CommandDataDescription0(CommandDataDescription0), // protocol < 1.0 - CommandDataDescription1(CommandDataDescription1), // protocol >= 1.0 - StateDataDescription(StateDataDescription), - Data(Data), - // Don't decode Dump packets here as we only need to process them as - // whole - DumpHeader(RawPacket), - DumpBlock(RawPacket), - ErrorResponse(ErrorResponse), - LogMessage(LogMessage), - ParameterStatus(ParameterStatus), - ReadyForCommand(ReadyForCommand), - RestoreReady(RestoreReady), - ServerHandshake(ServerHandshake), - ServerKeyData(ServerKeyData), - UnknownMessage(u8, Bytes), - PrepareComplete(PrepareComplete), -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ReadyForCommand { - pub headers: KeyValues, - pub transaction_state: TransactionState, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Authentication { - Ok, - Sasl { methods: Vec }, - SaslContinue { data: Bytes }, - SaslFinal { data: Bytes }, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ErrorSeverity { - Error, - Fatal, - Panic, - Unknown(u8), -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum MessageSeverity { - Debug, - Info, - Notice, - Warning, - Unknown(u8), -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum TransactionState { - // Not in a transaction block. - NotInTransaction = 0x49, - - // In a transaction block. - InTransaction = 0x54, - - // In a failed transaction block - // (commands will be rejected until the block is ended). - InFailedTransaction = 0x45, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ErrorResponse { - pub severity: ErrorSeverity, - pub code: u32, - pub message: String, - pub attributes: KeyValues, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct LogMessage { - pub severity: MessageSeverity, - pub code: u32, - pub text: String, - pub attributes: KeyValues, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ServerHandshake { - pub major_ver: u16, - pub minor_ver: u16, - pub extensions: HashMap, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ServerKeyData { - pub data: [u8; 32], -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ParameterStatus { - pub proto: ProtocolVersion, - pub name: Bytes, - pub value: Bytes, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CommandComplete0 { - pub headers: KeyValues, - pub status_data: Bytes, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CommandComplete1 { - pub annotations: Annotations, - pub capabilities: Capabilities, - pub status_data: Bytes, - pub state: Option, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct PrepareComplete { - pub headers: KeyValues, - pub cardinality: Cardinality, - pub input_typedesc_id: Uuid, - pub output_typedesc_id: Uuid, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ParseComplete { - pub headers: KeyValues, - pub cardinality: Cardinality, - pub input_typedesc_id: Uuid, - pub output_typedesc_id: Uuid, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CommandDataDescription0 { - pub headers: KeyValues, - pub result_cardinality: Cardinality, - pub input: RawTypedesc, - pub output: RawTypedesc, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CommandDataDescription1 { - pub annotations: Annotations, - pub capabilities: Capabilities, - pub result_cardinality: Cardinality, - pub input: RawTypedesc, - pub output: RawTypedesc, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct StateDataDescription { - pub typedesc: RawTypedesc, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Data { - pub data: Vec, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct RestoreReady { - pub headers: KeyValues, - pub jobs: u16, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct RawPacket { - pub data: Bytes, -} - -fn encode(buf: &mut Output, code: u8, msg: &T) -> Result<(), EncodeError> { - buf.reserve(5); - buf.put_u8(code); - let base = buf.len(); - buf.put_slice(&[0; 4]); - - msg.encode(buf)?; - - let size = u32::try_from(buf.len() - base) - .ok() - .context(errors::MessageTooLong)?; - buf[base..base + 4].copy_from_slice(&size.to_be_bytes()[..]); - Ok(()) -} - -impl CommandDataDescription0 { - pub fn output(&self) -> Result { - self.output.decode() - } - pub fn input(&self) -> Result { - self.input.decode() - } -} - -impl CommandDataDescription1 { - pub fn output(&self) -> Result { - self.output.decode() - } - pub fn input(&self) -> Result { - self.input.decode() - } -} - -impl From for CommandDataDescription1 { - fn from(value: CommandDataDescription0) -> Self { - Self { - annotations: HashMap::new(), - capabilities: decode_capabilities0(&value.headers).unwrap_or(Capabilities::ALL), - result_cardinality: value.result_cardinality, - input: value.input, - output: value.output, - } - } -} - -impl StateDataDescription { - pub fn parse(self) -> Result { - self.typedesc.decode() - } -} - -impl ParameterStatus { - pub fn parse_system_config(self) -> Result<(Typedesc, Bytes), DecodeError> { - let cur = &mut Input::new(self.proto.clone(), self.value); - let typedesc_data = Bytes::decode(cur)?; - let data = Bytes::decode(cur)?; - - let typedesc_buf = &mut Input::new(self.proto, typedesc_data); - let typedesc_id = Uuid::decode(typedesc_buf)?; - let typedesc = Typedesc::decode_with_id(typedesc_id, typedesc_buf)?; - Ok((typedesc, data)) - } -} - -impl ServerMessage { - pub fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - use ServerMessage::*; - match self { - ServerHandshake(h) => encode(buf, 0x76, h), - ErrorResponse(h) => encode(buf, 0x45, h), - LogMessage(h) => encode(buf, 0x4c, h), - Authentication(h) => encode(buf, 0x52, h), - ReadyForCommand(h) => encode(buf, 0x5a, h), - ServerKeyData(h) => encode(buf, 0x4b, h), - ParameterStatus(h) => encode(buf, 0x53, h), - CommandComplete0(h) => encode(buf, 0x43, h), - CommandComplete1(h) => encode(buf, 0x43, h), - PrepareComplete(h) => encode(buf, 0x31, h), - CommandDataDescription0(h) => encode(buf, 0x54, h), - CommandDataDescription1(h) => encode(buf, 0x54, h), - StateDataDescription(h) => encode(buf, 0x73, h), - Data(h) => encode(buf, 0x44, h), - RestoreReady(h) => encode(buf, 0x2b, h), - DumpHeader(h) => encode(buf, 0x40, h), - DumpBlock(h) => encode(buf, 0x3d, h), - - UnknownMessage(_, _) => errors::UnknownMessageCantBeEncoded.fail()?, - } - } - /// Decode exactly one frame from the buffer. - /// - /// This expects a full frame to already be in the buffer. It can return - /// an arbitrary error or be silent if a message is only partially present - /// in the buffer or if extra data is present. - pub fn decode(buf: &mut Input) -> Result { - use self::ServerMessage as M; - let data = &mut buf.slice(5..); - let result = match buf[0] { - 0x76 => ServerHandshake::decode(data).map(M::ServerHandshake)?, - 0x45 => ErrorResponse::decode(data).map(M::ErrorResponse)?, - 0x4c => LogMessage::decode(data).map(M::LogMessage)?, - 0x52 => Authentication::decode(data).map(M::Authentication)?, - 0x5a => ReadyForCommand::decode(data).map(M::ReadyForCommand)?, - 0x4b => ServerKeyData::decode(data).map(M::ServerKeyData)?, - 0x53 => ParameterStatus::decode(data).map(M::ParameterStatus)?, - 0x43 => { - if buf.proto().is_1() { - CommandComplete1::decode(data).map(M::CommandComplete1)? - } else { - CommandComplete0::decode(data).map(M::CommandComplete0)? - } - } - 0x31 => PrepareComplete::decode(data).map(M::PrepareComplete)?, - 0x44 => Data::decode(data).map(M::Data)?, - 0x2b => RestoreReady::decode(data).map(M::RestoreReady)?, - 0x40 => RawPacket::decode(data).map(M::DumpHeader)?, - 0x3d => RawPacket::decode(data).map(M::DumpBlock)?, - 0x54 => { - if buf.proto().is_1() { - CommandDataDescription1::decode(data).map(M::CommandDataDescription1)? - } else { - CommandDataDescription0::decode(data).map(M::CommandDataDescription0)? - } - } - 0x73 => StateDataDescription::decode(data).map(M::StateDataDescription)?, - code => M::UnknownMessage(code, data.copy_to_bytes(data.remaining())), - }; - ensure!(data.remaining() == 0, errors::ExtraData); - Ok(result) - } -} - -impl Encode for ServerHandshake { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(6); - buf.put_u16(self.major_ver); - buf.put_u16(self.minor_ver); - buf.put_u16( - u16::try_from(self.extensions.len()) - .ok() - .context(errors::TooManyExtensions)?, - ); - for (name, headers) in &self.extensions { - name.encode(buf)?; - buf.reserve(2); - buf.put_u16( - u16::try_from(headers.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (&name, value) in headers { - buf.reserve(2); - buf.put_u16(name); - value.encode(buf)?; - } - } - Ok(()) - } -} - -impl Decode for ServerHandshake { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 6, errors::Underflow); - let major_ver = buf.get_u16(); - let minor_ver = buf.get_u16(); - let num_ext = buf.get_u16(); - let mut extensions = HashMap::new(); - for _ in 0..num_ext { - let name = String::decode(buf)?; - ensure!(buf.remaining() >= 2, errors::Underflow); - let num_headers = buf.get_u16(); - let mut headers = HashMap::new(); - for _ in 0..num_headers { - headers.insert(buf.get_u16(), Bytes::decode(buf)?); - } - extensions.insert(name, headers); - } - Ok(ServerHandshake { - major_ver, - minor_ver, - extensions, - }) - } -} - -impl Encode for ErrorResponse { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(11); - buf.put_u8(self.severity.to_u8()); - buf.put_u32(self.code); - self.message.encode(buf)?; - buf.reserve(2); - buf.put_u16( - u16::try_from(self.attributes.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (&name, value) in &self.attributes { - buf.reserve(2); - buf.put_u16(name); - value.encode(buf)?; - } - Ok(()) - } -} - -impl Decode for ErrorResponse { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 11, errors::Underflow); - let severity = ErrorSeverity::from_u8(buf.get_u8()); - let code = buf.get_u32(); - let message = String::decode(buf)?; - ensure!(buf.remaining() >= 2, errors::Underflow); - let num_attributes = buf.get_u16(); - let mut attributes = HashMap::new(); - for _ in 0..num_attributes { - ensure!(buf.remaining() >= 4, errors::Underflow); - attributes.insert(buf.get_u16(), Bytes::decode(buf)?); - } - Ok(ErrorResponse { - severity, - code, - message, - attributes, - }) - } -} - -impl Encode for LogMessage { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(11); - buf.put_u8(self.severity.to_u8()); - buf.put_u32(self.code); - self.text.encode(buf)?; - buf.reserve(2); - buf.put_u16( - u16::try_from(self.attributes.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (&name, value) in &self.attributes { - buf.reserve(2); - buf.put_u16(name); - value.encode(buf)?; - } - Ok(()) - } -} - -impl Decode for LogMessage { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 11, errors::Underflow); - let severity = MessageSeverity::from_u8(buf.get_u8()); - let code = buf.get_u32(); - let text = String::decode(buf)?; - ensure!(buf.remaining() >= 2, errors::Underflow); - let num_attributes = buf.get_u16(); - let mut attributes = HashMap::new(); - for _ in 0..num_attributes { - ensure!(buf.remaining() >= 4, errors::Underflow); - attributes.insert(buf.get_u16(), Bytes::decode(buf)?); - } - Ok(LogMessage { - severity, - code, - text, - attributes, - }) - } -} - -impl Encode for Authentication { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - use Authentication as A; - buf.reserve(1); - match self { - A::Ok => buf.put_u32(0), - A::Sasl { methods } => { - buf.put_u32(0x0A); - buf.reserve(4); - buf.put_u32( - methods - .len() - .try_into() - .ok() - .context(errors::TooManyMethods)?, - ); - for meth in methods { - meth.encode(buf)?; - } - } - A::SaslContinue { data } => { - buf.put_u32(0x0B); - data.encode(buf)?; - } - A::SaslFinal { data } => { - buf.put_u32(0x0C); - data.encode(buf)?; - } - } - Ok(()) - } -} - -impl Decode for Authentication { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 4, errors::Underflow); - match buf.get_u32() { - 0x00 => Ok(Authentication::Ok), - 0x0A => { - ensure!(buf.remaining() >= 4, errors::Underflow); - let num_methods = buf.get_u32() as usize; - let mut methods = Vec::with_capacity(num_methods); - for _ in 0..num_methods { - methods.push(String::decode(buf)?); - } - Ok(Authentication::Sasl { methods }) - } - 0x0B => { - let data = Bytes::decode(buf)?; - Ok(Authentication::SaslContinue { data }) - } - 0x0C => { - let data = Bytes::decode(buf)?; - Ok(Authentication::SaslFinal { data }) - } - c => errors::AuthStatusInvalid { auth_status: c }.fail()?, - } - } -} - -impl Encode for ReadyForCommand { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(3); - buf.put_u16( - u16::try_from(self.headers.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (&name, value) in &self.headers { - buf.reserve(2); - buf.put_u16(name); - value.encode(buf)?; - } - buf.reserve(1); - buf.put_u8(self.transaction_state as u8); - Ok(()) - } -} -impl Decode for ReadyForCommand { - fn decode(buf: &mut Input) -> Result { - use TransactionState::*; - ensure!(buf.remaining() >= 3, errors::Underflow); - let mut headers = HashMap::new(); - let num_headers = buf.get_u16(); - for _ in 0..num_headers { - ensure!(buf.remaining() >= 4, errors::Underflow); - headers.insert(buf.get_u16(), Bytes::decode(buf)?); - } - ensure!(buf.remaining() >= 1, errors::Underflow); - let transaction_state = match buf.get_u8() { - 0x49 => NotInTransaction, - 0x54 => InTransaction, - 0x45 => InFailedTransaction, - s => errors::InvalidTransactionState { - transaction_state: s, - } - .fail()?, - }; - Ok(ReadyForCommand { - headers, - transaction_state, - }) - } -} - -impl ErrorSeverity { - pub fn from_u8(code: u8) -> ErrorSeverity { - use ErrorSeverity::*; - match code { - 120 => Error, - 200 => Fatal, - 255 => Panic, - _ => Unknown(code), - } - } - pub fn to_u8(&self) -> u8 { - use ErrorSeverity::*; - match *self { - Error => 120, - Fatal => 200, - Panic => 255, - Unknown(code) => code, - } - } -} - -impl MessageSeverity { - fn from_u8(code: u8) -> MessageSeverity { - use MessageSeverity::*; - match code { - 20 => Debug, - 40 => Info, - 60 => Notice, - 80 => Warning, - _ => Unknown(code), - } - } - fn to_u8(self) -> u8 { - use MessageSeverity::*; - match self { - Debug => 20, - Info => 40, - Notice => 60, - Warning => 80, - Unknown(code) => code, - } - } -} - -impl Encode for ServerKeyData { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.extend(&self.data[..]); - Ok(()) - } -} -impl Decode for ServerKeyData { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 32, errors::Underflow); - let mut data = [0u8; 32]; - buf.copy_to_slice(&mut data[..]); - Ok(ServerKeyData { data }) - } -} - -impl Encode for ParameterStatus { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - self.name.encode(buf)?; - self.value.encode(buf)?; - Ok(()) - } -} -impl Decode for ParameterStatus { - fn decode(buf: &mut Input) -> Result { - let name = Bytes::decode(buf)?; - let value = Bytes::decode(buf)?; - Ok(ParameterStatus { - proto: buf.proto().clone(), - name, - value, - }) - } -} - -impl Encode for CommandComplete0 { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(6); - buf.put_u16( - u16::try_from(self.headers.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (&name, value) in &self.headers { - buf.reserve(2); - buf.put_u16(name); - value.encode(buf)?; - } - self.status_data.encode(buf)?; - Ok(()) - } -} - -impl Decode for CommandComplete0 { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 6, errors::Underflow); - let num_headers = buf.get_u16(); - let mut headers = HashMap::new(); - for _ in 0..num_headers { - ensure!(buf.remaining() >= 4, errors::Underflow); - headers.insert(buf.get_u16(), Bytes::decode(buf)?); - } - let status_data = Bytes::decode(buf)?; - Ok(CommandComplete0 { - status_data, - headers, - }) - } -} - -impl Encode for CommandComplete1 { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(26); - buf.put_u16( - u16::try_from(self.annotations.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (name, value) in &self.annotations { - name.encode(buf)?; - value.encode(buf)?; - } - buf.put_u64(self.capabilities.bits()); - self.status_data.encode(buf)?; - if let Some(state) = &self.state { - state.typedesc_id.encode(buf)?; - state.data.encode(buf)?; - } else { - Uuid::from_u128(0).encode(buf)?; - Bytes::new().encode(buf)?; - } - Ok(()) - } -} - -impl Decode for CommandComplete1 { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 26, errors::Underflow); - let num_annotations = buf.get_u16(); - let mut annotations = HashMap::new(); - for _ in 0..num_annotations { - annotations.insert(String::decode(buf)?, String::decode(buf)?); - } - let capabilities = Capabilities::from_bits_retain(buf.get_u64()); - let status_data = Bytes::decode(buf)?; - let typedesc_id = Uuid::decode(buf)?; - let state_data = Bytes::decode(buf)?; - let state = if typedesc_id == Uuid::from_u128(0) { - None - } else { - Some(State { - typedesc_id, - data: state_data, - }) - }; - Ok(CommandComplete1 { - annotations, - capabilities, - status_data, - state, - }) - } -} - -impl Encode for PrepareComplete { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(35); - buf.put_u16( - u16::try_from(self.headers.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (&name, value) in &self.headers { - buf.reserve(2); - buf.put_u16(name); - value.encode(buf)?; - } - buf.reserve(33); - buf.put_u8(self.cardinality as u8); - self.input_typedesc_id.encode(buf)?; - self.output_typedesc_id.encode(buf)?; - Ok(()) - } -} - -impl Decode for PrepareComplete { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 35, errors::Underflow); - let num_headers = buf.get_u16(); - let mut headers = HashMap::new(); - for _ in 0..num_headers { - ensure!(buf.remaining() >= 4, errors::Underflow); - headers.insert(buf.get_u16(), Bytes::decode(buf)?); - } - ensure!(buf.remaining() >= 33, errors::Underflow); - let cardinality = TryFrom::try_from(buf.get_u8())?; - let input_typedesc_id = Uuid::decode(buf)?; - let output_typedesc_id = Uuid::decode(buf)?; - Ok(PrepareComplete { - headers, - cardinality, - input_typedesc_id, - output_typedesc_id, - }) - } -} - -impl Encode for CommandDataDescription0 { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - debug_assert!(!buf.proto().is_1()); - buf.reserve(43); - buf.put_u16( - u16::try_from(self.headers.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (&name, value) in &self.headers { - buf.reserve(2); - buf.put_u16(name); - value.encode(buf)?; - } - buf.reserve(41); - buf.put_u8(self.result_cardinality as u8); - self.input.id.encode(buf)?; - self.input.data.encode(buf)?; - self.output.id.encode(buf)?; - self.output.data.encode(buf)?; - Ok(()) - } -} - -impl Decode for CommandDataDescription0 { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 43, errors::Underflow); - let num_headers = buf.get_u16(); - let mut headers = HashMap::new(); - for _ in 0..num_headers { - ensure!(buf.remaining() >= 4, errors::Underflow); - headers.insert(buf.get_u16(), Bytes::decode(buf)?); - } - ensure!(buf.remaining() >= 41, errors::Underflow); - let result_cardinality = TryFrom::try_from(buf.get_u8())?; - let input = RawTypedesc { - proto: buf.proto().clone(), - id: Uuid::decode(buf)?, - data: Bytes::decode(buf)?, - }; - let output = RawTypedesc { - proto: buf.proto().clone(), - id: Uuid::decode(buf)?, - data: Bytes::decode(buf)?, - }; - - Ok(CommandDataDescription0 { - headers, - result_cardinality, - input, - output, - }) - } -} - -impl Encode for CommandDataDescription1 { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - debug_assert!(buf.proto().is_1()); - buf.reserve(51); - buf.put_u16( - u16::try_from(self.annotations.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (name, value) in &self.annotations { - buf.reserve(4); - name.encode(buf)?; - value.encode(buf)?; - } - buf.reserve(49); - buf.put_u64(self.capabilities.bits()); - buf.put_u8(self.result_cardinality as u8); - self.input.id.encode(buf)?; - self.input.data.encode(buf)?; - self.output.id.encode(buf)?; - self.output.data.encode(buf)?; - Ok(()) - } -} - -impl Decode for CommandDataDescription1 { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 51, errors::Underflow); - let num_annotations = buf.get_u16(); - let mut annotations = HashMap::new(); - for _ in 0..num_annotations { - ensure!(buf.remaining() >= 4, errors::Underflow); - annotations.insert(String::decode(buf)?, String::decode(buf)?); - } - ensure!(buf.remaining() >= 49, errors::Underflow); - let capabilities = Capabilities::from_bits_retain(buf.get_u64()); - let result_cardinality = TryFrom::try_from(buf.get_u8())?; - let input = RawTypedesc { - proto: buf.proto().clone(), - id: Uuid::decode(buf)?, - data: Bytes::decode(buf)?, - }; - let output = RawTypedesc { - proto: buf.proto().clone(), - id: Uuid::decode(buf)?, - data: Bytes::decode(buf)?, - }; - - Ok(CommandDataDescription1 { - annotations, - capabilities, - result_cardinality, - input, - output, - }) - } -} - -impl Encode for StateDataDescription { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - debug_assert!(buf.proto().is_1()); - self.typedesc.id.encode(buf)?; - self.typedesc.data.encode(buf)?; - Ok(()) - } -} - -impl Decode for StateDataDescription { - fn decode(buf: &mut Input) -> Result { - let typedesc = RawTypedesc { - proto: buf.proto().clone(), - id: Uuid::decode(buf)?, - data: Bytes::decode(buf)?, - }; - - Ok(StateDataDescription { typedesc }) - } -} - -impl Encode for Data { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(2); - buf.put_u16( - u16::try_from(self.data.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for chunk in &self.data { - chunk.encode(buf)?; - } - Ok(()) - } -} - -impl Decode for Data { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 2, errors::Underflow); - let num_chunks = buf.get_u16() as usize; - let mut data = Vec::with_capacity(num_chunks); - for _ in 0..num_chunks { - data.push(Bytes::decode(buf)?); - } - Ok(Data { data }) - } -} - -impl Encode for RestoreReady { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.reserve(4); - buf.put_u16( - u16::try_from(self.headers.len()) - .ok() - .context(errors::TooManyHeaders)?, - ); - for (&name, value) in &self.headers { - buf.reserve(2); - buf.put_u16(name); - value.encode(buf)?; - } - buf.reserve(2); - buf.put_u16(self.jobs); - Ok(()) - } -} - -impl Decode for RestoreReady { - fn decode(buf: &mut Input) -> Result { - ensure!(buf.remaining() >= 4, errors::Underflow); - let num_headers = buf.get_u16(); - let mut headers = HashMap::new(); - for _ in 0..num_headers { - ensure!(buf.remaining() >= 4, errors::Underflow); - headers.insert(buf.get_u16(), Bytes::decode(buf)?); - } - ensure!(buf.remaining() >= 2, errors::Underflow); - let jobs = buf.get_u16(); - Ok(RestoreReady { jobs, headers }) - } -} - -impl Encode for RawPacket { - fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> { - buf.extend(&self.data); - Ok(()) - } -} - -impl Decode for RawPacket { - fn decode(buf: &mut Input) -> Result { - Ok(RawPacket { - data: buf.copy_to_bytes(buf.remaining()), - }) - } -} - -impl PrepareComplete { - pub fn get_capabilities(&self) -> Option { - decode_capabilities0(&self.headers) - } -} - -fn decode_capabilities0(headers: &KeyValues) -> Option { - headers.get(&0x1001).and_then(|bytes| { - if bytes.len() == 8 { - let mut array = [0u8; 8]; - array.copy_from_slice(bytes); - Some(Capabilities::from_bits_retain(u64::from_be_bytes(array))) - } else { - None - } - }) -} diff --git a/edgedb-protocol/src/value.rs b/edgedb-protocol/src/value.rs deleted file mode 100644 index 686f6dc0..00000000 --- a/edgedb-protocol/src/value.rs +++ /dev/null @@ -1,272 +0,0 @@ -/*! -Contains the [Value] enum. -*/ -pub use crate::codec::EnumValue; -use crate::codec::{ - InputObjectShape, InputShapeElement, NamedTupleShape, ObjectShape, SQLRowShape, -}; -use crate::common::Cardinality; -use crate::model::{BigInt, ConfigMemory, Decimal, Range, Uuid}; -use crate::model::{DateDuration, Json, RelativeDuration}; -use crate::model::{Datetime, Duration, LocalDate, LocalDatetime, LocalTime}; - -#[derive(Clone, Debug, PartialEq)] -pub enum Value { - Nothing, - Uuid(Uuid), - Str(String), - Bytes(bytes::Bytes), - Int16(i16), - Int32(i32), - Int64(i64), - Float32(f32), - Float64(f64), - BigInt(BigInt), - ConfigMemory(ConfigMemory), - Decimal(Decimal), - Bool(bool), - Datetime(Datetime), - LocalDatetime(LocalDatetime), - LocalDate(LocalDate), - LocalTime(LocalTime), - Duration(Duration), - RelativeDuration(RelativeDuration), - DateDuration(DateDuration), - Json(Json), - Set(Vec), - Object { - shape: ObjectShape, - fields: Vec>, - }, - SparseObject(SparseObject), - Tuple(Vec), - NamedTuple { - shape: NamedTupleShape, - fields: Vec, - }, - SQLRow { - shape: SQLRowShape, - fields: Vec, - }, - Array(Vec), - Vector(Vec), - Enum(EnumValue), - Range(Range>), - PostGisGeometry(bytes::Bytes), - PostGisGeography(bytes::Bytes), - PostGisBox2d(bytes::Bytes), - PostGisBox3d(bytes::Bytes), -} - -#[derive(Clone, Debug)] -pub struct SparseObject { - pub(crate) shape: InputObjectShape, - pub(crate) fields: Vec>>, -} - -impl Value { - pub fn kind(&self) -> &'static str { - use Value::*; - match self { - Array(..) => "array", - BigInt(..) => "bigint", - Bool(..) => "bool", - Bytes(..) => "bytes", - ConfigMemory(..) => "cfg::memory", - DateDuration(..) => "cal::date_duration", - Datetime(..) => "datetime", - Decimal(..) => "decimal", - Duration(..) => "duration", - Enum(..) => "enum", - Float32(..) => "float32", - Float64(..) => "float64", - Int16(..) => "int16", - Int32(..) => "int32", - Int64(..) => "int64", - Json(..) => "json", - LocalDate(..) => "cal::local_date", - LocalDatetime(..) => "cal::local_datetime", - LocalTime(..) => "cal::local_time", - NamedTuple { .. } => "named_tuple", - Nothing => "nothing", - Object { .. } => "object", - Range { .. } => "range", - RelativeDuration(..) => "cal::relative_duration", - Set(..) => "set", - SparseObject { .. } => "sparse_object", - Str(..) => "str", - Tuple(..) => "tuple", - Uuid(..) => "uuid", - Vector(..) => "ext::pgvector::vector", - PostGisGeometry(..) => "ext::postgis::geometry", - PostGisGeography(..) => "ext::postgis::geography", - PostGisBox2d(..) => "ext::postgis::box2d", - PostGisBox3d(..) => "ext::postgis::box3d", - SQLRow { .. } => "sql_row", - } - } - pub fn empty_tuple() -> Value { - Value::Tuple(Vec::new()) - } - - pub fn try_from_uuid(input: &str) -> Result { - Ok(Self::Uuid(Uuid::parse_str(input)?)) - } -} - -impl SparseObject { - /// Create a new sparse object from key-value pairs - /// - /// Note: this method has two limitations: - /// 1. Shape created uses `AtMostOne` cardinality for all the elements. - /// 2. There are no extra shape elements - /// - /// Both of these are irrelevant when serializing the object. - pub fn from_pairs>>( - iter: impl IntoIterator, - ) -> SparseObject { - let mut elements = Vec::new(); - let mut fields = Vec::new(); - for (key, val) in iter.into_iter() { - elements.push(InputShapeElement { - cardinality: Some(Cardinality::AtMostOne), - name: key.to_string(), - }); - fields.push(Some(val.into())); - } - SparseObject { - shape: InputObjectShape::new(elements), - fields, - } - } - /// Create an empty sparse object - pub fn empty() -> SparseObject { - SparseObject { - shape: InputObjectShape::new(Vec::new()), - fields: Vec::new(), - } - } - pub fn pairs(&self) -> impl Iterator)> { - self.shape - .0 - .elements - .iter() - .zip(&self.fields) - .filter_map(|(el, opt)| opt.as_ref().map(|opt| (&*el.name, opt.as_ref()))) - } -} - -impl PartialEq for SparseObject { - fn eq(&self, other: &SparseObject) -> bool { - let mut num = 0; - let o = &other.shape.0.elements; - for (el, value) in self.shape.0.elements.iter().zip(&self.fields) { - if let Some(value) = value { - num += 1; - if let Some(pos) = o.iter().position(|e| e.name == el.name) { - if other.fields[pos].as_ref() != Some(value) { - return false; - } - } - } - } - let other_num = other.fields.iter().filter(|e| e.is_some()).count(); - num == other_num - } -} - -impl From for Value { - fn from(s: String) -> Value { - Value::Str(s) - } -} - -impl From<&str> for Value { - fn from(s: &str) -> Value { - Value::Str(s.to_string()) - } -} - -impl From for Value { - fn from(b: bool) -> Value { - Value::Bool(b) - } -} - -impl From for Value { - fn from(s: i16) -> Value { - Value::Int16(s) - } -} - -impl From for Value { - fn from(s: i32) -> Value { - Value::Int32(s) - } -} - -impl From for Value { - fn from(s: i64) -> Value { - Value::Int64(s) - } -} - -impl From for Value { - fn from(num: f32) -> Value { - Value::Float32(num) - } -} - -impl From for Value { - fn from(num: f64) -> Value { - Value::Float64(num) - } -} - -impl From for Value { - fn from(model: BigInt) -> Value { - Value::BigInt(model) - } -} - -impl From for Value { - fn from(v: Decimal) -> Value { - Value::Decimal(v) - } -} - -impl From for Value { - fn from(v: Uuid) -> Value { - Value::Uuid(v) - } -} - -impl From for Value { - fn from(v: Json) -> Value { - Value::Json(v) - } -} - -impl From for Value { - fn from(v: Duration) -> Value { - Value::Duration(v) - } -} - -impl From for Value { - fn from(v: Datetime) -> Value { - Value::Datetime(v) - } -} - -impl From for Value { - fn from(v: LocalDate) -> Value { - Value::LocalDate(v) - } -} - -impl From for Value { - fn from(v: LocalDatetime) -> Value { - Value::LocalDatetime(v) - } -} diff --git a/edgedb-protocol/src/value_opt.rs b/edgedb-protocol/src/value_opt.rs deleted file mode 100644 index 16ced97a..00000000 --- a/edgedb-protocol/src/value_opt.rs +++ /dev/null @@ -1,126 +0,0 @@ -use std::collections::HashMap; - -use edgedb_errors::{ClientEncodingError, Error, ErrorKind}; - -use crate::codec::{ObjectShape, ShapeElement}; -use crate::descriptors::Descriptor; -use crate::query_arg::{Encoder, QueryArgs}; -use crate::value::Value; - -/// An optional [Value] that can be constructed from `impl Into`, -/// `Option>`, `Vec>` or -/// `Option>>`. -/// Used by [named_args!](`crate::named_args!`) macro. -#[derive(Clone, Debug, PartialEq)] -pub struct ValueOpt(Option); - -impl> From for ValueOpt { - fn from(value: V) -> Self { - ValueOpt(Some(value.into())) - } -} -impl> From> for ValueOpt -where - Value: From, -{ - fn from(value: Option) -> Self { - ValueOpt(value.map(Value::from)) - } -} -impl> From> for ValueOpt -where - Value: From, -{ - fn from(value: Vec) -> Self { - ValueOpt(Some(Value::Array( - value.into_iter().map(Value::from).collect(), - ))) - } -} -impl> From>> for ValueOpt -where - Value: From, -{ - fn from(value: Option>) -> Self { - let mapped = value.map(|value| Value::Array(value.into_iter().map(Value::from).collect())); - ValueOpt(mapped) - } -} -impl From for Option { - fn from(value: ValueOpt) -> Self { - value.0 - } -} - -impl QueryArgs for HashMap<&str, ValueOpt> { - fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> { - if self.is_empty() && encoder.ctx.root_pos.is_none() { - return Ok(()); - } - - let root_pos = encoder.ctx.root_pos.ok_or_else(|| { - ClientEncodingError::with_message(format!( - "provided {} named arguments, but no arguments were expected by the server", - self.len() - )) - })?; - - let Descriptor::ObjectShape(target_shape) = encoder.ctx.get(root_pos)? else { - return Err(ClientEncodingError::with_message( - "query didn't expect named arguments", - )); - }; - - let mut shape_elements: Vec = Vec::new(); - let mut fields: Vec> = Vec::new(); - - for param_descriptor in target_shape.elements.iter() { - let value = self.get(param_descriptor.name.as_str()); - - let Some(value) = value else { - return Err(ClientEncodingError::with_message(format!( - "argument for ${} missing", - param_descriptor.name - ))); - }; - - shape_elements.push(ShapeElement::from(param_descriptor)); - fields.push(value.0.clone()); - } - - Value::Object { - shape: ObjectShape::new(shape_elements), - fields, - } - .encode(encoder) - } -} - -/// Constructs named query arguments that implement [QueryArgs] so they can be passed -/// into any query method. -/// ```no_run -/// use edgedb_protocol::value::Value; -/// -/// let query = "SELECT ($my_str, $my_int)"; -/// let args = edgedb_protocol::named_args! { -/// "my_str" => "Hello world!".to_string(), -/// "my_int" => Value::Int64(42), -/// }; -/// ``` -/// -/// The value side of an argument must be `impl Into`. -/// The type of the returned object is `HashMap<&str, ValueOpt>`. -#[macro_export] -macro_rules! named_args { - ($($key:expr => $value:expr,)+) => { $crate::named_args!($($key => $value),+) }; - ($($key:expr => $value:expr),*) => { - { - const CAP: usize = <[()]>::len(&[$({ stringify!($key); }),*]); - let mut map = ::std::collections::HashMap::<&str, $crate::value_opt::ValueOpt>::with_capacity(CAP); - $( - map.insert($key, $crate::value_opt::ValueOpt::from($value)); - )* - map - } - }; -} diff --git a/edgedb-protocol/tests/base.rs b/edgedb-protocol/tests/base.rs deleted file mode 100644 index 43ce58cf..00000000 --- a/edgedb-protocol/tests/base.rs +++ /dev/null @@ -1,12 +0,0 @@ -#[macro_export] -macro_rules! bconcat { - ($($token: expr)*) => { - &{ - let mut buf = ::bytes::BytesMut::new(); - $( - buf.extend($token); - )* - buf - } - } -} diff --git a/edgedb-protocol/tests/client_messages.rs b/edgedb-protocol/tests/client_messages.rs deleted file mode 100644 index 46edb33e..00000000 --- a/edgedb-protocol/tests/client_messages.rs +++ /dev/null @@ -1,307 +0,0 @@ -use std::collections::HashMap; -use std::error::Error; - -use bytes::{Bytes, BytesMut}; -use edgedb_protocol::common::InputLanguage; -use uuid::Uuid; - -use edgedb_protocol::client_message::OptimisticExecute; -use edgedb_protocol::client_message::Restore; -use edgedb_protocol::client_message::SaslInitialResponse; -use edgedb_protocol::client_message::SaslResponse; -use edgedb_protocol::client_message::{Cardinality, IoFormat, Parse, Prepare}; -use edgedb_protocol::client_message::{ClientHandshake, ClientMessage}; -use edgedb_protocol::client_message::{DescribeAspect, DescribeStatement}; -use edgedb_protocol::client_message::{Execute0, Execute1, ExecuteScript}; -use edgedb_protocol::common::{Capabilities, CompilationFlags, State}; -use edgedb_protocol::encoding::{Input, Output}; -use edgedb_protocol::features::ProtocolVersion; - -mod base; - -macro_rules! encoding_eq_ver { - ($major: expr, $minor: expr, $message: expr, $bytes: expr) => { - let proto = ProtocolVersion::new($major, $minor); - let data: &[u8] = $bytes; - let mut bytes = BytesMut::new(); - $message.encode(&mut Output::new(&proto, &mut bytes))?; - println!("Serialized bytes {:?}", bytes); - let bytes = bytes.freeze(); - assert_eq!(&bytes[..], data); - assert_eq!( - ClientMessage::decode(&mut Input::new(proto, Bytes::copy_from_slice(data)))?, - $message, - ); - }; -} - -macro_rules! encoding_eq { - ($message: expr, $bytes: expr) => { - let (major, minor) = ProtocolVersion::current().version_tuple(); - encoding_eq_ver!(major, minor, $message, $bytes); - }; -} - -#[test] -fn client_handshake() -> Result<(), Box> { - encoding_eq!( - ClientMessage::ClientHandshake(ClientHandshake { - major_ver: 1, - minor_ver: 2, - params: HashMap::new(), - extensions: HashMap::new(), - }), - b"\x56\x00\x00\x00\x0C\x00\x01\x00\x02\x00\x00\x00\x00" - ); - Ok(()) -} - -#[test] -fn execute_script() -> Result<(), Box> { - encoding_eq!( - ClientMessage::ExecuteScript(ExecuteScript { - headers: HashMap::new(), - script_text: String::from("START TRANSACTION"), - }), - b"Q\0\0\0\x1b\0\0\0\0\0\x11START TRANSACTION" - ); - Ok(()) -} - -#[test] -fn prepare() -> Result<(), Box> { - encoding_eq_ver!( - 0, - 13, - ClientMessage::Prepare(Prepare { - headers: HashMap::new(), - io_format: IoFormat::Binary, - expected_cardinality: Cardinality::AtMostOne, - statement_name: Bytes::from_static(b"example"), - command_text: String::from("SELECT 1;"), - }), - b"P\0\0\0 \0\0bo\0\0\0\x07example\0\0\0\tSELECT 1;" - ); - Ok(()) -} - -#[test] -fn parse() -> Result<(), Box> { - encoding_eq_ver!( - 1, - 0, - ClientMessage::Parse(Parse { - annotations: None, - allowed_capabilities: Capabilities::MODIFICATIONS, - compilation_flags: CompilationFlags::INJECT_OUTPUT_TYPE_NAMES, - implicit_limit: Some(77), - input_language: InputLanguage::EdgeQL, - output_format: IoFormat::Binary, - expected_cardinality: Cardinality::AtMostOne, - command_text: String::from("SELECT 1;"), - state: State { - typedesc_id: Uuid::from_u128(0), - data: Bytes::from(""), - }, - }), - b"P\0\0\0A\0\0\0\0\0\0\0\0\0\x01\0\0\0\0\0\0\0\x02\0\0\0\0\0\0\0Mbo\ - \0\0\0\tSELECT 1;\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0" - ); - Ok(()) -} - -#[test] -fn parse3() -> Result<(), Box> { - encoding_eq_ver!( - 3, - 0, - ClientMessage::Parse(Parse { - annotations: None, - allowed_capabilities: Capabilities::MODIFICATIONS, - compilation_flags: CompilationFlags::INJECT_OUTPUT_TYPE_NAMES, - implicit_limit: Some(77), - input_language: InputLanguage::EdgeQL, - output_format: IoFormat::Binary, - expected_cardinality: Cardinality::AtMostOne, - command_text: String::from("SELECT 1;"), - state: State { - typedesc_id: Uuid::from_u128(0), - data: Bytes::from(""), - }, - }), - b"P\0\0\0B\0\0\0\0\0\0\0\0\0\x01\0\0\0\0\0\0\0\x02\0\0\0\0\0\0\0MEbo\ - \0\0\0\tSELECT 1;\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0" - ); - Ok(()) -} - -#[test] -fn describe_statement() -> Result<(), Box> { - encoding_eq!( - ClientMessage::DescribeStatement(DescribeStatement { - headers: HashMap::new(), - aspect: DescribeAspect::DataDescription, - statement_name: Bytes::from_static(b"example"), - }), - b"D\0\0\0\x12\0\0T\0\0\0\x07example" - ); - Ok(()) -} - -#[test] -fn execute0() -> Result<(), Box> { - encoding_eq_ver!( - 0, - 13, - ClientMessage::Execute0(Execute0 { - headers: HashMap::new(), - statement_name: Bytes::from_static(b"example"), - arguments: Bytes::new(), - }), - b"E\0\0\0\x15\0\0\0\0\0\x07example\0\0\0\0" - ); - Ok(()) -} - -#[test] -fn execute1() -> Result<(), Box> { - encoding_eq_ver!( - 1, - 0, - ClientMessage::Execute1(Execute1 { - annotations: None, - allowed_capabilities: Capabilities::MODIFICATIONS, - compilation_flags: CompilationFlags::INJECT_OUTPUT_TYPE_NAMES, - implicit_limit: Some(77), - input_language: InputLanguage::EdgeQL, - output_format: IoFormat::Binary, - expected_cardinality: Cardinality::AtMostOne, - command_text: String::from("SELECT 1;"), - state: State { - typedesc_id: Uuid::from_u128(0), - data: Bytes::from(""), - }, - input_typedesc_id: Uuid::from_u128(123), - output_typedesc_id: Uuid::from_u128(456), - arguments: Bytes::new(), - }), - b"O\0\0\0e\0\0\0\0\0\0\0\0\0\x01\0\0\0\0\0\0\0\x02\0\0\0\0\0\0\0Mbo\ - \0\0\0\tSELECT 1;\ - \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\ - \0\0\0{\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\xc8\0\0\0\0" - ); - Ok(()) -} - -#[test] -fn execute3() -> Result<(), Box> { - encoding_eq_ver!( - 3, - 0, - ClientMessage::Execute1(Execute1 { - annotations: None, - allowed_capabilities: Capabilities::MODIFICATIONS, - compilation_flags: CompilationFlags::INJECT_OUTPUT_TYPE_NAMES, - implicit_limit: Some(77), - input_language: InputLanguage::EdgeQL, - output_format: IoFormat::Binary, - expected_cardinality: Cardinality::AtMostOne, - command_text: String::from("SELECT 1;"), - state: State { - typedesc_id: Uuid::from_u128(0), - data: Bytes::from(""), - }, - input_typedesc_id: Uuid::from_u128(123), - output_typedesc_id: Uuid::from_u128(456), - arguments: Bytes::new(), - }), - b"O\0\0\0f\0\0\0\0\0\0\0\0\0\x01\0\0\0\0\0\0\0\x02\0\0\0\0\0\0\0MEbo\ - \0\0\0\tSELECT 1;\ - \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\ - \0\0\0{\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\xc8\0\0\0\0" - ); - Ok(()) -} - -#[test] -fn optimistic_execute() -> Result<(), Box> { - encoding_eq_ver!( - 0, - 13, - ClientMessage::OptimisticExecute(OptimisticExecute { - headers: HashMap::new(), - io_format: IoFormat::Binary, - expected_cardinality: Cardinality::AtMostOne, - command_text: String::from("COMMIT"), - input_typedesc_id: Uuid::from_u128(0xFF), - output_typedesc_id: Uuid::from_u128(0x0), - arguments: Bytes::new(), - }), - b"O\0\0\x006\0\0bo\0\0\0\x06COMMIT\ - \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff\ - \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0" - ); - Ok(()) -} - -#[test] -fn sync() -> Result<(), Box> { - encoding_eq!(ClientMessage::Sync, b"S\0\0\0\x04"); - Ok(()) -} - -#[test] -fn flush() -> Result<(), Box> { - encoding_eq!(ClientMessage::Flush, b"H\0\0\0\x04"); - Ok(()) -} - -#[test] -fn terminate() -> Result<(), Box> { - encoding_eq!(ClientMessage::Terminate, b"X\0\0\0\x04"); - Ok(()) -} - -#[test] -fn authentication() -> Result<(), Box> { - encoding_eq!( - ClientMessage::AuthenticationSaslInitialResponse(SaslInitialResponse { - method: "SCRAM-SHA-256".into(), - data: "n,,n=tutorial,r=%NR65>7bQ2S3jzl^k$G&b1^A".into(), - }), - bconcat!(b"p\0\0\0A\0\0\0\rSCRAM-SHA-256" - b"\0\0\0(n,,n=tutorial," - b"r=%NR65>7bQ2S3jzl^k$G&b1^A") - ); - encoding_eq!( - ClientMessage::AuthenticationSaslResponse(SaslResponse { - data: bconcat!(b"c=biws," - b"r=%NR65>7bQ2S3jzl^k$G&b1^A" - b"YsykYKRbp/Gli53UEElsGb4I," - b"p=UNQQkuQ0m5RRy24Ovzj/" - b"sCevUB36WTDbGXIWbCIsJmo=") - .clone() - .freeze(), - }), - bconcat!(b"r\0\0\0p" - b"\0\0\0hc=biws," - b"r=%NR65>7bQ2S3jzl^k$G&b1^A" - b"YsykYKRbp/Gli53UEElsGb4I," - b"p=UNQQkuQ0m5RRy24Ovzj/" - b"sCevUB36WTDbGXIWbCIsJmo=") - ); - Ok(()) -} - -#[test] -fn restore() -> Result<(), Box> { - encoding_eq!( - ClientMessage::Restore(Restore { - headers: HashMap::new(), - jobs: 1, - data: Bytes::from_static(b"TEST"), - }), - b"<\x00\x00\x00\x0C\x00\x00\x00\x01TEST" - ); - Ok(()) -} diff --git a/edgedb-protocol/tests/codecs.rs b/edgedb-protocol/tests/codecs.rs deleted file mode 100644 index fc73bab1..00000000 --- a/edgedb-protocol/tests/codecs.rs +++ /dev/null @@ -1,1430 +0,0 @@ -#[macro_use] -extern crate pretty_assertions; - -use std::error::Error; -use std::sync::Arc; - -use bytes::Bytes; - -use edgedb_protocol::codec::build_codec; -use edgedb_protocol::codec::{Codec, ObjectShape}; -use edgedb_protocol::common::RawTypedesc; -use edgedb_protocol::descriptors::ArrayTypeDescriptor; -use edgedb_protocol::descriptors::BaseScalarTypeDescriptor; -use edgedb_protocol::descriptors::EnumerationTypeDescriptor; -use edgedb_protocol::descriptors::ScalarTypeDescriptor; -use edgedb_protocol::descriptors::SetDescriptor; -use edgedb_protocol::descriptors::TupleTypeDescriptor; -use edgedb_protocol::descriptors::{Descriptor, TypePos}; -use edgedb_protocol::descriptors::{MultiRangeTypeDescriptor, RangeTypeDescriptor}; -use edgedb_protocol::descriptors::{NamedTupleTypeDescriptor, TupleElement}; -use edgedb_protocol::descriptors::{ObjectShapeDescriptor, ShapeElement}; -use edgedb_protocol::features::ProtocolVersion; -use edgedb_protocol::model::{Datetime, Json, RelativeDuration}; -use edgedb_protocol::model::{Duration, LocalDate, LocalTime}; -use edgedb_protocol::server_message::StateDataDescription; -use edgedb_protocol::value::{SparseObject, Value}; -use uuid::Uuid; - -mod base; - -macro_rules! encoding_eq { - ($codec: expr, $bytes: expr, $value: expr) => { - let orig_value = $value; - let value = decode($codec, $bytes)?; - assert_eq!(value, orig_value); - let mut bytes = bytes::BytesMut::new(); - $codec.encode(&mut bytes, &orig_value)?; - println!("Serialized bytes {:?}", bytes); - let bytes = bytes.freeze(); - assert_eq!(&bytes[..], $bytes); - }; -} - -fn decode(codec: &Arc, data: &[u8]) -> Result> { - Ok(codec.decode(data)?) -} - -#[test] -fn int16() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000103" - .parse::()? - .into(), - })], - )?; - encoding_eq!(&codec, b"\0\0", Value::Int16(0)); - encoding_eq!(&codec, b"\x01\x05", Value::Int16(0x105)); - encoding_eq!(&codec, b"\x7F\xFF", Value::Int16(i16::MAX)); - encoding_eq!(&codec, b"\x80\x00", Value::Int16(i16::MIN)); - encoding_eq!(&codec, b"\xFF\xFF", Value::Int16(-1)); - Ok(()) -} - -#[test] -fn int32() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000104" - .parse::()? - .into(), - })], - )?; - encoding_eq!(&codec, b"\0\0\0\0", Value::Int32(0)); - encoding_eq!(&codec, b"\0\0\x01\x05", Value::Int32(0x105)); - encoding_eq!(&codec, b"\x7F\xFF\xFF\xFF", Value::Int32(i32::MAX)); - encoding_eq!(&codec, b"\x80\x00\x00\x00", Value::Int32(i32::MIN)); - encoding_eq!(&codec, b"\xFF\xFF\xFF\xFF", Value::Int32(-1)); - Ok(()) -} - -#[test] -fn int64() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000105" - .parse::()? - .into(), - })], - )?; - encoding_eq!(&codec, b"\0\0\0\0\0\0\0\0", Value::Int64(0)); - encoding_eq!(&codec, b"\0\0\0\0\0\0\x01\x05", Value::Int64(0x105)); - encoding_eq!( - &codec, - b"\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF", - Value::Int64(i64::MAX) - ); - encoding_eq!( - &codec, - b"\x80\x00\x00\x00\x00\x00\x00\x00", - Value::Int64(i64::MIN) - ); - encoding_eq!( - &codec, - b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF", - Value::Int64(-1) - ); - Ok(()) -} - -#[test] -fn float32() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000106" - .parse::()? - .into(), - })], - )?; - - encoding_eq!(&codec, b"\0\0\0\0", Value::Float32(0.0)); - encoding_eq!(&codec, b"\x80\0\0\0", Value::Float32(-0.0)); - encoding_eq!(&codec, b"?\x80\0\0", Value::Float32(1.0)); - encoding_eq!(&codec, b"\xbf\x8f\xbew", Value::Float32(-1.123)); - - match decode(&codec, b"\x7f\xc0\0\0")? { - Value::Float32(val) => assert!(val.is_nan()), - _ => panic!("could not parse NaN"), - }; - - match decode(&codec, b"\x7f\x80\0\0")? { - Value::Float32(val) => { - assert!(val.is_infinite()); - assert!(val.is_sign_positive()) - } - _ => panic!("could not parse +inf"), - }; - - match decode(&codec, b"\xff\x80\0\0")? { - Value::Float32(val) => { - assert!(val.is_infinite()); - assert!(val.is_sign_negative()) - } - _ => panic!("could not parse -inf"), - }; - - Ok(()) -} - -#[test] -fn float64() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000107" - .parse::()? - .into(), - })], - )?; - - encoding_eq!(&codec, b"\0\0\0\0\0\0\0\0", Value::Float64(0.0)); - encoding_eq!(&codec, b"\x80\0\0\0\0\0\0\0", Value::Float64(-0.0)); - encoding_eq!(&codec, b"?\xf0\0\0\0\0\0\0", Value::Float64(1.0)); - encoding_eq!(&codec, b"T\xb2I\xad%\x94\xc3}", Value::Float64(1e100)); - - match decode(&codec, b"\x7f\xf8\0\0\0\0\0\0")? { - Value::Float64(val) => assert!(val.is_nan()), - _ => panic!("could not parse NaN"), - }; - - match decode(&codec, b"\x7f\xf0\0\0\0\0\0\0")? { - Value::Float64(val) => { - assert!(val.is_infinite()); - assert!(val.is_sign_positive()) - } - _ => panic!("could not parse +inf"), - }; - - match decode(&codec, b"\xff\xf0\0\0\0\0\0\0")? { - Value::Float64(val) => { - assert!(val.is_infinite()); - assert!(val.is_sign_negative()) - } - _ => panic!("could not parse -inf"), - }; - - Ok(()) -} - -#[test] -fn str() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000101" - .parse::()? - .into(), - })], - )?; - encoding_eq!(&codec, b"hello", Value::Str(String::from("hello"))); - encoding_eq!(&codec, b"", Value::Str(String::from(""))); - encoding_eq!( - &codec, - b"\xd0\xbf\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82", - Value::Str(String::from("привет")) - ); - Ok(()) -} - -#[test] -fn bytes() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000102" - .parse::()? - .into(), - })], - )?; - encoding_eq!(&codec, b"hello", Value::Bytes(b"hello"[..].into())); - encoding_eq!(&codec, b"", Value::Bytes(b""[..].into())); - encoding_eq!( - &codec, - b"\x00\x01\x02\x03\x81", - Value::Bytes(b"\x00\x01\x02\x03\x81"[..].into()) - ); - Ok(()) -} - -#[test] -fn uuid() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000100" - .parse::()? - .into(), - })], - )?; - encoding_eq!( - &codec, - b"I(\xcc\x1e e\x11\xea\x88H{S\xa6\xad\xb3\x83", - Value::Uuid("4928cc1e-2065-11ea-8848-7b53a6adb383".parse::()?) - ); - Ok(()) -} - -#[test] -fn duration() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-00000000010e" - .parse::()? - .into(), - })], - )?; - - // SELECT '2019-11-29T00:00:00Z'-'2000-01-01T00:00:00Z' - encoding_eq!( - &codec, - b"\0\x02;o\xad\xff\0\0\0\0\0\0\0\0\0\0", - Value::Duration(Duration::from_micros(7272 * 86400 * 1_000_000)) - ); - // SELECT '2019-11-29T00:00:00Z'-'2019-11-28T01:00:00Z' - encoding_eq!( - &codec, - b"\0\0\0\x13GC\xbc\0\0\0\0\0\0\0\0\0", - Value::Duration(Duration::from_micros(82800 * 1_000_000)) - ); - encoding_eq!( - &codec, - b"\xff\xff\xff\xff\xd3,\xba\xe0\0\0\0\0\0\0\0\0", - Value::Duration(Duration::from_micros(-752043296)) - ); - - assert_eq!( - decode(&codec, b"\0\0\0\0\0\0\0\0\0\0\0\x01\0\0\0\0") - .unwrap_err() - .to_string(), - "non-zero reserved bytes received in data" - ); - Ok(()) -} - -#[test] -fn relative_duration() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000111" - .parse::()? - .into(), - })], - )?; - - // SELECT - // '2 years 7 months 16 days 48 hours 45 minutes 7.6 seconds' - encoding_eq!( - &codec, - b"\0\0\0\x28\xdd\x11\x72\x80\0\0\0\x10\0\0\0\x1f", - Value::RelativeDuration( - RelativeDuration::from_years(2) - + RelativeDuration::from_months(7) - + RelativeDuration::from_days(16) - + RelativeDuration::from_hours(48) - + RelativeDuration::from_minutes(45) - + RelativeDuration::from_secs(7) - + RelativeDuration::from_millis(600) - ) - ); - Ok(()) -} - -#[test] -fn null_codec() -> Result<(), Box> { - let codec = build_codec(None, &[])?; - encoding_eq!(&codec, b"", Value::Nothing); - Ok(()) -} - -#[test] -fn object_codec() -> Result<(), Box> { - let elements = vec![ - ShapeElement { - flag_implicit: true, - flag_link_property: false, - flag_link: false, - cardinality: None, - name: String::from("__tid__"), - type_pos: TypePos(0), - source_type_pos: None, - }, - ShapeElement { - flag_implicit: false, - flag_link_property: false, - flag_link: false, - cardinality: None, - name: String::from("id"), - type_pos: TypePos(0), - source_type_pos: None, - }, - ]; - let shape = elements.as_slice().into(); - let codec = build_codec( - Some(TypePos(1)), - &[ - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000100" - .parse::()? - .into(), - }), - Descriptor::ObjectShape(ObjectShapeDescriptor { - id: "5d5ebe41-eac8-eab7-a24e-cc3a8cd2766c" - .parse::()? - .into(), - ephemeral_free_shape: false, - type_pos: None, - elements, - }), - ], - )?; - // TODO(tailhook) test with non-zero reserved bytes - encoding_eq!( - &codec, - bconcat!( - b"\0\0\0\x02\0\0\x00\x00\0\0\0\x100Wd\0 d" - b"\x11\xea\x98\xc53\xc5\xcf\xb4r^\0\0\x00" - b"\x00\0\0\0\x10I(\xcc\x1e e\x11\xea\x88H{S" - b"\xa6\xad\xb3\x83"), - Value::Object { - shape, - fields: vec![ - Some(Value::Uuid("30576400-2064-11ea-98c5-33c5cfb4725e".parse()?)), - Some(Value::Uuid("4928cc1e-2065-11ea-8848-7b53a6adb383".parse()?)), - ] - } - ); - Ok(()) -} - -#[test] -fn input_codec() -> Result<(), Box> { - let sdd = StateDataDescription { - typedesc: RawTypedesc { - proto: ProtocolVersion::new(1, 0), - id: "fd6c3b17504a714858ec2282431ce72c".parse()?, - data: Bytes::from_static( - b"\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0\ - \x01\x01\x04\xcf\x9d\xce6\x17\xf05O\t%g\x8eW\xa1\x842\0\x02\ - \0\0\0\0\x06\xc6R\xf3\xf1\xdd\xe7\0a?\x07|=&\x0b\xfbt\0\x01\ - \0\x01\xff\xff\xff\xff\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\ - \x0e\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x05\0\xa5zjc\xee\ - \xc4@\x91\xabnI\x97#\xf5\xe8\xaa\0\0\x02\0\0\0\0\0\0\0\0\ - \0\0\0\0\0\0\x01\t\x01\xd9\xa1-\xbfH\xfa\xeb\x1a/\xf5xe7\ - \xc8\xb8\xee\0\0\0\x16w\xe5\x87Y\xbd\x05\xb9\x14\xce\x8a\ - \xc2\x99\x85b5\0\x07\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x010\ - \x07\x82w\xed\x1a\xfd\xe0\x11\xec\x8bl\x85\xd0\xc8\xdc\xcd[\ - \0\x02\0\0\0\x0bAlwaysAllow\0\0\0\nNeverAllow\x01v\x9eH\xcb#\\1\ - \x90c&\x9b\x90p-\xa7\x03\0\0\0\xb1\xef6\xe2\xbb%Wr\xafk\x11\x84l\ - \x183n\0\x0b\x07\x85[<\"\xfd\xe0\x11\xec\x9a\xf6\xa1U\x99\xf2+\xc2\ - \0\x03\0\0\0\x03One\0\0\0\x03Two\0\0\0\x05Three\x08t\x13\xa1IP\xe6\ - \xc3\xf9*\xd7U1\x9f\xf1\xe1o\0\x10\0\0\0\0o\0\0\0\x07durprop\0\x03\ - \0\0\0\0o\0\0\0\x14__pg_max_connections\0\x04\0\0\0\0o\ - \0\0\0\x17query_execution_timeout\0\x03\0\0\0\0m\0\0\0\tmultiprop\ - \0\x05\0\0\0\0o\0\0\0\x1b__internal_no_const_folding\0\x06\0\0\0\0\ - m\0\0\0\x06sysobj\0\x08\0\0\0\0o\0\0\0\x07memprop\0\t\0\0\0\0o\0\0\ - \0\x13__internal_testmode\0\x06\0\0\0\0o\ - \0\0\0\x15apply_access_policies\0\x06\0\0\0\0o\ - \0\0\0 session_idle_transaction_timeout\0\x03\0\0\0\0o\ - \0\0\0\x0eallow_bare_ddl\0\n\0\0\0\0o\0\0\0\nsingleprop\ - \0\0\0\0\0\0o\0\0\0\x16allow_dml_in_functions\0\x06\0\0\0\0\ - o\0\0\0\x19__internal_sess_testvalue\0\x04\0\0\0\0m\ - \0\0\0\x07sessobj\0\x0c\0\0\0\0o\0\0\0\x08enumprop\0\r\x08!s\xfc,)\ - \x19\x80\x13/E\xea\xf3!\x98\x84\t\0\x01\0\0\0\0o\ - \0\0\0\x17default::my_globalvar_1\0\0\x08\xfdl;\x17PJqHX\xec\"\x82\ - C\x1c\xe7,\0\x04\0\0\0\0o\0\0\0\x06module\0\0\0\0\0\0o\ - \0\0\0\x07aliases\0\x02\0\0\0\0o\0\0\0\x07globals\0\x0f\0\0\0\0\ - o\0\0\0\x06config\0\x0e", - ), - }, - }; - let out_desc = sdd.parse()?; - let codec = build_codec(Some(TypePos(16)), out_desc.descriptors())?; - encoding_eq!( - &codec, - b"\0\0\0\x03\0\0\0\0\0\0\0\x07default\0\0\0\x02\0\0\0\x1c\ - \0\0\0\x01\0\0\0\0\0\0\0\x10GLOBAL VAR VALUE\ - \0\0\0\x03\0\0\0\x1c\0\0\0\x01\0\0\0\t\0\0\0\x10\ - \0\0\0\0\x11\xe1\xa3\0\0\0\0\0\0\0\0\0", - Value::SparseObject(SparseObject::from_pairs([ - ("module", Some(Value::Str("default".into()))), - ( - "globals", - Some(Value::SparseObject(SparseObject::from_pairs([( - "default::my_globalvar_1", - Some(Value::Str("GLOBAL VAR VALUE".into())) - ),]))) - ), - ( - "config", - Some(Value::SparseObject(SparseObject::from_pairs([( - "session_idle_transaction_timeout", - Some(Value::Duration(Duration::from_micros(300_000_000))) - ),]))) - ), - ])) - ); - Ok(()) -} - -#[test] -fn set_codec() -> Result<(), Box> { - let inner_elements = vec![ - ShapeElement { - flag_implicit: true, - flag_link_property: false, - flag_link: false, - cardinality: None, - name: "__tid__".into(), - type_pos: TypePos(0), - source_type_pos: None, - }, - ShapeElement { - flag_implicit: true, - flag_link_property: false, - flag_link: false, - cardinality: None, - name: "id".into(), - type_pos: TypePos(0), - source_type_pos: None, - }, - ShapeElement { - flag_implicit: false, - flag_link_property: false, - flag_link: false, - cardinality: None, - name: "first_name".into(), - type_pos: TypePos(1), - source_type_pos: None, - }, - ]; - let outer_elements = vec![ - ShapeElement { - flag_implicit: true, - flag_link_property: false, - flag_link: false, - cardinality: None, - name: "__tid__".into(), - type_pos: TypePos(0), - source_type_pos: None, - }, - ShapeElement { - flag_implicit: true, - flag_link_property: false, - flag_link: false, - cardinality: None, - name: "id".into(), - type_pos: TypePos(0), - source_type_pos: None, - }, - ShapeElement { - flag_implicit: false, - flag_link_property: false, - flag_link: false, - cardinality: None, - name: "first_name".into(), - type_pos: TypePos(1), - source_type_pos: None, - }, - ShapeElement { - flag_implicit: false, - flag_link_property: false, - flag_link: true, - cardinality: None, - name: "collegues".into(), - type_pos: TypePos(3), - source_type_pos: None, - }, - ]; - let inner_shape = ObjectShape::from(&inner_elements[..]); - let outer_shape = ObjectShape::from(&outer_elements[..]); - let codec = build_codec( - Some(TypePos(4)), - &[ - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000100" - .parse::()? - .into(), - }), - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000101" - .parse::()? - .into(), - }), - Descriptor::ObjectShape(ObjectShapeDescriptor { - id: "8faa7193-48c6-4263-18d3-1a127652569b" - .parse::()? - .into(), - elements: inner_elements, - ephemeral_free_shape: false, - type_pos: None, - }), - Descriptor::Set(SetDescriptor { - id: "afbb389d-aa73-2aae-9310-84a9163cb5ed" - .parse::()? - .into(), - type_pos: TypePos(2), - }), - Descriptor::ObjectShape(ObjectShapeDescriptor { - id: "9740ff04-324e-08a4-4ac7-2192d72c6967" - .parse::()? - .into(), - elements: outer_elements, - ephemeral_free_shape: false, - type_pos: None, - }), - ], - )?; - // TODO(tailhook) test with non-zero reserved bytes - encoding_eq!( - &codec, - bconcat!( - b"\0\0\0\x04\0\0\x00\x00\0\0\0\x10\x0c\xf06\xbd " - b"\xbd\x11\xea\xa4\xeb\xe9T\xb4(\x13\x91\0\0\x00\x00\0\0\0\x10" - b"[\xe3\x9c( \xbd\x11\xea\xaa\xb9g4\x82*\xf1\xc9\0\0\0\x00\0\0\0" - b"\x04Ryan\0\0\x00\x00\0\0\0\x9f\0\0\0\x01\0\0\0\0\0\0\x00\x00\0" - b"\0\0\x02\0\0\0\x01\0\0\0?\0\0\0\x03\0\0\x00\x00\0\0\0\x10\x0c\xf0" - b"6\xbd \xbd\x11\xea\xa4\xeb\xe9T\xb4(\x13\x91\0\0\x00\x00\0\0\0\x10" - b"[\xe3\x9e\x80 \xbd\x11\xea\xaa\xb9\x17]\xbf\x18G\xe5\0\0\0\x00\0\0" - b"\0\x03Ana\0\0\0D\0\0\0\x03\0\0\x00\x00\0\0\0\x10\x0c\xf06\xbd " - b"\xbd\x11\xea\xa4\xeb\xe9T\xb4(\x13\x91\0\0\x00\x00\0\0\0\x10[" - b"\xe3\x97\x14 \xbd\x11\xea\xaa\xb9?7\xe7 \xb8T\0\0\0\x00\0\0\0" - b"\x08Harrison" - ), - Value::Object { - shape: outer_shape.clone(), - fields: vec![ - Some(Value::Uuid("0cf036bd-20bd-11ea-a4eb-e954b4281391".parse()?)), - Some(Value::Uuid("5be39c28-20bd-11ea-aab9-6734822af1c9".parse()?)), - Some(Value::Str(String::from("Ryan"))), - Some(Value::Set(vec![ - Value::Object { - shape: inner_shape.clone(), - fields: vec![ - Some(Value::Uuid("0cf036bd-20bd-11ea-a4eb-e954b4281391".parse()?)), - Some(Value::Uuid("5be39e80-20bd-11ea-aab9-175dbf1847e5".parse()?)), - Some(Value::Str(String::from("Ana"))), - ] - }, - Value::Object { - shape: inner_shape, - fields: vec![ - Some(Value::Uuid("0cf036bd-20bd-11ea-a4eb-e954b4281391".parse()?)), - Some(Value::Uuid("5be39714-20bd-11ea-aab9-3f37e720b854".parse()?)), - Some(Value::Str(String::from("Harrison"))), - ] - } - ])), - ] - } - ); - encoding_eq!( - &codec, - bconcat!(b"\0\0\0\x04\0\0\x00\x00\0\0\0\x10\x0c\xf06" - b"\xbd \xbd\x11\xea\xa4\xeb\xe9T\xb4(\x13\x91\0\0\x00\x00\0\0\0\x10" - b"[\xe3\x9c( \xbd\x11\xea\xaa\xb9g4\x82*\xf1\xc9\0\0\0\x00" - b"\0\0\0\x04Ryan\0\0\x00\x00\0\0\0\x0c\0\0\0\0\0\0\0\0\0\0\x00\x00" - ), - Value::Object { - shape: outer_shape.clone(), - fields: vec![ - Some(Value::Uuid("0cf036bd-20bd-11ea-a4eb-e954b4281391".parse()?)), - Some(Value::Uuid("5be39c28-20bd-11ea-aab9-6734822af1c9".parse()?)), - Some(Value::Str(String::from("Ryan"))), - Some(Value::Set(vec![])), - ] - } - ); - encoding_eq!( - &codec, - bconcat!(b"\0\0\0\x04\0\0\x00\x00\0\0\0\x10\x0c\xf06" - b"\xbd \xbd\x11\xea\xa4\xeb\xe9T\xb4(\x13\x91\0\0\x00\x00\0\0\0\x10" - b"[\xe3\x9c( \xbd\x11\xea\xaa\xb9g4\x82*\xf1\xc9\0\0\0\x00" - b"\xFF\xFF\xFF\xFF\0\0\x00\x00\0\0\0\x0c\0\0\0\0\0\0\0\0\0\0\x00\x00" - ), - Value::Object { - shape: outer_shape, - fields: vec![ - Some(Value::Uuid("0cf036bd-20bd-11ea-a4eb-e954b4281391".parse()?)), - Some(Value::Uuid("5be39c28-20bd-11ea-aab9-6734822af1c9".parse()?)), - None, - Some(Value::Set(vec![])), - ] - } - ); - Ok(()) -} - -#[test] -#[cfg(feature = "num-bigint")] -fn bigint() -> Result<(), Box> { - use num_bigint::BigInt; - use std::convert::TryInto; - use std::str::FromStr; - - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000110" - .parse::()? - .into(), - })], - )?; - encoding_eq!(&codec, b"\0\x01\0\0\0\0\0\0\0*", Value::BigInt(42.into())); - encoding_eq!( - &codec, - b"\0\x01\0\x01\0\0\0\0\0\x03", - Value::BigInt((30000).into()) - ); - encoding_eq!( - &codec, - b"\0\x02\0\x01\0\0\0\0\0\x03\0\x01", - Value::BigInt((30001).into()) - ); - encoding_eq!( - &codec, - b"\0\x02\0\x01@\0\0\0\0\x01\x13\x88", - Value::BigInt((-15000).into()) - ); - encoding_eq!( - &codec, - b"\0\x01\0\x05\0\0\0\0\0\n", - Value::BigInt(BigInt::from_str("1000000000000000000000")?.try_into()?) - ); - Ok(()) -} - -#[test] -#[cfg(feature = "bigdecimal")] -fn decimal() -> Result<(), Box> { - use bigdecimal::BigDecimal; - use std::convert::TryInto; - use std::str::FromStr; - - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000108" - .parse::()? - .into(), - })], - )?; - encoding_eq!( - &codec, - b"\0\x01\0\0\0\0\0\x02\0*", - Value::Decimal(BigDecimal::from_str("42.00")?.try_into()?) - ); - encoding_eq!( - &codec, - b"\0\x05\0\x01\0\0\0\t\x04\xd2\x16.#4\r\x80\x1bX", - Value::Decimal(BigDecimal::from_str("12345678.901234567")?.try_into()?) - ); - encoding_eq!( - &codec, - b"\0\x01\0\x19\0\0\0\0\0\x01", - Value::Decimal(BigDecimal::from_str("1e100")?.try_into()?) - ); - encoding_eq!( - &codec, - b"\0\x06\0\x0b@\0\0\0\0\x07\x01P\x1cB\x08\x9e$!\0\xc8", - Value::Decimal( - BigDecimal::from_str("-703367234220692490200000000000000000000000000")?.try_into()? - ) - ); - encoding_eq!( - &codec, - b"\0\x06\0\x0b@\0\0\0\0\x07\x01P\x1cB\x08\x9e$!\0\xc8", - Value::Decimal(BigDecimal::from_str("-7033672342206924902e26")?.try_into()?) - ); - Ok(()) -} - -#[test] -fn bool() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000109" - .parse::()? - .into(), - })], - )?; - encoding_eq!(&codec, b"\x01", Value::Bool(true)); - encoding_eq!(&codec, b"\x00", Value::Bool(false)); - Ok(()) -} - -#[test] -fn datetime() -> Result<(), Box> { - use std::time::Duration; - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-00000000010a" - .parse::()? - .into(), - })], - )?; - - encoding_eq!( - &codec, - b"\0\x02=^\x1bTc\xe7", - Value::Datetime(Datetime::UNIX_EPOCH + Duration::new(1577109148, 156903000)) - ); - Ok(()) -} - -#[test] -fn local_datetime() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-00000000010b" - .parse::()? - .into(), - })], - )?; - - encoding_eq!( - &codec, - b"\0\x02=^@\xf9\x1f\xfd", - Value::LocalDatetime(Datetime::from_unix_micros(1577109779709949).into()) - ); - Ok(()) -} - -#[test] -fn local_date() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-00000000010c" - .parse::()? - .into(), - })], - )?; - - encoding_eq!( - &codec, - b"\0\0\x1c\x80", - Value::LocalDate(LocalDate::from_days(7296)) - ); - Ok(()) -} - -#[test] -fn vector() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "9565dd88-04f5-11ee-a691-0b6ebe179825" - .parse::()? - .into(), - })], - )?; - - encoding_eq!( - &codec, - b"\0\x03\0\0?\x80\0\0@\0\0\0@@\0\0", - Value::Vector(vec![1., 2., 3.]) - ); - Ok(()) -} - -#[test] -fn local_time() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-00000000010d" - .parse::()? - .into(), - })], - )?; - - encoding_eq!( - &codec, - b"\0\0\0\x0b\xd7\x84\0\x01", - Value::LocalTime(LocalTime::from_micros(50860392449)) - ); - Ok(()) -} - -#[test] -fn json() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-00000000010f" - .parse::()? - .into(), - })], - )?; - - encoding_eq!( - &codec, - b"\x01\"txt\"", - Value::Json(Json::new_unchecked(String::from(r#""txt""#))) - ); - Ok(()) -} - -#[test] -fn custom_scalar() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[ - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000101" - .parse::()? - .into(), - }), - Descriptor::Scalar(ScalarTypeDescriptor { - id: "234dc787-2646-11ea-bebd-010d530c06ca" - .parse::()? - .into(), - base_type_pos: Some(TypePos(0)), - name: None, - schema_defined: None, - ancestors: vec![], - }), - ], - )?; - - encoding_eq!(&codec, b"xx", Value::Str(String::from("xx"))); - Ok(()) -} - -#[test] -fn tuple() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(2)), - &[ - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000105" - .parse::()? - .into(), - }), - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000101" - .parse::()? - .into(), - }), - Descriptor::Tuple(TupleTypeDescriptor { - id: "6c87a50a-fce2-dcae-6872-8c4c9c4d1e7c" - .parse::()? - .into(), - element_types: vec![TypePos(0), TypePos(1)], - name: None, - schema_defined: None, - ancestors: vec![], - }), - ], - )?; - - // TODO(tailhook) test with non-zero reserved bytes - encoding_eq!( - &codec, - bconcat!(b"\0\0\0\x02\0\0\0\x00\0\0\0" - b"\x08\0\0\0\0\0\0\0\x01\0\0\0\x00\0\0\0\x03str"), - Value::Tuple(vec![Value::Int64(1), Value::Str("str".into()),]) - ); - Ok(()) -} - -#[test] -fn named_tuple() -> Result<(), Box> { - let elements = vec![ - TupleElement { - name: "a".into(), - type_pos: TypePos(0), - }, - TupleElement { - name: "b".into(), - type_pos: TypePos(1), - }, - ]; - let shape = elements.as_slice().into(); - let codec = build_codec( - Some(TypePos(2)), - &[ - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000105" - .parse::()? - .into(), - }), - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000101" - .parse::()? - .into(), - }), - Descriptor::NamedTuple(NamedTupleTypeDescriptor { - id: "101385c1-d6d5-ec67-eec4-b2b88be8a197" - .parse::()? - .into(), - elements, - name: None, - schema_defined: None, - ancestors: vec![], - }), - ], - )?; - - // TODO(tailhook) test with non-zero reserved bytes - encoding_eq!( - &codec, - bconcat!(b"\0\0\0\x02\0\0\0\x00\0\0\0" - b"\x08\0\0\0\0\0\0\0\x01\0\0\0\x00\0\0\0\x01x"), - Value::NamedTuple { - shape, - fields: vec![Value::Int64(1), Value::Str("x".into()),], - } - ); - Ok(()) -} - -#[test] -fn array() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(1)), - &[ - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000105" - .parse::()? - .into(), - }), - Descriptor::Array(ArrayTypeDescriptor { - id: "b0105467-a177-635f-e207-0a21867f9be0" - .parse::()? - .into(), - type_pos: TypePos(0), - dimensions: vec![None], - name: None, - schema_defined: None, - ancestors: vec![], - }), - ], - )?; - - // TODO(tailhook) test with non-zero reserved bytes - encoding_eq!( - &codec, - bconcat!(b"\0\0\0\x01\0\0\0\0\0\0\0\x00\0\0\0\x03" - b"\0\0\0\x01\0\0\0\x08\0\0\0\0\0\0\0\x01" - b"\0\0\0\x08\0\0\0\0\0\0\0\x02" - b"\0\0\0\x08\0\0\0\0\0\0\0\x03"), - Value::Array(vec![Value::Int64(1), Value::Int64(2), Value::Int64(3),]) - ); - encoding_eq!( - &codec, - bconcat!(b"\0\0\0\0\0\0\0\0\0\0\0\x00"), - Value::Array(vec![]) - ); - Ok(()) -} - -#[test] -fn enums() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::Enumeration(EnumerationTypeDescriptor { - id: "ac5dc6a4-2656-11ea-aa6d-233f91e80ff6" - .parse::()? - .into(), - members: vec!["x".into(), "y".into()], - name: None, - schema_defined: None, - ancestors: vec![], - })], - )?; - encoding_eq!(&codec, bconcat!(b"x"), Value::Enum("x".into())); - Ok(()) -} - -#[test] -fn set_of_arrays() -> Result<(), Box> { - let elements = vec![ - ShapeElement { - flag_implicit: true, - flag_link_property: false, - flag_link: false, - cardinality: None, - name: String::from("__tname__"), - type_pos: TypePos(0), - source_type_pos: None, - }, - ShapeElement { - flag_implicit: true, - flag_link_property: false, - flag_link: false, - cardinality: None, - name: String::from("id"), - type_pos: TypePos(1), - source_type_pos: None, - }, - ShapeElement { - flag_implicit: false, - flag_link_property: false, - flag_link: false, - cardinality: None, - name: String::from("sets"), - type_pos: TypePos(4), - source_type_pos: None, - }, - ]; - let shape = ObjectShape::from(&elements[..]); - let elements = elements.as_slice().into(); - let codec = build_codec( - Some(TypePos(5)), - &[ - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000101" - .parse::()? - .into(), // str - }), - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000100" - .parse::()? - .into(), // uuid - }), - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000105" - .parse::()? - .into(), // int64 - }), - Descriptor::Array(ArrayTypeDescriptor { - id: "b0105467-a177-635f-e207-0a21867f9be0" - .parse::()? - .into(), - type_pos: TypePos(2), - dimensions: vec![None], - name: None, - schema_defined: None, - ancestors: vec![], - }), - Descriptor::Set(SetDescriptor { - id: "499ffd5c-f21b-574d-af8a-1c094c9d6fb0" - .parse::()? - .into(), - type_pos: TypePos(3), - }), - Descriptor::ObjectShape(ObjectShapeDescriptor { - id: "499ffd5c-f21b-574d-af8a-1c094c9d6fb0" - .parse::()? - .into(), - elements, - ephemeral_free_shape: false, - type_pos: None, - }), - ], - )?; - encoding_eq!( - &codec, - bconcat!( - // TODO(tailhook) test with non-zero reserved bytes - b"\0\0\0\x03\0\0\0\0\0\0\0\x10schema::Function" - b"\0\0\0\0\0\0\0\x10\xb8\xf2\x91\x99\x8b#\x11" - b"\xeb\xb9EO\x882\x0e[\xd6\0\0\0\0\0\0\0\x80" - b"\0\0\0\x01\0\0\0\0\0\0\0\0\0\0\0\x02\0\0\0\x01\0\0\08" - b"\0\0\0\x01\0\0\0\0\0\0\0,\0\0\0\x01\0\0\0\0\0\0\0\0" - b"\0\0\0\x02\0\0\0\x01\0\0\0\x08\0\0\0\0\0\0\0\x01\0\0\0\x08" - b"\0\0\0\0\0\0\0\x02\0\0\0,\0\0\0\x01\0\0\0\0\0\0\0 " - b"\0\0\0\x01\0\0\0\0\0\0\0\0\0\0\0\x01\0\0\0\x01\0\0\0\x08" - b"\0\0\0\0\0\0\0\x03"), - Value::Object { - shape, - fields: vec![ - Some(Value::Str("schema::Function".into())), - Some(Value::Uuid("b8f29199-8b23-11eb-b945-4f88320e5bd6".parse()?)), - Some(Value::Set(vec![ - Value::Array(vec![Value::Int64(1), Value::Int64(2),]), - Value::Array(vec![Value::Int64(3),]), - ])) - ] - } - ); - Ok(()) -} - -#[test] -fn range() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(1)), - &[ - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000105" - .parse::()? - .into(), - }), - Descriptor::Range(RangeTypeDescriptor { - id: "7f8919fd845bb1badae19d40d96ea0a8" - .parse::() - .unwrap() - .into(), - type_pos: TypePos(0), - name: None, - schema_defined: None, - ancestors: vec![], - }), - ], - )?; - - encoding_eq!( - &codec, - b"\x02\0\0\0\x08\0\0\0\0\0\0\0\x07\0\0\0\x08\0\0\0\0\0\0\0'", - std::ops::Range { - start: 7i64, - end: 39 - } - .into() - ); - Ok(()) -} - -#[test] -fn multi_range() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(1)), - &[ - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000105" - .parse::()? - .into(), - }), - Descriptor::MultiRange(MultiRangeTypeDescriptor { - id: "08fc943ff87d44b68e76ba8dbeed4d00" - .parse::() - .unwrap() - .into(), - type_pos: TypePos(0), - name: None, - schema_defined: None, - ancestors: vec![], - }), - ], - )?; - - encoding_eq!( - &codec, - b"\0\0\0\x01\0\0\0\x19\x02\0\0\0\x08\0\0\0\0\0\0\0\x07\0\0\0\x08\0\0\0\0\0\0\0'", - Value::Array(vec![std::ops::Range { - start: 7i64, - end: 39 - } - .into()]) - ); - Ok(()) -} - -#[test] -fn postgis_geometry() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "44c901c0-d922-4894-83c8-061bd05e4840" - .parse::()? - .into(), - })], - )?; - - encoding_eq!( - &codec, - /* - * Point - * 01 - byteOrder, Little Endian - * 01000000 - wkbType, WKBPoint - * 0000000000000040 - x, 2.0 - * 000000000000F03F - y, 1.0 - */ - b"\ - \x01\ - \x01\x00\x00\x00\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - ", - Value::PostGisGeometry( - b"\ - \x01\ - \x01\x00\x00\x00\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - "[..] - .into() - ) - ); - Ok(()) -} - -#[test] -fn postgis_geography() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "4d738878-3a5f-4821-ab76-9d8e7d6b32c4" - .parse::()? - .into(), - })], - )?; - encoding_eq!( - &codec, - /* - * Point - * 01 - byteOrder, Little Endian - * 01000000 - wkbType, WKBPoint - * 0000000000000040 - x, 2.0 - * 000000000000F03F - y, 1.0 - */ - b"\ - \x01\ - \x01\x00\x00\x00\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - ", - Value::PostGisGeography( - b"\ - \x01\ - \x01\x00\x00\x00\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - "[..] - .into() - ) - ); - Ok(()) -} - -#[test] -fn postgis_box_2d() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "7fae5536-6311-4f60-8eb9-096a5d972f48" - .parse::()? - .into(), - })], - )?; - encoding_eq!( - &codec, - /* - * Polygon - * 01 - byteOrder, Little Endian - * 03000000 - wkbType, wkbPolygon - * 01000000 - numRings, 1 - * 05000000 - numPoints, 5 - * 000000000000F03F - x, 1.0 - * 000000000000F03F - y, 1.0 - * 0000000000000040 - x, 2.0 - * 000000000000F03F - y, 1.0 - * 0000000000000040 - x, 2.0 - * 0000000000000040 - y, 2.0 - * 000000000000F03F - x, 1.0 - * 0000000000000040 - y, 2.0 - * 000000000000F03F - x, 1.0 - * 000000000000F03F - y, 1.0 - */ - b"\ - \x01\ - \x03\x00\x00\x00\ - \x01\x00\x00\x00\ - \x05\x00\x00\x00\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - ", - Value::PostGisBox2d( - b"\ - \x01\ - \x03\x00\x00\x00\ - \x01\x00\x00\x00\ - \x05\x00\x00\x00\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - "[..] - .into() - ) - ); - Ok(()) -} - -#[test] -fn postgis_box_3d() -> Result<(), Box> { - let codec = build_codec( - Some(TypePos(0)), - &[Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "c1a50ff8-fded-48b0-85c2-4905a8481433" - .parse::()? - .into(), - })], - )?; - encoding_eq!( - &codec, - /* - * Polygon - * 01 - byteOrder, Little Endian - * 03000080 - wkbType, wkbPolygonZ - * 01000000 - numRings, 1 - * 05000000 - numPoints, 5 - * 000000000000F03F - x, 1.0 - * 000000000000F03F - y, 1.0 - * 0000000000000000 - z, 0.0 - * 0000000000000040 - x, 2.0 - * 000000000000F03F - y, 1.0 - * 0000000000000000 - z, 0.0 - * 0000000000000040 - x, 2.0 - * 0000000000000040 - y, 2.0 - * 000000000000F03F - x, 1.0 - * 0000000000000000 - z, 0.0 - * 0000000000000040 - y, 2.0 - * 000000000000F03F - x, 1.0 - * 000000000000F03F - y, 1.0 - * 0000000000000000 - z, 0.0 - */ - b"\ - \x01\ - \x03\x00\x00\x80\ - \x01\x00\x00\x00\ - \x05\x00\x00\x00\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\x00\x00\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\x00\x00\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\x00\x00\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\x00\x00\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\x00\x00\ - ", - Value::PostGisBox3d( - b"\ - \x01\ - \x03\x00\x00\x80\ - \x01\x00\x00\x00\ - \x05\x00\x00\x00\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\x00\x00\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\x00\x00\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\x00\x00\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\x00\x40\ - \x00\x00\x00\x00\x00\x00\x00\x00\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\xF0\x3F\ - \x00\x00\x00\x00\x00\x00\x00\x00\ - "[..] - .into() - ) - ); - Ok(()) -} diff --git a/edgedb-protocol/tests/datetime_chrono.rs b/edgedb-protocol/tests/datetime_chrono.rs deleted file mode 100644 index 5a518abe..00000000 --- a/edgedb-protocol/tests/datetime_chrono.rs +++ /dev/null @@ -1,400 +0,0 @@ -#[cfg(feature = "chrono")] -mod chrono { - - use std::convert::TryInto; - use std::str::FromStr; - - use bytes::{Buf, BytesMut}; - use edgedb_protocol::codec::{self, Codec}; - use edgedb_protocol::model::{Datetime, LocalDatetime, LocalTime}; - use edgedb_protocol::value::Value; - use test_case::test_case; - - // ======== - // Datetime - // ======== - - // Minimum and Maximum - // ------------------- - #[test_case( - /*input*/ "9999-12-31T23:59:59.999999499Z", - // Note: Can't round up here, so --^ - /*micros*/ 252455615999999999, - /*formatted*/ "9999-12-31T23:59:59.999999Z" - ; "maximum" - )] - #[test_case( - /*input*/ "0001-01-01T00:00:00.000000Z", - /*micros*/ -63082281600000000, - /*formatted*/ "0001-01-01T00:00:00Z" - ; "minimum" - )] - // Rounding in Various Ranges - // -------------------------- - #[test_case( - /*input*/ "1814-03-09T01:02:03.000005500Z", - /*micros*/ -5863791476999994, - /*formatted*/ "1814-03-09T01:02:03.000006Z" - ; "negative unix timestamp, round up" - )] - #[test_case( - /*input*/ "1814-03-09T01:02:03.000005501Z", - /*micros*/ -5863791476999994, - /*formatted*/ "1814-03-09T01:02:03.000006Z" - ; "negative unix timestamp, 5501" - )] - #[test_case( - /*input*/ "1814-03-09T01:02:03.000005499Z", - /*micros*/ -5863791476999995, - /*formatted*/ "1814-03-09T01:02:03.000005Z" - ; "negative unix timestamp, 5499" - )] - #[test_case( - /*input*/ "1856-08-27T01:02:03.000004500Z", - /*micros*/ -4523554676999996, - /*formatted*/ "1856-08-27T01:02:03.000004Z" - ; "negative unix timestamp, round down" - )] - #[test_case( - /*input*/ "1856-08-27T01:02:03.000004501Z", - /*micros*/ -4523554676999995, - /*formatted*/ "1856-08-27T01:02:03.000005Z" - ; "negative unix timestamp, 4501" - )] - #[test_case( - /*input*/ "1856-08-27T01:02:03.000004499Z", - /*micros*/ -4523554676999996, - /*formatted*/ "1856-08-27T01:02:03.000004Z" - ; "negative unix timestamp, 4499" - )] - #[test_case( - /*input*/ "1969-12-31T23:59:59.999999500Z", - /*micros*/ -946684800000000, - /*formatted*/ "1970-01-01T00:00:00Z" - ; "unix timestamp to zero" - )] - #[test_case( - /*input*/ "1997-07-05T01:02:03.000009500Z", - /*micros*/ -78620276999990, - /*formatted*/ "1997-07-05T01:02:03.000010Z" - ; "negative postgres timestamp, round up" - )] - #[test_case( - /*input*/ "1997-07-05T01:02:03.000009500Z", - /*micros*/ -78620276999990, - /*formatted*/ "1997-07-05T01:02:03.000010Z" - ; "negative postgres timestamp, 9501" - )] - #[test_case( - /*input*/ "1997-07-05T01:02:03.000009499Z", - /*micros*/ -78620276999991, - /*formatted*/ "1997-07-05T01:02:03.000009Z" - ; "negative postgres timestamp, 9499" - )] - #[test_case( - /*input*/ "1997-07-05T01:02:03.000000500Z", - /*micros*/ -78620277000000, - /*formatted*/ "1997-07-05T01:02:03Z" - ; "negative postgres timestamp, round down" - )] - #[test_case( - /*input*/ "1997-07-05T01:02:03.000000501Z", - /*micros*/ -78620276999999, - /*formatted*/ "1997-07-05T01:02:03.000001Z" - ; "negative postgres timestamp, 501" - )] - #[test_case( - /*input*/ "1997-07-05T01:02:03.000000499Z", - /*micros*/ -78620277000000, - /*formatted*/ "1997-07-05T01:02:03Z" - ; "negative postgres timestamp, 499" - )] - #[test_case( - /*input*/ "1999-12-31T23:59:59.999999500Z", - /*micros*/ 0, - /*formatted*/ "2000-01-01T00:00:00Z" - ; "postgres timestamp to zero" - )] - #[test_case( - /*input*/ "2014-02-27T00:00:00.000001500Z", - /*micros*/ 446774400000002, - /*formatted*/ "2014-02-27T00:00:00.000002Z" - ; "positive timestamp, round up" - )] - #[test_case( - /*input*/ "2014-02-27T00:00:00.000001501Z", - /*micros*/ 446774400000002, - /*formatted*/ "2014-02-27T00:00:00.000002Z" - ; "positive timestamp, 1501" - )] - #[test_case( - /*input*/ "2014-02-27T00:00:00.000001499Z", - /*micros*/ 446774400000001, - /*formatted*/ "2014-02-27T00:00:00.000001Z" - ; "positive timestamp, 1499" - )] - #[test_case( - /*input*/ "2022-02-24T05:43:03.000002500Z", - /*micros*/ 698996583000002, - /*formatted*/ "2022-02-24T05:43:03.000002Z" - ; "positive timestamp, round down" - )] - #[test_case( - /*input*/ "2022-02-24T05:43:03.000002501Z", - /*micros*/ 698996583000003, - /*formatted*/ "2022-02-24T05:43:03.000003Z" - ; "positive timestamp, 2501" - )] - #[test_case( - /*input*/ "2022-02-24T05:43:03.000002499Z", - /*micros*/ 698996583000002, - /*formatted*/ "2022-02-24T05:43:03.000002Z" - ; "positive timestamp, 2499" - )] - fn datetime(input: &str, micros: i64, formatted: &str) { - let chrono = chrono::DateTime::::from_str(input).unwrap(); - let edgedb: Datetime = chrono.try_into().unwrap(); - assert_eq!(format!("{:?}", edgedb), formatted); - - let mut buf = BytesMut::new(); - let val = Value::Datetime(edgedb.clone()); - codec::Datetime.encode(&mut buf, &val).unwrap(); - let serialized_micros = buf.get_i64(); - - assert_eq!(serialized_micros, micros); - - let rev = chrono::DateTime::::from(edgedb); - assert_eq!(format!("{:?}", rev), formatted); - } - - // ============== - // Local Datetime - // ============== - - // Minimum and Maximum - // ------------------- - #[test_case( - /*input*/ "9999-12-31T23:59:59.999999499", - // Note: Can't round up here, so --^ - /*micros*/ 252455615999999999, - /*formatted*/ "9999-12-31T23:59:59.999999" - ; "maximum" - )] - #[test_case( - /*input*/ "0001-01-01T00:00:00.000000", - /*micros*/ -63082281600000000, - /*formatted*/ "0001-01-01T00:00:00" - ; "minimum" - )] - // Rounding in Various Ranges - // -------------------------- - #[test_case( - /*input*/ "1814-03-09T01:02:03.000005500", - /*micros*/ -5863791476999994, - /*formatted*/ "1814-03-09T01:02:03.000006" - ; "negative unix timestamp, round up" - )] - #[test_case( - /*input*/ "1814-03-09T01:02:03.000005501", - /*micros*/ -5863791476999994, - /*formatted*/ "1814-03-09T01:02:03.000006" - ; "negative unix timestamp, 5501" - )] - #[test_case( - /*input*/ "1814-03-09T01:02:03.000005499", - /*micros*/ -5863791476999995, - /*formatted*/ "1814-03-09T01:02:03.000005" - ; "negative unix timestamp, 5499" - )] - #[test_case( - /*input*/ "1856-08-27T01:02:03.000004500", - /*micros*/ -4523554676999996, - /*formatted*/ "1856-08-27T01:02:03.000004" - ; "negative unix timestamp, round down" - )] - #[test_case( - /*input*/ "1856-08-27T01:02:03.000004501", - /*micros*/ -4523554676999995, - /*formatted*/ "1856-08-27T01:02:03.000005" - ; "negative unix timestamp, 4501" - )] - #[test_case( - /*input*/ "1856-08-27T01:02:03.000004499", - /*micros*/ -4523554676999996, - /*formatted*/ "1856-08-27T01:02:03.000004" - ; "negative unix timestamp, 4499" - )] - #[test_case( - /*input*/ "1969-12-31T23:59:59.999999500", - /*micros*/ -946684800000000, - /*formatted*/ "1970-01-01T00:00:00" - ; "unix timestamp to zero" - )] - #[test_case( - /*input*/ "1997-07-05T01:02:03.000009500", - /*micros*/ -78620276999990, - /*formatted*/ "1997-07-05T01:02:03.000010" - ; "negative postgres timestamp, round up" - )] - #[test_case( - /*input*/ "1997-07-05T01:02:03.000009500", - /*micros*/ -78620276999990, - /*formatted*/ "1997-07-05T01:02:03.000010" - ; "negative postgres timestamp, 9501" - )] - #[test_case( - /*input*/ "1997-07-05T01:02:03.000009499", - /*micros*/ -78620276999991, - /*formatted*/ "1997-07-05T01:02:03.000009" - ; "negative postgres timestamp, 9499" - )] - #[test_case( - /*input*/ "1997-07-05T01:02:03.000000500", - /*micros*/ -78620277000000, - /*formatted*/ "1997-07-05T01:02:03" - ; "negative postgres timestamp, round down" - )] - #[test_case( - /*input*/ "1997-07-05T01:02:03.000000501", - /*micros*/ -78620276999999, - /*formatted*/ "1997-07-05T01:02:03.000001" - ; "negative postgres timestamp, 501" - )] - #[test_case( - /*input*/ "1997-07-05T01:02:03.000000499", - /*micros*/ -78620277000000, - /*formatted*/ "1997-07-05T01:02:03" - ; "negative postgres timestamp, 499" - )] - #[test_case( - /*input*/ "1999-12-31T23:59:59.999999500", - /*micros*/ 0, - /*formatted*/ "2000-01-01T00:00:00" - ; "postgres timestamp to zero" - )] - #[test_case( - /*input*/ "2014-02-27T00:00:00.000001500", - /*micros*/ 446774400000002, - /*formatted*/ "2014-02-27T00:00:00.000002" - ; "positive timestamp, round up" - )] - #[test_case( - /*input*/ "2014-02-27T00:00:00.000001501", - /*micros*/ 446774400000002, - /*formatted*/ "2014-02-27T00:00:00.000002" - ; "positive timestamp, 1501" - )] - #[test_case( - /*input*/ "2014-02-27T00:00:00.000001499", - /*micros*/ 446774400000001, - /*formatted*/ "2014-02-27T00:00:00.000001" - ; "positive timestamp, 1499" - )] - #[test_case( - /*input*/ "2022-02-24T05:43:03.000002500", - /*micros*/ 698996583000002, - /*formatted*/ "2022-02-24T05:43:03.000002" - ; "positive timestamp, round down" - )] - #[test_case( - /*input*/ "2022-02-24T05:43:03.000002501", - /*micros*/ 698996583000003, - /*formatted*/ "2022-02-24T05:43:03.000003" - ; "positive timestamp, 2501" - )] - #[test_case( - /*input*/ "2022-02-24T05:43:03.000002499", - /*micros*/ 698996583000002, - /*formatted*/ "2022-02-24T05:43:03.000002" - ; "positive timestamp, 2499" - )] - fn local_datetime(input: &str, micros: i64, formatted: &str) { - let chrono = chrono::NaiveDateTime::from_str(input).unwrap(); - let edgedb: LocalDatetime = chrono.try_into().unwrap(); - assert_eq!(format!("{:?}", edgedb), formatted); - - let mut buf = BytesMut::new(); - let val = Value::LocalDatetime(edgedb.clone()); - codec::LocalDatetime.encode(&mut buf, &val).unwrap(); - let serialized_micros = buf.get_i64(); - - assert_eq!(serialized_micros, micros); - - let rev = chrono::NaiveDateTime::from(edgedb); - assert_eq!(format!("{:?}", rev), formatted); - } - - // ========== - // Local Time - // ========== - #[test_case( - /*input*/ "23:59:59.999999500", - // Note: Can't round up here, so --^ - /*micros*/ 0, - /*formatted*/ "00:00:00" - ; "wraparound" - )] - #[test_case( - /*input*/ "00:00:00.000000", - /*micros*/ 0, - /*formatted*/ "00:00:00" - ; "minimum" - )] - #[test_case( - /*input*/ "23:59:59.999999", - /*micros*/ 86399999999, - /*formatted*/ "23:59:59.999999" - ; "maximum" - )] - #[test_case( - /*input*/ "01:02:03.000005500", - /*micros*/ 3723000006, - /*formatted*/ "01:02:03.000006" - ; "round up" - )] - #[test_case( - /*input*/ "01:02:03.000005501", - /*micros*/ 3723000006, - /*formatted*/ "01:02:03.000006" - ; "5501" - )] - #[test_case( - /*input*/ "01:02:03.000005499", - /*micros*/ 3723000005, - /*formatted*/ "01:02:03.000005" - ; "5499" - )] - #[test_case( - /*input*/ "01:02:03.000004500", - /*micros*/ 3723000004, - /*formatted*/ "01:02:03.000004" - ; "round down" - )] - #[test_case( - /*input*/ "01:02:03.000004501", - /*micros*/ 3723000005, - /*formatted*/ "01:02:03.000005" - ; "4501" - )] - #[test_case( - /*input*/ "01:02:03.000004499", - /*micros*/ 3723000004, - /*formatted*/ "01:02:03.000004" - ; "4499" - )] - fn local_time(input: &str, micros: i64, formatted: &str) { - let chrono = chrono::NaiveTime::from_str(input).unwrap(); - let edgedb: LocalTime = chrono.into(); - assert_eq!(format!("{:?}", edgedb), formatted); - - let mut buf = BytesMut::new(); - let val = Value::LocalTime(edgedb.clone()); - codec::LocalTime.encode(&mut buf, &val).unwrap(); - let serialized_micros = buf.get_i64(); - - assert_eq!(serialized_micros, micros); - - let rev = chrono::NaiveTime::from(edgedb); - assert_eq!(format!("{:?}", rev), formatted); - } -} diff --git a/edgedb-protocol/tests/datetime_system.rs b/edgedb-protocol/tests/datetime_system.rs deleted file mode 100644 index 7afc941d..00000000 --- a/edgedb-protocol/tests/datetime_system.rs +++ /dev/null @@ -1,264 +0,0 @@ -use std::convert::{TryFrom, TryInto}; -use std::time::{Duration as StdDuration, SystemTime, UNIX_EPOCH}; - -use bytes::{Buf, BytesMut}; -use edgedb_protocol::codec::{self, Codec}; -use edgedb_protocol::model::{Datetime, Duration}; -use edgedb_protocol::value::Value; -use test_case::test_case; - -// ======== -// Datetime -// ======== -// -// Note: pre-1970 dates have hardcoded unix time and they are below - -// Maximum -// ------- -#[test_case( - /*input*/ "9999-12-31T23:59:59.999999499Z", - // Note: Can't round up here, so --^ - /*micros*/ 252455615999999999, - /*formatted*/ "9999-12-31T23:59:59.999999000Z" - ; "maximum" -)] -// Rounding in Various Ranges >= 1970 -// --------------------------------- -#[test_case( - /*input*/ "1997-07-05T01:02:03.000009500Z", - /*micros*/ -78620276999990, - /*formatted*/ "1997-07-05T01:02:03.000010000Z" - ; "negative postgres timestamp, round up" -)] -#[test_case( - /*input*/ "1997-07-05T01:02:03.000009500Z", - /*micros*/ -78620276999990, - /*formatted*/ "1997-07-05T01:02:03.000010000Z" - ; "negative postgres timestamp, 9501" -)] -#[test_case( - /*input*/ "1997-07-05T01:02:03.000009499Z", - /*micros*/ -78620276999991, - /*formatted*/ "1997-07-05T01:02:03.000009000Z" - ; "negative postgres timestamp, 9499" -)] -#[test_case( - /*input*/ "1997-07-05T01:02:03.000000500Z", - /*micros*/ -78620277000000, - /*formatted*/ "1997-07-05T01:02:03Z" - ; "negative postgres timestamp, round down" -)] -#[test_case( - /*input*/ "1997-07-05T01:02:03.000000501Z", - /*micros*/ -78620276999999, - /*formatted*/ "1997-07-05T01:02:03.000001000Z" - ; "negative postgres timestamp, 501" -)] -#[test_case( - /*input*/ "1997-07-05T01:02:03.000000499Z", - /*micros*/ -78620277000000, - /*formatted*/ "1997-07-05T01:02:03Z" - ; "negative postgres timestamp, 499" -)] -#[test_case( - /*input*/ "1999-12-31T23:59:59.999999500Z", - /*micros*/ 0, - /*formatted*/ "2000-01-01T00:00:00Z" - ; "postgres timestamp to zero" -)] -#[test_case( - /*input*/ "2014-02-27T00:00:00.000001500Z", - /*micros*/ 446774400000002, - /*formatted*/ "2014-02-27T00:00:00.000002000Z" - ; "positive timestamp, round up" -)] -#[test_case( - /*input*/ "2014-02-27T00:00:00.000001501Z", - /*micros*/ 446774400000002, - /*formatted*/ "2014-02-27T00:00:00.000002000Z" - ; "positive timestamp, 1501" -)] -#[test_case( - /*input*/ "2014-02-27T00:00:00.000001499Z", - /*micros*/ 446774400000001, - /*formatted*/ "2014-02-27T00:00:00.000001000Z" - ; "positive timestamp, 1499" -)] -#[test_case( - /*input*/ "2022-02-24T05:43:03.000002500Z", - /*micros*/ 698996583000002, - /*formatted*/ "2022-02-24T05:43:03.000002000Z" - ; "positive timestamp, round down" -)] -#[test_case( - /*input*/ "2022-02-24T05:43:03.000002501Z", - /*micros*/ 698996583000003, - /*formatted*/ "2022-02-24T05:43:03.000003000Z" - ; "positive timestamp, 2501" -)] -#[test_case( - /*input*/ "2022-02-24T05:43:03.000002499Z", - /*micros*/ 698996583000002, - /*formatted*/ "2022-02-24T05:43:03.000002000Z" - ; "positive timestamp, 2499" -)] -fn datetime(input: &str, micros: i64, formatted: &str) { - let system = humantime::parse_rfc3339(input).unwrap(); - let edgedb: Datetime = system.try_into().unwrap(); - // different format but we assert microseconds anyways - // assert_eq!(format!("{:?}", edgedb), formatted); - - let mut buf = BytesMut::new(); - let val = Value::Datetime(edgedb.clone()); - codec::Datetime.encode(&mut buf, &val).unwrap(); - let serialized_micros = buf.get_i64(); - - assert_eq!(serialized_micros, micros); - - let rev = SystemTime::try_from(edgedb).unwrap(); - assert_eq!(humantime::format_rfc3339(rev).to_string(), formatted); -} - -#[test_case( - /*input: "0001-01-01T00:00:00.000000Z",*/ - StdDuration::new(62135596800, 0), - /*micros*/ -63082281600000000, - /*formatted: "0001-01-01T00:00:00Z"*/ - /*output*/ StdDuration::new(62135596800, 0) - ; "minimum" -)] -// Rounding in pre Unix Epoch -// -------------------------- -#[test_case( - /*input: "1814-03-09T01:02:03.000005500Z",*/ - StdDuration::new(4917106676, 999994500), - /*micros*/ -5863791476999994, - /*formatted "1814-03-09T01:02:03.000006000Z"*/ - /*output*/ StdDuration::new(4917106676, 999994000) - ; "negative unix timestamp, round up" -)] -#[test_case( - /*input: "1814-03-09T01:02:03.000005501Z",*/ - StdDuration::new(4917106676, 999994499), - /*micros*/ -5863791476999994, - /*formatted: "1814-03-09T01:02:03.000006000Z"*/ - /*output*/ StdDuration::new(4917106676, 999994000) - ; "negative unix timestamp, 5501" -)] -#[test_case( - /*input: "1814-03-09T01:02:03.000005499Z",*/ - StdDuration::new(4917106676, 999994501), - /*micros*/ -5863791476999995, - /*formatted: "1814-03-09T01:02:03.000005000Z"*/ - /*output*/ StdDuration::new(4917106676, 999995000) - ; "negative unix timestamp, 5499" -)] -#[test_case( - /*input: "1856-08-27T01:02:03.000004500Z",*/ - StdDuration::new(3576869876, 999995500), - /*micros*/ -4523554676999996, - /*formatted: "1856-08-27T01:02:03.000004000Z"*/ - /*output*/ StdDuration::new(3576869876, 999996000) - ; "negative unix timestamp, round down" -)] -#[test_case( - /*input: "1856-08-27T01:02:03.000004501Z",*/ - StdDuration::new(3576869876, 999995499), - /*micros*/ -4523554676999995, - /*formatted:"1856-08-27T01:02:03.000005000Z" */ - /*output*/ StdDuration::new(3576869876, 999995000) - ; "negative unix timestamp, 4501" -)] -#[test_case( - /*input: "1856-08-27T01:02:03.000004499Z",*/ - StdDuration::new(3576869876, 999995501), - /*micros*/ -4523554676999996, - /*formatted: "1856-08-27T01:02:03.000004000Z"*/ - /*output*/ StdDuration::new(3576869876, 999996000) - ; "negative unix timestamp, 4499" -)] -#[test_case( - /*input: "1969-12-31T23:59:59.999999500Z",*/ - StdDuration::new(0, 500), - /*micros*/ -946684800000000, - /*formatted: "1970-01-01T00:00:00Z"*/ - /*output*/ StdDuration::new(0, 0) - ; "unix timestamp to zero" -)] -fn datetime_pre_1970(input: StdDuration, micros: i64, output: StdDuration) { - let edgedb: Datetime = (UNIX_EPOCH - input).try_into().unwrap(); - // different format but we assert microseconds anyways - // assert_eq!(format!("{:?}", edgedb), formatted); - - let mut buf = BytesMut::new(); - let val = Value::Datetime(edgedb.clone()); - codec::Datetime.encode(&mut buf, &val).unwrap(); - let serialized_micros = buf.get_i64(); - - assert_eq!(serialized_micros, micros); - - let rev = SystemTime::try_from(edgedb).unwrap(); - assert_eq!(rev, UNIX_EPOCH - output); -} - -#[test_case( - /*input*/ StdDuration::new(0, 0), - /*micros*/ 0, - /*output*/ StdDuration::new(0, 0) - ; "Zero" -)] -#[test_case( - /*input*/ StdDuration::new(1234, 567890123), - /*micros*/ 1234567890, - /*output*/ StdDuration::new(1234, 567890000) - ; "Some value" -)] -#[test_case( - /*input*/ StdDuration::new(1, 2500), - /*micros*/ 1000002, - /*output*/ StdDuration::new(1, 2000) - ; "round down" -)] -#[test_case( - /*input*/ StdDuration::new(23, 2499), - /*micros*/ 23000002, - /*output*/ StdDuration::new(23, 2000) - ; "2499 nanos" -)] -#[test_case( - /*input*/ StdDuration::new(456, 2501), - /*micros*/ 456000003, - /*output*/ StdDuration::new(456, 3000) - ; "2501 nanos" -)] -#[test_case( - /*input*/ StdDuration::new(5789, 3500), - /*micros*/ 5789000004, - /*output*/ StdDuration::new(5789, 4000) - ; "round up" -)] -#[test_case( - /*input*/ StdDuration::new(12345, 3499), - /*micros*/ 12345000003, - /*output*/ StdDuration::new(12345, 3000) - ; "3499 nanos" -)] -#[test_case( - /*input*/ StdDuration::new(789012, 3501), - /*micros*/ 789012000004, - /*output*/ StdDuration::new(789012, 4000) - ; "3501 nanos" -)] -fn duration(input: StdDuration, micros: i64, output: StdDuration) { - let edgedb: Duration = input.try_into().unwrap(); - - let mut buf = BytesMut::new(); - let val = Value::Duration(edgedb.clone()); - codec::Duration.encode(&mut buf, &val).unwrap(); - let serialized_micros = buf.get_i64(); - - assert_eq!(serialized_micros, micros); - - let rev: StdDuration = edgedb.try_into().unwrap(); - assert_eq!(rev, output); -} diff --git a/edgedb-protocol/tests/decode.rs b/edgedb-protocol/tests/decode.rs deleted file mode 100644 index f5b52b27..00000000 --- a/edgedb-protocol/tests/decode.rs +++ /dev/null @@ -1,8 +0,0 @@ -use edgedb_protocol::model::Vector; -use edgedb_protocol::queryable::Queryable; - -#[test] -fn decode_vector() { - let vec = Vector::decode(&Default::default(), b"\0\x03\0\0?\x80\0\0@\0\0\0@@\0\0").unwrap(); - assert_eq!(vec, Vector(vec![1., 2., 3.])); -} diff --git a/edgedb-protocol/tests/error_response.bin b/edgedb-protocol/tests/error_response.bin deleted file mode 100644 index 92b01b4569106048967671c06d7e9e04e45d17f3..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 458 zcmb`Dy-veG4963(gCt+*7GjR6Ud_W$E2 z_eD{Do^HxwnVshtSw>KLn^|c98$`$^4gwP@dP=0g0$ax&>^`T=nLv@3U|FLvB2>*7S!{Th7@dY{jywD3+M0m4*u)q)ffLOokIWs diff --git a/edgedb-protocol/tests/parameter_status.bin b/edgedb-protocol/tests/parameter_status.bin deleted file mode 100644 index 81f7d3bcd4d3c9058b2fd493a379652bd91748e5..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 46 ycmWG$U|`S%Vzz?x#FUgGAWK@mJijPgza+OnKP5?DuUId@Jvh)u&&)IV!QwV51;^W diff --git a/edgedb-protocol/tests/server_messages.rs b/edgedb-protocol/tests/server_messages.rs deleted file mode 100644 index 5329b504..00000000 --- a/edgedb-protocol/tests/server_messages.rs +++ /dev/null @@ -1,370 +0,0 @@ -use std::collections::HashMap; -use std::error::Error; -use std::fs; - -use bytes::{Bytes, BytesMut}; -use uuid::Uuid; - -use edgedb_protocol::common::{Capabilities, RawTypedesc}; -use edgedb_protocol::encoding::{Input, Output}; -use edgedb_protocol::features::ProtocolVersion; -use edgedb_protocol::server_message::Authentication; -use edgedb_protocol::server_message::CommandDataDescription1; -use edgedb_protocol::server_message::RestoreReady; -use edgedb_protocol::server_message::ServerHandshake; -use edgedb_protocol::server_message::ServerMessage; -use edgedb_protocol::server_message::StateDataDescription; -use edgedb_protocol::server_message::{Cardinality, PrepareComplete}; -use edgedb_protocol::server_message::{CommandComplete0, CommandComplete1}; -use edgedb_protocol::server_message::{CommandDataDescription0, Data}; -use edgedb_protocol::server_message::{ErrorResponse, ErrorSeverity}; -use edgedb_protocol::server_message::{LogMessage, MessageSeverity}; -use edgedb_protocol::server_message::{ParameterStatus, ServerKeyData}; -use edgedb_protocol::server_message::{ReadyForCommand, TransactionState}; - -mod base; - -macro_rules! encoding_eq_ver { - ($major: expr, $minor: expr, $message: expr, $bytes: expr) => { - let proto = ProtocolVersion::new($major, $minor); - let data: &[u8] = $bytes; - let mut bytes = BytesMut::new(); - $message.encode(&mut Output::new(&proto, &mut bytes))?; - println!("Serialized bytes {:?}", bytes); - let bytes = bytes.freeze(); - assert_eq!(&bytes[..], data); - assert_eq!( - ServerMessage::decode(&mut Input::new(proto, Bytes::copy_from_slice(data)))?, - $message, - ); - }; -} - -macro_rules! encoding_eq { - ($message: expr, $bytes: expr) => { - let (major, minor) = ProtocolVersion::current().version_tuple(); - encoding_eq_ver!(major, minor, $message, $bytes); - }; -} - -macro_rules! map { - ($($key:expr => $value:expr),*) => { - { - #[allow(unused_mut)] - let mut h = HashMap::new(); - $( - h.insert($key, $value); - )* - h - } - } -} - -#[test] -fn server_handshake() -> Result<(), Box> { - encoding_eq!( - ServerMessage::ServerHandshake(ServerHandshake { - major_ver: 1, - minor_ver: 0, - extensions: HashMap::new(), - }), - b"v\0\0\0\n\0\x01\0\0\0\0" - ); - Ok(()) -} - -#[test] -fn ready_for_command() -> Result<(), Box> { - encoding_eq!( - ServerMessage::ReadyForCommand(ReadyForCommand { - transaction_state: TransactionState::NotInTransaction, - headers: HashMap::new(), - }), - b"Z\0\0\0\x07\0\0I" - ); - Ok(()) -} - -#[test] -fn error_response() -> Result<(), Box> { - encoding_eq!( - ServerMessage::ErrorResponse(ErrorResponse { - severity: ErrorSeverity::Error, - code: 50397184, - message: String::from( - "missing required connection parameter \ - in ClientHandshake message: \"user\"" - ), - attributes: map! { - 257 => Bytes::from_static("Traceback (most recent call last):\n File \"edb/server/mng_port/edgecon.pyx\", line 1077, in edb.server.mng_port.edgecon.EdgeConnection.main\n await self.auth()\n File \"edb/server/mng_port/edgecon.pyx\", line 178, in auth\n raise errors.BinaryProtocolError(\nedb.errors.BinaryProtocolError: missing required connection parameter in ClientHandshake message: \"user\"\n".as_bytes()) - }, - }), - &fs::read("tests/error_response.bin")?[..] - ); - Ok(()) -} - -#[test] -fn server_key_data() -> Result<(), Box> { - encoding_eq!( - ServerMessage::ServerKeyData(ServerKeyData { data: [0u8; 32] }), - &fs::read("tests/server_key_data.bin")?[..] - ); - Ok(()) -} - -#[test] -fn parameter_status() -> Result<(), Box> { - encoding_eq!( - ServerMessage::ParameterStatus(ParameterStatus { - proto: ProtocolVersion::current(), - name: Bytes::from_static(b"pgaddr"), - value: Bytes::from_static(b"/work/tmp/db/.s.PGSQL.60128"), - }), - &fs::read("tests/parameter_status.bin")?[..] - ); - Ok(()) -} - -#[test] -fn command_complete0() -> Result<(), Box> { - encoding_eq_ver!( - 0, - 13, - ServerMessage::CommandComplete0(CommandComplete0 { - headers: HashMap::new(), - status_data: Bytes::from_static(b"okay"), - }), - b"C\0\0\0\x0e\0\0\0\0\0\x04okay" - ); - Ok(()) -} - -#[test] -fn command_complete1() -> Result<(), Box> { - encoding_eq_ver!( - 1, - 0, - ServerMessage::CommandComplete1(CommandComplete1 { - annotations: HashMap::new(), - capabilities: Capabilities::MODIFICATIONS, - status_data: Bytes::from_static(b"okay"), - state: None, - }), - b"C\0\0\0*\0\0\0\0\0\0\0\0\0\x01\0\0\0\x04okay\ - \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\ - \0\0\0\0" - ); - Ok(()) -} - -#[test] -fn prepare_complete() -> Result<(), Box> { - encoding_eq!( - ServerMessage::PrepareComplete(PrepareComplete { - headers: HashMap::new(), - cardinality: Cardinality::AtMostOne, - input_typedesc_id: Uuid::from_u128(0xFF), - output_typedesc_id: Uuid::from_u128(0x105), - }), - b"1\0\0\0'\0\0o\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x05" - ); - encoding_eq!( - ServerMessage::PrepareComplete(PrepareComplete { - headers: HashMap::new(), - cardinality: Cardinality::NoResult, - input_typedesc_id: Uuid::from_u128(0xFF), - output_typedesc_id: Uuid::from_u128(0x0), - }), - b"1\0\0\0'\0\0n\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0" - ); - Ok(()) -} - -#[test] -fn command_data_description0() -> Result<(), Box> { - encoding_eq_ver!( - 0, - 13, - ServerMessage::CommandDataDescription0(CommandDataDescription0 { - headers: HashMap::new(), - result_cardinality: Cardinality::AtMostOne, - input: RawTypedesc { - proto: ProtocolVersion::new(0, 13), - id: Uuid::from_u128(0xFF), - data: Bytes::from_static(b"\x04\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff\0\0"), - }, - output: RawTypedesc { - proto: ProtocolVersion::new(0, 13), - id: Uuid::from_u128(0x105), - data: Bytes::from_static(b"\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x05"), - }, - }), - bconcat!(b"T\0\0\0S\0\0o" - b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff" - b"\0\0\0\x13" - b"\x04\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff\0" - b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x05" - b"\0\0\0\x11" - b"\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x05") - ); - encoding_eq_ver!( - 0, - 13, - ServerMessage::CommandDataDescription0(CommandDataDescription0 { - headers: HashMap::new(), - result_cardinality: Cardinality::NoResult, - input: RawTypedesc { - proto: ProtocolVersion::new(0, 13), - id: Uuid::from_u128(0xFF), - data: Bytes::from_static(b"\x04\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff\0\0"), - }, - output: RawTypedesc { - proto: ProtocolVersion::new(0, 13), - id: Uuid::from_u128(0), - data: Bytes::from_static(b""), - }, - }), - bconcat!(b"T\0\0\0B\0\0n" - b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff" - b"\0\0\0\x13" - b"\x04\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff\0" - b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0") - ); - Ok(()) -} - -#[test] -fn command_data_description1() -> Result<(), Box> { - encoding_eq_ver!( - 1, - 0, - ServerMessage::CommandDataDescription1(CommandDataDescription1 { - annotations: HashMap::new(), - capabilities: Capabilities::MODIFICATIONS, - result_cardinality: Cardinality::AtMostOne, - input: RawTypedesc { - proto: ProtocolVersion::new(1, 0), - id: Uuid::from_u128(0xFF), - data: Bytes::from_static(b"\x04\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff\0\0"), - }, - output: RawTypedesc { - proto: ProtocolVersion::new(1, 0), - id: Uuid::from_u128(0x105), - data: Bytes::from_static(b"\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x05"), - }, - }), - bconcat!(b"T\0\0\0[\0\0\0\0\0\0\0\0\0\x01o" - b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff" - b"\0\0\0\x13" - b"\x04\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff\0" - b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x05" - b"\0\0\0\x11" - b"\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x05") - ); - encoding_eq_ver!( - 1, - 0, - ServerMessage::CommandDataDescription1(CommandDataDescription1 { - annotations: HashMap::new(), - capabilities: Capabilities::MODIFICATIONS, - result_cardinality: Cardinality::NoResult, - input: RawTypedesc { - proto: ProtocolVersion::new(1, 0), - id: Uuid::from_u128(0xFF), - data: Bytes::from_static(b"\x04\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff\0\0"), - }, - output: RawTypedesc { - proto: ProtocolVersion::new(1, 0), - id: Uuid::from_u128(0), - data: Bytes::from_static(b""), - }, - }), - bconcat!(b"T\0\0\0J\0\0\0\0\0\0\0\0\0\x01n" - b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff" - b"\0\0\0\x13" - b"\x04\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff\0" - b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0") - ); - Ok(()) -} - -#[test] -fn data() -> Result<(), Box> { - encoding_eq!( - ServerMessage::Data(Data { - data: vec![Bytes::from_static(b"\0\0\0\0\0\0\0\x01")], - }), - b"D\0\0\0\x12\0\x01\0\0\0\x08\0\0\0\0\0\0\0\x01" - ); - Ok(()) -} - -#[test] -fn restore_ready() -> Result<(), Box> { - encoding_eq!( - ServerMessage::RestoreReady(RestoreReady { - jobs: 1, - headers: HashMap::new(), - }), - b"+\0\0\0\x08\0\0\0\x01" - ); - Ok(()) -} - -#[test] -fn authentication() -> Result<(), Box> { - encoding_eq!( - ServerMessage::Authentication(Authentication::Ok), - b"\x52\0\0\0\x08\x00\x00\x00\x00" - ); - encoding_eq!( - ServerMessage::Authentication(Authentication::Sasl { - methods: vec![String::from("SCRAM-SHA-256")], - }), - b"R\0\0\0\x1d\0\0\0\n\0\0\0\x01\0\0\0\rSCRAM-SHA-256" - ); - encoding_eq!( - ServerMessage::Authentication(Authentication::SaslContinue { - data: Bytes::from_static(b"sasl_interim_data"), - }), - b"R\0\0\0\x1d\x00\x00\x00\x0b\0\0\0\x11sasl_interim_data" - ); - encoding_eq!( - ServerMessage::Authentication(Authentication::SaslFinal { - data: Bytes::from_static(b"sasl_final_data"), - }), - b"R\0\0\0\x1b\x00\x00\x00\x0c\0\0\0\x0fsasl_final_data" - ); - Ok(()) -} - -#[test] -fn log_message() -> Result<(), Box> { - encoding_eq!( - ServerMessage::LogMessage(LogMessage { - severity: MessageSeverity::Notice, - code: 0xF0_00_00_00, - text: "changing system config".into(), - attributes: map! {}, - }), - b"L\0\0\0%<\xf0\0\0\0\0\0\0\x16changing system config\0\0" - ); - Ok(()) -} - -#[test] -fn state_data_description() -> Result<(), Box> { - encoding_eq!( - ServerMessage::StateDataDescription(StateDataDescription { - typedesc: RawTypedesc { - proto: ProtocolVersion::current(), - id: Uuid::from_u128(0x105), - data: Bytes::from_static(b"\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x05"), - }, - }), - b"s\0\0\0)\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x05\0\0\0\ - \x11\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x05" - ); - Ok(()) -} diff --git a/edgedb-protocol/tests/type_descriptors.rs b/edgedb-protocol/tests/type_descriptors.rs deleted file mode 100644 index 92a20517..00000000 --- a/edgedb-protocol/tests/type_descriptors.rs +++ /dev/null @@ -1,401 +0,0 @@ -use bytes::{Buf, Bytes}; -use std::error::Error; - -use edgedb_protocol::descriptors::BaseScalarTypeDescriptor; -use edgedb_protocol::descriptors::ObjectTypeDescriptor; -use edgedb_protocol::descriptors::ScalarTypeDescriptor; -use edgedb_protocol::descriptors::TupleTypeDescriptor; -use edgedb_protocol::descriptors::{Descriptor, TypePos}; -use edgedb_protocol::descriptors::{ObjectShapeDescriptor, ShapeElement}; -use edgedb_protocol::encoding::Input; -use edgedb_protocol::errors::DecodeError; -use edgedb_protocol::features::ProtocolVersion; -use uuid::Uuid; - -mod base; - -fn decode(pv: ProtocolVersion, bytes: &[u8]) -> Result, DecodeError> { - let bytes = Bytes::copy_from_slice(bytes); - let mut input = Input::new(pv, bytes); - let mut result = Vec::new(); - while input.remaining() > 0 { - result.push(Descriptor::decode(&mut input)?); - } - assert!(input.remaining() == 0); - Ok(result) -} - -fn decode_2_0(bytes: &[u8]) -> Result, DecodeError> { - decode(ProtocolVersion::new(2, 0), bytes) -} - -fn decode_1_0(bytes: &[u8]) -> Result, DecodeError> { - decode(ProtocolVersion::new(1, 0), bytes) -} - -fn decode_0_10(bytes: &[u8]) -> Result, DecodeError> { - decode(ProtocolVersion::new(0, 10), bytes) -} - -#[test] -fn empty_tuple() -> Result<(), Box> { - // `SELECT ()` - assert_eq!( - decode_1_0(b"\x04\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff\0\0")?, - vec![Descriptor::Tuple(TupleTypeDescriptor { - id: "00000000-0000-0000-0000-0000000000FF" - .parse::()? - .into(), - element_types: Vec::new(), - name: None, - schema_defined: None, - ancestors: vec![], - }),] - ); - Ok(()) -} - -#[test] -fn one_tuple() -> Result<(), Box> { - // `SELECT (1,)` - assert_eq!( - decode_1_0(bconcat!( - b"\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x05" - b"\x04\x1cyGes%\x89Sa\x03\xe7\x87vE\xad9\0\x01\0\0"))?, - vec![ - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000105" - .parse::()? - .into(), - }), - Descriptor::Tuple(TupleTypeDescriptor { - id: "1c794765-7325-8953-6103-e7877645ad39" - .parse::()? - .into(), - element_types: vec![TypePos(0)], - name: None, - schema_defined: None, - ancestors: vec![], - }), - ] - ); - Ok(()) -} - -#[test] -fn single_int_1_0() -> Result<(), Box> { - assert_eq!( - decode_1_0(b"\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x05")?, - vec![Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000105" - .parse::()? - .into(), - })] - ); - Ok(()) -} - -#[test] -fn single_int_2_0() -> Result<(), Box> { - assert_eq!( - decode_2_0(b"\0\0\0\"\x03\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x05\0\0\0\nstd::int64\x01\0\0")?, - vec![Descriptor::Scalar(ScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000105" - .parse::()? - .into(), - name: Some(String::from("std::int64")), - schema_defined: Some(true), - ancestors: vec![], - base_type_pos: None, - })] - ); - Ok(()) -} - -#[test] -fn single_derived_int_2_0() -> Result<(), Box> { - assert_eq!( - decode_2_0(bconcat!( - b"\0\0\0\"\x03\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x05\0\0\0\n" - b"std::int64\x01\0\0\0\0\0)\x03\x91v\xff\x8c\x95\xb6\x11\xef\x9c" - b" [\x0e\x8c=\xaa\xc8\0\0\0\x0fdefault::my_int\x01\0\x01\0\0\0\0\0" - b"-\x03J\xa0\x08{\x95\xb7\x11\xef\xbd\xe2?\xfa\xe3\r\x13\xe9\0\0\0" - b"\x11default::my_int_2\x01\0\x02\0\x01\0\0" - ))?, - vec![ - Descriptor::Scalar(ScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000105" - .parse::()? - .into(), - name: Some(String::from("std::int64")), - schema_defined: Some(true), - ancestors: vec![], - base_type_pos: None, - }), - Descriptor::Scalar(ScalarTypeDescriptor { - id: "9176ff8c-95b6-11ef-9c20-5b0e8c3daac8" - .parse::()? - .into(), - name: Some(String::from("default::my_int")), - schema_defined: Some(true), - ancestors: vec![TypePos(0)], - base_type_pos: Some(TypePos(0)), - }), - Descriptor::Scalar(ScalarTypeDescriptor { - id: "4aa0087b-95b7-11ef-bde2-3ffae30d13e9" - .parse::()? - .into(), - name: Some(String::from("default::my_int_2")), - schema_defined: Some(true), - ancestors: vec![TypePos(1), TypePos(0)], - base_type_pos: Some(TypePos(0)), - }), - ] - ); - Ok(()) -} - -#[test] -fn duration() -> Result<(), Box> { - assert_eq!( - decode_1_0(b"\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x0e")?, - vec![Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-00000000010e" - .parse::()? - .into(), - })] - ); - Ok(()) -} - -#[test] -fn object_0_10() -> Result<(), Box> { - assert_eq!( - decode_0_10(bconcat!( - b"\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\0\x02" - b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x01\x01n" - b"\xbb\xbe\xda\0P\x14\xfe\x84\xbc\x82\x15@\xb1" - b"R\xcd\0\x03\x01\0\0\0\x07__tid__\0\0\x01" - b"\0\0\0\x02id\0\0\0\0\0\0\x05title\0\x01"))?, - vec![ - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000100" - .parse::()? - .into(), - }), - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000101" - .parse::()? - .into(), - }), - Descriptor::ObjectShape(ObjectShapeDescriptor { - id: "6ebbbeda-0050-14fe-84bc-821540b152cd" - .parse::()? - .into(), - ephemeral_free_shape: false, - type_pos: None, - elements: vec![ - ShapeElement { - flag_implicit: true, - flag_link_property: false, - flag_link: false, - cardinality: None, - name: String::from("__tid__"), - type_pos: TypePos(0), - source_type_pos: None, - }, - ShapeElement { - flag_implicit: true, - flag_link_property: false, - flag_link: false, - cardinality: None, - name: String::from("id"), - type_pos: TypePos(0), - source_type_pos: None, - }, - ShapeElement { - flag_implicit: false, - flag_link_property: false, - flag_link: false, - cardinality: None, - name: String::from("title"), - type_pos: TypePos(1), - source_type_pos: None, - } - ] - }) - ] - ); - Ok(()) -} - -#[test] -fn object_1_0() -> Result<(), Box> { - use edgedb_protocol::common::Cardinality::*; - assert_eq!( - decode_1_0(bconcat!( - // equivalent of 0.10 - //b"\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x01\x02" - //b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\0\x01,sT" - //b"\xf9\x8f\xfac\xed\x10\x8d\x9c\xe4\x156\xd3\x92\0\x03" - //b"\x01\0\0\0\t__tname__\0\0\x01\0\0\0\x02id\0\x01\0\0\0\0\x05title\0\0" - b"\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x01\x02" - b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\0\x01n'\xdb\xa0" - b"xa$\xc2\x86\xa9\x15\xa6\xf2\xe3\xfa\xf5\0\x03\0\0\0" - b"\x01A\0\0\0\t__tname__\0\0\0\0\0\x01A\0\0\0\x02id" - b"\0\x01\0\0\0\0o\0\0\0\x05title\0\0" - ))?, - vec![ - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000101" - .parse::()? - .into(), - }), - Descriptor::BaseScalar(BaseScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000100" - .parse::()? - .into(), - }), - Descriptor::ObjectShape(ObjectShapeDescriptor { - id: "6e27dba0-7861-24c2-86a9-15a6f2e3faf5" - .parse::()? - .into(), - ephemeral_free_shape: false, - type_pos: None, - elements: vec![ - ShapeElement { - flag_implicit: true, - flag_link_property: false, - flag_link: false, - cardinality: Some(One), - name: String::from("__tname__"), - type_pos: TypePos(0), - source_type_pos: None, - }, - ShapeElement { - flag_implicit: true, - flag_link_property: false, - flag_link: false, - cardinality: Some(One), - name: String::from("id"), - type_pos: TypePos(1), - source_type_pos: None, - }, - ShapeElement { - flag_implicit: false, - flag_link_property: false, - flag_link: false, - cardinality: Some(AtMostOne), - name: String::from("title"), - type_pos: TypePos(0), - source_type_pos: None, - } - ] - }) - ] - ); - Ok(()) -} - -#[test] -fn object_2_0() -> Result<(), Box> { - use edgedb_protocol::common::Cardinality::*; - // SELECT Foo { - // id, - // title, - // [IS Bar].body, - // } - assert_eq!( - decode_2_0(bconcat!( - b"\0\0\0 \x03\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01\x01\0\0\0\x08" - b"std::str\x01\0\0\0\0\0!\x03\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01" - b"\0\0\0\0\tstd::uuid\x01\0\0\0\0\0\"\n\xc3\xcc\xa7R\x95\xb7" - b"\x11\xef\xb4\x87\x1d\x1b\x9f\xa20\x03\0\0\0\x0cdefault::Foo" - b"\x01\0\0\0\"\n\r\xdc\xd7\x1e\x95\xb8\x11\xef\x82M!7\x80\\^4" - b"\0\0\0\x0cdefault::Bar\x01\0\0\0^\x01\x1dMg\xe7{\xdd]9\x90\x97" - b"O\x82\xfa\xd8\xaf7\0\0\x02\0\x04\0\0\0\x01A\0\0\0\t__tname__" - b"\0\0\0\x02\0\0\0\0A\0\0\0\x02id\0\x01\0\x02\0\0\0\0o\0\0\0\x05" - b"title\0\0\0\x02\0\0\0\0o\0\0\0\x04body\0\0\0\x03" - ))?, - vec![ - Descriptor::Scalar(ScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000101" - .parse::()? - .into(), - name: Some(String::from("std::str")), - schema_defined: Some(true), - ancestors: vec![], - base_type_pos: None, - }), - Descriptor::Scalar(ScalarTypeDescriptor { - id: "00000000-0000-0000-0000-000000000100" - .parse::()? - .into(), - name: Some(String::from("std::uuid")), - schema_defined: Some(true), - ancestors: vec![], - base_type_pos: None, - }), - Descriptor::Object(ObjectTypeDescriptor { - id: "c3cca752-95b7-11ef-b487-1d1b9fa23003" - .parse::()? - .into(), - name: Some(String::from("default::Foo")), - schema_defined: Some(true), - }), - Descriptor::Object(ObjectTypeDescriptor { - id: "0ddcd71e-95b8-11ef-824d-2137805c5e34" - .parse::()? - .into(), - name: Some(String::from("default::Bar")), - schema_defined: Some(true), - }), - Descriptor::ObjectShape(ObjectShapeDescriptor { - id: "1d4d67e7-7bdd-5d39-9097-4f82fad8af37" - .parse::()? - .into(), - ephemeral_free_shape: false, - type_pos: Some(TypePos(2)), - elements: vec![ - ShapeElement { - flag_implicit: true, - flag_link_property: false, - flag_link: false, - cardinality: Some(One), - name: String::from("__tname__"), - type_pos: TypePos(0), - source_type_pos: Some(TypePos(2)), - }, - ShapeElement { - flag_implicit: false, - flag_link_property: false, - flag_link: false, - cardinality: Some(One), - name: String::from("id"), - type_pos: TypePos(1), - source_type_pos: Some(TypePos(2)), - }, - ShapeElement { - flag_implicit: false, - flag_link_property: false, - flag_link: false, - cardinality: Some(AtMostOne), - name: String::from("title"), - type_pos: TypePos(0), - source_type_pos: Some(TypePos(2)), - }, - ShapeElement { - flag_implicit: false, - flag_link_property: false, - flag_link: false, - cardinality: Some(AtMostOne), - name: String::from("body"), - type_pos: TypePos(0), - source_type_pos: Some(TypePos(3)), - }, - ] - }) - ] - ); - Ok(()) -} diff --git a/edgedb-tokio/Cargo.toml b/edgedb-tokio/Cargo.toml index db712f3c..7a54a00f 100644 --- a/edgedb-tokio/Cargo.toml +++ b/edgedb-tokio/Cargo.toml @@ -6,70 +6,25 @@ authors = ["MagicStack Inc. "] edition = "2021" description = """ EdgeDB database client implementation for tokio. + This crate has been renamed to gel-tokio. """ readme = "README.md" rust-version.workspace = true [dependencies] -edgedb-protocol = { path = "../edgedb-protocol", version = "0.6.0", features = [ - "with-serde", -] } -edgedb-errors = { path = "../edgedb-errors", version = "0.4.1" } -edgedb-derive = { path = "../edgedb-derive", version = "0.5.1", optional = true } -tokio = { version = "1.15", features = ["net", "time", "sync", "macros"] } -bytes = "1.5.0" -scram = { version = "0.7", package = "scram-2" } -serde = { version = "1.0", features = ["derive"] } +edgedb-derive = { path = "../edgedb-derive", version = "0.6", optional = true } serde_json = { version = "1.0", optional = true } -sha1 = { version = "0.10.1", features = ["std"] } -base16ct = { version = "0.2.0", features = ["alloc"] } -log = "0.4.8" -rand = "0.8" -url = "2.1.1" -tls-api = { version = "0.12.0" } -tls-api-not-tls = { version = "0.12.1" } -tls-api-rustls = { version = "0.12.1" } -rustls = { version = "0.23.5", default-features = false, features = [ - "ring", -] } # keep in sync with tls-api -rustls-native-certs = "0.8.1" -rustls-pemfile = "2.1.2" -webpki = { package = "rustls-webpki", version = "0.102.2", features = [ - "std", -], default-features = false } -webpki-roots = "0.26.1" -async-trait = "0.1.52" -anyhow = "1.0.53" # needed for tls-api dirs = { version = "5.0.0", optional = true } -arc-swap = "1.5.1" -once_cell = "1.9.0" tokio-stream = { version = "0.1.11", optional = true } -base64 = "0.22.1" -crc16 = "0.4.0" -socket2 = "0.5" - -[target.'cfg(target_family="unix")'.dev-dependencies] -command-fds = "0.3.0" - -[dev-dependencies] -shutdown_hooks = "0.1.0" -env_logger = "0.11" -thiserror = "2" -test-log = "0.2.8" -futures-util = "0.3.21" -miette = { version = "7.2.0", features = ["fancy"] } -edgedb-errors = { path = "../edgedb-errors", features = ["miette"] } -test-utils = { git = "https://github.com/edgedb/test-utils.git" } -tempfile = "3.13.0" [features] -default = ["derive", "env"] -derive = ["edgedb-derive"] -env = ["fs"] -admin_socket = ["dirs"] -unstable = ["serde_json", "tokio-stream"] # features for CLI and Wasm -fs = ["tokio/fs", "dirs", "serde_json"] -miette-errors = ["edgedb-errors/miette"] +default = [] +derive = [] +env = [] +admin_socket = [] +unstable = [] +fs = [] +miette-errors = [] [lints] workspace = true diff --git a/edgedb-tokio/README.md b/edgedb-tokio/README.md index 0029e7ff..14f62774 100644 --- a/edgedb-tokio/README.md +++ b/edgedb-tokio/README.md @@ -1,42 +1,9 @@ EdgeDB Rust Binding for Tokio ============================= -Work in progress asynchronous bindings of EdgeDB for [Tokio](https://tokio.rs/) -main loop. - -# Example Usage - -```rust -#[tokio::main] -async fn main() -> anyhow::Result<()> { - let conn = edgedb_tokio::create_client().await?; - let val = conn.query_required_single::( - "SELECT 7*8", - &(), - ).await?; - println!("7*8 is: {}", val); - Ok(()) -} -``` - -# Transaction Example - -```rust -#[tokio::main] -async fn main() -> anyhow::Result<()> { - let conn = edgedb_tokio::create_client().await?; - let val = conn.transaction(|mut transaction| async move { - transaction.query_required_single::( - "SELECT (UPDATE Counter SET { value := .value + 1}).value LIMIT 1", - &() - ).await - }).await?; - println!("Counter: {val}"); - Ok(()) -} -``` - -More [examples on github](https://github.com/edgedb/edgedb-rust/tree/master/edgedb-tokio/examples) +Asynchronous bindings of EdgeDB for [Tokio](https://tokio.rs/) main loop. + +> This crate has been renamed to [gel-tokio](https://crates.io/crates/gel-tokio). License diff --git a/edgedb-tokio/examples/query_args.rs b/edgedb-tokio/examples/query_args.rs deleted file mode 100644 index e75a9f09..00000000 --- a/edgedb-tokio/examples/query_args.rs +++ /dev/null @@ -1,47 +0,0 @@ -#[tokio::main] -async fn main() -> anyhow::Result<()> { - let conn = edgedb_tokio::create_client().await?; - - let no_args: String = conn - .query_required_single("select {'No args inside here'};", &()) - .await?; - println!("No args:\n{no_args}"); - - let arg = "Hi I'm arg"; - let one_arg: String = conn - .query_required_single("select {'One arg inside here: ' ++ $0};", &(arg,)) - .await?; - // Note the comma to indicate this is a tuple: ^^^^^^^^ - println!("One arg:\n{one_arg}"); - - let vec_of_numbers: Vec = conn - .query("select {$0, $1}", &(10, 20)) - .await?; - println!("Here are your numbers back: {vec_of_numbers:?}"); - - let num_res: i32 = conn - .query_required_single("select {$0 + $1}", &(10, 20)) - .await?; - println!("Two args added together: {num_res}"); - - let name = "Gadget"; - let title = "Inspector"; - let val = conn - .query_required_single::( - "with person := - (name := $0, title := $1) - select person.title ++ ' ' ++ person.name;", - &(name, title), - ) - .await?; - println!("And person's name was... {}!", val); - - // Arguments must be positional ($0, $1, $2, etc.), not named - let named_args_fail: Result = conn - .query_required_single("select {$num1 + $num2}", &(10, 20)) - .await; - assert!(format!("{named_args_fail:?}") - .contains("expected positional arguments, got num1 instead of 0")); - - Ok(()) -} diff --git a/edgedb-tokio/examples/simple.rs b/edgedb-tokio/examples/simple.rs deleted file mode 100644 index cebe74ba..00000000 --- a/edgedb-tokio/examples/simple.rs +++ /dev/null @@ -1,9 +0,0 @@ -#[tokio::main] -async fn main() -> anyhow::Result<()> { - let conn = edgedb_tokio::create_client().await?; - let val = conn - .query_required_single::("SELECT 7*8", &()) - .await?; - println!("7*8 is: {}", val); - Ok(()) -} diff --git a/edgedb-tokio/examples/transaction.rs b/edgedb-tokio/examples/transaction.rs deleted file mode 100644 index 4032f45f..00000000 --- a/edgedb-tokio/examples/transaction.rs +++ /dev/null @@ -1,17 +0,0 @@ -#[tokio::main] -async fn main() -> anyhow::Result<()> { - env_logger::init(); - let conn = edgedb_tokio::create_client().await?; - let val = conn - .transaction(|mut transaction| async move { - transaction - .query_required_single::( - "SELECT (UPDATE Counter SET { value := .value + 1}).value LIMIT 1", - &(), - ) - .await - }) - .await?; - println!("Counter: {val}"); - Ok(()) -} diff --git a/edgedb-tokio/examples/transaction_errors.rs b/edgedb-tokio/examples/transaction_errors.rs deleted file mode 100644 index 5f7571c1..00000000 --- a/edgedb-tokio/examples/transaction_errors.rs +++ /dev/null @@ -1,56 +0,0 @@ -use rand::{thread_rng, Rng}; -use std::error::Error; - -use edgedb_errors::{ErrorKind, UserError}; - -#[derive(thiserror::Error, Debug)] -#[error("should not apply this counter update")] -struct CounterError; - -fn check_val0(val: i64) -> anyhow::Result<()> { - if val % 3 == 0 { - if thread_rng().gen_bool(0.9) { - Err(CounterError)?; - } - } - Ok(()) -} - -fn check_val1(val: i64) -> Result<(), CounterError> { - if val % 3 == 1 { - if thread_rng().gen_bool(0.1) { - Err(CounterError)?; - } - } - Ok(()) -} - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - env_logger::init(); - let conn = edgedb_tokio::create_client().await?; - let res = conn - .transaction(|mut transaction| async move { - let val = transaction - .query_required_single::( - " - WITH counter := (UPDATE Counter SET { value := .value + 1}), - SELECT counter.value LIMIT 1 - ", - &(), - ) - .await?; - check_val0(val)?; - check_val1(val).map_err(UserError::with_source)?; - Ok(val) - }) - .await; - match res { - Ok(val) => println!("New counter value: {val}"), - Err(e) if e.source().map_or(false, |e| e.is::()) => { - println!("Skipping: {e:#}"); - } - Err(e) => return Err(e)?, - } - Ok(()) -} diff --git a/edgedb-tokio/src/builder.rs b/edgedb-tokio/src/builder.rs deleted file mode 100644 index fec92b30..00000000 --- a/edgedb-tokio/src/builder.rs +++ /dev/null @@ -1,2335 +0,0 @@ -use std::borrow::Cow; -use std::collections::HashMap; -use std::env; -use std::ffi::{OsStr, OsString}; -use std::fmt; -use std::io; -use std::path::{Path, PathBuf}; -use std::str::{self, FromStr}; -use std::sync::Arc; -use std::time::Duration; - -use base64::Engine; -use rustls::client::danger::ServerCertVerifier; -use rustls::crypto; -use serde_json::from_slice; -use sha1::Digest; -use tokio::fs; - -use edgedb_protocol::model; - -use crate::credentials::{Credentials, TlsSecurity}; -use crate::env::{get_env, Env}; -use crate::errors::{ClientError, Error, ErrorKind, ResultExt}; -use crate::errors::{ClientNoCredentialsError, NoCloudConfigFound}; -use crate::errors::{InterfaceError, InvalidArgumentError}; -use crate::{tls, PROJECT_FILES}; - -pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); -pub const DEFAULT_WAIT: Duration = Duration::from_secs(30); -pub const DEFAULT_TCP_KEEPALIVE: Duration = Duration::from_secs(60); -pub const DEFAULT_POOL_SIZE: usize = 10; -pub const DEFAULT_HOST: &str = "localhost"; -pub const DEFAULT_PORT: u16 = 5656; -const DOMAIN_LABEL_MAX_LENGTH: usize = 63; -const CLOUD_INSTANCE_NAME_MAX_LENGTH: usize = DOMAIN_LABEL_MAX_LENGTH - 2 + 1; // "--" -> "/" - -type Verifier = Arc; - -mod sealed { - use super::*; - - /// Helper trait to extract errors and redirect them to the Vec. - pub(super) trait ErrorBuilder { - /// Convert a Result, Error> to an Option. - /// If the result is an error, it is pushed to the Vec. - fn maybe(&mut self, res: Result, Error>) -> Option; - - /// Convert a Result to an Option. - /// If the result is an error, it is pushed to the Vec. - fn check(&mut self, res: Result) -> Option; - } - - impl ErrorBuilder for Vec { - fn maybe(&mut self, res: Result, Error>) -> Option { - match res { - Ok(v) => v, - Err(e) => { - self.push(e); - None - } - } - } - - fn check(&mut self, res: Result) -> Option { - match res { - Ok(v) => Some(v), - Err(e) => { - self.push(e); - None - } - } - } - } -} - -use sealed::ErrorBuilder; - -/// Client security mode. -#[derive(Default, Debug, Clone, Copy)] -pub enum ClientSecurity { - /// Disable security checks - InsecureDevMode, - /// Always verify domain an certificate - Strict, - /// Verify domain only if no specific certificate is configured - #[default] - Default, -} - -/// Client security mode. -#[derive(Debug, Clone, Copy)] -pub enum CloudCerts { - Staging, - Local, -} - -impl CloudCerts { - pub fn root(&self) -> &'static str { - match self { - // Staging certs retrieved from - // https://letsencrypt.org/docs/staging-environment/#root-certificates - CloudCerts::Staging => include_str!("letsencrypt_staging.pem"), - // Local nebula development root cert found in - // nebula/infra/terraform/local/ca/root.certificate.pem - CloudCerts::Local => include_str!("nebula_development.pem"), - } - } -} - -/// TCP keepalive configuration. -#[derive(Default, Debug, Clone, Copy)] -pub enum TcpKeepalive { - /// Disable TCP keepalive probes. - Disabled, - /// Explicit duration between TCP keepalive probes. - Explicit(Duration), - /// Default: 60 seconds. - #[default] - Default, -} - -impl TcpKeepalive { - fn as_keepalive(&self) -> Option { - match self { - TcpKeepalive::Disabled => None, - TcpKeepalive::Default => Some(DEFAULT_TCP_KEEPALIVE), - TcpKeepalive::Explicit(duration) => Some(*duration), - } - } -} - -/// A builder used to create connections. -#[derive(Debug, Clone, Default)] -pub struct Builder { - instance: Option, - dsn: Option, - credentials: Option, - credentials_file: Option, - host: Option, - port: Option, - unix_path: Option, - user: Option, - database: Option, - branch: Option, - password: Option, - tls_ca_file: Option, - tls_security: Option, - tls_server_name: Option, - client_security: Option, - pem_certificates: Option, - wait_until_available: Option, - admin: bool, - connect_timeout: Option, - tcp_keepalive: Option, - secret_key: Option, - cloud_profile: Option, - - // Pool configuration - max_concurrency: Option, -} -/// Configuration of the client -/// -/// Use [`Builder`][] to create an instance -#[derive(Clone)] -pub struct Config(pub(crate) Arc); - -impl Config { - /// The duration for which the client will attempt to establish a connection. - pub fn wait_until_available(&self) -> Duration { - self.0.wait - } -} - -#[derive(Debug, Clone)] -pub(crate) struct ConfigInner { - pub address: Address, - pub admin: bool, - pub user: String, - pub password: Option, - pub secret_key: Option, - pub cloud_profile: Option, - pub database: String, - pub branch: String, - pub verifier: Verifier, - pub wait: Duration, - pub connect_timeout: Duration, - pub cloud_certs: Option, - #[allow(dead_code)] // used only only for tests - pub extra_dsn_query_args: HashMap, - #[allow(dead_code)] // used only on unstable feature - pub creds_file_outdated: bool, - - // Whether to set TCP keepalive or not - pub tcp_keepalive: Option, - - // Pool configuration - pub max_concurrency: Option, - - pub tls_server_name: Option, - - instance_name: Option, - tls_security: TlsSecurity, - client_security: ClientSecurity, - pem_certificates: Option, -} - -#[derive(Debug, Clone)] -pub(crate) enum Address { - Tcp((String, u16)), - #[allow(dead_code)] // TODO(tailhook), but for cli only - Unix(PathBuf), -} - -struct DisplayAddr<'a>(Option<&'a Address>); - -struct DsnHelper<'a> { - url: &'a url::Url, - admin: bool, - query: HashMap, Cow<'a, str>>, -} - -/// Parsed EdgeDB instance name. -#[derive(Clone, Debug)] -pub enum InstanceName { - /// Instance configured locally - Local(String), - /// Instance running on the EdgeDB Cloud - Cloud { - /// Organization name - org_slug: String, - /// Instance name within the organization - name: String, - }, -} - -#[derive(Debug, serde::Deserialize)] -pub struct CloudConfig { - pub secret_key: String, -} - -#[derive(Debug, serde::Deserialize)] -struct Claims { - #[serde(rename = "iss", skip_serializing_if = "Option::is_none")] - issuer: Option, -} - -#[cfg(unix)] -fn path_bytes(path: &Path) -> &'_ [u8] { - use std::os::unix::ffi::OsStrExt; - path.as_os_str().as_bytes() -} - -#[cfg(windows)] -fn path_bytes<'x>(path: &'x Path) -> &'x [u8] { - path.to_str() - .expect("windows paths are always valid UTF-16") - .as_bytes() -} - -fn hash(path: &Path) -> String { - format!( - "{:x}", - sha1::Sha1::new_with_prefix(path_bytes(path)).finalize() - ) -} - -fn stash_name(path: &Path) -> OsString { - let hash = hash(path); - let base = path.file_name().unwrap_or(OsStr::new("")); - let mut base = base.to_os_string(); - base.push("-"); - base.push(&hash); - base -} - -fn config_dir() -> Result { - let dir = if cfg!(windows) { - dirs::data_local_dir() - .ok_or_else(|| ClientError::with_message("cannot determine local data directory"))? - .join("EdgeDB") - .join("config") - } else { - dirs::config_dir() - .ok_or_else(|| ClientError::with_message("cannot determine config directory"))? - .join("edgedb") - }; - Ok(dir) -} - -#[allow(dead_code)] -#[cfg(target_os = "linux")] -fn default_runtime_base() -> Result { - extern "C" { - fn geteuid() -> u32; - } - Ok(Path::new("/run/user").join(unsafe { geteuid() }.to_string())) -} - -#[allow(dead_code)] -#[cfg(not(target_os = "linux"))] -fn default_runtime_base() -> Result { - Err(ClientError::with_message( - "no default runtime dir for the platform", - )) -} - -/// Compute the path to the project's stash file based on the canonical path. -pub fn get_stash_path(project_dir: &Path) -> Result { - let canonical = project_dir.canonicalize().map_err(|e| { - ClientError::with_source(e).context("project directory could not be canonicalized") - })?; - Ok(config_dir()?.join("projects").join(stash_name(&canonical))) -} - -fn is_valid_local_instance_name(name: &str) -> bool { - // For local instance names: - // 1. Allow only letters, numbers, underscores and single dashes - // 2. Must not start or end with a dash - // regex: ^[a-zA-Z_0-9]+(-[a-zA-Z_0-9]+)*$ - let mut chars = name.chars(); - match chars.next() { - Some(c) if c.is_ascii_alphanumeric() || c == '_' => {} - _ => return false, - } - let mut was_dash = false; - for c in chars { - if c == '-' { - if was_dash { - return false; - } else { - was_dash = true; - } - } else { - if !c.is_ascii_alphanumeric() && c != '_' { - return false; - } - was_dash = false; - } - } - !was_dash -} - -fn is_valid_cloud_instance_name(name: &str) -> bool { - // For cloud instance name part: - // 1. Allow only letters, numbers and single dashes - // 2. Must not start or end with a dash - // regex: ^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$ - let mut chars = name.chars(); - match chars.next() { - Some(c) if c.is_ascii_alphanumeric() => {} - _ => return false, - } - let mut was_dash = false; - for c in chars { - if c == '-' { - if was_dash { - return false; - } else { - was_dash = true; - } - } else { - if !c.is_ascii_alphanumeric() { - return false; - } - was_dash = false; - } - } - !was_dash -} - -fn is_valid_cloud_org_name(name: &str) -> bool { - // For cloud organization slug part: - // 1. Allow only letters, numbers, underscores and single dashes - // 2. Must not end with a dash - // regex: ^-?[a-zA-Z0-9_]+(-[a-zA-Z0-9]+)*$ - let mut chars = name.chars(); - match chars.next() { - Some(c) if c.is_ascii_alphanumeric() || c == '-' || c == '_' => {} - _ => return false, - } - let mut was_dash = false; - for c in chars { - if c == '-' { - if was_dash { - return false; - } else { - was_dash = true; - } - } else { - if !(c.is_ascii_alphanumeric() || c == '_') { - return false; - } - was_dash = false; - } - } - !was_dash -} - -impl fmt::Display for InstanceName { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - InstanceName::Local(name) => name.fmt(f), - InstanceName::Cloud { org_slug, name } => write!(f, "{}/{}", org_slug, name), - } - } -} - -impl FromStr for InstanceName { - type Err = Error; - fn from_str(name: &str) -> Result { - if let Some((org_slug, name)) = name.split_once('/') { - if !is_valid_cloud_instance_name(name) { - return Err(ClientError::with_message(format!( - "invalid cloud instance name \"{}\", must follow \ - regex: ^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$", - name, - ))); - } - if !is_valid_cloud_org_name(org_slug) { - return Err(ClientError::with_message(format!( - "invalid cloud org name \"{}\", must follow \ - regex: ^-?[a-zA-Z0-9_]+(-[a-zA-Z0-9]+)*$", - org_slug, - ))); - } - if name.len() > CLOUD_INSTANCE_NAME_MAX_LENGTH { - return Err(ClientError::with_message(format!( - "invalid cloud instance name \"{}\": \ - length cannot exceed {} characters", - name, CLOUD_INSTANCE_NAME_MAX_LENGTH, - ))); - } - Ok(InstanceName::Cloud { - org_slug: org_slug.into(), - name: name.into(), - }) - } else { - if !is_valid_local_instance_name(name) { - return Err(ClientError::with_message(format!( - "invalid instance name \"{}\", must be either following \ - regex: ^[a-zA-Z_0-9]+(-[a-zA-Z_0-9]+)*$ or \ - a cloud instance name ORG/INST.", - name, - ))); - } - Ok(InstanceName::Local(name.into())) - } - } -} - -fn cloud_config_file(profile: &str) -> anyhow::Result { - Ok(config_dir()? - .join("cloud-credentials") - .join(format!("{}.json", profile))) -} - -impl fmt::Display for DisplayAddr<'_> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match &self.0 { - Some(Address::Tcp((host, port))) => { - write!(f, "{}:{}", host, port) - } - Some(Address::Unix(path)) => write!(f, "unix:{}", path.display()), - None => write!(f, ""), - } - } -} - -impl<'a> DsnHelper<'a> { - fn from_url(url: &'a url::Url) -> Result { - use std::collections::hash_map::Entry::*; - - let admin = url.scheme() == "edgedbadmin"; - let mut query = HashMap::new(); - for (k, v) in url.query_pairs() { - match query.entry(k) { - Vacant(e) => { - e.insert(v); - } - Occupied(e) => { - return Err(ClientError::with_message(format!( - "{:?} is defined multiple times in the DSN query", - e.key() - )) - .context("invalid DSN")); - } - } - } - Ok(Self { url, admin, query }) - } - - fn ignore_value(&mut self, key: &str) { - self.query.remove(key); - self.query.remove(&format!("{}_env", key)[..]); - self.query.remove(&format!("{}_file", key)[..]); - } - - async fn retrieve_value( - &mut self, - key: &'static str, - v_from_url: Option, - conv: impl FnOnce(String) -> Result, - ) -> Result, Error> { - self._retrieve_value(key, v_from_url, conv) - .await - .context("invalid DSN") - } - - async fn _retrieve_value( - &mut self, - key: &'static str, - v_from_url: Option, - conv: impl FnOnce(String) -> Result, - ) -> Result, Error> { - let v_query = self.query.remove(key); - let k_env = format!("{key}_env"); - let v_env = self.query.remove(k_env.as_str()); - let k_file = format!("{key}_file"); - let v_file = self.query.remove(k_file.as_str()); - - let defined_param_names = vec![ - v_from_url.as_ref().map(|_| format!("{key} of URL")), - v_query.as_ref().map(|_| format!("query {key}")), - v_env.as_ref().map(|_| format!("query {k_env}")), - v_file.as_ref().map(|_| format!("query {k_file}")), - ] - .into_iter() - .flatten() - .collect::>(); - if defined_param_names.len() > 1 { - return Err(InterfaceError::with_message(format!( - "{key} defined multiple times: {}", - defined_param_names.join(", "), - ))); - } - - if v_from_url.is_some() { - Ok(v_from_url) - } else if let Some(val) = v_query { - conv(val.to_string()) - .map(|rv| Some(rv)) - .with_context(|| format!("failed to parse value of query {key}")) - } else if let Some(env_name) = v_env { - let val = get_env(&env_name)?.ok_or(ClientError::with_message(format!( - "{k_env}: {env_name} is not set" - )))?; - conv(val) - .map(|rv| Some(rv)) - .with_context(|| format!("failed to parse value of {k_env}: {env_name}")) - } else if let Some(file_path) = v_file { - let val = fs::read_to_string(Path::new(file_path.as_ref())) - .await - .map_err(|e| { - ClientError::with_source(e) - .context(format!("error reading {k_file}: {file_path}")) - })?; - conv(val) - .map(|rv| Some(rv)) - .with_context(|| format!("failed to parse content of {k_file}: {file_path}")) - } else { - Ok(None) - } - } - - async fn retrieve_host(&mut self) -> Result, Error> { - if let Some(url::Host::Ipv6(host)) = self.url.host() { - // async-std uses raw IPv6 address without "[]" - Ok(Some(host.to_string())) - } else { - let url_host = if let Some(host) = self.url.host_str() { - validate_host(host)?; - Some(host.to_owned()) - } else { - None - }; - self.retrieve_value("host", url_host, validate_host).await - } - } - - async fn retrieve_tls_server_name(&mut self) -> Result, Error> { - self.retrieve_value("tls_server_name", None, Ok).await - } - - async fn retrieve_port(&mut self) -> Result, Error> { - self.retrieve_value("port", self.url.port(), |s| { - s.parse() - .map_err(|e| InterfaceError::with_source(e).context("invalid port")) - }) - .await - } - - async fn retrieve_user(&mut self) -> Result, Error> { - let username = self.url.username(); - let v = if username.is_empty() { - None - } else { - Some(username.to_owned()) - }; - self.retrieve_value("user", v, validate_user).await - } - - async fn retrieve_password(&mut self) -> Result, Error> { - let v = self.url.password().map(|s| s.to_owned()); - self.retrieve_value("password", v, Ok).await - } - - async fn retrieve_database(&mut self) -> Result, Error> { - let v = self.url.path().strip_prefix('/').and_then(|s| { - if s.is_empty() { - None - } else { - Some(s.to_owned()) - } - }); - self.retrieve_value("database", v, |s| { - let s = s.strip_prefix('/').unwrap_or(&s); - validate_database(&s)?; - Ok(s.to_owned()) - }) - .await - } - - async fn retrieve_branch(&mut self) -> Result, Error> { - let v = self.url.path().strip_prefix('/').and_then(|s| { - if s.is_empty() { - None - } else { - Some(s.to_owned()) - } - }); - self.retrieve_value("branch", v, |s| { - let s = s.strip_prefix('/').unwrap_or(&s); - validate_branch(&s)?; - Ok(s.to_owned()) - }) - .await - } - - async fn retrieve_secret_key(&mut self) -> Result, Error> { - self.retrieve_value("secret_key", None, Ok).await - } - - async fn retrieve_tls_ca_file(&mut self) -> Result, Error> { - self.retrieve_value("tls_ca_file", None, Ok).await - } - - async fn retrieve_tls_security(&mut self) -> Result, Error> { - self.retrieve_value("tls_security", None, |x| x.parse()) - .await - } - - async fn retrieve_wait_until_available(&mut self) -> Result, Error> { - self.retrieve_value("wait_until_available", None, |s| { - s.parse::() - .map_err(ClientError::with_source) - .and_then(|d| match d.is_negative() { - false => Ok(d.abs_duration()), - true => Err(ClientError::with_message( - "negative durations are unsupported", - )), - }) - }) - .await - } - - fn remaining_queries(&self) -> HashMap { - self.query - .iter() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect() - } -} - -impl Builder { - /// Create a builder with empty options - pub fn new() -> Builder { - Default::default() - } - - /// Set instance name - #[cfg(feature = "env")] - pub fn instance(&mut self, name: &str) -> Result<&mut Self, Error> { - self.instance = Some(name.parse()?); - Ok(self) - } - - /// Set connection parameters as DSN - #[cfg(feature = "env")] - pub fn dsn(&mut self, dsn: &str) -> Result<&mut Self, Error> { - if !dsn.starts_with("edgedb://") && !dsn.starts_with("edgedbadmin://") && !dsn.starts_with("gel://") { - return Err(InvalidArgumentError::with_message(format!( - "String {:?} is not a valid DSN", - dsn - ))); - }; - let url = url::Url::parse(dsn).map_err(|e| { - InvalidArgumentError::with_source(e).context(format!("cannot parse DSN {:?}", dsn)) - })?; - self.dsn = Some(url); - Ok(self) - } - - /// Set connection parameters as credentials structure - pub fn credentials(&mut self, credentials: &Credentials) -> Result<&mut Self, Error> { - if let Some(cert_data) = &credentials.tls_ca { - validate_certs(cert_data).context("invalid certificates in `tls_ca`")?; - } - self.credentials = Some(credentials.clone()); - Ok(self) - } - - /// Set connection parameters from file - /// - /// Note: file is not read immediately but is read when configuration is - /// being built. - #[cfg(feature = "fs")] - pub fn credentials_file(&mut self, path: impl AsRef) -> &mut Self { - self.credentials_file = Some(path.as_ref().to_path_buf()); - self - } - - /// Set host to connect to - pub fn host(&mut self, host: &str) -> Result<&mut Self, Error> { - validate_host(host)?; - self.host = Some(host.to_string()); - Ok(self) - } - - /// Override server name indication (SNI) in TLS handshake - pub fn tls_server_name(&mut self, tls_server_name: &str) -> Result<&mut Self, Error> { - validate_host(tls_server_name)?; - self.tls_server_name = Some(tls_server_name.to_string()); - Ok(self) - } - - /// Set port to connect to - pub fn port(&mut self, port: u16) -> Result<&mut Self, Error> { - validate_port(port)?; - self.port = Some(port); - Ok(self) - } - - /// Set path to unix socket - #[cfg(feature = "admin_socket")] - pub fn unix_path(&mut self, path: impl AsRef) -> &mut Self { - self.unix_path = Some(path.as_ref().to_path_buf()); - self - } - - #[cfg(feature = "admin_socket")] - pub fn admin(&mut self, admin: bool) -> &mut Self { - self.admin = admin; - self - } - - /// Set the user name for authentication. - pub fn user(&mut self, user: &str) -> Result<&mut Self, Error> { - validate_user(user)?; - self.user = Some(user.to_string()); - Ok(self) - } - - /// Set the password for SCRAM authentication. - pub fn password(&mut self, password: &str) -> &mut Self { - self.password = Some(password.to_string()); - self - } - /// Set the database name. - pub fn database(&mut self, database: &str) -> Result<&mut Self, Error> { - validate_database(database)?; - self.database = Some(database.into()); - Ok(self) - } - - /// Set the branch name. - pub fn branch(&mut self, branch: &str) -> Result<&mut Self, Error> { - validate_branch(branch)?; - self.branch = Some(branch.into()); - Ok(self) - } - - /// Set certificate authority for TLS from file - /// - /// Note: file is not read immediately but is read when configuration is - /// being built. - #[cfg(feature = "fs")] - pub fn tls_ca_file(&mut self, path: &Path) -> &mut Self { - self.tls_ca_file = Some(path.to_path_buf()); - self - } - - /// Updates the client TLS security mode. - /// - /// By default, the certificate chain is always verified; but hostname - /// verification is disabled if configured to use only a - /// specific certificate, and enabled if root certificates are used. - pub fn tls_security(&mut self, value: TlsSecurity) -> &mut Self { - self.tls_security = Some(value); - self - } - - /// Modifies the client security mode. - /// - /// InsecureDevMode changes tls_security only from Default to Insecure - /// Strict ensures tls_security is also Strict - pub fn client_security(&mut self, value: ClientSecurity) -> &mut Self { - self.client_security = Some(value); - self - } - - /// Set the allowed certificate as a PEM file. - pub fn pem_certificates(&mut self, cert_data: &str) -> Result<&mut Self, Error> { - validate_certs(cert_data).context("invalid PEM certificate")?; - self.pem_certificates = Some(cert_data.into()); - Ok(self) - } - - /// Set the secret key for JWT authentication. - pub fn secret_key(&mut self, secret_key: &str) -> &mut Self { - self.secret_key = Some(secret_key.into()); - self - } - - /// Set the time to wait for the database server to become available. - /// - /// This works by ignoring certain errors known to happen while the - /// database is starting up or restarting (e.g. "connection refused" or - /// early "connection reset"). - /// - /// Note: the amount of time establishing a connection can take is the sum - /// of `wait_until_available` plus `connect_timeout` - pub fn wait_until_available(&mut self, time: Duration) -> &mut Self { - self.wait_until_available = Some(time); - self - } - - /// A timeout for a single connect attempt. - /// - /// The default is 10 seconds. A subsecond timeout should be fine for most - /// networks. However, in some cases this can be much slower. That's - /// because this timeout includes authentication, during which: - /// * The password is checked (slow by design). - /// * A compiler process is launched (slow now, may be optimized later). - /// - /// So in a concurrent case on slower VMs (such as CI with parallel - /// tests), 10 seconds is more reasonable default. - /// - /// The `wait_until_available` setting should be larger than this value to - /// allow multiple attempts. - /// - /// Note: the amount of time establishing a connection can take is the sum - /// of `wait_until_available` plus `connect_timeout` - pub fn connect_timeout(&mut self, timeout: Duration) -> &mut Self { - self.connect_timeout = Some(timeout); - self - } - - /// Sets the TCP keepalive interval and time for the database connection to - /// ensure that the remote end of the connection is still alive, and to - /// inform any network intermediaries that this connection is not idle. By - /// default, a keepalive probe will be sent once every 60 seconds once the - /// connection has been idle for 60 seconds. - /// - /// Note: If the connection is not made over a TCP socket, this value will - /// be unused. If the current platform does not support explicit TCP - /// keep-alive intervals on the socket, keepalives will be enabled and the - /// operating-system default for the intervals will be used. - pub fn tcp_keepalive(&mut self, tcp_keepalive: TcpKeepalive) -> &mut Self { - self.tcp_keepalive = Some(tcp_keepalive); - self - } - - /// Set the maximum number of underlying database connections. - pub fn max_concurrency(&mut self, value: usize) -> &mut Self { - self.max_concurrency = Some(value); - self - } - - /// Build connection and pool configuration in constrained mode - /// - /// Normal [`Builder::build_env()`], reads environment variables and files - /// if appropriate to build configuration variables. This method never reads - /// files or environment variables. Therefore it never blocks, so is not - /// asyncrhonous. - /// - /// The limitations are: - /// - /// 1. [`Builder::credentials_file()`] is not supported - /// 2. [`Builder::dsn()`] is not supported yet (although, will be - /// implemented later restricing `*_file` and `*_env` query args - #[cfg(any(feature = "unstable", test))] - pub fn constrained_build(&self) -> Result { - let address = if let Some(unix_path) = &self.unix_path { - let port = self.port.unwrap_or(DEFAULT_PORT); - Address::Unix(resolve_unix(unix_path, port, self.admin)) - } else if let Some(credentials) = &self.credentials { - let host = self - .host - .clone() - .or_else(|| credentials.host.clone()) - .unwrap_or(DEFAULT_HOST.into()); - let port = self.port.unwrap_or(credentials.port); - Address::Tcp((host, port)) - } else { - Address::Tcp(( - self.host.clone().unwrap_or_else(|| DEFAULT_HOST.into()), - self.port.unwrap_or(DEFAULT_PORT), - )) - }; - if self.instance.is_some() - || self.dsn.is_some() - || self.credentials_file.is_some() - || self.tls_ca_file.is_some() - || self.secret_key.is_some() - || self.cloud_profile.is_some() - { - return Err(InterfaceError::with_message( - "unsupported constraint builder param", - )); - } - let creds = self.credentials.as_ref(); - let mut cfg = ConfigInner { - address, - tls_server_name: self.tls_server_name.clone(), - admin: self.admin, - user: self - .user - .clone() - .or_else(|| creds.map(|c| c.user.clone())) - .unwrap_or_else(|| "edgedb".into()), - password: self - .password - .clone() - .or_else(|| creds.and_then(|c| c.password.clone())), - secret_key: self.secret_key.clone(), - cloud_profile: self.cloud_profile.clone(), - cloud_certs: None, - database: self - .database - .clone() - .or_else(|| creds.and_then(|c| c.database.clone())) - .unwrap_or_else(|| "edgedb".into()), - branch: self - .branch - .clone() - .or_else(|| creds.and_then(|c| c.branch.clone())) - .unwrap_or_else(|| "__default__".into()), - instance_name: None, - wait: self.wait_until_available.unwrap_or(DEFAULT_WAIT), - connect_timeout: self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT), - extra_dsn_query_args: HashMap::new(), - creds_file_outdated: false, - pem_certificates: self - .pem_certificates - .clone() - .or_else(|| creds.and_then(|c| c.tls_ca.clone())), - tcp_keepalive: self.tcp_keepalive.unwrap_or_default().as_keepalive(), - // Pool configuration - max_concurrency: self.max_concurrency, - - // Temporary placeholders - verifier: Arc::new(tls::NullVerifier), - client_security: self.client_security.unwrap_or_default(), - tls_security: self - .tls_security - .or_else(|| creds.map(|c| c.tls_security)) - .unwrap_or_default(), - }; - - cfg.verifier = cfg.make_verifier(cfg.compute_tls_security()?); - - Ok(Config(Arc::new(cfg))) - } - - /// Build connection and pool configuration object - pub async fn build_env(&self) -> Result { - let (complete, config, mut errors) = self._build_no_fail().await; - if !complete { - return Err(ClientNoCredentialsError::with_message( - "EdgeDB connection options are not initialized. \ - Run `edgedb project init` or use environment variables \ - to configure connection.", - )); - } - if !errors.is_empty() { - return Err(errors.remove(0)); - } - Ok(config) - } - - async fn compound_owned(&self, cfg: &mut ConfigInner, errors: &mut Vec) { - let mut conflict = None; - if let Some(instance) = &self.instance { - conflict = Some("instance"); - errors.check(read_instance(cfg, instance).await); - } - if let Some(dsn) = &self.dsn { - if let Some(conflict) = conflict { - errors.push(InvalidArgumentError::with_message(format!( - "dsn argument conflicts with {}", - conflict - ))); - } - conflict = Some("dsn"); - self.read_dsn(cfg, dsn, errors).await; - } - if let Some(credentials_file) = &self.credentials_file { - if let Some(conflict) = conflict { - errors.push(InvalidArgumentError::with_message(format!( - "credentials_file argument conflicts with {}", - conflict - ))); - } - conflict = Some("credentials_file"); - errors.check(read_credentials(cfg, credentials_file).await); - } - if let Some(credentials) = &self.credentials { - if let Some(conflict) = conflict { - errors.push(InvalidArgumentError::with_message(format!( - "credentials argument conflicts with {}", - conflict - ))); - } - conflict = Some("credentials"); - errors.check(set_credentials(cfg, credentials)); - } - if let Some(host) = &self.host { - if let Some(conflict) = conflict { - errors.push(InvalidArgumentError::with_message(format!( - "host argument conflicts with {}", - conflict - ))); - } - conflict = Some("host"); - cfg.address = Address::Tcp((host.into(), self.port.unwrap_or(DEFAULT_PORT))); - } else if let Some(port) = &self.port { - if let Some(conflict) = conflict { - errors.push(InvalidArgumentError::with_message(format!( - "port argument conflicts with {}", - conflict - ))); - } - if let Address::Tcp((_, ref mut portref)) = &mut cfg.address { - *portref = *port - } - } - if let Some(unix_path) = &self.unix_path { - if let Some(conflict) = conflict { - errors.push(InvalidArgumentError::with_message(format!( - "unix_path argument conflicts with {}", - conflict - ))); - } - #[allow(unused_assignments)] - { - conflict = Some("unix_path"); - } - let port = match cfg.address { - Address::Tcp((_, port)) => port, - Address::Unix(_) => DEFAULT_PORT, - }; - let full_path = resolve_unix(unix_path, port, self.admin); - cfg.address = Address::Unix(full_path); - } - if let Some((d, b)) = &self.database.as_ref().zip(self.branch.as_ref()) { - if d != b { - errors.push(InvalidArgumentError::with_message(format!( - "database {d} conflicts with branch {b}" - ))) - } - } - } - - async fn granular_owned(&self, cfg: &mut ConfigInner, errors: &mut Vec) { - if let Some(database) = &self.database { - cfg.database.clone_from(database); - } - - if let Some(branch) = &self.branch { - cfg.branch.clone_from(branch); - } - - if let Some(user) = &self.user { - cfg.user.clone_from(user); - } - - if let Some(password) = &self.password { - cfg.password = Some(password.clone()); - } - - if let Some(tls_server_name) = &self.tls_server_name { - cfg.tls_server_name = Some(tls_server_name.clone()); - } - - if let Some(tls_ca_file) = &self.tls_ca_file { - if let Some(pem) = errors.check(read_certificates(tls_ca_file).await) { - cfg.pem_certificates = Some(pem) - } - } - - if let Some(pem) = &self.pem_certificates { - cfg.pem_certificates = Some(pem.clone()); - } - - if let Some(security) = self.tls_security { - cfg.tls_security = security; - } - - if let Some(wait) = self.wait_until_available { - cfg.wait = wait; - } - } - - async fn compound_env(&self, cfg: &mut ConfigInner, errors: &mut Vec) -> bool { - let instance = Env::instance(); - let dsn = Env::dsn(); - let credentials_file = Env::credentials_file(); - let host = Env::host(); - let port = Env::port(); - - fn has(opt: &Result, Error>) -> bool { - opt.as_ref().map(|s| s.as_ref()).ok().flatten().is_some() - } - - let groups = [ - (has(&instance), "GEL_INSTANCE"), - (has(&dsn), "GEL_DSN"), - (has(&credentials_file), "GEL_CREDENTIALS_FILE"), - (has(&host) || has(&port), "GEL_HOST or GEL_PORT"), - ]; - - let has_envs = groups - .into_iter() - .filter_map(|(has, name)| if has { Some(name) } else { None }) - .collect::>(); - - if has_envs.len() > 1 { - errors.push(InvalidArgumentError::with_message(format!( - "environment variable {} conflicts with {}", - has_envs[0], - has_envs[1..].join(", "), - ))); - } - - if let Some(instance) = errors.maybe(instance) { - errors.check(read_instance(cfg, &instance).await); - } - if let Some(dsn) = errors.maybe(dsn) { - self.read_dsn(cfg, &dsn, errors).await - } - if let Some(fpath) = errors.maybe(credentials_file) { - errors.check(read_credentials(cfg, fpath).await); - } - if let Some(host) = errors.maybe(host) { - cfg.address = Address::Tcp((host, DEFAULT_PORT)); - } - if let Some(port) = errors.maybe(port) { - if let Address::Tcp((_, ref mut portref)) = &mut cfg.address { - *portref = port.into(); - } - } - - // This code needs a total rework... - - // Because an incomplete configuration trumps errors, we return "complete" if - // there are errors, so those errors can be reported. - !has_envs.is_empty() || !errors.is_empty() - } - - async fn secret_key_env(&self, cfg: &mut ConfigInner, errors: &mut Vec) { - cfg.secret_key = self - .secret_key - .clone() - .or_else(|| errors.maybe(Env::secret_key())); - } - - async fn granular_env(&self, cfg: &mut ConfigInner, errors: &mut Vec) { - let database_branch = self.database.as_ref().or(self.branch.as_ref()) - .cloned() - .or_else(|| { - let database = errors.maybe(Env::database()); - let branch = errors.maybe(Env::branch()); - - if database.is_some() && branch.is_some() { - errors.push(InvalidArgumentError::with_message( - "Invalid environment: variables `EDGEDB_DATABASE` and `EDGEDB_BRANCH` are mutually exclusive", - )); - return None; - } - - database.or(branch) - }); - if let Some(name) = database_branch { - cfg.database.clone_from(&name); - cfg.branch = name; - } - - let user = self.user.clone().or_else(|| errors.maybe(Env::user())); - if let Some(user) = user { - cfg.user = user; - } - - let tls_server_name = self - .tls_server_name - .clone() - .or_else(|| errors.maybe(Env::tls_server_name())); - if let Some(tls_server_name) = tls_server_name { - cfg.tls_server_name = Some(tls_server_name); - } - - let password = self - .password - .clone() - .or_else(|| errors.maybe(Env::password())); - if let Some(password) = password { - cfg.password = Some(password); - } - - let tls_ca_file = self - .tls_ca_file - .clone() - .or_else(|| errors.maybe(Env::tls_ca_file())); - if let Some(tls_ca_file) = tls_ca_file { - if let Some(pem) = errors.check(read_certificates(tls_ca_file).await) { - cfg.pem_certificates = Some(pem) - } - } - - let tls_ca = errors.maybe(Env::tls_ca()); - if let Some(pem) = tls_ca { - if let Some(()) = errors.check(validate_certs(&pem)) { - cfg.pem_certificates = Some(pem) - } - } - - let security = errors.maybe(Env::client_tls_security()); - if let Some(security) = security { - cfg.tls_security = security; - } - - let wait = self - .wait_until_available - .or_else(|| errors.maybe(Env::wait_until_available())); - if let Some(wait) = wait { - cfg.wait = wait; - } - } - - async fn read_dsn(&self, cfg: &mut ConfigInner, url: &url::Url, errors: &mut Vec) { - let mut dsn = match DsnHelper::from_url(url) { - Ok(dsn) => dsn, - Err(e) => { - errors.push(e); - return; - } - }; - let host = errors - .maybe(dsn.retrieve_host().await) - .unwrap_or_else(|| DEFAULT_HOST.into()); - let port = errors - .maybe(dsn.retrieve_port().await) - .unwrap_or(DEFAULT_PORT); - if let Some(value) = errors.maybe(dsn.retrieve_tls_server_name().await) { - cfg.tls_server_name = Some(value) - } - cfg.address = Address::Tcp((host, port)); - cfg.admin = dsn.admin; - if let Some(value) = errors.maybe(dsn.retrieve_user().await) { - cfg.user = value - } - if self.password.is_none() { - if let Some(value) = errors.maybe(dsn.retrieve_password().await) { - cfg.password = Some(value) - } - } else { - dsn.ignore_value("password"); - } - - let has_query_branch = dsn.query.contains_key("branch") - || dsn.query.contains_key("branch_env") - || dsn.query.contains_key("branch_file"); - let has_query_database = dsn.query.contains_key("database") - || dsn.query.contains_key("database_env") - || dsn.query.contains_key("database_file"); - if has_query_branch && has_query_database { - errors.push(InvalidArgumentError::with_message( - "Invalid DSN: `database` and `branch` are mutually exclusive", - )); - } - if self.branch.is_none() && self.database.is_none() { - let database_or_branch = if has_query_database { - dsn.retrieve_database().await - } else { - dsn.retrieve_branch().await - }; - - if let Some(name) = errors.maybe(database_or_branch) { - { - cfg.branch.clone_from(&name); - cfg.database = name; - } - } - } else { - dsn.ignore_value("branch"); - dsn.ignore_value("database"); - } - - if let Some(value) = errors.maybe(dsn.retrieve_secret_key().await) { - cfg.secret_key = Some(value) - } - if self.tls_ca_file.is_none() { - if let Some(path) = errors.maybe(dsn.retrieve_tls_ca_file().await) { - if let Some(pem) = errors.check(read_certificates(&path).await) { - cfg.pem_certificates = Some(pem) - } - } - } else { - dsn.ignore_value("tls_ca_file"); - } - if let Some(value) = errors.maybe(dsn.retrieve_tls_security().await) { - cfg.tls_security = value - } - if let Some(value) = errors.maybe(dsn.retrieve_wait_until_available().await) { - cfg.wait = value - } - - cfg.extra_dsn_query_args = dsn.remaining_queries(); - } - - async fn read_project(&self, cfg: &mut ConfigInner, errors: &mut Vec) -> bool { - let pair = errors.maybe(self._get_stash_path().await); - if let Some((project, stash)) = pair { - errors.check(self._read_project(cfg, &project, &stash).await); - true - } else { - false - } - } - - async fn _get_stash_path(&self) -> Result, Error> { - let Some(dir) = get_project_path(None, true).await? else { - return Ok(None); - }; - let dir = dir - .parent() - .ok_or_else(|| ClientError::with_message("Project file has no parent"))?; - let stash_path = get_stash_path(dir)?; - if fs::metadata(&stash_path).await.is_ok() { - return Ok(Some((dir.to_owned(), stash_path))); - } - Ok(None) - } - - async fn _read_project( - &self, - cfg: &mut ConfigInner, - project_dir: &Path, - stash_path: &Path, - ) -> Result<(), Error> { - let path = stash_path.join("instance-name"); - let instance = fs::read_to_string(&path).await.map_err(|e| { - ClientError::with_source(e).context(format!( - "error reading project settings {:?}: {:?}", - project_dir, path - )) - })?; - let instance = instance.trim().parse().map_err(|e| { - ClientError::with_source(e).context(format!( - "cannot parse project's instance name: {:?}", - instance - )) - })?; - if matches!(instance, InstanceName::Cloud { .. }) { - if cfg.secret_key.is_none() && cfg.cloud_profile.is_none() { - let path = stash_path.join("cloud-profile"); - let profile = fs::read_to_string(&path) - .await - .map_err(|e| { - ClientError::with_source(e).context(format!( - "error reading project settings {:?}: {:?}", - project_dir, path - )) - })? - .trim() - .into(); - cfg.cloud_profile = Some(profile); - } - } - read_instance(cfg, &instance).await?; - let path = stash_path.join("database"); - match fs::read_to_string(&path).await { - Ok(text) => { - validate_database(text.trim()) - .with_context(|| { - format!( - "error reading project settings {:?}: {:?}", - project_dir, path - ) - })? - .clone_into(&mut cfg.database); - cfg.branch.clone_from(&cfg.database); - } - Err(e) if e.kind() == io::ErrorKind::NotFound => {} - Err(e) => { - return Err(ClientError::with_source(e).context(format!( - "error reading project settings {:?}: {:?}", - project_dir, path - ))) - } - } - Ok(()) - } - - /// Build connection and pool configuration object - /// - /// This is similar to `build_env` but never fails and fills in whatever - /// fields possible in `Config`. - /// - /// First boolean item in the tuple is `true` if configuration is complete - /// and can be used for connections. - #[cfg(any(feature = "unstable", test))] - pub async fn build_no_fail(&self) -> (bool, Config, Vec) { - self._build_no_fail().await - } - - async fn _build_no_fail(&self) -> (bool, Config, Vec) { - let mut errors = Vec::new(); - - let mut cfg = ConfigInner { - address: Address::Tcp((DEFAULT_HOST.into(), DEFAULT_PORT)), - tls_server_name: self.tls_server_name.clone(), - admin: self.admin, - user: "edgedb".into(), - password: None, - secret_key: None, - cloud_profile: None, - cloud_certs: None, - database: "edgedb".into(), - branch: "__default__".into(), - instance_name: None, - wait: self.wait_until_available.unwrap_or(DEFAULT_WAIT), - connect_timeout: self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT), - extra_dsn_query_args: HashMap::new(), - creds_file_outdated: false, - pem_certificates: self.pem_certificates.clone(), - client_security: self.client_security.unwrap_or_default(), - tls_security: self.tls_security.unwrap_or_default(), - tcp_keepalive: self.tcp_keepalive.unwrap_or_default().as_keepalive(), - // Pool configuration - max_concurrency: self.max_concurrency, - - // Temporary placeholders - verifier: Arc::new(tls::NullVerifier), - }; - - cfg.cloud_profile = self - .cloud_profile - .clone() - .or_else(|| errors.maybe(Env::cloud_profile())); - - let complete = if self.host.is_some() - || self.port.is_some() - || self.unix_path.is_some() - || self.dsn.is_some() - || self.instance.is_some() - || self.credentials.is_some() - || self.credentials_file.is_some() - { - cfg.secret_key.clone_from(&self.secret_key); - self.compound_owned(&mut cfg, &mut errors).await; - self.granular_owned(&mut cfg, &mut errors).await; - true - } else { - self.secret_key_env(&mut cfg, &mut errors).await; - let complete = if self.compound_env(&mut cfg, &mut errors).await { - true - } else { - self.read_project(&mut cfg, &mut errors).await - }; - self.granular_env(&mut cfg, &mut errors).await; - complete - }; - - let security = errors.maybe(Env::client_security()); - - if let Some(security) = security { - cfg.client_security = security; - } - - let cloud_certs = errors.maybe(Env::_cloud_certs()); - if let Some(cloud_certs) = cloud_certs { - cfg.cloud_certs = Some(cloud_certs); - } - - // we don't overwrite this param in cfg because we want - // `with_pem_certificates` to bump security to Strict - let tls_security = errors - .check(cfg.compute_tls_security()) - .unwrap_or(TlsSecurity::Strict); - cfg.verifier = cfg.make_verifier(tls_security); - - (complete, Config(Arc::new(cfg)), errors) - } -} - -fn resolve_unix(path: impl AsRef, port: u16, admin: bool) -> PathBuf { - let has_socket_name = path - .as_ref() - .file_name() - .and_then(|x| x.to_str()) - .map(|x| x.contains(".s.EDGEDB")) - .unwrap_or(false); - let path = if has_socket_name { - // it's the full path - path.as_ref().to_path_buf() - } else { - let socket_name = if admin { - format!(".s.EDGEDB.admin.{}", port) - } else { - format!(".s.EDGEDB.{}", port) - }; - path.as_ref().join(socket_name) - }; - path -} - -async fn read_instance(cfg: &mut ConfigInner, name: &InstanceName) -> Result<(), Error> { - cfg.instance_name = Some(name.clone()); - match name { - InstanceName::Local(name) => { - read_credentials( - cfg, - config_dir()? - .join("credentials") - .join(format!("{}.json", name)), - ) - .await?; - } - InstanceName::Cloud { org_slug, name } => { - let secret_key = if let Some(secret_key) = &cfg.secret_key { - secret_key.clone() - } else { - let profile = cfg.cloud_profile.as_deref().unwrap_or("default"); - let path = cloud_config_file(profile)?; - let data = match fs::read(path).await { - Ok(data) => data, - Err(e) if e.kind() == io::ErrorKind::NotFound => { - let hint_cmd = if profile == "default" { - "edgedb cloud login".into() - } else { - format!("edgedb cloud login --cloud-profile {:?}", profile) - }; - return Err(NoCloudConfigFound::with_message( - "connecting cloud instance requires a secret key", - ) - .with_headers(HashMap::from([( - 0x_00_01, // FIELD_HINT - bytes::Bytes::from(format!( - "try `{}`, or provide a secret key to connect with", - hint_cmd - )), - )]))); - } - Err(e) => return Err(ClientError::with_source(e))?, - }; - let config: CloudConfig = from_slice(&data).map_err(ClientError::with_source)?; - config.secret_key - }; - let claims_b64 = secret_key - .split('.') - .nth(1) - .ok_or(ClientError::with_message("Illegal JWT token"))?; - let claims = base64::engine::general_purpose::URL_SAFE_NO_PAD - .decode(claims_b64) - .map_err(ClientError::with_source)?; - let claims: Claims = from_slice(&claims).map_err(ClientError::with_source)?; - let dns_zone = claims - .issuer - .ok_or(ClientError::with_message("Invalid secret key"))?; - let org_slug = org_slug.to_lowercase(); - let name = name.to_lowercase(); - let msg = format!("{}/{}", org_slug, name); - let checksum = crc16::State::::calculate(msg.as_bytes()); - let dns_bucket = format!("c-{:02}", checksum % 100); - cfg.address = Address::Tcp(( - format!("{}--{}.{}.i.{}", name, org_slug, dns_bucket, dns_zone), - DEFAULT_PORT, - )); - cfg.secret_key = Some(secret_key); - } - } - Ok(()) -} - -async fn read_credentials(cfg: &mut ConfigInner, path: impl AsRef) -> Result<(), Error> { - let path = path.as_ref(); - async { - let data = fs::read(path).await.map_err(ClientError::with_source)?; - let creds = serde_json::from_slice(&data).map_err(ClientError::with_source)?; - set_credentials(cfg, &creds)?; - Ok(()) - } - .await - .map_err(|e: Error| e.context(format!("cannot read credentials file {}", path.display())))?; - Ok(()) -} - -async fn read_certificates(path: impl AsRef) -> Result { - let data = fs::read_to_string(path.as_ref()) - .await - .map_err(|e| ClientError::with_source(e).context("error reading TLS CA file"))?; - validate_certs(&data).context("invalid certificates")?; - Ok(data) -} - -fn set_credentials(cfg: &mut ConfigInner, creds: &Credentials) -> Result<(), Error> { - if let Some(cert_data) = &creds.tls_ca { - validate_certs(cert_data).context("invalid certificates in `tls_ca`")?; - cfg.pem_certificates = Some(cert_data.into()); - } - cfg.address = Address::Tcp(( - creds.host.clone().unwrap_or_else(|| DEFAULT_HOST.into()), - creds.port, - )); - cfg.user.clone_from(&creds.user); - cfg.password.clone_from(&creds.password); - - if let Some((b, d)) = creds.branch.as_ref().zip(creds.database.as_ref()) { - if b != d { - return Err(ClientError::with_message( - "branch and database are mutually exclusive", - )); - } - } - let mut db_branch = creds.branch.as_ref().or(creds.database.as_ref()); - if creds.branch.is_none() && creds.database.as_ref().map_or(false, |d| d == "edgedb") { - db_branch = None; - } - cfg.database = db_branch.cloned().unwrap_or_else(|| "edgedb".into()); - cfg.branch = db_branch.cloned().unwrap_or_else(|| "__default__".into()); - cfg.tls_server_name = creds.tls_server_name.clone(); - cfg.tls_security = creds.tls_security; - cfg.creds_file_outdated = creds.file_outdated; - Ok(()) -} - -fn validate_certs(data: &str) -> Result<(), Error> { - let root_store = tls::read_root_cert_pem(data).map_err(ClientError::with_source_ref)?; - if root_store.is_empty() { - return Err(ClientError::with_message( - "PEM data contains no certificate", - )); - } - Ok(()) -} - -fn validate_host>(host: T) -> Result { - if host.as_ref().is_empty() { - return Err(InvalidArgumentError::with_message( - "invalid host: empty string", - )); - } else if host.as_ref().contains(',') { - return Err(InvalidArgumentError::with_message( - "invalid host: multiple hosts", - )); - } - Ok(host) -} - -fn validate_port(port: u16) -> Result { - if port == 0 { - return Err(InvalidArgumentError::with_message( - "invalid port: port cannot be zero", - )); - } - Ok(port) -} - -fn validate_branch>(branch: T) -> Result { - if branch.as_ref().is_empty() { - return Err(InvalidArgumentError::with_message( - "invalid branch: empty string", - )); - } - Ok(branch) -} - -fn validate_database>(database: T) -> Result { - if database.as_ref().is_empty() { - return Err(InvalidArgumentError::with_message( - "invalid database: empty string", - )); - } - Ok(database) -} - -fn validate_user>(user: T) -> Result { - if user.as_ref().is_empty() { - return Err(InvalidArgumentError::with_message( - "invalid user: empty string", - )); - } - Ok(user) -} - -impl Config { - /// A displayable form for an address this builder will connect to - pub fn display_addr(&self) -> impl fmt::Display + '_ { - DisplayAddr(Some(&self.0.address)) - } - - /// Is admin connection desired - #[cfg(feature = "admin_socket")] - pub fn admin(&self) -> bool { - self.0.admin - } - - /// User name - pub fn user(&self) -> &str { - &self.0.user - } - - /// Database name - pub fn database(&self) -> &str { - &self.0.database - } - - /// Database branch name - pub fn branch(&self) -> &str { - &self.0.branch - } - - /// Extract credentials from the [Builder] so they can be saved as JSON. - pub fn as_credentials(&self) -> Result { - let (host, port) = match &self.0.address { - Address::Tcp(pair) => pair, - Address::Unix(_) => { - return Err(ClientError::with_message( - "Unix socket address cannot \ - be saved as credentials file", - )); - } - }; - - Ok(Credentials { - host: Some(host.clone()), - port: *port, - user: self.0.user.clone(), - password: self.0.password.clone(), - branch: if self.0.branch == "__default__" { - None - } else { - Some(self.0.branch.clone()) - }, - - // this is not strictly needed (it gets overwritten when reading), - // but we want to keep backward compatibility. If you downgrade CLI, - // we want it to be able to interact with the new format of credentials. - database: Some(if self.0.branch == "__default__" { - "edgedb".into() - } else { - self.0.branch.clone() - }), - tls_ca: self.0.pem_certificates.clone(), - tls_security: self.0.tls_security, - tls_server_name: self.0.tls_server_name.clone(), - file_outdated: false, - }) - } - - /// Generate debug JSON string - #[cfg(feature = "unstable")] - pub fn to_json(&self) -> String { - serde_json::json!({ - "address": match &self.0.address { - Address::Tcp((host, port)) => serde_json::json!([host, port]), - Address::Unix(path) => serde_json::json!(path.to_str().unwrap()), - }, - "database": self.0.database, - "branch": self.0.branch, - "user": self.0.user, - "password": self.0.password, - "secretKey": self.0.secret_key, - "tlsCAData": self.0.pem_certificates, - "tlsSecurity": self.0.compute_tls_security().unwrap(), - "tlsServerName": self.0.tls_server_name, - "serverSettings": self.0.extra_dsn_query_args, - "waitUntilAvailable": self.0.wait.as_micros() as i64, - }) - .to_string() - } - - /// Server host name (if doesn't use unix socket) - pub fn host(&self) -> Option<&str> { - match self.0.address { - Address::Tcp((ref host, _)) => Some(host), - _ => None, - } - } - - /// Server port (if doesn't use unix socket) - pub fn port(&self) -> Option { - match self.0.address { - Address::Tcp((_, port)) => Some(port), - _ => None, - } - } - - /// Instance name if set and if it's local - pub fn local_instance_name(&self) -> Option<&str> { - match self.0.instance_name { - Some(InstanceName::Local(ref name)) => Some(name), - _ => None, - } - } - - /// Name of the instance if set - pub fn instance_name(&self) -> Option<&InstanceName> { - self.0.instance_name.as_ref() - } - - /// Secret key if set - pub fn secret_key(&self) -> Option<&str> { - self.0.secret_key.as_deref() - } - - /// Return HTTP(s) url to server - /// - /// If not connected via unix socket - pub fn http_url(&self, tls: bool) -> Option { - match &self.0.address { - Address::Tcp((host, port)) => { - let s = if tls { "s" } else { "" }; - Some(format!("http{}://{}:{}", s, host, port)) - } - Address::Unix(_) => None, - } - } - - fn _get_unix_path(&self) -> Result, Error> { - match &self.0.address { - Address::Unix(path) => Ok(Some(path.clone())), - Address::Tcp(_) => Ok(None), - } - } - - /// Return the same config with changed password - pub fn with_password(mut self, password: &str) -> Config { - Arc::make_mut(&mut self.0).password = Some(password.to_owned()); - self - } - - /// Return the same config with changed database - pub fn with_database(mut self, database: &str) -> Result { - if database.is_empty() { - return Err(InvalidArgumentError::with_message( - "invalid database: empty string", - )); - } - database.clone_into(&mut Arc::make_mut(&mut self.0).database); - Ok(self) - } - - /// Return the same config with changed database branch - pub fn with_branch(mut self, branch: &str) -> Result { - if branch.is_empty() { - return Err(InvalidArgumentError::with_message( - "invalid branch: empty string", - )); - } - branch.clone_into(&mut Arc::make_mut(&mut self.0).branch); - Ok(self) - } - - /// Return the same config with changed wait until available timeout - #[cfg(any(feature = "unstable", test))] - pub fn with_wait_until_available(mut self, wait: Duration) -> Config { - Arc::make_mut(&mut self.0).wait = wait; - self - } - - /// Return the same config with changed certificates - #[cfg(any(feature = "unstable", test))] - pub fn with_pem_certificates(mut self, pem: &str) -> Result { - validate_certs(pem).context("invalid PEM certificate")?; - let cfg = Arc::make_mut(&mut self.0); - cfg.pem_certificates = Some(pem.to_owned()); - cfg.verifier = cfg.make_verifier(cfg.compute_tls_security()?); - Ok(self) - } - - #[cfg(feature = "admin_socket")] - pub fn with_unix_path(mut self, path: &Path) -> Config { - Arc::make_mut(&mut self.0).address = Address::Unix(path.into()); - self - } - - /// Returns true if credentials file is in outdated format - #[cfg(any(feature = "unstable", test))] - pub fn is_creds_file_outdated(&self) -> bool { - self.0.creds_file_outdated - } - - /// Return the certificate store of the config - #[cfg(any(feature = "unstable", test))] - pub fn root_cert_store(&self) -> Result { - Ok(self.0.root_cert_store()) - } - - /// Return the same config with changed certificate verifier - /// - /// Command-line tool uses this for interactive verifier - #[cfg(any(feature = "unstable", test))] - pub fn with_cert_verifier(mut self, verifier: Verifier) -> Config { - Arc::make_mut(&mut self.0).verifier = verifier; - self - } -} - -impl ConfigInner { - fn compute_tls_security(&self) -> Result { - use TlsSecurity::*; - - match (self.client_security, self.tls_security) { - (ClientSecurity::Strict, Insecure | NoHostVerification) => { - Err(ClientError::with_message(format!( - "client_security=strict and tls_security={} don't comply", - self.tls_security, - ))) - } - (ClientSecurity::Strict, _) => Ok(Strict), - (ClientSecurity::InsecureDevMode, Default) => Ok(Insecure), - (_, Default) if self.pem_certificates.is_none() => Ok(Strict), - (_, Default) => Ok(NoHostVerification), - (_, ts) => Ok(ts), - } - } - fn root_cert_store(&self) -> rustls::RootCertStore { - if self.pem_certificates.is_some() { - tls::read_root_cert_pem(self.pem_certificates.as_deref().unwrap_or("")) - .expect("all certificates have been verified previously") - } else { - let mut root_store = rustls::RootCertStore { - roots: webpki_roots::TLS_SERVER_ROOTS.into(), - }; - if let Some(certs) = self.cloud_certs { - root_store.extend( - tls::read_root_cert_pem(certs.root()) - .expect("embedded certs are correct") - .roots, - ); - } - - root_store - } - } - fn make_verifier(&self, tls_security: TlsSecurity) -> Verifier { - use TlsSecurity::*; - - let root_store = Arc::new(self.root_cert_store()); - - match tls_security { - Insecure => Arc::new(tls::NullVerifier) as Verifier, - NoHostVerification => Arc::new(tls::NoHostnameVerifier::new(root_store)) as Verifier, - Strict => { - let cryto_provider = crypto::CryptoProvider::get_default() - .cloned() - .unwrap_or_else(|| Arc::new(crypto::ring::default_provider())); - - rustls::client::WebPkiServerVerifier::builder_with_provider( - root_store, - cryto_provider, - ) - .build() - .expect("WebPkiServerVerifier to build correctly") as Verifier - } - Default => unreachable!(), - } - } -} - -impl fmt::Debug for Config { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Config") - .field("address", &self.0.address) - .field("max_concurrency", &self.0.max_concurrency) - // TODO(tailhook) more fields - .finish() - } -} - -impl FromStr for ClientSecurity { - type Err = Error; - fn from_str(s: &str) -> Result { - use ClientSecurity::*; - - match s { - "default" => Ok(Default), - "strict" => Ok(Strict), - "insecure_dev_mode" => Ok(InsecureDevMode), - mode => Err(ClientError::with_message(format!( - "Invalid client security: {:?}. \ - Options: default, strict, insecure_dev_mode.", - mode - ))), - } - } -} - -impl FromStr for CloudCerts { - type Err = Error; - fn from_str(s: &str) -> Result { - use CloudCerts::*; - - match s { - "staging" => Ok(Staging), - "local" => Ok(Local), - option => Err(ClientError::with_message(format!( - "Invalid cloud certificates: {:?}. \ - Options: staging, local.", - option - ))), - } - } -} - -/// Searches for a project file either from the current directory or from a -/// specified directory, optionally searching parent directories. -pub async fn get_project_path( - override_dir: Option<&Path>, - search_parents: bool, -) -> Result, Error> { - let dir = - match override_dir { - Some(v) => Cow::Borrowed(v), - None => Cow::Owned(env::current_dir().map_err(|e| { - ClientError::with_source(e).context("failed to get current directory") - })?), - }; - - search_dir(&dir, search_parents).await -} - -async fn search_dir(base: &Path, search_parents: bool) -> Result, Error> { - let mut path = base; - loop { - let mut found = vec![]; - for name in PROJECT_FILES { - let file = path.join(name); - match fs::metadata(&file).await { - Ok(_) => found.push(file), - Err(e) if e.kind() == io::ErrorKind::NotFound => {} - Err(e) => return Err(ClientError::with_source(e)), - } - } - - // Future note: we allow multiple configuration files to be found in one - // folder but you must ensure that they contain the same contents - // (either via copy or symlink). - if found.len() > 1 { - let (a, b) = found.split_at(1); - let a = &a[0]; - let s = fs::read_to_string(a) - .await - .map_err(|e| ClientError::with_source(e).context("failed to read file"))?; - for file in b { - if fs::read_to_string(file) - .await - .map_err(|e| ClientError::with_source(e).context("failed to read file"))? - != s - { - return Err(ClientError::with_message(format!( - "{:?} and {:?} found in {base:?} but the contents are different", - a.file_name(), - file.file_name() - ))); - } - } - return Ok(Some(found.into_iter().next().unwrap())); - } else if let Some(path) = found.pop() { - // Found just one - return Ok(Some(path)); - } - - if !search_parents { - break; - } - if let Some(parent) = path.parent() { - path = parent; - } else { - break; - } - } - Ok(None) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_project_file_priority() { - let temp = tempfile::tempdir().unwrap(); - let base = temp.path(); - - let gel_path = base.join("gel.toml"); - let edgedb_path = base.join("edgedb.toml"); - - // Test gel.toml only - fs::write(&gel_path, "test1").await.unwrap(); - let found = get_project_path(Some(base), false).await.unwrap().unwrap(); - assert_eq!(found, gel_path); - - // Test edgedb.toml only - fs::remove_file(&gel_path).await.unwrap(); - fs::write(&edgedb_path, "test2").await.unwrap(); - let found = get_project_path(Some(base), false).await.unwrap().unwrap(); - assert_eq!(found, edgedb_path); - - // Test both files with same content - fs::write(&gel_path, "test3").await.unwrap(); - fs::write(&edgedb_path, "test3").await.unwrap(); - let found = get_project_path(Some(base), false).await.unwrap().unwrap(); - assert_eq!(found, gel_path); - - // Test both files with different content - fs::write(&gel_path, "test4").await.unwrap(); - fs::write(&edgedb_path, "test5").await.unwrap(); - let err = get_project_path(Some(base), false).await.unwrap_err(); - assert!(err.to_string().contains("but the contents are different")); - } - - #[tokio::test] - async fn test_read_credentials() { - let cfg = Builder::new() - .credentials_file("tests/credentials1.json") - .build_env() - .await - .unwrap(); - assert!(matches!(&cfg.0.address, Address::Tcp((_, 10702)))); - assert_eq!(&cfg.0.user, "test3n"); - assert_eq!(&cfg.0.database, "test3n"); - assert_eq!(cfg.0.password, Some("lZTBy1RVCfOpBAOwSCwIyBIR".into())); - } - - #[tokio::test] - async fn display() { - let dsn_schemes = ["edgedb", "edgedbadmin", "gel"]; - for dsn_scheme in dsn_schemes { - let cfg = Builder::new() - .dsn(&format!("{dsn_scheme}://localhost:1756")) - .unwrap() - .build_env() - .await - .unwrap(); - assert!(matches!( - &cfg.0.address, - Address::Tcp((host, 1756)) if host == "localhost" - )); - /* TODO(tailhook) - bld.unix_path("/test/my.sock"); - assert_eq!(bld.build().unwrap()._get_unix_path().unwrap(), - Some("/test/my.sock/.s.EDGEDB.5656".into())); - */ - #[cfg(feature = "admin_socket")] - { - let cfg = Builder::new() - .unix_path("/test/.s.EDGEDB.8888") - .build_env() - .await - .unwrap(); - assert_eq!( - cfg._get_unix_path().unwrap(), - Some("/test/.s.EDGEDB.8888".into()) - ); - let cfg = Builder::new() - .port(8888) - .unwrap() - .unix_path("/test") - .build_env() - .await - .unwrap(); - assert_eq!( - cfg._get_unix_path().unwrap(), - Some("/test/.s.EDGEDB.8888".into()) - ); - } - } - } - - #[tokio::test] - async fn from_dsn() { - let dsn_schemes = ["edgedb", "edgedbadmin", "gel"]; - for dsn_scheme in dsn_schemes { - let cfg = Builder::new() - .dsn(&format!("{dsn_scheme}://user1:EiPhohl7@edb-0134.elb.us-east-2.amazonaws.com/db2")) - .unwrap() - .build_env() - .await - .unwrap(); - assert!(matches!( - &cfg.0.address, - Address::Tcp((host, 5656)) - if host == "edb-0134.elb.us-east-2.amazonaws.com", - )); - assert_eq!(&cfg.0.user, "user1"); - assert_eq!(&cfg.0.database, "db2"); - assert_eq!(&cfg.0.branch, "db2"); - assert_eq!(cfg.0.password, Some("EiPhohl7".into())); - - let cfg = Builder::new() - .dsn(&format!("{dsn_scheme}://user2@edb-0134.elb.us-east-2.amazonaws.com:1756/db2")) - .unwrap() - .build_env() - .await - .unwrap(); - assert!(matches!( - &cfg.0.address, - Address::Tcp((host, 1756)) - if host == "edb-0134.elb.us-east-2.amazonaws.com", - )); - assert_eq!(&cfg.0.user, "user2"); - assert_eq!(&cfg.0.database, "db2"); - assert_eq!(&cfg.0.branch, "db2"); - assert_eq!(cfg.0.password, None); - - // Tests overriding - let cfg = Builder::new() - .dsn(&format!("{dsn_scheme}://edb-0134.elb.us-east-2.amazonaws.com:1756")) - .unwrap() - .build_env() - .await - .unwrap(); - assert!(matches!( - &cfg.0.address, - Address::Tcp((host, 1756)) - if host == "edb-0134.elb.us-east-2.amazonaws.com", - )); - assert_eq!(&cfg.0.user, "edgedb"); - assert_eq!(&cfg.0.database, "edgedb"); - assert_eq!(&cfg.0.branch, "__default__"); - assert_eq!(cfg.0.password, None); - - let cfg = Builder::new() - .dsn(&format!("{dsn_scheme}://user3:123123@[::1]:5555/abcdef")) - .unwrap() - .build_env() - .await - .unwrap(); - assert!(matches!( - &cfg.0.address, - Address::Tcp((host, 5555)) if host == "::1", - )); - assert_eq!(&cfg.0.user, "user3"); - assert_eq!(&cfg.0.database, "abcdef"); - assert_eq!(&cfg.0.branch, "abcdef"); - assert_eq!(cfg.0.password, Some("123123".into())); - } - } - - #[tokio::test] - #[should_panic] // servo/rust-url#424 - async fn from_dsn_ipv6_scoped_address() { - let dsn_schemes = ["edgedb", "edgedbadmin", "gel"]; - for dsn_scheme in dsn_schemes { - let cfg = Builder::new() - .dsn(&format!("{dsn_scheme}://user3@[fe80::1ff:fe23:4567:890a%25eth0]:3000/ab")) - .unwrap() - .build_env() - .await - .unwrap(); - assert!(matches!( - &cfg.0.address, - Address::Tcp((host, 3000)) if host == "fe80::1ff:fe23:4567:890a%eth0", - )); - assert_eq!(&cfg.0.user, "user3"); - assert_eq!(&cfg.0.database, "ab"); - assert_eq!(cfg.0.password, None); - } - } - - #[test] - fn test_instance_name() { - for inst_name in [ - "abc", - "_localdev", - "123", - "___", - "12345678901234567890123456789012345678901234567890123456789012345678901234567890", - "abc-123", - "a-b-c_d-e-f", - "_-_-_-_", - "abc/def", - "123/456", - "abc-123/def-456", - "123-abc/456-def", - "a-b-c/1-2-3", - "-leading-dash/abc", - "_leading-underscore/abc", - "under_score/abc", - "-vicfg-hceTeOuz6iXr3vkXPf0Wsudd/test123", - ] { - match InstanceName::from_str(inst_name) { - Ok(InstanceName::Local(name)) => assert_eq!(name, inst_name), - Ok(InstanceName::Cloud { org_slug, name }) => { - let (o, i) = inst_name - .split_once('/') - .expect("test case must have one slash"); - assert_eq!(org_slug, o); - assert_eq!(name, i); - } - Err(e) => panic!("{:#}", e), - } - } - for name in [ - "", - "-leading-dash", - "trailing-dash-", - "double--dash", - "trailing-dash-/abc", - "double--dash/abc", - "abc/-leading-dash", - "abc/trailing-dash-", - "abc/double--dash", - "abc/_localdev", - "123/45678901234567890123456789012345678901234567890123456789012345678901234567890", - ] { - assert!( - InstanceName::from_str(name).is_err(), - "unexpected success: {}", - name - ); - } - } -} diff --git a/edgedb-tokio/src/client.rs b/edgedb-tokio/src/client.rs deleted file mode 100644 index 957944e6..00000000 --- a/edgedb-tokio/src/client.rs +++ /dev/null @@ -1,613 +0,0 @@ -use std::future::Future; -use std::sync::Arc; - -use edgedb_protocol::common::{Capabilities, Cardinality, IoFormat}; -use edgedb_protocol::model::Json; -use edgedb_protocol::query_arg::QueryArgs; -use edgedb_protocol::QueryResult; -use tokio::time::sleep; - -use crate::builder::Config; -use crate::errors::InvalidArgumentError; -use crate::errors::NoDataError; -use crate::errors::{Error, ErrorKind, SHOULD_RETRY}; -use crate::options::{RetryOptions, TransactionOptions}; -use crate::raw::{Options, PoolState, Response}; -use crate::raw::{Pool, QueryCapabilities}; -use crate::state::{AliasesDelta, ConfigDelta, GlobalsDelta}; -use crate::state::{AliasesModifier, ConfigModifier, Fn, GlobalsModifier}; -use crate::transaction::{transaction, Transaction}; -use crate::ResultVerbose; - -/// The EdgeDB Client. -/// -/// Internally it contains a connection pool. -/// -/// To create a client, use [`create_client`](crate::create_client) function (it -/// gets database connection configuration from environment). You can also use -/// [`Builder`](crate::Builder) to [`build`](`crate::Builder::new`) custom -/// [`Config`] and [create a client](Client::new) using that config. -/// -/// The `with_` methods ([`with_retry_options`](crate::Client::with_retry_options), [`with_transaction_options`](crate::Client::with_transaction_options), etc.) -/// let you create a shallow copy of the client with adjusted options. -#[derive(Debug, Clone)] -pub struct Client { - options: Arc, - pool: Pool, -} - -impl Client { - /// Create a new connection pool. - /// - /// Note this does not create a connection immediately. - /// Use [`ensure_connected()`][Client::ensure_connected] to establish a - /// connection and verify that the connection is usable. - pub fn new(config: &Config) -> Client { - Client { - options: Default::default(), - pool: Pool::new(config), - } - } - - /// Ensure that there is at least one working connection to the pool. - /// - /// This can be used at application startup to ensure that you have a - /// working connection. - pub async fn ensure_connected(&self) -> Result<(), Error> { - self.pool.acquire().await?; - Ok(()) - } - - /// Query with retry. - async fn query_helper( - &self, - query: impl AsRef, - arguments: &A, - io_format: IoFormat, - cardinality: Cardinality, - ) -> Result>, Error> - where - A: QueryArgs, - R: QueryResult, - { - let mut iteration = 0; - loop { - let mut conn = self.pool.acquire().await?; - - let conn = conn.inner(); - let state = &self.options.state; - let caps = Capabilities::MODIFICATIONS | Capabilities::DDL; - match conn - .query( - query.as_ref(), - arguments, - state, - &self.options.annotations, - caps, - io_format, - cardinality, - ) - .await - { - Ok(resp) => return Ok(resp), - Err(e) => { - let allow_retry = match e.get::() { - // Error from a weird source, or just a bug - // Let's keep on the safe side - None => false, - Some(QueryCapabilities::Unparsed) => true, - Some(QueryCapabilities::Parsed(c)) => c.is_empty(), - }; - if allow_retry && e.has_tag(SHOULD_RETRY) { - let rule = self.options.retry.get_rule(&e); - iteration += 1; - if iteration < rule.attempts { - let duration = (rule.backoff)(iteration); - log::info!("Error: {:#}. Retrying in {:?}...", e, duration); - sleep(duration).await; - continue; - } - } - return Err(e); - } - } - } - } - - /// Execute a query and return a collection of results and warnings produced by the server. - /// - /// You will usually have to specify the return type for the query: - /// - /// ```rust,ignore - /// let greeting: (Vec, _) = conn.query_with_warnings("select 'hello'", &()).await?; - /// ``` - /// - /// This method can be used with both static arguments, like a tuple of - /// scalars, and with dynamic arguments [`edgedb_protocol::value::Value`]. - /// Similarly, dynamically typed results are also supported. - pub async fn query_verbose( - &self, - query: impl AsRef + Send, - arguments: &A, - ) -> Result>, Error> - where - A: QueryArgs, - R: QueryResult, - { - Client::query_helper(self, query, arguments, IoFormat::Binary, Cardinality::Many) - .await - .map(|Response { data, warnings, .. }| ResultVerbose { data, warnings }) - } - - /// Execute a query and return a collection of results. - /// - /// You will usually have to specify the return type for the query: - /// - /// ```rust,ignore - /// let greeting = pool.query::("SELECT 'hello'", &()); - /// // or - /// let greeting: Vec = pool.query("SELECT 'hello'", &()); - /// - /// let two_numbers: Vec = conn.query("select {$0, $1}", &(10, 20)).await?; - /// ``` - /// - /// This method can be used with both static arguments, like a tuple of - /// scalars, and with dynamic arguments [`edgedb_protocol::value::Value`]. - /// Similarly, dynamically typed results are also supported. - pub async fn query( - &self, - query: impl AsRef + Send, - arguments: &A, - ) -> Result, Error> - where - A: QueryArgs, - R: QueryResult, - { - Client::query_helper(self, query, arguments, IoFormat::Binary, Cardinality::Many) - .await - .map(|r| r.data) - } - - /// Execute a query and return a single result - /// - /// You will usually have to specify the return type for the query: - /// - /// ```rust,ignore - /// let greeting = pool.query_single::("SELECT 'hello'", &()); - /// // or - /// let greeting: Option = pool.query_single( - /// "SELECT 'hello'", - /// &() - /// ); - /// ``` - /// - /// This method can be used with both static arguments, like a tuple of - /// scalars, and with dynamic arguments [`edgedb_protocol::value::Value`]. - /// Similarly, dynamically typed results are also supported. - pub async fn query_single( - &self, - query: impl AsRef + Send, - arguments: &A, - ) -> Result, Error> - where - A: QueryArgs, - R: QueryResult + Send, - { - Client::query_helper( - self, - query, - arguments, - IoFormat::Binary, - Cardinality::AtMostOne, - ) - .await - .map(|x| x.data.into_iter().next()) - } - - /// Execute a query and return a single result - /// - /// The query must return exactly one element. If the query returns more - /// than one element, a - /// [`ResultCardinalityMismatchError`][crate::errors::ResultCardinalityMismatchError] - /// is raised. If the query returns an empty set, a - /// [`NoDataError`][crate::errors::NoDataError] is raised. - /// - /// You will usually have to specify the return type for the query: - /// - /// ```rust,ignore - /// let greeting = pool.query_required_single::( - /// "SELECT 'hello'", - /// &(), - /// ); - /// // or - /// let greeting: String = pool.query_required_single( - /// "SELECT 'hello'", - /// &(), - /// ); - /// ``` - /// - /// This method can be used with both static arguments, like a tuple of - /// scalars, and with dynamic arguments [`edgedb_protocol::value::Value`]. - /// Similarly, dynamically typed results are also supported. - pub async fn query_required_single( - &self, - query: impl AsRef + Send, - arguments: &A, - ) -> Result - where - A: QueryArgs, - R: QueryResult + Send, - { - Client::query_helper( - self, - query, - arguments, - IoFormat::Binary, - Cardinality::AtMostOne, - ) - .await - .and_then(|x| { - x.data - .into_iter() - .next() - .ok_or_else(|| NoDataError::with_message("query row returned zero results")) - }) - } - - /// Execute a query and return the result as JSON. - pub async fn query_json( - &self, - query: impl AsRef, - arguments: &impl QueryArgs, - ) -> Result { - let res = self - .query_helper::(query, arguments, IoFormat::Json, Cardinality::Many) - .await?; - - let json = res - .data - .into_iter() - .next() - .ok_or_else(|| NoDataError::with_message("query row returned zero results"))?; - - // we trust database to produce valid json - Ok(Json::new_unchecked(json)) - } - - /// Execute a query and return a single result as JSON. - /// - /// The query must return exactly one element. If the query returns more - /// than one element, a - /// [`ResultCardinalityMismatchError`][crate::errors::ResultCardinalityMismatchError] - /// is raised. - /// - /// ```rust,ignore - /// let query = "select ( - /// insert Account { - /// username := $0 - /// }) { - /// username, - /// id - /// };"; - /// let json_res: Option = client - /// .query_single_json(query, &("SomeUserName",)) - /// .await?; - /// ``` - pub async fn query_single_json( - &self, - query: impl AsRef, - arguments: &impl QueryArgs, - ) -> Result, Error> { - let res = self - .query_helper::(query, arguments, IoFormat::Json, Cardinality::AtMostOne) - .await?; - - // we trust database to produce valid json - Ok(res.data.into_iter().next().map(Json::new_unchecked)) - } - - /// Execute a query and return a single result as JSON. - /// - /// The query must return exactly one element. If the query returns more - /// than one element, a - /// [`ResultCardinalityMismatchError`][crate::errors::ResultCardinalityMismatchError] - /// is raised. If the query returns an empty set, a - /// [`NoDataError`][crate::errors::NoDataError] is raised. - pub async fn query_required_single_json( - &self, - query: impl AsRef, - arguments: &impl QueryArgs, - ) -> Result { - self.query_single_json(query, arguments) - .await? - .ok_or_else(|| NoDataError::with_message("query row returned zero results")) - } - - /// Execute a query and don't expect result - /// - /// This method can be used with both static arguments, like a tuple of - /// scalars, and with dynamic arguments [`edgedb_protocol::value::Value`]. - /// Similarly, dynamically typed results are also supported. - pub async fn execute(&self, query: impl AsRef, arguments: &A) -> Result<(), Error> - where - A: QueryArgs, - { - let mut iteration = 0; - loop { - let mut conn = self.pool.acquire().await?; - - let conn = conn.inner(); - let state = &self.options.state; - let caps = Capabilities::MODIFICATIONS | Capabilities::DDL; - match conn - .execute( - query.as_ref(), - arguments, - state, - &self.options.annotations, - caps, - ) - .await - { - Ok(_) => return Ok(()), - Err(e) => { - let allow_retry = match e.get::() { - // Error from a weird source, or just a bug - // Let's keep on the safe side - None => false, - Some(QueryCapabilities::Unparsed) => true, - Some(QueryCapabilities::Parsed(c)) => c.is_empty(), - }; - if allow_retry && e.has_tag(SHOULD_RETRY) { - let rule = self.options.retry.get_rule(&e); - iteration += 1; - if iteration < rule.attempts { - let duration = (rule.backoff)(iteration); - log::info!("Error: {:#}. Retrying in {:?}...", e, duration); - sleep(duration).await; - continue; - } - } - return Err(e); - } - } - } - } - - /// Execute a transaction - /// - /// Transaction body must be encompassed in the closure. The closure **may - /// be executed multiple times**. This includes not only database queries - /// but also executing the whole function, so the transaction code must be - /// prepared to be idempotent. - /// - /// # Returning custom errors - /// - /// See [this example](https://github.com/edgedb/edgedb-rust/blob/master/edgedb-tokio/examples/transaction_errors.rs) - /// and [the documentation of the `edgedb_errors` crate](https://docs.rs/edgedb-errors/latest/edgedb_errors/) - /// for how to return custom error types. - /// - /// # Panics - /// - /// Function panics when transaction object passed to the closure is not - /// dropped after closure exists. General rule: do not store transaction - /// anywhere and do not send to another coroutine. Pass to all further - /// function calls by reference. - /// - /// # Example - /// - /// ```rust,no_run - /// # async fn transaction() -> Result<(), edgedb_tokio::Error> { - /// let conn = edgedb_tokio::create_client().await?; - /// let val = conn.transaction(|mut tx| async move { - /// tx.query_required_single::(" - /// WITH C := UPDATE Counter SET { value := .value + 1} - /// SELECT C.value LIMIT 1 - /// ", &() - /// ).await - /// }).await?; - /// # Ok(()) - /// # } - /// ``` - pub async fn transaction(&self, body: B) -> Result - where - B: FnMut(Transaction) -> F, - F: Future>, - { - transaction(&self.pool, self.options.clone(), body).await - } - - /// Returns client with adjusted options for future transactions. - /// - /// This method returns a "shallow copy" of the current client - /// with modified transaction options. - /// - /// Both ``self`` and returned client can be used after, but when using - /// them transaction options applied will be different. - /// - /// Transaction options are used by the ``transaction`` method. - pub fn with_transaction_options(&self, options: TransactionOptions) -> Self { - Client { - options: Arc::new(Options { - transaction: options, - retry: self.options.retry.clone(), - state: self.options.state.clone(), - annotations: self.options.annotations.clone(), - }), - pool: self.pool.clone(), - } - } - /// Returns client with adjusted options for future retrying - /// transactions. - /// - /// This method returns a "shallow copy" of the current client - /// with modified transaction options. - /// - /// Both ``self`` and returned client can be used after, but when using - /// them transaction options applied will be different. - pub fn with_retry_options(&self, options: RetryOptions) -> Self { - Client { - options: Arc::new(Options { - transaction: self.options.transaction.clone(), - retry: options, - state: self.options.state.clone(), - annotations: self.options.annotations.clone(), - }), - pool: self.pool.clone(), - } - } - - fn with_state(&self, f: impl FnOnce(&PoolState) -> PoolState) -> Self { - Client { - options: Arc::new(Options { - transaction: self.options.transaction.clone(), - retry: self.options.retry.clone(), - state: Arc::new(f(&self.options.state)), - annotations: self.options.annotations.clone(), - }), - pool: self.pool.clone(), - } - } - - /// Returns the client with the specified global variables set - /// - /// Most commonly used with `#[derive(GlobalsDelta)]`. - /// - /// Note: this method is incremental, i.e. it adds (or removes) globals - /// instead of setting a definite set of variables. Use - /// `.with_globals(Unset(["name1", "name2"]))` to unset some variables. - /// - /// This method returns a "shallow copy" of the current client - /// with modified global variables - /// - /// Both ``self`` and returned client can be used after, but when using - /// them transaction options applied will be different. - pub fn with_globals(&self, globals: impl GlobalsDelta) -> Self { - self.with_state(|s| s.with_globals(globals)) - } - - /// Returns the client with the specified global variables set - /// - /// This method returns a "shallow copy" of the current client - /// with modified global variables - /// - /// Both ``self`` and returned client can be used after, but when using - /// them transaction options applied will be different. - /// - /// This is equivalent to `.with_globals(Fn(f))` but more ergonomic as it - /// allows type inference for lambda. - pub fn with_globals_fn(&self, f: impl FnOnce(&mut GlobalsModifier)) -> Self { - self.with_state(|s| s.with_globals(Fn(f))) - } - - /// Returns the client with the specified aliases set - /// - /// This method returns a "shallow copy" of the current client - /// with modified aliases. - /// - /// Both ``self`` and returned client can be used after, but when using - /// them transaction options applied will be different. - pub fn with_aliases(&self, aliases: impl AliasesDelta) -> Self { - self.with_state(|s| s.with_aliases(aliases)) - } - - /// Returns the client with the specified aliases set - /// - /// This method returns a "shallow copy" of the current client - /// with modified aliases. - /// - /// Both ``self`` and returned client can be used after, but when using - /// them transaction options applied will be different. - /// - /// This is equivalent to `.with_aliases(Fn(f))` but more ergonomic as it - /// allows type inference for lambda. - pub fn with_aliases_fn(&self, f: impl FnOnce(&mut AliasesModifier)) -> Self { - self.with_state(|s| s.with_aliases(Fn(f))) - } - - /// Returns the client with the default module set or unset - /// - /// This method returns a "shallow copy" of the current client - /// with modified default module. - /// - /// Both ``self`` and returned client can be used after, but when using - /// them transaction options applied will be different. - pub fn with_default_module(&self, module: Option>) -> Self { - self.with_state(|s| s.with_default_module(module.map(|m| m.into()))) - } - - /// Returns the client with the specified config - /// - /// Note: this method is incremental, i.e. it adds (or removes) individual - /// settings instead of setting a definite configuration. Use - /// `.with_config(Unset(["name1", "name2"]))` to unset some settings. - /// - /// This method returns a "shallow copy" of the current client - /// with modified global variables - /// - /// Both ``self`` and returned client can be used after, but when using - /// them transaction options applied will be different. - pub fn with_config(&self, cfg: impl ConfigDelta) -> Self { - self.with_state(|s| s.with_config(cfg)) - } - - /// Returns the client with the specified config - /// - /// Most commonly used with `#[derive(ConfigDelta)]`. - /// - /// This method returns a "shallow copy" of the current client - /// with modified global variables - /// - /// Both ``self`` and returned client can be used after, but when using - /// them transaction options applied will be different. - /// - /// This is equivalent to `.with_config(Fn(f))` but more ergonomic as it - /// allows type inference for lambda. - pub fn with_config_fn(&self, f: impl FnOnce(&mut ConfigModifier)) -> Self { - self.with_state(|s| s.with_config(Fn(f))) - } - - /// Returns the client with the specified query tag. - /// - /// This method returns a "shallow copy" of the current client - /// with modified query tag. - /// - /// Both ``self`` and returned client can be used after, but when using - /// them query tag applied will be different. - pub fn with_tag(&self, tag: Option<&str>) -> Result { - const KEY: &str = "tag"; - - let annotations = if self.options.annotations.get(KEY).map(|s| s.as_str()) != tag { - let mut annotations = (*self.options.annotations).clone(); - if let Some(tag) = tag { - if tag.starts_with("edgedb/") { - return Err(InvalidArgumentError::with_message("reserved tag: edgedb/*")); - } - if tag.starts_with("gel/") { - return Err(InvalidArgumentError::with_message("reserved tag: gel/*")); - } - if tag.len() > 128 { - return Err(InvalidArgumentError::with_message( - "tag too long (> 128 bytes)", - )); - } - annotations.insert(KEY.to_string(), tag.to_string()); - } else { - annotations.remove(KEY); - } - Arc::new(annotations) - } else { - self.options.annotations.clone() - }; - - Ok(Client { - options: Arc::new(Options { - transaction: self.options.transaction.clone(), - retry: self.options.retry.clone(), - state: self.options.state.clone(), - annotations, - }), - pool: self.pool.clone(), - }) - } -} diff --git a/edgedb-tokio/src/credentials.rs b/edgedb-tokio/src/credentials.rs deleted file mode 100644 index 3e0db8a8..00000000 --- a/edgedb-tokio/src/credentials.rs +++ /dev/null @@ -1,212 +0,0 @@ -//! Credentials file handling routines -use std::fmt; -use std::str::FromStr; - -use serde::{ser, Deserialize, Serialize}; - -use crate::errors::{Error, ErrorKind}; - -/// TLS Client Security Mode -#[derive(Default, Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum TlsSecurity { - /// Allow any certificate for TLS connection - Insecure, - /// Verify certificate against trusted chain but allow any host name - /// - /// This is useful for localhost (you can't make trusted chain certificate - /// for localhost). And when certificate of specific server is stored in - /// credentials file so it's secure regardless of which host name was used - /// to expose the server to the network. - NoHostVerification, - /// Normal TLS certificate check (checks trusted chain and hostname) - Strict, - /// If there is a specific certificate in credentials, do not check - /// the host name, otherwise use `Strict` mode - #[default] - Default, -} - -/// A structure that represents the contents of the credentials file. -#[derive(Debug, Clone)] -#[non_exhaustive] -pub struct Credentials { - pub host: Option, - pub port: u16, - pub user: String, - pub password: Option, - pub database: Option, - pub branch: Option, - pub tls_ca: Option, - pub tls_security: TlsSecurity, - pub tls_server_name: Option, - pub(crate) file_outdated: bool, -} - -#[derive(Serialize, Deserialize)] -struct CredentialsCompat { - #[serde(default, skip_serializing_if = "Option::is_none")] - host: Option, - #[serde(default = "default_port")] - port: u16, - user: String, - #[serde(default, skip_serializing_if = "Option::is_none")] - password: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - database: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - branch: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - tls_cert_data: Option, // deprecated - #[serde(default, skip_serializing_if = "Option::is_none")] - tls_ca: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - tls_server_name: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - tls_verify_hostname: Option, // deprecated - tls_security: Option, -} - -fn default_port() -> u16 { - 5656 -} - -impl FromStr for TlsSecurity { - type Err = Error; - fn from_str(val: &str) -> Result { - match val { - "default" => Ok(TlsSecurity::Default), - "insecure" => Ok(TlsSecurity::Insecure), - "no_host_verification" => Ok(TlsSecurity::NoHostVerification), - "strict" => Ok(TlsSecurity::Strict), - val => Err(crate::errors::ClientError::with_message(format!( - "Invalid value {:?}. \ - Options: default, insecure, no_host_verification, strict.", - val, - ))), - } - } -} - -impl fmt::Display for TlsSecurity { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.as_str().fmt(f) - } -} - -impl TlsSecurity { - fn as_str(&self) -> &'static str { - use TlsSecurity::*; - - match self { - Default => "default", - Insecure => "insecure", - NoHostVerification => "no_host_verification", - Strict => "strict", - } - } -} - -impl Default for Credentials { - fn default() -> Credentials { - Credentials { - host: None, - port: 5656, - user: "edgedb".into(), - password: None, - database: None, - branch: None, - tls_ca: None, - tls_server_name: None, - tls_security: TlsSecurity::Default, - file_outdated: false, - } - } -} - -impl Serialize for Credentials { - fn serialize(&self, serializer: S) -> Result - where - S: ser::Serializer, - { - let creds = CredentialsCompat { - host: self.host.clone(), - port: self.port, - user: self.user.clone(), - password: self.password.clone(), - database: self.database.clone(), - branch: self.branch.clone(), - tls_ca: self.tls_ca.clone(), - tls_server_name: self.tls_server_name.clone(), - tls_cert_data: self.tls_ca.clone(), - tls_security: Some(self.tls_security), - tls_verify_hostname: match self.tls_security { - TlsSecurity::Default => None, - TlsSecurity::Strict => Some(true), - TlsSecurity::NoHostVerification => Some(false), - TlsSecurity::Insecure => Some(false), - }, - }; - - CredentialsCompat::serialize(&creds, serializer) - } -} - -#[cfg(feature = "fs")] -impl<'de> Deserialize<'de> for Credentials { - fn deserialize(deserializer: D) -> Result - where - D: serde::de::Deserializer<'de>, - { - let creds = CredentialsCompat::deserialize(deserializer)?; - let expected_verify = match creds.tls_security { - Some(TlsSecurity::Strict) => Some(true), - Some(TlsSecurity::NoHostVerification) => Some(false), - Some(TlsSecurity::Insecure) => Some(false), - _ => None, - }; - if creds.tls_verify_hostname.is_some() - && creds.tls_security.is_some() - && expected_verify - .zip(creds.tls_verify_hostname) - .map(|(creds, expected)| creds != expected) - .unwrap_or(false) - { - Err(serde::de::Error::custom(format!( - "detected conflicting settings. \ - \ntls_security =\n{}\nbut tls_verify_hostname =\n{}", - serde_json::to_string(&creds.tls_security).map_err(serde::de::Error::custom)?, - serde_json::to_string(&creds.tls_verify_hostname) - .map_err(serde::de::Error::custom)?, - ))) - } else if creds.tls_ca.is_some() - && creds.tls_cert_data.is_some() - && creds.tls_ca != creds.tls_cert_data - { - Err(serde::de::Error::custom(format!( - "detected conflicting settings. \ - \ntls_ca =\n{:#?}\nbut tls_cert_data =\n{:#?}", - creds.tls_ca, creds.tls_cert_data, - ))) - } else { - Ok(Credentials { - host: creds.host, - port: creds.port, - user: creds.user, - password: creds.password, - database: creds.database, - branch: creds.branch, - tls_ca: creds.tls_ca.or(creds.tls_cert_data.clone()), - tls_server_name: creds.tls_server_name, - tls_security: creds - .tls_security - .unwrap_or(match creds.tls_verify_hostname { - None => TlsSecurity::Default, - Some(true) => TlsSecurity::Strict, - Some(false) => TlsSecurity::NoHostVerification, - }), - file_outdated: creds.tls_verify_hostname.is_some() && creds.tls_security.is_none(), - }) - } - } -} diff --git a/edgedb-tokio/src/env.rs b/edgedb-tokio/src/env.rs deleted file mode 100644 index a9923a71..00000000 --- a/edgedb-tokio/src/env.rs +++ /dev/null @@ -1,226 +0,0 @@ -use std::fmt::Debug; -use std::io; -use std::num::NonZeroU16; -use std::time::Duration; -use std::{env, path::PathBuf, str::FromStr}; - -use edgedb_protocol::model; -use url::Url; - -use crate::errors::{ClientError, Error, ErrorKind}; -use crate::{builder::CloudCerts, ClientSecurity, InstanceName, TlsSecurity}; - -#[cfg_attr(feature = "unstable", macro_export)] -macro_rules! define_env { - ($( - #[doc=$doc:expr] - #[env($($env_name:expr),+)] - $(#[preprocess=$preprocess:expr])? - $(#[parse=$parse:expr])? - $(#[validate=$validate:expr])? - $name:ident: $type:ty - ),* $(,)?) => { - #[derive(Debug, Clone)] - pub struct Env { - } - - impl Env { - $( - #[doc = $doc] - pub fn $name() -> ::std::result::Result<::std::option::Option<$type>, $crate::Error> { - const ENV_NAMES: &[&str] = &[$(stringify!($env_name)),+]; - let Some((name, s)) = $crate::env::get_envs(ENV_NAMES)? else { - return Ok(None); - }; - $(let Some(s) = $preprocess(s) else { - return Ok(None); - };)? - - // This construct lets us choose between $parse and std::str::FromStr - // without requiring all types to implement FromStr. - #[allow(unused_labels)] - let value: $type = 'block: { - $( - break 'block $parse(&name, &s)?; - - // Disable the fallback parser - #[cfg(all(debug_assertions, not(debug_assertions)))] - )? - $crate::env::parse(&name, &s)? - }; - - $($validate(name, &value)?;)? - Ok(Some(value)) - } - )* - } - }; -} - -define_env!( - /// The host to connect to. - #[env(GEL_HOST, EDGEDB_HOST)] - #[validate=validate_host] - host: String, - - /// The port to connect to. - #[env(GEL_PORT, EDGEDB_PORT)] - #[preprocess=ignore_docker_tcp_port] - port: NonZeroU16, - - /// The database name to connect to. - #[env(GEL_DATABASE, EDGEDB_DATABASE)] - #[validate=non_empty_string] - database: String, - - /// The branch name to connect to. - #[env(GEL_BRANCH, EDGEDB_BRANCH)] - #[validate=non_empty_string] - branch: String, - - /// The username to connect as. - #[env(GEL_USER, EDGEDB_USER)] - #[validate=non_empty_string] - user: String, - - /// The password to use for authentication. - #[env(GEL_PASSWORD, EDGEDB_PASSWORD)] - password: String, - - /// TLS server name to verify. - #[env(GEL_TLS_SERVER_NAME, EDGEDB_TLS_SERVER_NAME)] - tls_server_name: String, - - /// Path to credentials file. - #[env(GEL_CREDENTIALS_FILE, EDGEDB_CREDENTIALS_FILE)] - credentials_file: String, - - /// Instance name to connect to. - #[env(GEL_INSTANCE, EDGEDB_INSTANCE)] - instance: InstanceName, - - /// Connection DSN string. - #[env(GEL_DSN, EDGEDB_DSN)] - dsn: Url, - - /// Secret key for authentication. - #[env(GEL_SECRET_KEY, EDGEDB_SECRET_KEY)] - secret_key: String, - - /// Client security mode. - #[env(GEL_CLIENT_SECURITY, EDGEDB_CLIENT_SECURITY)] - client_security: ClientSecurity, - - /// TLS security mode. - #[env(GEL_CLIENT_TLS_SECURITY, EDGEDB_CLIENT_TLS_SECURITY)] - client_tls_security: TlsSecurity, - - /// Path to TLS CA certificate file. - #[env(GEL_TLS_CA, EDGEDB_TLS_CA)] - tls_ca: String, - - /// Path to TLS CA certificate file. - #[env(GEL_TLS_CA_FILE, EDGEDB_TLS_CA_FILE)] - tls_ca_file: PathBuf, - - /// Cloud profile name. - #[env(GEL_CLOUD_PROFILE, EDGEDB_CLOUD_PROFILE)] - cloud_profile: String, - - /// Cloud certificates mode. - #[env(_GEL_CLOUD_CERTS, _EDGEDB_CLOUD_CERTS)] - _cloud_certs: CloudCerts, - - /// How long to wait for server to become available. - #[env(GEL_WAIT_UNTIL_AVAILABLE, EDGEDB_WAIT_UNTIL_AVAILABLE)] - #[parse=parse_duration] - wait_until_available: Duration, -); - -fn ignore_docker_tcp_port(s: String) -> Option { - static PORT_WARN: std::sync::Once = std::sync::Once::new(); - - if s.starts_with("tcp://") { - PORT_WARN.call_once(|| { - eprintln!("GEL_PORT/EDGEDB_PORT is ignored when using Docker TCP port"); - }); - None - } else { - Some(s) - } -} - -fn non_empty_string(var: &str, s: &str) -> Result<(), Error> { - if s.is_empty() { - Err(create_var_error(var, "empty string")) - } else { - Ok(()) - } -} - -fn validate_host(var: &str, s: &str) -> Result<(), Error> { - if s.is_empty() { - return Err(create_var_error(var, "invalid host: empty string")); - } else if s.contains(',') { - return Err(create_var_error(var, "invalid host: multiple hosts")); - } - Ok(()) -} - -#[inline(never)] -#[doc(hidden)] -pub fn parse(var: &str, s: &str) -> Result -where - ::Err: Debug, -{ - s.parse().map_err(|e| create_var_error(var, e)) -} - -#[inline(never)] -pub(crate) fn get_env(name: &str) -> Result, Error> { - let var = env::var(name); - match var { - Ok(v) if v.is_empty() => Ok(None), - Ok(v) => Ok(Some(v)), - Err(env::VarError::NotPresent) => Ok(None), - Err(e) => Err(create_var_error(name, e)), - } -} - -#[inline(never)] -#[doc(hidden)] -pub fn get_envs(names: &'static [&'static str]) -> Result, Error> { - let mut value = None; - let mut found_vars = Vec::new(); - - for name in names { - if let Some(val) = get_env(name)? { - found_vars.push(format!("{}={}", name, val)); - if value.is_none() { - value = Some((*name, val)); - } - } - } - - if found_vars.len() > 1 { - log::warn!( - "Multiple environment variables set: {}", - found_vars.join(", ") - ); - } - - Ok(value) -} - -fn parse_duration(var: &str, s: &str) -> Result { - let duration = model::Duration::from_str(s).map_err(|e| create_var_error(var, e))?; - - duration.try_into().map_err(|e| create_var_error(var, e)) -} - -fn create_var_error(var: &str, e: impl Debug) -> Error { - ClientError::with_source(io::Error::new( - io::ErrorKind::InvalidInput, - format!("{var} is invalid: {e:?}"), - )) -} diff --git a/edgedb-tokio/src/errors.rs b/edgedb-tokio/src/errors.rs deleted file mode 100644 index cd9d6aea..00000000 --- a/edgedb-tokio/src/errors.rs +++ /dev/null @@ -1,2 +0,0 @@ -//! Errors that can be returned by a client -pub use edgedb_errors::{kinds::*, Error, ErrorKind, ResultExt}; diff --git a/edgedb-tokio/src/letsencrypt_staging.pem b/edgedb-tokio/src/letsencrypt_staging.pem deleted file mode 100644 index 44dce211..00000000 --- a/edgedb-tokio/src/letsencrypt_staging.pem +++ /dev/null @@ -1,47 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIFmDCCA4CgAwIBAgIQU9C87nMpOIFKYpfvOHFHFDANBgkqhkiG9w0BAQsFADBm -MQswCQYDVQQGEwJVUzEzMDEGA1UEChMqKFNUQUdJTkcpIEludGVybmV0IFNlY3Vy -aXR5IFJlc2VhcmNoIEdyb3VwMSIwIAYDVQQDExkoU1RBR0lORykgUHJldGVuZCBQ -ZWFyIFgxMB4XDTE1MDYwNDExMDQzOFoXDTM1MDYwNDExMDQzOFowZjELMAkGA1UE -BhMCVVMxMzAxBgNVBAoTKihTVEFHSU5HKSBJbnRlcm5ldCBTZWN1cml0eSBSZXNl -YXJjaCBHcm91cDEiMCAGA1UEAxMZKFNUQUdJTkcpIFByZXRlbmQgUGVhciBYMTCC -AiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBALbagEdDTa1QgGBWSYkyMhsc -ZXENOBaVRTMX1hceJENgsL0Ma49D3MilI4KS38mtkmdF6cPWnL++fgehT0FbRHZg -jOEr8UAN4jH6omjrbTD++VZneTsMVaGamQmDdFl5g1gYaigkkmx8OiCO68a4QXg4 -wSyn6iDipKP8utsE+x1E28SA75HOYqpdrk4HGxuULvlr03wZGTIf/oRt2/c+dYmD -oaJhge+GOrLAEQByO7+8+vzOwpNAPEx6LW+crEEZ7eBXih6VP19sTGy3yfqK5tPt -TdXXCOQMKAp+gCj/VByhmIr+0iNDC540gtvV303WpcbwnkkLYC0Ft2cYUyHtkstO -fRcRO+K2cZozoSwVPyB8/J9RpcRK3jgnX9lujfwA/pAbP0J2UPQFxmWFRQnFjaq6 -rkqbNEBgLy+kFL1NEsRbvFbKrRi5bYy2lNms2NJPZvdNQbT/2dBZKmJqxHkxCuOQ -FjhJQNeO+Njm1Z1iATS/3rts2yZlqXKsxQUzN6vNbD8KnXRMEeOXUYvbV4lqfCf8 -mS14WEbSiMy87GB5S9ucSV1XUrlTG5UGcMSZOBcEUpisRPEmQWUOTWIoDQ5FOia/ -GI+Ki523r2ruEmbmG37EBSBXdxIdndqrjy+QVAmCebyDx9eVEGOIpn26bW5LKeru -mJxa/CFBaKi4bRvmdJRLAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNVHRMB -Af8EBTADAQH/MB0GA1UdDgQWBBS182Xy/rAKkh/7PH3zRKCsYyXDFDANBgkqhkiG -9w0BAQsFAAOCAgEAncDZNytDbrrVe68UT6py1lfF2h6Tm2p8ro42i87WWyP2LK8Y -nLHC0hvNfWeWmjZQYBQfGC5c7aQRezak+tHLdmrNKHkn5kn+9E9LCjCaEsyIIn2j -qdHlAkepu/C3KnNtVx5tW07e5bvIjJScwkCDbP3akWQixPpRFAsnP+ULx7k0aO1x -qAeaAhQ2rgo1F58hcflgqKTXnpPM02intVfiVVkX5GXpJjK5EoQtLceyGOrkxlM/ -sTPq4UrnypmsqSagWV3HcUlYtDinc+nukFk6eR4XkzXBbwKajl0YjztfrCIHOn5Q -CJL6TERVDbM/aAPly8kJ1sWGLuvvWYzMYgLzDul//rUF10gEMWaXVZV51KpS9DY/ -5CunuvCXmEQJHo7kGcViT7sETn6Jz9KOhvYcXkJ7po6d93A/jy4GKPIPnsKKNEmR -xUuXY4xRdh45tMJnLTUDdC9FIU0flTeO9/vNpVA8OPU1i14vCz+MU8KX1bV3GXm/ -fxlB7VBBjX9v5oUep0o/j68R/iDlCOM4VVfRa8gX6T2FU7fNdatvGro7uQzIvWof -gN9WUwCbEMBy/YhBSrXycKA8crgGg3x1mIsopn88JKwmMBa68oS7EHM9w7C4y71M -7DiA+/9Qdp9RBWJpTS9i/mDnJg1xvo8Xz49mrrgfmcAXTCJqXi24NatI3Oc= ------END CERTIFICATE----- ------BEGIN CERTIFICATE----- -MIICTjCCAdSgAwIBAgIRAIPgc3k5LlLVLtUUvs4K/QcwCgYIKoZIzj0EAwMwaDEL -MAkGA1UEBhMCVVMxMzAxBgNVBAoTKihTVEFHSU5HKSBJbnRlcm5ldCBTZWN1cml0 -eSBSZXNlYXJjaCBHcm91cDEkMCIGA1UEAxMbKFNUQUdJTkcpIEJvZ3VzIEJyb2Nj -b2xpIFgyMB4XDTIwMDkwNDAwMDAwMFoXDTQwMDkxNzE2MDAwMFowaDELMAkGA1UE -BhMCVVMxMzAxBgNVBAoTKihTVEFHSU5HKSBJbnRlcm5ldCBTZWN1cml0eSBSZXNl -YXJjaCBHcm91cDEkMCIGA1UEAxMbKFNUQUdJTkcpIEJvZ3VzIEJyb2Njb2xpIFgy -MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAEOvS+w1kCzAxYOJbA06Aw0HFP2tLBLKPo -FQqR9AMskl1nC2975eQqycR+ACvYelA8rfwFXObMHYXJ23XLB+dAjPJVOJ2OcsjT -VqO4dcDWu+rQ2VILdnJRYypnV1MMThVxo0IwQDAOBgNVHQ8BAf8EBAMCAQYwDwYD -VR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQU3tGjWWQOwZo2o0busBB2766XlWYwCgYI -KoZIzj0EAwMDaAAwZQIwRcp4ZKBsq9XkUuN8wfX+GEbY1N5nmCRc8e80kUkuAefo -uc2j3cICeXo1cOybQ1iWAjEA3Ooawl8eQyR4wrjCofUE8h44p0j7Yl/kBlJZT8+9 -vbtH7QiVzeKCOTQPINyRql6P ------END CERTIFICATE----- diff --git a/edgedb-tokio/src/lib.rs b/edgedb-tokio/src/lib.rs index 3f75f8f8..a00bd8a5 100644 --- a/edgedb-tokio/src/lib.rs +++ b/edgedb-tokio/src/lib.rs @@ -1,181 +1,7 @@ /*! EdgeDB client for Tokio -👉 New! Check out the new [EdgeDB client tutorial](`tutorial`). 👈 - -The main way to use EdgeDB bindings is to use the [`Client`]. It encompasses -connection pool to the database that is transparent for user. Individual -queries can be made via methods on the client. Correlated queries are done -via [transactions](Client::transaction). - -To create a client, use the [`create_client`] function (it gets a database -connection configuration from environment). You can also use a [`Builder`] -to [`build`](`Builder::new`) custom [`Config`] and [create a -client](Client::new) using that config. - -# Example - -```rust,no_run -#[tokio::main] -async fn main() -> anyhow::Result<()> { - let conn = edgedb_tokio::create_client().await?; - let val = conn.query_required_single::( - "SELECT 7*8", - &(), - ).await?; - println!("7*8 is: {}", val); - Ok(()) -} -``` -More [examples on github](https://github.com/edgedb/edgedb-rust/tree/master/edgedb-tokio/examples) - -# Nice Error Reporting - -We use [miette] crate for including snippets in your error reporting code. - -To make it work, first you need enable `fancy` feature in your top-level -crate's `Cargo.toml`: -```toml -[dependencies] -miette = { version="5.3.0", features=["fancy"] } -edgedb-tokio = { version="*", features=["miette-errors"] } -``` - -Then if you use `miette` all the way through your application, it just -works: -```rust,no_run -#[tokio::main] -async fn main() -> miette::Result<()> { - let conn = edgedb_tokio::create_client().await?; - conn.query::("SELECT 1+2)", &()).await?; - Ok(()) -} -``` - -However, if you use some boxed error container (e.g. [anyhow]), you -might need to downcast error for printing: -```rust,no_run -async fn do_something() -> anyhow::Result<()> { - let conn = edgedb_tokio::create_client().await?; - conn.query::("SELECT 1+2)", &()).await?; - Ok(()) -} - -#[tokio::main] -async fn main() { - match do_something().await { - Ok(res) => res, - Err(e) => { - e.downcast::() - .map(|e| eprintln!("{:?}", miette::Report::new(e))) - .unwrap_or_else(|e| eprintln!("{:#}", e)); - std::process::exit(1); - } - } -} -``` - -In some cases, where parts of your code use `miette::Result` or -`miette::Report` before converting to the boxed (anyhow) container, you -might want a little bit more complex downcasting: - -```rust,no_run -# async fn do_something() -> anyhow::Result<()> { unimplemented!() } -#[tokio::main] -async fn main() { - match do_something().await { - Ok(res) => res, - Err(e) => { - e.downcast::() - .map(|e| eprintln!("{:?}", miette::Report::new(e))) - .or_else(|e| e.downcast::() - .map(|e| eprintln!("{:?}", e))) - .unwrap_or_else(|e| eprintln!("{:#}", e)); - std::process::exit(1); - } - } -} -``` - -Note that last two examples do hide error contexts from anyhow and do not -pretty print if `source()` of the error is `edgedb_errors::Error` but not -the top-level one. We leave those more complex cases as an excersize to the -reader. - -[miette]: https://crates.io/crates/miette -[anyhow]: https://crates.io/crates/anyhow +> This crate has been renamed to [gel-tokio](https://crates.io/crates/gel-tokio). */ -#![cfg_attr( - not(feature = "unstable"), - warn(missing_docs, missing_debug_implementations) -)] - -macro_rules! unstable_pub_mods { - ($(mod $mod_name:ident;)*) => { - $( - #[cfg(feature = "unstable")] - pub mod $mod_name; - #[cfg(not(feature = "unstable"))] - mod $mod_name; - )* - } -} - -// If the unstable feature is enabled, the modules will be public. -// If the unstable feature is not enabled, the modules will be private. -unstable_pub_mods! { - mod builder; - mod credentials; - mod raw; - mod server_params; - mod tls; - mod env; -} - -mod client; -mod errors; -mod options; -mod query_executor; -mod sealed; -pub mod state; -mod transaction; -pub mod tutorial; - -pub use edgedb_derive::{ConfigDelta, GlobalsDelta, Queryable}; - -pub use builder::{Builder, ClientSecurity, Config, InstanceName, TcpKeepalive}; -pub use client::Client; -pub use credentials::TlsSecurity; -pub use errors::Error; -pub use options::{RetryCondition, RetryOptions, TransactionOptions}; -pub use query_executor::{QueryExecutor, ResultVerbose}; -pub use state::{ConfigDelta, GlobalsDelta}; -pub use transaction::Transaction; - -/// The ordered list of project filenames supported. -pub const PROJECT_FILES: &[&str] = &["gel.toml", "edgedb.toml"]; - -/// The default project filename. -pub const DEFAULT_PROJECT_FILE: &str = PROJECT_FILES[0]; - -#[cfg(feature = "unstable")] -pub use builder::{get_project_path, get_stash_path}; - -/// Create a connection to the database with default parameters -/// -/// It's expected that connection parameters are set up using environment -/// (either environment variables or project configuration in a file named by -/// [`PROJECT_FILES`]) so no configuration is specified here. -/// -/// This method tries to esablish single connection immediately to ensure that -/// configuration is valid and will error out otherwise. -/// -/// For more fine-grained setup see [`Client`] and [`Builder`] documentation and -/// the source of this function. -#[cfg(feature = "env")] -pub async fn create_client() -> Result { - let pool = Client::new(&Builder::new().build_env().await?); - pool.ensure_connected().await?; - Ok(pool) -} +compile_error!("edgedb-tokio has been renamed to gel-tokio"); \ No newline at end of file diff --git a/edgedb-tokio/src/nebula_development.pem b/edgedb-tokio/src/nebula_development.pem deleted file mode 100644 index 3c07762e..00000000 --- a/edgedb-tokio/src/nebula_development.pem +++ /dev/null @@ -1,13 +0,0 @@ ------BEGIN CERTIFICATE----- -MIICBjCCAaugAwIBAgIUGLnu92rPr79+DsDQBtolXEZENwMwCgYIKoZIzj0EAwIw -UDELMAkGA1UEBhMCVVMxGjAYBgNVBAoMEUVkZ2VEQiAoaW50ZXJuYWwpMSUwIwYD -VQQDDBxOZWJ1bGEgSW5mcmEgUm9vdCBDQSAobG9jYWwpMB4XDTIzMDExNDIzMDkw -M1oXDTMyMTAxMzIzMDkwM1owUDELMAkGA1UEBhMCVVMxGjAYBgNVBAoMEUVkZ2VE -QiAoaW50ZXJuYWwpMSUwIwYDVQQDDBxOZWJ1bGEgSW5mcmEgUm9vdCBDQSAobG9j -YWwpMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEHJk/v57y1dG1xekQjeYwqlW7 -45fvlWIIid/EfcyBNCyvWhLUyQUz3urmK81rJlFYCexq/kgazTeBFJyWbrvLLKNj -MGEwHQYDVR0OBBYEFN5PvqC9p5e4HC99o3z0pJrRuIpeMB8GA1UdIwQYMBaAFN5P -vqC9p5e4HC99o3z0pJrRuIpeMA8GA1UdEwEB/wQFMAMBAf8wDgYDVR0PAQH/BAQD -AgEGMAoGCCqGSM49BAMCA0kAMEYCIQDedUpRy5YtQAHROrh/ZsWPlvek3vguuRrE -y4u6fdOVhgIhAJ4pJLfdoWQsHPUOcnVG5fBgdSnoCJhGQyuGyp+NDu1q ------END CERTIFICATE----- diff --git a/edgedb-tokio/src/options.rs b/edgedb-tokio/src/options.rs deleted file mode 100644 index 317c7ce5..00000000 --- a/edgedb-tokio/src/options.rs +++ /dev/null @@ -1,201 +0,0 @@ -use std::collections::HashMap; -use std::fmt; -use std::sync::Arc; -use std::time::Duration; - -use once_cell::sync::Lazy; -use rand::{thread_rng, Rng}; - -use crate::errors::{Error, IdleSessionTimeoutError}; - -/// Single immediate retry on idle is fine -/// -/// This doesn't have to be configured. -static IDLE_TIMEOUT_RULE: Lazy = Lazy::new(|| RetryRule { - attempts: 2, - backoff: Arc::new(|_| Duration::new(0, 0)), -}); - -/// Specific condition for retrying queries -/// -/// This is used for fine-grained control for retrying queries and transactions -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -#[non_exhaustive] -pub enum RetryCondition { - /// Optimistic transaction error - TransactionConflict, - /// Network failure between client and server - NetworkError, -} - -/// Options for [`transaction()`](crate::Client::transaction) -/// -/// Must be set on a [`Client`](crate::Client) via -/// [`with_transaction_options`](crate::Client::with_transaction_options). -#[derive(Debug, Clone, Default)] -pub struct TransactionOptions { - read_only: bool, - deferrable: bool, -} - -/// This structure contains options for retrying transactions and queries -/// -/// Must be set on a [`Client`](crate::Client) via -/// [`with_retry_options`](crate::Client::with_retry_options). -#[derive(Debug, Clone)] -pub struct RetryOptions(Arc); - -#[derive(Debug, Clone)] -struct RetryOptionsInner { - default: RetryRule, - overrides: HashMap, -} - -#[derive(Clone)] -pub(crate) struct RetryRule { - pub(crate) attempts: u32, - pub(crate) backoff: Arc Duration + Send + Sync>, -} - -impl TransactionOptions { - /// Set whether transaction is read-only - pub fn read_only(mut self, read_only: bool) -> Self { - self.read_only = read_only; - self - } - /// Set whether transaction is deferrable - pub fn deferrable(mut self, deferrable: bool) -> Self { - self.deferrable = deferrable; - self - } -} - -impl Default for RetryRule { - fn default() -> RetryRule { - RetryRule { - attempts: 3, - backoff: Arc::new(|n| { - Duration::from_millis(2u64.pow(n) * 100 + thread_rng().gen_range(0..100)) - }), - } - } -} - -impl Default for RetryOptions { - fn default() -> RetryOptions { - RetryOptions(Arc::new(RetryOptionsInner { - default: RetryRule::default(), - overrides: HashMap::new(), - })) - } -} - -impl RetryOptions { - /// Create a new [`RetryOptions`] object with the default rule - pub fn new( - self, - attempts: u32, - backoff: impl Fn(u32) -> Duration + Send + Sync + 'static, - ) -> Self { - RetryOptions(Arc::new(RetryOptionsInner { - default: RetryRule { - attempts, - backoff: Arc::new(backoff), - }, - overrides: HashMap::new(), - })) - } - /// Add a retrying rule for a specific condition - pub fn with_rule( - mut self, - condition: RetryCondition, - attempts: u32, - backoff: impl Fn(u32) -> Duration + Send + Sync + 'static, - ) -> Self { - let inner = Arc::make_mut(&mut self.0); - inner.overrides.insert( - condition, - RetryRule { - attempts, - backoff: Arc::new(backoff), - }, - ); - self - } - pub(crate) fn get_rule(&self, err: &Error) -> &RetryRule { - use edgedb_errors::{ClientError, TransactionConflictError}; - use RetryCondition::*; - - if err.is::() { - &IDLE_TIMEOUT_RULE - } else if err.is::() { - self.0 - .overrides - .get(&TransactionConflict) - .unwrap_or(&self.0.default) - } else if err.is::() { - self.0 - .overrides - .get(&NetworkError) - .unwrap_or(&self.0.default) - } else { - &self.0.default - } - } -} - -struct DebugBackoff(F, u32); - -impl fmt::Debug for DebugBackoff -where - F: Fn(u32) -> Duration, -{ - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if self.1 > 3 { - for i in 0..3 { - write!(f, "{:?}, ", (self.0)(i))?; - } - write!(f, "...")?; - } else { - write!(f, "{:?}", (self.0)(0))?; - for i in 1..self.1 { - write!(f, ", {:?}", (self.0)(i))?; - } - } - Ok(()) - } -} - -#[test] -fn debug_backoff() { - assert_eq!( - format!( - "{:?}", - DebugBackoff(|i| Duration::from_secs(10 + (i as u64) * 10), 3) - ), - "10s, 20s, 30s" - ); - assert_eq!( - format!( - "{:?}", - DebugBackoff(|i| Duration::from_secs(10 + (i as u64) * 10), 10) - ), - "10s, 20s, 30s, ..." - ); - assert_eq!( - format!( - "{:?}", - DebugBackoff(|i| Duration::from_secs(10 + (i as u64) * 10), 2) - ), - "10s, 20s" - ); -} - -impl fmt::Debug for RetryRule { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("RetryRule") - .field("attempts", &self.attempts) - .field("backoff", &DebugBackoff(&*self.backoff, self.attempts)) - .finish() - } -} diff --git a/edgedb-tokio/src/query_executor.rs b/edgedb-tokio/src/query_executor.rs deleted file mode 100644 index 0aa01901..00000000 --- a/edgedb-tokio/src/query_executor.rs +++ /dev/null @@ -1,253 +0,0 @@ -use edgedb_protocol::query_arg::QueryArgs; -use edgedb_protocol::QueryResult; -use edgedb_protocol::{annotations::Warning, model::Json}; -use std::future::Future; - -use crate::{Client, Error, Transaction}; - -/// Query result with additional metadata. -#[non_exhaustive] -#[derive(Debug)] -pub struct ResultVerbose { - /// Query results - pub data: R, - - /// Query warnings - pub warnings: Vec, -} - -/// Abstracts over different query executors -/// In particular &Client and &mut Transaction -pub trait QueryExecutor: Sized { - /// see [Client::query] - fn query( - self, - query: impl AsRef + Send, - arguments: &A, - ) -> impl Future, Error>> + Send - where - A: QueryArgs, - R: QueryResult + Send; - - /// see [Client::query_with_warnings] - fn query_verbose( - self, - query: impl AsRef + Send, - arguments: &A, - ) -> impl Future>, Error>> + Send - where - A: QueryArgs, - R: QueryResult + Send; - - /// see [Client::query_single] - fn query_single( - self, - query: impl AsRef + Send, - arguments: &A, - ) -> impl Future, Error>> + Send - where - A: QueryArgs, - R: QueryResult + Send; - - /// see [Client::query_required_single] - fn query_required_single( - self, - query: impl AsRef + Send, - arguments: &A, - ) -> impl std::future::Future> + Send - where - A: QueryArgs, - R: QueryResult + Send; - - /// see [Client::query_json] - fn query_json( - self, - query: &str, - arguments: &impl QueryArgs, - ) -> impl Future> + Send; - - /// see [Client::query_single_json] - fn query_single_json( - self, - query: &str, - arguments: &impl QueryArgs, - ) -> impl Future, Error>> + Send; - - /// see [Client::query_required_single_json] - fn query_required_single_json( - self, - query: &str, - arguments: &impl QueryArgs, - ) -> impl Future>; - - /// see [Client::execute] - fn execute( - self, - query: &str, - arguments: &A, - ) -> impl Future> + Send - where - A: QueryArgs; -} - -impl QueryExecutor for &Client { - fn query( - self, - query: impl AsRef + Send, - arguments: &A, - ) -> impl Future, Error>> - where - A: QueryArgs, - R: QueryResult, - { - Client::query(self, query, arguments) - } - - fn query_verbose( - self, - query: impl AsRef + Send, - arguments: &A, - ) -> impl Future>, Error>> + Send - where - A: QueryArgs, - R: QueryResult + Send, - { - Client::query_verbose(self, query, arguments) - } - - fn query_single( - self, - query: impl AsRef + Send, - arguments: &A, - ) -> impl Future, Error>> - where - A: QueryArgs, - R: QueryResult + Send, - { - Client::query_single(self, query, arguments) - } - - fn query_required_single( - self, - query: impl AsRef + Send, - arguments: &A, - ) -> impl Future> - where - A: QueryArgs, - R: QueryResult + Send, - { - Client::query_required_single(self, query, arguments) - } - - fn query_json( - self, - query: &str, - arguments: &impl QueryArgs, - ) -> impl Future> { - Client::query_json(self, query, arguments) - } - - fn query_single_json( - self, - query: &str, - arguments: &impl QueryArgs, - ) -> impl Future, Error>> { - Client::query_single_json(self, query, arguments) - } - - fn query_required_single_json( - self, - query: &str, - arguments: &impl QueryArgs, - ) -> impl Future> { - Client::query_required_single_json(self, query, arguments) - } - - fn execute(self, query: &str, arguments: &A) -> impl Future> - where - A: QueryArgs, - { - Client::execute(self, query, arguments) - } -} - -impl QueryExecutor for &mut Transaction { - fn query( - self, - query: impl AsRef + Send, - arguments: &A, - ) -> impl Future, Error>> - where - A: QueryArgs, - R: QueryResult, - { - Transaction::query(self, query, arguments) - } - - fn query_verbose( - self, - query: impl AsRef + Send, - arguments: &A, - ) -> impl Future>, Error>> + Send - where - A: QueryArgs, - R: QueryResult + Send, - { - Transaction::query_verbose(self, query, arguments) - } - - fn query_single( - self, - query: impl AsRef + Send, - arguments: &A, - ) -> impl Future, Error>> - where - A: QueryArgs, - R: QueryResult + Send, - { - Transaction::query_single(self, query, arguments) - } - - fn query_required_single( - self, - query: impl AsRef + Send, - arguments: &A, - ) -> impl Future> - where - A: QueryArgs, - R: QueryResult + Send, - { - Transaction::query_required_single(self, query, arguments) - } - - fn query_json( - self, - query: &str, - arguments: &impl QueryArgs, - ) -> impl Future> { - Transaction::query_json(self, query, arguments) - } - - fn query_single_json( - self, - query: &str, - arguments: &impl QueryArgs, - ) -> impl Future, Error>> { - Transaction::query_single_json(self, query, arguments) - } - - fn query_required_single_json( - self, - query: &str, - arguments: &impl QueryArgs, - ) -> impl Future> { - Transaction::query_required_single_json(self, query, arguments) - } - - fn execute(self, query: &str, arguments: &A) -> impl Future> - where - A: QueryArgs, - { - Transaction::execute(self, query, arguments) - } -} diff --git a/edgedb-tokio/src/raw/connection.rs b/edgedb-tokio/src/raw/connection.rs deleted file mode 100644 index 1eddfdd2..00000000 --- a/edgedb-tokio/src/raw/connection.rs +++ /dev/null @@ -1,854 +0,0 @@ -use std::borrow::Cow; -use std::cmp::min; -use std::collections::HashMap; -use std::error::Error as _; -use std::future::{self, Future}; -use std::io; -use std::str; -use std::time::Duration; - -use bytes::{Bytes, BytesMut}; -use rand::{thread_rng, Rng}; -use rustls::pki_types::DnsName; -use scram::ScramClient; -use socket2::TcpKeepalive; -use tls_api::TlsConnectorBuilder; -use tls_api::{TlsConnector, TlsConnectorBox, TlsStream, TlsStreamDyn}; -use tls_api_not_tls::TlsConnector as PlainConnector; -use tokio::io::ReadBuf; -use tokio::io::{AsyncRead, AsyncReadExt}; -use tokio::io::{AsyncWrite, AsyncWriteExt}; -use tokio::net::TcpStream; -use tokio::time::{sleep, timeout_at, Instant}; - -use edgedb_protocol::client_message::{ClientHandshake, ClientMessage}; -use edgedb_protocol::encoding::{Input, Output}; -use edgedb_protocol::features::ProtocolVersion; -use edgedb_protocol::server_message::{ - Authentication, ErrorResponse, ServerHandshake, ServerMessage, -}; -use edgedb_protocol::server_message::{ - MessageSeverity, ParameterStatus, RawTypedesc, TransactionState, -}; -use edgedb_protocol::value::Value; - -use crate::builder::{Address, Config}; -use crate::errors::{ - AuthenticationError, ClientConnectionEosError, ClientConnectionError, - ClientConnectionFailedError, ClientConnectionFailedTemporarilyError, ClientEncodingError, - ClientError, Error, ErrorKind, IdleSessionTimeoutError, PasswordRequired, - ProtocolEncodingError, ProtocolError, ProtocolTlsError, -}; -use crate::raw::queries::Guard; -use crate::raw::{Connection, PingInterval}; -use crate::server_params::{ServerParam, ServerParams, SystemConfig}; -use crate::tls; - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub(crate) enum Mode { - Normal { idle_since: Instant }, - Dirty, - AwaitingPing, -} - -impl Connection { - pub fn is_consistent(&self) -> bool { - matches!(self.mode, Mode::Normal { .. }) - } - pub async fn is_connection_reset(&mut self) -> bool { - tokio::select! { biased; - msg = wait_message(&mut self.stream, &mut self.in_buf, &self.proto) - => { - match msg { - Ok(ServerMessage::ErrorResponse(e)) => { - let e: Error = e.into(); - if e.is::() { - log::debug!("Connection reset due to inactivity."); - } else { - log::warn!("Unexpected error: {:#}", e); - } - true - } - Ok(m) => { - log::warn!("Unsolicited message: {:?}", m); - true - } - Err(e) => { - log::debug!("I/O error: {:#}", e); - true - } - } - } - _ = future::ready(()) => { - if self.in_buf.is_empty() { - false - } else { - log::warn!("Unsolicited partial data {:?}", - &self.in_buf[..min(self.in_buf.len(), 16)]); - true - } - } - } - } - pub async fn connect(config: &Config) -> Result { - connect(config).await.map_err(|e| { - if e.is::() { - e.refine_kind::() - } else { - e - } - }) - } - pub async fn send_messages<'x>( - &mut self, - msgs: impl IntoIterator, - ) -> Result<(), Error> { - send_messages(&mut self.stream, &mut self.out_buf, &self.proto, msgs).await - } - pub async fn message(&mut self) -> Result { - wait_message(&mut self.stream, &mut self.in_buf, &self.proto).await - } - pub fn get_server_param(&self) -> Option<&T::Value> { - self.server_params.get::() - } - #[cfg(feature = "unstable")] - pub async fn ping_while(&mut self, other: F) -> T - where - F: Future, - { - if self.ping_interval == PingInterval::Unknown { - self.ping_interval = self.calc_ping_interval(); - } - if let PingInterval::Interval(interval) = self.ping_interval { - let result = tokio::select! { biased; - _ = self.background_pings(interval) => unreachable!(), - res = other => res, - }; - if self.mode == Mode::AwaitingPing { - self.synchronize_ping().await.ok(); - } - result - } else { - other.await - } - } - async fn do_pings(&mut self, interval: Duration) -> Result<(), Error> { - if self.mode == Mode::AwaitingPing { - self.synchronize_ping().await?; - } - - while let Mode::Normal { - idle_since: last_pong, - } = self.mode - { - match timeout_at(last_pong + interval, self.passive_wait()).await { - Err(_) => {} - Ok(Err(e)) => { - self.mode = Mode::Dirty; - return Err(ClientConnectionError::with_source(e))?; - } - Ok(Ok(_)) => unreachable!(), - } - - self.mode = Mode::Dirty; - self.send_messages(&[ClientMessage::Sync]).await?; - self.mode = Mode::AwaitingPing; - self.synchronize_ping().await?; - } - Ok(()) - } - async fn background_pings(&mut self, interval: Duration) -> T { - self.do_pings(interval) - .await - .map_err(|e| log::info!("Connection error during background pings: {}", e)) - .ok(); - debug_assert_eq!(self.mode, Mode::Dirty); - future::pending::<()>().await; - unreachable!(); - } - async fn synchronize_ping<'a>(&mut self) -> Result<(), Error> { - debug_assert_eq!(self.mode, Mode::AwaitingPing); - - // Guard mechanism was invented for real queries, so we have to - // make a little bit of workaround just for Pings - let spurious_guard = Guard; - match self.expect_ready(spurious_guard).await { - Ok(()) => Ok(()), - Err(e) => { - self.mode = Mode::Dirty; - Err(e) - } - } - } - pub async fn passive_wait(&mut self) -> io::Result<()> { - loop { - let msg = self - .message() - .await - .map_err(|_| io::ErrorKind::InvalidData)?; - match msg { - // TODO(tailhook) update parameters? - ServerMessage::ParameterStatus(_) => {} - _ => return Err(io::ErrorKind::InvalidData)?, - } - } - } - fn calc_ping_interval(&self) -> PingInterval { - if let Some(config) = self.server_params.get::() { - if let Some(timeout) = config.session_idle_timeout { - if timeout.is_zero() { - log::info!( - "Server disabled session_idle_timeout; \ - pings are disabled." - ); - PingInterval::Disabled - } else { - let interval = Duration::from_secs( - (timeout.saturating_sub(Duration::from_secs(1)).as_secs_f64() * 0.9).ceil() - as u64, - ); - if interval.is_zero() { - log::warn!( - "session_idle_timeout={:?} is too short; \ - pings are disabled.", - timeout, - ); - PingInterval::Disabled - } else { - log::info!( - "Setting ping interval to {:?} as \ - session_idle_timeout={:?}", - interval, - timeout, - ); - PingInterval::Interval(interval) - } - } - } else { - PingInterval::Unknown - } - } else { - PingInterval::Unknown - } - } - pub async fn terminate(mut self) -> Result<(), Error> { - let _ = self.begin_request()?; // not need to cleanup after that - self.send_messages(&[ClientMessage::Terminate]).await?; - match self.message().await { - Err(e) if e.is::() => Ok(()), - Err(e) => Err(e), - Ok(msg) => Err(ProtocolError::with_message(format!( - "unsolicited message {:?}", - msg - ))), - } - } - pub fn transaction_state(&self) -> TransactionState { - self.transaction_state - } - pub fn state_descriptor(&self) -> &RawTypedesc { - &self.state_desc - } - pub fn protocol(&self) -> &ProtocolVersion { - &self.proto - } -} - -async fn connect(cfg: &Config) -> Result { - let tls = tls::connector(cfg.0.verifier.clone()) - .map_err(|e| ClientError::with_source_ref(e).context("cannot create TLS connector"))?; - match &cfg.0.address { - Address::Unix(path) => { - log::info!("Connecting via Unix `{}`", path.display()); - } - Address::Tcp((host, port)) => { - log::info!("Connecting via TCP {host}:{port}"); - } - } - - let start = Instant::now(); - let wait = cfg.0.wait; - let warned = &mut false; - let conn = loop { - match connect_timeout(cfg, connect2(cfg, &tls, warned)).await { - Err(e) if is_temporary(&e) => { - log::debug!("Temporary connection error: {:#}", e); - if wait > start.elapsed() { - sleep(connect_sleep()).await; - continue; - } else if wait > Duration::new(0, 0) { - return Err(e.context(format!("cannot establish connection for {wait:?}"))); - } else { - return Err(e); - } - } - Err(e) => { - log::debug!("Connection error: {:#}", e); - return Err(e)?; - } - Ok(conn) => break conn, - } - }; - Ok(conn) -} - -async fn connect2( - cfg: &Config, - tls: &TlsConnectorBox, - warned: &mut bool, -) -> Result { - let stream = match connect3(cfg, tls).await { - Err(e) if e.is::() => { - if !*warned { - log::warn!( - "TLS connection failed. \ - Trying plaintext..." - ); - *warned = true; - } - connect3( - cfg, - &PlainConnector::builder() - .map_err(ClientError::with_source_ref)? - .build() - .map_err(ClientError::with_source_ref)? - .into_dyn(), - ) - .await? - } - Err(e) => return Err(e), - Ok(r) => match r.get_alpn_protocol() { - Ok(Some(protocol)) if protocol == b"edgedb-binary" => r, - _ => match &cfg.0.address { - Address::Tcp(_) => Err(ClientConnectionFailedError::with_message( - "Server does not support the EdgeDB binary protocol.", - ))?, - Address::Unix(_) => r, // don't check ALPN on UNIX stream - }, - }, - }; - connect4(cfg, stream).await -} - -async fn connect3(cfg: &Config, tls: &TlsConnectorBox) -> Result { - match &cfg.0.address { - Address::Tcp(addr @ (host, _)) => { - let conn = TcpStream::connect(addr) - .await - .map_err(ClientConnectionError::with_source)?; - - // Set keep-alive on the socket, but don't fail if this isn't successful - if let Some(keepalive) = cfg.0.tcp_keepalive { - let sock = socket2::SockRef::from(&conn); - #[cfg(target_os = "openbsd")] - let res = sock.set_keepalive(true); - #[cfg(not(target_os = "openbsd"))] - let res = sock.set_tcp_keepalive( - &TcpKeepalive::new() - .with_interval(keepalive) - .with_time(keepalive), - ); - if let Err(e) = res { - log::warn!("Failed to set TCP keepalive: {e:?}"); - } - } - - let host = match &cfg.0.tls_server_name { - Some(server_name) => Cow::from(server_name), - None => { - if DnsName::try_from(host.clone()).is_err() { - // FIXME: https://github.com/rustls/rustls/issues/184 - // If self.host is neither an IP address nor a valid DNS - // name, the hacks below won't make it valid anyways. - let host = format!("{}.host-for-ip.edgedb.net", host); - // for ipv6addr - let host = host.replace([':', '%'], "-"); - if host.starts_with('-') { - Cow::from(format!("i{}", host)) - } else { - Cow::from(host) - } - } else { - Cow::from(host) - } - } - }; - Ok(tls.connect(&host[..], conn).await.map_err(tls_fail)?) - } - Address::Unix(path) => { - #[cfg(windows)] - { - return Err(ClientError::with_message( - "Unix socket are not supported on windows", - )); - } - #[cfg(unix)] - { - use tokio::net::UnixStream; - let conn = UnixStream::connect(&path) - .await - .map_err(ClientConnectionError::with_source)?; - Ok(PlainConnector::builder() - .map_err(ClientError::with_source_ref)? - .build() - .map_err(ClientError::with_source_ref)? - .into_dyn() - .connect("localhost", conn) - .await - .map_err(tls_fail)?) - } - } - } -} - -async fn connect4(cfg: &Config, mut stream: TlsStream) -> Result { - let mut proto = ProtocolVersion::current(); - let mut out_buf = BytesMut::with_capacity(8192); - let mut in_buf = BytesMut::with_capacity(8192); - - let mut params = HashMap::new(); - params.insert(String::from("user"), cfg.0.user.clone()); - params.insert(String::from("database"), cfg.0.database.clone()); - params.insert(String::from("branch"), cfg.0.branch.clone()); - if let Some(secret_key) = cfg.0.secret_key.clone() { - params.insert(String::from("secret_key"), secret_key); - } - let (major_ver, minor_ver) = proto.version_tuple(); - send_messages( - &mut stream, - &mut out_buf, - &proto, - &[ClientMessage::ClientHandshake(ClientHandshake { - major_ver, - minor_ver, - params, - extensions: HashMap::new(), - })], - ) - .await?; - - let mut msg = wait_message(&mut stream, &mut in_buf, &proto).await?; - if let ServerMessage::ServerHandshake(ServerHandshake { - major_ver, - minor_ver, - extensions: _, - }) = msg - { - proto = ProtocolVersion::new(major_ver, minor_ver); - // TODO(tailhook) record extensions - msg = wait_message(&mut stream, &mut in_buf, &proto).await?; - } - match msg { - ServerMessage::Authentication(Authentication::Ok) => {} - ServerMessage::Authentication(Authentication::Sasl { methods }) => { - if methods.iter().any(|x| x == "SCRAM-SHA-256") { - if let Some(password) = &cfg.0.password { - scram( - &mut stream, - &mut in_buf, - &mut out_buf, - &proto, - &cfg.0.user, - password, - ) - .await?; - } else { - return Err(PasswordRequired::with_message( - "Password required for the specified user/host", - )); - } - } else { - return Err(AuthenticationError::with_message(format!( - "No supported authentication \ - methods: {:?}", - methods - ))); - } - } - ServerMessage::ErrorResponse(err) => { - return Err(err.into()); - } - msg => { - return Err(ProtocolError::with_message(format!( - "Error authenticating, unexpected message {:?}", - msg - ))); - } - } - - let mut server_params = ServerParams::new(); - let mut state_desc = RawTypedesc::uninitialized(); - loop { - let msg = wait_message(&mut stream, &mut in_buf, &proto).await?; - match msg { - ServerMessage::ReadyForCommand(ready) => { - assert_eq!(ready.transaction_state, TransactionState::NotInTransaction); - break; - } - ServerMessage::ServerKeyData(_) => { - // TODO(tailhook) store it somehow? - } - ServerMessage::ParameterStatus(par) => match &par.name[..] { - #[cfg(feature = "unstable")] - b"pgaddr" => { - use crate::server_params::PostgresAddress; - - let pgaddr: PostgresAddress = match serde_json::from_slice(&par.value[..]) { - Ok(a) => a, - Err(e) => { - log::warn!("Can't decode param {:?}: {}", par.name, e); - continue; - } - }; - server_params.set::(pgaddr); - } - #[cfg(feature = "unstable")] - b"pgdsn" => { - use crate::server_params::PostgresDsn; - - let pgdsn = match str::from_utf8(&par.value) { - Ok(a) => a.to_owned(), - Err(e) => { - log::warn!("Can't decode param {:?}: {}", par.name, e); - continue; - } - }; - - server_params.set::(PostgresDsn(pgdsn)); - } - b"system_config" => { - handle_system_config(par, &mut server_params)?; - } - _ => {} - }, - ServerMessage::StateDataDescription(d) => { - state_desc = d.typedesc; - } - ServerMessage::ErrorResponse(ErrorResponse { - severity, - code, - message, - attributes, - }) => { - log::warn!("Error received from server: {message}. Severity: {severity:?}. Code: {code:#x}"); - log::debug!("Error details: {attributes:?}"); - } - _ => { - log::warn!("unsolicited message {msg:#?}"); - } - } - } - Ok(Connection { - proto, - server_params, - mode: Mode::Normal { - idle_since: Instant::now(), - }, - transaction_state: TransactionState::NotInTransaction, - state_desc, - in_buf, - out_buf, - stream, - ping_interval: PingInterval::Unknown, - }) -} - -async fn scram( - stream: &mut TlsStream, - in_buf: &mut BytesMut, - out_buf: &mut BytesMut, - proto: &ProtocolVersion, - user: &str, - password: &str, -) -> Result<(), Error> { - use edgedb_protocol::client_message::SaslInitialResponse; - use edgedb_protocol::client_message::SaslResponse; - - let scram = ScramClient::new(user, password, None); - - let (scram, first) = scram.client_first(); - send_messages( - stream, - out_buf, - proto, - &[ClientMessage::AuthenticationSaslInitialResponse( - SaslInitialResponse { - method: "SCRAM-SHA-256".into(), - data: Bytes::copy_from_slice(first.as_bytes()), - }, - )], - ) - .await?; - let msg = wait_message(stream, in_buf, proto).await?; - let data = match msg { - ServerMessage::Authentication(Authentication::SaslContinue { data }) => data, - ServerMessage::ErrorResponse(err) => { - return Err(err.into()); - } - msg => { - return Err(ProtocolError::with_message(format!( - "Bad auth response: {:?}", - msg - ))); - } - }; - let data = str::from_utf8(&data[..]).map_err(|e| { - ProtocolError::with_source(e).context("invalid utf-8 in SCRAM-SHA-256 auth") - })?; - let scram = scram - .handle_server_first(data) - .map_err(AuthenticationError::with_source)?; - let (scram, data) = scram.client_final(); - send_messages( - stream, - out_buf, - proto, - &[ClientMessage::AuthenticationSaslResponse(SaslResponse { - data: Bytes::copy_from_slice(data.as_bytes()), - })], - ) - .await?; - let msg = wait_message(stream, in_buf, proto).await?; - let data = match msg { - ServerMessage::Authentication(Authentication::SaslFinal { data }) => data, - ServerMessage::ErrorResponse(err) => { - return Err(err.into()); - } - msg => { - return Err(ProtocolError::with_message(format!( - "auth response: {:?}", - msg - ))); - } - }; - let data = str::from_utf8(&data[..]) - .map_err(|_| ProtocolError::with_message("invalid utf-8 in SCRAM-SHA-256 auth"))?; - scram - .handle_server_final(data) - .map_err(|e| AuthenticationError::with_message(format!("Authentication error: {}", e)))?; - loop { - let msg = wait_message(stream, in_buf, proto).await?; - match msg { - ServerMessage::Authentication(Authentication::Ok) => break, - ServerMessage::ErrorResponse(ErrorResponse { - severity, - code, - message, - attributes, - }) => { - log::warn!("Error received from server: {message}. Severity: {severity:?}. Code: {code:#x}"); - log::debug!("Error details: {attributes:?}"); - } - msg => { - log::warn!("unsolicited message {msg:?}"); - } - }; - } - Ok(()) -} - -fn handle_system_config( - param_status: ParameterStatus, - server_params: &mut ServerParams, -) -> Result<(), Error> { - let (typedesc, data) = param_status - .parse_system_config() - .map_err(ProtocolEncodingError::with_source)?; - let codec = typedesc - .build_codec() - .map_err(ProtocolEncodingError::with_source)?; - let system_config = codec - .decode(data.as_ref()) - .map_err(ProtocolEncodingError::with_source)?; - let mut config = SystemConfig { - session_idle_timeout: None, - }; - if let Value::Object { shape, fields } = system_config { - for (el, field) in shape.elements.iter().zip(fields) { - match el.name.as_str() { - "id" => {} - "session_idle_timeout" => { - config.session_idle_timeout = match field { - Some(Value::Duration(timeout)) => Some(timeout.abs_duration()), - _ => { - log::warn!("Wrong protocol: {}={:?}", el.name, field); - None - } - }; - } - name => { - log::debug!("Unhandled system config: {}={:?}", name, field); - } - } - } - server_params.set::(config); - } else { - log::warn!("Received empty system config message."); - } - Ok(()) -} - -pub(crate) async fn send_messages<'x>( - stream: &mut (impl AsyncWrite + Unpin), - buf: &mut BytesMut, - proto: &ProtocolVersion, - messages: impl IntoIterator, -) -> Result<(), Error> { - buf.truncate(0); - for msg in messages { - log::debug!(target: "edgedb::outgoing::frame", - "Frame Contents: {:#?}", msg); - msg.encode(&mut Output::new(proto, buf)) - .map_err(ClientEncodingError::with_source)?; - } - stream - .write_all_buf(buf) - .await - .map_err(ClientConnectionError::with_source)?; - Ok(()) -} - -fn conn_err(err: io::Error) -> Error { - ClientConnectionError::with_source(err) -} - -pub async fn wait_message<'x>( - stream: &mut (impl AsyncRead + Unpin), - buf: &mut BytesMut, - proto: &ProtocolVersion, -) -> Result { - loop { - match _wait_message(stream, buf, proto).await? { - ServerMessage::LogMessage(msg) => { - match msg.severity { - MessageSeverity::Debug => { - log::debug!("[{}] {}", msg.code, msg.text); - } - MessageSeverity::Notice | MessageSeverity::Info => { - log::info!("[{}] {}", msg.code, msg.text); - } - MessageSeverity::Warning | MessageSeverity::Unknown(_) => { - log::warn!("[{}] {}", msg.code, msg.text); - } - } - continue; - } - msg => return Ok(msg), - } - } -} - -async fn _read_buf(stream: &mut (impl AsyncRead + Unpin), buf: &mut BytesMut) -> io::Result { - // Because of a combination of multiple different API impedence - // mismatches, when read_buf is called on a tls_api tokio stream, - // tls_api will zero the entire buffer on each call. This leads to - // pathological quadratic repeated zeroing when the buffer is much - // larger than the bytes read per call. - // (like for the 10MiB buffers for dump packets) - // - // We fix this by capping the size of the buffer that we pass to - // read_buf. - let cap = buf.spare_capacity_mut(); - let cap_len = cap.len(); - let mut rbuf = ReadBuf::uninit(&mut cap[..min(cap_len, 16 * 1024)]); - let n = stream.read_buf(&mut rbuf).await?; - unsafe { - buf.set_len(buf.len() + n); - } - Ok(n) -} - -async fn _wait_message<'x>( - stream: &mut (impl AsyncRead + Unpin), - buf: &mut BytesMut, - proto: &ProtocolVersion, -) -> Result { - while buf.len() < 5 { - buf.reserve(5); - if _read_buf(stream, buf).await.map_err(conn_err)? == 0 { - return Err(ClientConnectionEosError::with_message( - "end of stream while reading message", - )); - } - } - let len = u32::from_be_bytes(buf[1..5].try_into().unwrap()) as usize; - let frame_len = len + 1; - - while buf.len() < frame_len { - buf.reserve(frame_len - buf.len()); - if _read_buf(stream, buf).await.map_err(conn_err)? == 0 { - return Err(ClientConnectionEosError::with_message( - "end of stream while reading message", - )); - } - } - let frame = buf.split_to(frame_len).freeze(); - let result = ServerMessage::decode(&mut Input::new(proto.clone(), frame)) - .map_err(ProtocolEncodingError::with_source)?; - - log::debug!(target: "edgedb::incoming::frame", - "Frame Contents: {:#?}", result); - - Ok(result) -} - -fn connect_sleep() -> Duration { - Duration::from_millis(thread_rng().gen_range(10u64..200u64)) -} - -async fn connect_timeout(cfg: &Config, f: F) -> Result -where - F: Future>, -{ - use tokio::time::timeout; - - timeout(cfg.0.connect_timeout, f).await.unwrap_or_else(|_| { - Err(ClientConnectionFailedTemporarilyError::with_source( - io::Error::from(io::ErrorKind::TimedOut), - )) - }) -} - -fn is_temporary(e: &Error) -> bool { - use io::ErrorKind::{ - AddrNotAvailable, ConnectionAborted, ConnectionRefused, ConnectionReset, NotFound, - TimedOut, UnexpectedEof, - }; - - if e.is::() { - return true; - } - // todo(tailhook) figure out whether TLS api errors are properly unpacked - if e.is::() { - let io_err = e.source().and_then(|src| { - src.downcast_ref::() - .or_else(|| src.downcast_ref::>().map(|b| &**b)) - }); - if let Some(e) = io_err { - match e.kind() { - | ConnectionRefused - | ConnectionReset - | ConnectionAborted - | NotFound // For unix sockets - | TimedOut - | UnexpectedEof // For Docker server which is starting up - | AddrNotAvailable // Docker exposed ports not yet bound - => return true, - _ => {}, - } - } - } - false -} - -fn tls_fail(e: anyhow::Error) -> Error { - if let Some(e) = e.downcast_ref::() { - if matches!(e, rustls::Error::InvalidMessage(_)) { - return ProtocolTlsError::with_message( - "corrupt message, possibly server \ - does not support TLS connection.", - ); - } - } - ClientConnectionError::with_source_ref(e) -} diff --git a/edgedb-tokio/src/raw/dumps.rs b/edgedb-tokio/src/raw/dumps.rs deleted file mode 100644 index 7c6a28b9..00000000 --- a/edgedb-tokio/src/raw/dumps.rs +++ /dev/null @@ -1,308 +0,0 @@ -use std::collections::HashMap; -use std::mem; -use std::time::{Duration, Instant}; - -use bytes::Bytes; -use tokio::time::sleep; -use tokio_stream::{Stream, StreamExt}; - -use edgedb_errors::ProtocolOutOfOrderError; -use edgedb_errors::{Error, ErrorKind}; -use edgedb_protocol::client_message::{ClientMessage, Restore, RestoreBlock}; -use edgedb_protocol::client_message::{Dump2, Dump3, DumpFlags}; -use edgedb_protocol::server_message::{RawPacket, ServerMessage}; - -use crate::raw::connection::{send_messages, wait_message}; -use crate::raw::queries::Guard; -use crate::raw::{Connection, Response}; - -enum DumpState { - Header(RawPacket), - Blocks, - Complete(Response<()>), - Error(Error), - Reset, -} - -pub struct DumpStream<'a> { - conn: &'a mut Connection, - state: DumpState, - guard: Option, -} - -impl Connection { - pub async fn restore( - &mut self, - header: Bytes, - mut stream: impl Stream> + Unpin, - ) -> Result, Error> { - let guard = self.begin_request()?; - let start_headers = Instant::now(); - self.send_messages(&[ClientMessage::Restore(Restore { - headers: HashMap::new(), - jobs: 1, - data: header, - })]) - .await?; - - match self.message().await? { - ServerMessage::RestoreReady(_) => { - log::info!("Schema applied in {:?}", start_headers.elapsed()); - } - ServerMessage::ErrorResponse(err) => { - self.send_messages(&[ClientMessage::Sync]).await?; - self.expect_ready_or_eos(guard) - .await - .map_err(|e| log::warn!("Error waiting for Ready after error: {e:#}")) - .ok(); - return Err(Into::::into(err).context("error initiating restore protocol")); - } - msg => { - return Err(ProtocolOutOfOrderError::with_message(format!( - "unsolicited message {:?}", - msg - )))?; - } - } - - let start_blocks = Instant::now(); - let mut num_blocks = 0; - let mut total_len = 0; - while let Some(data) = stream.next().await.transpose()? { - num_blocks += 1; - total_len += data.len(); - let (mut rd, mut wr) = tokio::io::split(&mut self.stream); - let block = [ClientMessage::RestoreBlock(RestoreBlock { data })]; - tokio::select! { - msg = wait_message(&mut rd, &mut self.in_buf, &self.proto) - => match msg? { - ServerMessage::ErrorResponse(err) => { - self.send_messages(&[ClientMessage::Sync]).await?; - self.expect_ready_or_eos(guard).await - .map_err(|e| log::warn!( - "Error waiting for Ready \ - after error: {e:#}")) - .ok(); - return Err(Into::::into(err))?; - } - msg => { - return Err(ProtocolOutOfOrderError::with_message( - format!("unsolicited message {:?}", msg)))?; - } - }, - res = send_messages(&mut wr, &mut self.out_buf, - &self.proto, &block) - => res?, - } - log::info!(target: "edgedb::restore", "Block {num_blocks} processed: {:.02} MB restored", total_len as f64 / 1048576.0); - } - self.send_messages(&[ClientMessage::RestoreEof]).await?; - log::info!(target: "edgedb::restore", - "Database restored in {:?}", start_blocks.elapsed()); - - let wait = wait_print_loop(); - tokio::pin!(wait); - loop { - let msg = tokio::select! { - _ = &mut wait => unreachable!(), - msg = self.message() => msg?, - }; - match msg { - ServerMessage::StateDataDescription(d) => { - self.state_desc = d.typedesc; - } - ServerMessage::CommandComplete0(complete) => { - log::info!("Complete in {:?}", start_headers.elapsed()); - self.end_request(guard); - return Ok(Response { - status_data: complete.status_data, - new_state: None, - data: (), - warnings: vec![], - }); - } - ServerMessage::CommandComplete1(complete) => { - log::info!("Complete in {:?}", start_headers.elapsed()); - self.end_request(guard); - return Ok(Response { - status_data: complete.status_data, - new_state: complete.state, - data: (), - warnings: vec![], - }); - } - ServerMessage::ErrorResponse(err) => { - self.send_messages(&[ClientMessage::Sync]).await?; - self.expect_ready_or_eos(guard) - .await - .map_err(|e| { - log::warn!( - "Error waiting for Ready \ - after error: {e:#}" - ) - }) - .ok(); - return Err(Into::::into(err))?; - } - _ => { - return Err(ProtocolOutOfOrderError::with_message(format!( - "unsolicited message {:?}", - msg - )))?; - } - } - } - } - pub async fn dump(&mut self) -> Result, Error> { - self.dump_with_secrets(false).await - } - pub async fn dump_with_secrets(&mut self, with_secrets: bool) -> Result, Error> { - let guard = self.begin_request()?; - - if self.proto.is_3() { - let mut flags = DumpFlags::empty(); - if with_secrets { - flags |= DumpFlags::DUMP_SECRETS; - } - self.send_messages(&[ - ClientMessage::Dump3(Dump3 { - annotations: None, - flags, - }), - ClientMessage::Sync, - ]) - .await?; - } else { - let mut headers = HashMap::new(); - if with_secrets { - headers.insert(0xFF10, Bytes::from(vec![with_secrets as u8])); - } - - self.send_messages(&[ClientMessage::Dump2(Dump2 { headers }), ClientMessage::Sync]) - .await?; - } - - let msg = self.message().await?; - let header = match msg { - ServerMessage::DumpHeader(packet) => packet, - ServerMessage::ErrorResponse(err) => { - self.expect_ready_or_eos(guard) - .await - .map_err(|e| log::warn!("Error waiting for Ready after error: {e:#}")) - .ok(); - return Err(Into::::into(err).context("error receiving dump header")); - } - _ => { - return Err(ProtocolOutOfOrderError::with_message(format!( - "unsolicited message {:?}", - msg - )))?; - } - }; - Ok(DumpStream { - conn: self, - state: DumpState::Header(header), - guard: Some(guard), - }) - } -} - -impl DumpStream<'_> { - pub async fn complete(mut self) -> Result, Error> { - self.process_complete().await - } - pub fn take_header(&mut self) -> Option { - match mem::replace(&mut self.state, DumpState::Reset) { - DumpState::Header(header) => { - self.state = DumpState::Blocks; - Some(header) - } - state => { - self.state = state; - None - } - } - } - pub async fn next_block(&mut self) -> Option { - match &self.state { - DumpState::Header(_) | DumpState::Blocks => match self.conn.message().await { - Ok(ServerMessage::DumpBlock(packet)) => Some(packet), - Ok(ServerMessage::CommandComplete0(complete)) - if self.guard.is_some() && !self.conn.proto.is_1() => - { - let guard = self.guard.take().unwrap(); - if let Err(e) = self.conn.expect_ready(guard).await { - self.state = DumpState::Error(e) - } else { - self.state = DumpState::Complete(Response { - status_data: complete.status_data, - new_state: None, - data: (), - warnings: vec![], - }); - } - None - } - Ok(ServerMessage::CommandComplete1(complete)) - if self.guard.is_some() && self.conn.proto.is_1() => - { - let guard = self.guard.take().unwrap(); - if let Err(e) = self.conn.expect_ready(guard).await { - self.state = DumpState::Error(e) - } else { - self.state = DumpState::Complete(Response { - status_data: complete.status_data, - new_state: complete.state, - data: (), - warnings: vec![], - }); - } - None - } - Ok(ServerMessage::ErrorResponse(err)) => { - let guard = self.guard.take().unwrap(); - self.conn - .expect_ready_or_eos(guard) - .await - .map_err(|e| log::warn!("Error waiting for Ready after error: {e:#}")) - .ok(); - self.state = DumpState::Error(err.into()); - None - } - Ok(msg) => { - self.state = DumpState::Error(ProtocolOutOfOrderError::with_message(format!( - "unsolicited message {:?}", - msg - ))); - None - } - Err(e) => { - self.state = DumpState::Error(e); - None - } - }, - _ => None, - } - } - pub async fn process_complete(&mut self) -> Result, Error> { - use DumpState::*; - - match mem::replace(&mut self.state, Reset) { - Header(..) | Blocks => panic!("process_complete() called too early"), - Complete(c) => Ok(c), - Error(e) => Err(e), - - Reset => panic!("process_complete() called twice"), - } - } -} - -async fn wait_print_loop() { - // This future should be canceled restore loop finishes - let start_waiting = Instant::now(); - loop { - sleep(Duration::from_secs(60)).await; - log::info!(target: "edgedb::restore", - "Waiting for complete {:?}", start_waiting.elapsed()); - } -} diff --git a/edgedb-tokio/src/raw/mod.rs b/edgedb-tokio/src/raw/mod.rs deleted file mode 100644 index a86b51fe..00000000 --- a/edgedb-tokio/src/raw/mod.rs +++ /dev/null @@ -1,191 +0,0 @@ -#![cfg_attr(not(feature = "unstable"), allow(dead_code))] - -mod connection; -#[cfg(feature = "unstable")] -mod dumps; -mod options; -mod queries; -mod response; -pub mod state; - -use std::collections::VecDeque; -use std::sync::{Arc, Mutex as BlockingMutex}; -use std::time::Duration; - -use bytes::{Bytes, BytesMut}; -use tls_api::TlsStream; -use tokio::sync::{self, Semaphore}; - -use edgedb_protocol::common::{Capabilities, RawTypedesc}; -use edgedb_protocol::features::ProtocolVersion; -use edgedb_protocol::server_message::CommandDataDescription1; -use edgedb_protocol::server_message::TransactionState; - -use crate::builder::Config; -use crate::errors::{ClientError, Error, ErrorKind}; -use crate::server_params::ServerParams; - -pub use options::Options; -pub use response::ResponseStream; -pub use state::{PoolState, State}; - -#[cfg(feature = "unstable")] -pub use dumps::DumpStream; - -#[derive(Clone, Debug)] -pub struct Pool(Arc); - -pub enum QueryCapabilities { - Unparsed, - Parsed(Capabilities), -} - -pub struct Description; - -#[derive(Debug)] -struct PoolInner { - pub config: Config, - pub semaphore: Arc, - pub queue: BlockingMutex>, -} - -#[derive(Debug)] -pub struct PoolConnection { - inner: Option, - #[allow(dead_code)] // needed only for Drop side effect - permit: sync::OwnedSemaphorePermit, - pool: Arc, -} - -#[derive(Debug)] -pub struct Connection { - proto: ProtocolVersion, - server_params: ServerParams, - mode: connection::Mode, - transaction_state: TransactionState, - state_desc: RawTypedesc, - in_buf: BytesMut, - out_buf: BytesMut, - stream: TlsStream, - ping_interval: PingInterval, -} - -#[derive(Debug)] -pub struct Response { - pub status_data: Bytes, - pub new_state: Option, - pub data: T, - pub warnings: Vec, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub(crate) enum PingInterval { - Unknown, - Disabled, - Interval(Duration), -} - -impl edgedb_errors::Field for QueryCapabilities { - const NAME: &'static str = "capabilities"; - type Value = QueryCapabilities; -} - -impl edgedb_errors::Field for Description { - const NAME: &'static str = "descriptor"; - type Value = CommandDataDescription1; -} - -impl Pool { - pub fn new(config: &Config) -> Pool { - let concurrency = config - .0 - .max_concurrency - // TODO(tailhook) use 1 and get concurrency from the connection - .unwrap_or(crate::builder::DEFAULT_POOL_SIZE); - Pool(Arc::new(PoolInner { - semaphore: Arc::new(Semaphore::new(concurrency)), - queue: BlockingMutex::new(VecDeque::with_capacity(concurrency)), - config: config.clone(), - })) - } - pub async fn acquire(&self) -> Result { - self.0.acquire().await - } -} - -impl PoolInner { - fn _next_conn(&self, _permit: &sync::OwnedSemaphorePermit) -> Option { - self.queue - .lock() - .expect("pool shared state mutex is not poisoned") - .pop_front() - } - async fn acquire(self: &Arc) -> Result { - let permit = self - .semaphore - .clone() - .acquire_owned() - .await - .map_err(|e| ClientError::with_source(e).context("cannot acquire connection"))?; - while let Some(mut conn) = self._next_conn(&permit) { - assert!(conn.is_consistent()); - if conn.is_connection_reset().await { - continue; - } - return Ok(PoolConnection { - inner: Some(conn), - permit, - pool: self.clone(), - }); - } - let conn = Connection::connect(&self.config).await?; - // Make sure that connection is wrapped before we commit, - // so that connection is returned into a pool if we fail - // to commit because of async stuff - Ok(PoolConnection { - inner: Some(conn), - permit, - pool: self.clone(), - }) - } -} - -impl PoolConnection { - pub fn is_consistent(&self) -> bool { - self.inner - .as_ref() - .map(|c| c.is_consistent()) - .unwrap_or(false) - } -} - -impl Drop for PoolConnection { - fn drop(&mut self) { - if let Some(conn) = self.inner.take() { - if conn.is_consistent() { - self.pool - .queue - .lock() - .expect("pool shared state mutex is not poisoned") - .push_back(conn); - } - } - } -} - -impl Response { - fn map(self, f: impl FnOnce(T) -> Result) -> Result, R> { - Ok(Response { - status_data: self.status_data, - new_state: self.new_state, - data: f(self.data)?, - warnings: self.warnings, - }) - } - - fn log_warnings(&self) { - for w in &self.warnings { - log::warn!(target: "edgedb_tokio::warning", "{w}"); - } - } -} diff --git a/edgedb-tokio/src/raw/options.rs b/edgedb-tokio/src/raw/options.rs deleted file mode 100644 index e8ed38b3..00000000 --- a/edgedb-tokio/src/raw/options.rs +++ /dev/null @@ -1,14 +0,0 @@ -use std::sync::Arc; - -use edgedb_protocol::encoding::Annotations; - -use crate::options::{RetryOptions, TransactionOptions}; -use crate::raw::state::PoolState; - -#[derive(Debug, Clone, Default)] -pub struct Options { - pub(crate) transaction: TransactionOptions, - pub(crate) retry: RetryOptions, - pub(crate) state: Arc, - pub(crate) annotations: Arc, -} diff --git a/edgedb-tokio/src/raw/queries.rs b/edgedb-tokio/src/raw/queries.rs deleted file mode 100644 index 4edf13ed..00000000 --- a/edgedb-tokio/src/raw/queries.rs +++ /dev/null @@ -1,735 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; - -use bytes::{Bytes, BytesMut}; -use tokio::time::Instant; - -use edgedb_errors::fields::QueryText; -use edgedb_protocol::client_message::OptimisticExecute; -use edgedb_protocol::client_message::{ClientMessage, Parse, Prepare}; -use edgedb_protocol::client_message::{DescribeAspect, DescribeStatement}; -use edgedb_protocol::client_message::{Execute0, Execute1}; -use edgedb_protocol::common::CompilationOptions; -use edgedb_protocol::common::{Capabilities, Cardinality, InputLanguage, IoFormat}; -use edgedb_protocol::descriptors::Typedesc; -use edgedb_protocol::encoding::Annotations; -use edgedb_protocol::features::ProtocolVersion; -use edgedb_protocol::model::Uuid; -use edgedb_protocol::query_arg::{Encoder, QueryArgs}; -use edgedb_protocol::server_message::{CommandDataDescription1, PrepareComplete}; -use edgedb_protocol::server_message::{Data, ServerMessage}; -use edgedb_protocol::QueryResult; - -use crate::errors::NoResultExpected; -use crate::errors::{ClientConnectionEosError, ProtocolEncodingError}; -use crate::errors::{ClientInconsistentError, ProtocolOutOfOrderError}; -use crate::errors::{Error, ErrorKind}; -use crate::raw::connection::Mode; -use crate::raw::{Connection, PoolConnection, QueryCapabilities}; -use crate::raw::{Description, Response, ResponseStream, State}; - -pub(crate) struct Guard; - -impl Connection { - pub(crate) fn begin_request(&mut self) -> Result { - match self.mode { - Mode::Normal { .. } => { - self.mode = Mode::Dirty; - Ok(Guard) - } - Mode::Dirty => Err(ClientInconsistentError::build()), - // TODO(tailhook) technically we could just wait ping here - Mode::AwaitingPing => Err(ClientInconsistentError::with_message("interrupted ping")), - } - } - pub(crate) fn end_request(&mut self, _guard: Guard) { - self.mode = Mode::Normal { - idle_since: Instant::now(), - }; - } - pub(crate) async fn expect_ready(&mut self, guard: Guard) -> Result<(), Error> { - loop { - let msg = self.message().await?; - - // TODO(tailhook) should we react on messages somehow? - // At least parse LogMessage's? - - if let ServerMessage::ReadyForCommand(ready) = msg { - self.transaction_state = ready.transaction_state; - self.end_request(guard); - return Ok(()); - } - } - } - - pub(crate) async fn expect_ready_or_eos(&mut self, guard: Guard) -> Result<(), Error> { - match self.expect_ready(guard).await { - Ok(()) => Ok(()), - Err(e) if e.is::() => { - assert!(!self.is_consistent()); - Ok(()) - } - Err(e) => Err(e), - } - } - pub async fn parse( - &mut self, - flags: &CompilationOptions, - query: &str, - state: &dyn State, - annotations: &Arc, - ) -> Result { - if self.proto.is_1() { - self._parse1(flags, query, state, annotations) - .await - .map_err(|e| e.set::(query)) - } else { - let pre = self - ._prepare0(flags, query) - .await - .map_err(|e| e.set::(query))?; - self._describe0(pre).await - } - } - async fn _parse1( - &mut self, - flags: &CompilationOptions, - query: &str, - state: &dyn State, - annotations: &Arc, - ) -> Result { - let guard = self.begin_request()?; - self.send_messages(&[ - ClientMessage::Parse(Parse::new( - flags, - query, - state.encode(&self.state_desc)?, - self.proto.is_3().then(|| annotations.clone()), - )), - ClientMessage::Sync, - ]) - .await?; - - loop { - let msg = self.message().await?; - match msg { - ServerMessage::StateDataDescription(d) => { - self.state_desc = d.typedesc; - } - ServerMessage::CommandDataDescription1(data_desc) => { - self.expect_ready(guard).await?; - return Ok(data_desc); - } - ServerMessage::ErrorResponse(err) => { - self.expect_ready_or_eos(guard) - .await - .map_err(|e| log::warn!("Error waiting for Ready after error: {e:#}")) - .ok(); - return Err(err.into()); - } - _ => { - return Err(ProtocolOutOfOrderError::with_message(format!( - "Unsolicited message {:?}", - msg - ))); - } - } - } - } - async fn _prepare0( - &mut self, - flags: &CompilationOptions, - query: &str, - ) -> Result { - let guard = self.begin_request()?; - self.send_messages(&[ - ClientMessage::Prepare(Prepare::new(flags, query)), - ClientMessage::Sync, - ]) - .await?; - - match self.message().await? { - ServerMessage::PrepareComplete(data) => { - self.expect_ready(guard).await?; - Ok(data) - } - ServerMessage::ErrorResponse(err) => { - self.expect_ready_or_eos(guard) - .await - .map_err(|e| log::warn!("Error waiting for Ready after error: {e:#}")) - .ok(); - Err(err.into()) - } - msg => Err(ProtocolOutOfOrderError::with_message(format!( - "Unsolicited message {:?}", - msg - ))), - } - } - async fn _describe0( - &mut self, - prepare: PrepareComplete, - ) -> Result { - let guard = self.begin_request()?; - self.send_messages(&[ - ClientMessage::DescribeStatement(DescribeStatement { - headers: HashMap::new(), - aspect: DescribeAspect::DataDescription, - statement_name: Bytes::from(""), - }), - ClientMessage::Sync, - ]) - .await?; - - let desc = match self.message().await? { - ServerMessage::CommandDataDescription0(data_desc) => { - self.expect_ready(guard).await?; - data_desc - } - ServerMessage::ErrorResponse(err) => { - self.expect_ready_or_eos(guard) - .await - .map_err(|e| log::warn!("Error waiting for Ready after error: {e:#}")) - .ok(); - return Err(err.into()); - } - msg => { - return Err(ProtocolOutOfOrderError::with_message(format!( - "Unsolicited message {:?}", - msg - ))); - } - }; - // normalize CommandDataDescription0 into Parse (proto 1.x) output - Ok(CommandDataDescription1 { - annotations: HashMap::new(), - capabilities: prepare.get_capabilities().unwrap_or(Capabilities::ALL), - result_cardinality: prepare.cardinality, - input: desc.input, - output: desc.output, - }) - } - async fn _execute( - &mut self, - opts: &CompilationOptions, - query: &str, - state: &dyn State, - annotations: &Arc, - desc: &CommandDataDescription1, - arguments: &Bytes, - ) -> Result>, Error> { - if self.proto.is_1() { - self._execute1(opts, query, state, annotations, desc, arguments) - .await - .map_err(|e| e.set::(query)) - } else { - self._execute0(arguments) - .await - .map_err(|e| e.set::(query)) - } - } - - async fn _execute1( - &mut self, - opts: &CompilationOptions, - query: &str, - state: &dyn State, - annotations: &Arc, - desc: &CommandDataDescription1, - arguments: &Bytes, - ) -> Result>, Error> { - let guard = self.begin_request()?; - self.send_messages(&[ - ClientMessage::Execute1(Execute1 { - annotations: self.proto.is_3().then(|| annotations.clone()), - allowed_capabilities: opts.allow_capabilities, - compilation_flags: opts.flags(), - implicit_limit: opts.implicit_limit, - input_language: opts.input_language, - output_format: opts.io_format, - expected_cardinality: opts.expected_cardinality, - command_text: query.into(), - state: state.encode(&self.state_desc)?, - input_typedesc_id: desc.input.id, - output_typedesc_id: desc.output.id, - arguments: arguments.clone(), - }), - ClientMessage::Sync, - ]) - .await?; - - let mut data = Vec::new(); - let mut description = None; - let mut warnings: Vec = Vec::new(); - loop { - let msg = self.message().await?; - match msg { - ServerMessage::StateDataDescription(d) => { - self.state_desc = d.typedesc; - } - ServerMessage::CommandDataDescription1(desc) => { - warnings.extend(edgedb_protocol::annotations::decode_warnings( - &desc.annotations, - )?); - description = Some(desc); - } - ServerMessage::Data(datum) => { - data.push(datum); - } - ServerMessage::CommandComplete1(complete) => { - self.expect_ready(guard).await?; - return Ok(Response { - status_data: complete.status_data, - new_state: complete.state, - data, - warnings, - }); - } - ServerMessage::ErrorResponse(err) => { - self.expect_ready_or_eos(guard) - .await - .map_err(|e| log::warn!("Error waiting for Ready after error: {e:#}")) - .ok(); - let mut err: Error = err.into(); - if let Some(desc) = description { - err = err.set::(desc); - } - return Err(err); - } - _ => { - return Err(ProtocolOutOfOrderError::with_message(format!( - "Unsolicited message {:?}", - msg - ))); - } - } - } - } - - async fn _execute0(&mut self, arguments: &Bytes) -> Result>, Error> { - let guard = self.begin_request()?; - self.send_messages(&[ - ClientMessage::Execute0(Execute0 { - headers: HashMap::new(), - statement_name: Bytes::from(""), - arguments: arguments.clone(), - }), - ClientMessage::Sync, - ]) - .await?; - - let mut data = Vec::new(); - loop { - let msg = self.message().await?; - match msg { - ServerMessage::Data(datum) => { - data.push(datum); - } - ServerMessage::CommandComplete0(complete) => { - self.expect_ready(guard).await?; - return Ok(Response { - status_data: complete.status_data, - new_state: None, - data, - warnings: vec![], - }); - } - ServerMessage::ErrorResponse(err) => { - self.expect_ready_or_eos(guard) - .await - .map_err(|e| log::warn!("Error waiting for Ready after error: {e:#}")) - .ok(); - return Err(err.into()); - } - _ => { - return Err(ProtocolOutOfOrderError::with_message(format!( - "Unsolicited message {:?}", - msg - ))); - } - } - } - } - pub async fn execute_stream( - &mut self, - opts: &CompilationOptions, - query: &str, - state: &dyn State, - annotations: &Arc, - desc: &CommandDataDescription1, - arguments: &A, - ) -> Result, Error> - where - A: QueryArgs, - R: QueryResult, - R::State: Unpin, - { - let inp_desc = desc.input().map_err(ProtocolEncodingError::with_source)?; - - let mut arg_buf = BytesMut::with_capacity(8); - arguments.encode(&mut Encoder::new( - &inp_desc.as_query_arg_context(), - &mut arg_buf, - ))?; - - let guard = self.begin_request()?; - if self.proto.is_1() { - self.send_messages(&[ - ClientMessage::Execute1(Execute1 { - annotations: self.proto.is_3().then(|| annotations.clone()), - allowed_capabilities: opts.allow_capabilities, - compilation_flags: opts.flags(), - implicit_limit: opts.implicit_limit, - input_language: opts.input_language, - output_format: opts.io_format, - expected_cardinality: opts.expected_cardinality, - command_text: query.into(), - state: state.encode(&self.state_desc)?, - input_typedesc_id: desc.input.id, - output_typedesc_id: desc.output.id, - arguments: arg_buf.freeze(), - }), - ClientMessage::Sync, - ]) - .await?; - } else { - // TODO(tailhook) maybe use OptimisticExecute instead? - self.send_messages(&[ - ClientMessage::Execute0(Execute0 { - headers: HashMap::new(), - statement_name: Bytes::from(""), - arguments: arg_buf.freeze(), - }), - ClientMessage::Sync, - ]) - .await?; - } - - let out_desc = desc.output().map_err(ProtocolEncodingError::with_source)?; - ResponseStream::new(self, &out_desc, guard).await - } - pub async fn try_execute_stream( - &mut self, - opts: &CompilationOptions, - query: &str, - state: &dyn State, - annotations: &Arc, - input: &Typedesc, - output: &Typedesc, - arguments: &A, - ) -> Result, Error> - where - A: QueryArgs, - R: QueryResult, - R::State: Unpin, - { - let mut arg_buf = BytesMut::with_capacity(8); - arguments.encode(&mut Encoder::new( - &input.as_query_arg_context(), - &mut arg_buf, - ))?; - - let guard = self.begin_request()?; - if self.proto.is_1() { - self.send_messages(&[ - ClientMessage::Execute1(Execute1 { - annotations: self.proto.is_3().then(|| annotations.clone()), - allowed_capabilities: opts.allow_capabilities, - compilation_flags: opts.flags(), - implicit_limit: opts.implicit_limit, - input_language: opts.input_language, - output_format: opts.io_format, - expected_cardinality: opts.expected_cardinality, - command_text: query.into(), - state: state.encode(&self.state_desc)?, - input_typedesc_id: *input.id(), - output_typedesc_id: *input.id(), - arguments: arg_buf.freeze(), - }), - ClientMessage::Sync, - ]) - .await?; - } else { - self.send_messages(&[ - ClientMessage::OptimisticExecute(OptimisticExecute::new( - opts, - query, - arg_buf.freeze(), - *input.id(), - *output.id(), - )), - ClientMessage::Sync, - ]) - .await?; - } - - ResponseStream::new(self, output, guard).await - } - pub async fn statement( - &mut self, - flags: &CompilationOptions, - query: &str, - state: &dyn State, - annotations: &Arc, - ) -> Result<(), Error> { - if self.proto.is_1() { - self._statement1(flags, query, state, annotations).await - } else { - self._statement0(flags, query).await - } - } - - async fn _statement1( - &mut self, - opts: &CompilationOptions, - query: &str, - state: &dyn State, - annotations: &Arc, - ) -> Result<(), Error> { - let guard = self.begin_request()?; - self.send_messages(&[ - ClientMessage::Execute1(Execute1 { - annotations: self.proto.is_3().then(|| annotations.clone()), - allowed_capabilities: opts.allow_capabilities, - compilation_flags: opts.flags(), - implicit_limit: opts.implicit_limit, - input_language: opts.input_language, - output_format: opts.io_format, - expected_cardinality: opts.expected_cardinality, - command_text: query.into(), - state: state.encode(&self.state_desc)?, - input_typedesc_id: Uuid::from_u128(0), - output_typedesc_id: Uuid::from_u128(0), - arguments: Bytes::new(), - }), - ClientMessage::Sync, - ]) - .await?; - - loop { - let msg = self.message().await?; - match msg { - ServerMessage::StateDataDescription(d) => { - self.state_desc = d.typedesc; - } - ServerMessage::Data(_) => {} - ServerMessage::CommandComplete1(..) => { - self.expect_ready(guard).await?; - return Ok(()); - } - ServerMessage::ErrorResponse(err) => { - self.expect_ready_or_eos(guard) - .await - .map_err(|e| log::warn!("Error waiting for Ready after error: {e:#}")) - .ok(); - return Err(err.into()); - } - _ => { - return Err(ProtocolOutOfOrderError::with_message(format!( - "Unsolicited message {:?}", - msg - ))); - } - } - } - } - async fn _statement0(&mut self, flags: &CompilationOptions, query: &str) -> Result<(), Error> { - let guard = self.begin_request()?; - self.send_messages(&[ - ClientMessage::OptimisticExecute(OptimisticExecute::new( - flags, - query, - Bytes::new(), - Uuid::from_u128(0x0), - Uuid::from_u128(0x0), - )), - ClientMessage::Sync, - ]) - .await?; - - loop { - let msg = self.message().await?; - match msg { - ServerMessage::Data(_) => {} - ServerMessage::CommandComplete0(_) => { - self.expect_ready(guard).await?; - return Ok(()); - } - ServerMessage::ErrorResponse(err) => { - self.expect_ready_or_eos(guard) - .await - .map_err(|e| log::warn!("Error waiting for Ready after error: {e:#}")) - .ok(); - return Err(err.into()); - } - _ => { - return Err(ProtocolOutOfOrderError::with_message(format!( - "Unsolicited message {:?}", - msg - ))); - } - } - } - } - - pub async fn query( - &mut self, - query: &str, - arguments: &A, - state: &dyn State, - annotations: &Arc, - allow_capabilities: Capabilities, - io_format: IoFormat, - cardinality: Cardinality, - ) -> Result>, Error> - where - A: QueryArgs, - R: QueryResult, - { - let mut caps = QueryCapabilities::Unparsed; - let result = async { - let flags = CompilationOptions { - implicit_limit: None, - implicit_typenames: false, - implicit_typeids: false, - explicit_objectids: true, - allow_capabilities, - io_format, - input_language: InputLanguage::EdgeQL, - expected_cardinality: cardinality, - }; - let desc = self.parse(&flags, query, state, annotations).await?; - caps = QueryCapabilities::Parsed(desc.capabilities); - let inp_desc = desc.input().map_err(ProtocolEncodingError::with_source)?; - - let mut arg_buf = BytesMut::with_capacity(8); - if let Err(e) = arguments.encode(&mut Encoder::new( - &inp_desc.as_query_arg_context(), - &mut arg_buf, - )) { - return Err(e.set::(desc)); - } - - let response = self - ._execute(&flags, query, state, annotations, &desc, &arg_buf.freeze()) - .await?; - response.log_warnings(); - - let out_desc = desc.output().map_err(ProtocolEncodingError::with_source)?; - match out_desc.root_pos() { - Some(root_pos) => { - let ctx = out_desc.as_queryable_context(); - let mut state = R::prepare(&ctx, root_pos)?; - response.map(|data| { - data.into_iter() - .flat_map(|chunk| chunk.data) - .map(|chunk| R::decode(&mut state, &chunk)) - .collect::, _>>() - }) - } - None => Err(NoResultExpected::build()), - } - } - .await; - result.map_err(|e| e.set::(caps)) - } - - pub async fn execute( - &mut self, - query: &str, - arguments: &A, - state: &dyn State, - annotations: &Arc, - allow_capabilities: Capabilities, - ) -> Result, Error> - where - A: QueryArgs, - { - let mut caps = QueryCapabilities::Unparsed; - let result: Result<_, Error> = async { - let flags = CompilationOptions { - implicit_limit: None, - implicit_typenames: false, - implicit_typeids: false, - explicit_objectids: true, - allow_capabilities, - input_language: InputLanguage::EdgeQL, - io_format: IoFormat::Binary, - expected_cardinality: Cardinality::Many, - }; - let desc = self.parse(&flags, query, state, annotations).await?; - caps = QueryCapabilities::Parsed(desc.capabilities); - let inp_desc = desc.input().map_err(ProtocolEncodingError::with_source)?; - - let mut arg_buf = BytesMut::with_capacity(8); - if let Err(e) = arguments.encode(&mut Encoder::new( - &inp_desc.as_query_arg_context(), - &mut arg_buf, - )) { - return Err(e.set::(desc)); - } - - let response = self - ._execute(&flags, query, state, annotations, &desc, &arg_buf.freeze()) - .await?; - response.log_warnings(); - response.map(|_| Ok::<_, Error>(())) - } - .await; - result.map_err(|e| e.set::(caps)) - } -} - -impl PoolConnection { - pub async fn parse( - &mut self, - flags: &CompilationOptions, - query: &str, - state: &dyn State, - annotations: &Arc, - ) -> Result { - self.inner().parse(flags, query, state, annotations).await - } - pub async fn execute( - &mut self, - opts: &CompilationOptions, - query: &str, - state: &dyn State, - annotations: &Arc, - desc: &CommandDataDescription1, - arguments: &Bytes, - ) -> Result, Error> { - self.inner() - ._execute(opts, query, state, annotations, desc, arguments) - .await - .map(|r| r.data) - } - pub async fn statement( - &mut self, - query: &str, - state: &dyn State, - annotations: &Arc, - ) -> Result<(), Error> { - let flags = CompilationOptions { - implicit_limit: None, - implicit_typenames: false, - implicit_typeids: false, - explicit_objectids: false, - allow_capabilities: Capabilities::ALL, - input_language: InputLanguage::EdgeQL, - io_format: IoFormat::Binary, - expected_cardinality: Cardinality::Many, // no result is unsupported - }; - self.inner().statement(&flags, query, state, annotations).await - } - pub fn proto(&self) -> &ProtocolVersion { - &self - .inner - .as_ref() - .expect("connection is not dropped") - .proto - } - pub fn inner(&mut self) -> &mut Connection { - self.inner.as_mut().expect("connection is not dropped") - } -} diff --git a/edgedb-tokio/src/raw/response.rs b/edgedb-tokio/src/raw/response.rs deleted file mode 100644 index 84dfb95b..00000000 --- a/edgedb-tokio/src/raw/response.rs +++ /dev/null @@ -1,319 +0,0 @@ -use std::collections::VecDeque; -use std::mem; - -use bytes::Bytes; -use edgedb_errors::ProtocolEncodingError; -use edgedb_errors::{Error, ErrorKind}; -use edgedb_errors::{ParameterTypeMismatchError, ProtocolOutOfOrderError}; -use edgedb_protocol::annotations::Warning; -use edgedb_protocol::common::State; -use edgedb_protocol::descriptors::Typedesc; -use edgedb_protocol::server_message::CommandDataDescription1; -use edgedb_protocol::server_message::{ErrorResponse, ServerMessage}; -use edgedb_protocol::{annotations, QueryResult}; - -use crate::raw::queries::Guard; -use crate::raw::{Connection, Description, Response}; - -enum Buffer { - Reading(VecDeque), - Complete { - status_data: Bytes, - new_state: Option, - }, - ErrorResponse(ErrorResponse), - Error(Error), - Reset, -} - -pub struct ResponseStream<'a, T: QueryResult> -where - T::State: Unpin, -{ - connection: &'a mut Connection, - buffer: Buffer, - state: Option, - guard: Option, - description: Option, - warnings: Vec, -} - -impl<'a, T: QueryResult> ResponseStream<'a, T> -where - T::State: Unpin, -{ - pub(crate) async fn new( - connection: &'a mut Connection, - out_desc: &Typedesc, - guard: Guard, - ) -> Result, Error> { - use Buffer::*; - - let buffer; - let mut description = None; - let mut guard = Some(guard); - loop { - match connection.message().await? { - ServerMessage::StateDataDescription(d) => { - connection.state_desc = d.typedesc; - } - ServerMessage::Data(datum) => { - buffer = Reading(datum.data.into()); - break; - } - ServerMessage::CommandComplete1(complete) if connection.proto.is_1() => { - let guard = guard.take().unwrap(); - connection.expect_ready(guard).await?; - buffer = Complete { - status_data: complete.status_data, - new_state: complete.state, - }; - break; - } - ServerMessage::CommandComplete0(complete) if !connection.proto.is_1() => { - let guard = guard.take().unwrap(); - connection.expect_ready(guard).await?; - buffer = Complete { - status_data: complete.status_data, - new_state: None, - }; - break; - } - ServerMessage::CommandDataDescription1(desc) if connection.proto.is_1() => { - description = Some(desc); - } - ServerMessage::CommandDataDescription0(desc) if !connection.proto.is_1() => { - let guard = guard.take().unwrap(); - connection.expect_ready(guard).await?; - let err = ParameterTypeMismatchError::build() - .set::(CommandDataDescription1::from(desc)); - return Err(err); - } - ServerMessage::ErrorResponse(err) => { - let guard = guard.take().unwrap(); - connection - .expect_ready_or_eos(guard) - .await - .map_err(|e| log::warn!("Error waiting for Ready after error: {e:#}")) - .ok(); - let mut err: edgedb_errors::Error = err.into(); - if let Some(desc) = description.take() { - err = err.set::(desc); - } - return Err(err); - } - msg => { - return Err(ProtocolOutOfOrderError::with_message(format!( - "Unsolicited message {:?}", - msg - )))?; - } - } - } - let warnings = description - .as_ref() - .map(|d| annotations::decode_warnings(&d.annotations)) - .transpose()? - .unwrap_or_default(); - let computed_desc = description - .as_ref() - .map(|d| d.output()) - .transpose() - .map_err(ProtocolEncodingError::with_source)?; - let computed_desc = computed_desc.as_ref().unwrap_or(out_desc); - if let Some(type_pos) = computed_desc.root_pos() { - let ctx = computed_desc.as_queryable_context(); - let state = T::prepare(&ctx, type_pos)?; - Ok(ResponseStream { - connection, - buffer, - state: Some(state), - guard, - description, - warnings, - }) - } else { - Ok(ResponseStream { - connection, - buffer, - state: None, - guard, - description, - warnings, - }) - } - } - pub fn can_contain_data(&self) -> bool { - self.state.is_some() - } - async fn expect_ready(&mut self) { - let guard = self.guard.take().expect("guard is checked before"); - if let Err(e) = self.connection.expect_ready(guard).await { - self.buffer = Buffer::Error(e); - } - } - async fn ignore_data(&mut self) { - use Buffer::*; - - loop { - match self.connection.message().await { - Ok(ServerMessage::StateDataDescription(d)) => { - self.connection.state_desc = d.typedesc; - } - Ok(ServerMessage::Data(_)) if self.state.is_some() => {} - Ok(ServerMessage::CommandComplete1(complete)) - if self.guard.is_some() && self.connection.proto.is_1() => - { - self.buffer = Complete { - status_data: complete.status_data, - new_state: complete.state, - }; - self.expect_ready().await; - return; - } - Ok(ServerMessage::CommandComplete0(complete)) - if self.guard.is_some() && !self.connection.proto.is_1() => - { - self.buffer = Complete { - status_data: complete.status_data, - new_state: None, - }; - self.expect_ready().await; - return; - } - Ok(ServerMessage::ErrorResponse(err)) if self.guard.is_some() => { - let guard = self.guard.take().unwrap(); - self.connection - .expect_ready_or_eos(guard) - .await - .map_err(|e| log::warn!("Error waiting for Ready after error: {e:#}")) - .ok(); - self.buffer = ErrorResponse(err); - return; - } - Ok(msg) => { - self.buffer = Error(ProtocolOutOfOrderError::with_message(format!( - "Unsolicited message {:?}", - msg - ))); - return; - } - Err(e) => { - self.buffer = Error(e); - return; - } - } - } - } - pub async fn next_element(&mut self) -> Option { - use Buffer::*; - - let Reading(ref mut buffer) = self.buffer else { - return None; - }; - loop { - if let Some(element) = buffer.pop_front() { - let state = self - .state - .as_mut() - .expect("data packets are ignored if state is None"); - match T::decode(state, &element) { - Ok(value) => return Some(value), - Err(e) => { - self.ignore_data().await; - self.buffer = Error(e); - return None; - } - } - } - match self.connection.message().await { - Ok(ServerMessage::StateDataDescription(d)) => { - self.connection.state_desc = d.typedesc; - } - Ok(ServerMessage::Data(datum)) if self.state.is_some() => { - buffer.extend(datum.data); - } - Ok(ServerMessage::CommandComplete1(complete)) - if self.guard.is_some() && self.connection.proto.is_1() => - { - self.expect_ready().await; - self.buffer = Complete { - status_data: complete.status_data, - new_state: complete.state, - }; - return None; - } - Ok(ServerMessage::CommandComplete0(complete)) - if self.guard.is_some() && !self.connection.proto.is_1() => - { - self.expect_ready().await; - self.buffer = Complete { - status_data: complete.status_data, - new_state: None, - }; - return None; - } - Ok(ServerMessage::ErrorResponse(err)) if self.guard.is_some() => { - let guard = self.guard.take().unwrap(); - self.connection - .expect_ready_or_eos(guard) - .await - .map_err(|e| log::warn!("Error waiting for Ready after error: {e:#}")) - .ok(); - self.buffer = ErrorResponse(err); - return None; - } - Ok(msg) => { - self.buffer = Error(ProtocolOutOfOrderError::with_message(format!( - "Unsolicited message {:?}", - msg - ))); - return None; - } - Err(e) => { - self.buffer = Error(e); - return None; - } - } - } - } - pub fn warnings(&self) -> &[Warning] { - &self.warnings - } - pub async fn complete(mut self) -> Result, Error> { - self.process_complete().await - } - pub async fn process_complete(&mut self) -> Result, Error> { - use Buffer::*; - while matches!(self.buffer, Reading(_)) { - self.ignore_data().await - } - - match mem::replace(&mut self.buffer, Buffer::Reset) { - Reading(_) => unreachable!(), - Complete { - status_data, - new_state, - } => { - let warnings = std::mem::take(&mut self.warnings); - let response = Response { - status_data, - new_state, - data: (), - warnings, - }; - response.log_warnings(); - Ok(response) - } - Error(e) => Err(e), - ErrorResponse(e) => { - let mut err: edgedb_errors::Error = e.into(); - if let Some(desc) = self.description.take() { - err = err.set::(desc); - } - Err(err) - } - Reset => panic!("process_complete() called twice"), - } - } -} diff --git a/edgedb-tokio/src/raw/state.rs b/edgedb-tokio/src/raw/state.rs deleted file mode 100644 index a8bd2d26..00000000 --- a/edgedb-tokio/src/raw/state.rs +++ /dev/null @@ -1,445 +0,0 @@ -//! Connection state modification utilities - -use std::collections::{BTreeMap, HashMap}; -use std::sync::Arc; - -use arc_swap::ArcSwapOption; -use edgedb_protocol::client_message::State as EncodedState; -use edgedb_protocol::descriptors::{RawTypedesc, StateBorrow}; -use edgedb_protocol::model::Uuid; -use edgedb_protocol::query_arg::QueryArg; -use edgedb_protocol::value::Value; - -use crate::errors::{ClientError, Error, ErrorKind, ProtocolEncodingError}; - -/// Unset a set of global or config variables -/// -/// Accepts an iterator of names. Used with globals lie this: -/// -/// ```rust,no_run -/// # use edgedb_tokio::state::Unset; -/// # #[tokio::main] -/// # async fn main() { -/// # let conn = edgedb_tokio::create_client().await.unwrap(); -/// let conn = conn.with_globals(Unset(["xxx", "yyy"])); -/// # } -/// ``` -#[derive(Debug)] -pub struct Unset(pub I); - -/// Use a closure to set or unset global or config variables -/// -/// ```rust,no_run -/// # use edgedb_tokio::state::{Fn, GlobalsModifier}; -/// # #[tokio::main] -/// # async fn main() { -/// # let conn = edgedb_tokio::create_client().await.unwrap(); -/// let conn = conn.with_globals(Fn(|m: &mut GlobalsModifier| { -/// m.set("x", "x_value"); -/// m.unset("y"); -/// })); -/// # } -/// ``` -#[derive(Debug)] -pub struct Fn(pub F); - -#[derive(Debug)] -pub struct PoolState { - raw_state: RawState, - cache: ArcSwapOption, -} - -#[derive(Debug)] -struct RawState { - // The idea behind the split between common and globals is that once - // setting module/aliases/config becomes bottleneck it's possible to have - // connection pool with those settings pre-set. But globals are supposed to - // have per-request values more often (one example is having `user_id` - // global variable). - common: Arc, - globals: BTreeMap, -} - -#[derive(Debug)] -struct CommonState { - module: Option, - aliases: BTreeMap, - config: BTreeMap, -} - -/// Utility object used to modify globals -/// -/// This object is passed to [`Fn`] closure and [`GlobalsDelta::apply`]. -#[derive(Debug)] -pub struct GlobalsModifier<'a> { - globals: &'a mut BTreeMap, - module: &'a str, - aliases: &'a BTreeMap, -} - -/// Utility object used to modify config -/// -/// This object is passed to [`Fn`] closure and [`ConfigDelta::apply`]. -#[derive(Debug)] -pub struct ConfigModifier<'a> { - config: &'a mut BTreeMap, -} - -/// Utility object used to modify aliases -/// -/// This object is passed to [`AliasesDelta::apply`] to do the actual -/// modification -#[derive(Debug)] -pub struct AliasesModifier<'a> { - data: &'a mut BTreeMap, -} - -/// Trait that modifies global variables -pub trait GlobalsDelta { - /// Applies variables delta using specified modifier object - fn apply(self, man: &mut GlobalsModifier<'_>); -} - -/// Trait that modifies config variables -pub trait ConfigDelta { - /// Applies variables delta using specified modifier object - fn apply(self, man: &mut ConfigModifier<'_>); -} - -/// Trait that modifies module aliases -pub trait AliasesDelta { - /// Applies variables delta using specified modifier object - fn apply(self, man: &mut AliasesModifier); -} - -pub trait SealedState { - fn encode(&self, desc: &RawTypedesc) -> Result; -} - -/// Provides state of the session in the binary form -/// -/// This trait is sealed. -pub trait State: SealedState + Send + Sync {} - -impl GlobalsModifier<'_> { - /// Set global variable to a value - /// - /// If `key` doesn't contain module name (`::` to be more - /// specific) then the variable name is resolved using current module. - /// Otherwise, modules are resolved using aliases if any. Note: modules are - /// resolved at method call time. This means that a sequence like this: - /// ```rust,no_run - /// # use edgedb_tokio::state::Fn; - /// # #[tokio::main] - /// # async fn main() { - /// # let conn = edgedb_tokio::create_client().await.unwrap(); - /// let conn = conn - /// .with_globals_fn(|m| m.set("var1", "value1")) - /// .with_default_module(Some("another_module")) - /// .with_globals_fn(|m| m.set("var1", "value2")); - /// # } - /// ``` - /// Will set `var1` in `default` and in `another_module` to different - /// values. - /// - /// # Panics - /// - /// This methods panics if `value` cannot be converted into dynamically - /// typed `Value` (`QueryArg::to_value()` method returns error). To avoid - /// this panic use either native EdgeDB types (e.g. - /// `edgedb_protocol::model::Datetime` instead of `std::time::SystemTime` - /// or call `to_value` manually before passing to `set`. - pub fn set(&mut self, key: &str, value: T) { - let value = value.to_value().expect("global can be encoded"); - if let Some(ns_off) = key.rfind("::") { - if let Some(alias) = self.aliases.get(&key[..ns_off]) { - self.globals.insert( - format!("{alias}::{suffix}", suffix = &key[ns_off + 2..]), - value, - ); - } else { - self.globals.insert(key.into(), value); - } - } else { - self.globals - .insert(format!("{}::{}", self.module, key), value); - } - } - /// Unset the global variable - /// - /// In most cases this will effectively set the variable to a default - /// value. - /// - /// To set variable to the actual empty value use `set("name", - /// Value::Nothing)`. - /// - /// Note: same namespacing rules like for `set` are applied here. - pub fn unset(&mut self, key: &str) { - if let Some(ns_off) = key.rfind("::") { - if let Some(alias) = self.aliases.get(&key[..ns_off]) { - self.globals - .remove(&format!("{alias}::{suffix}", suffix = &key[ns_off + 2..])); - } else { - self.globals.remove(key); - } - } else { - self.globals.remove(&format!("{}::{}", self.module, key)); - } - } -} - -impl ConfigModifier<'_> { - /// Set configuration setting to a value - /// - /// # Panics - /// - /// This methods panics if `value` cannot be converted into dynamically - /// typed `Value` (`QueryArg::to_value()` method returns error). To avoid - /// this panic use either native EdgeDB types (e.g. - /// `edgedb_protocol::model::Datetime` instead of `std::time::SystemTime` - /// or call `to_value` manually before passing to `set`. - pub fn set(&mut self, key: &str, value: T) { - let value = value.to_value().expect("config can be encoded"); - self.config.insert(key.into(), value); - } - /// Unset the global variable - /// - /// In most cases this will effectively set the variable to a default - /// value. - /// - /// To set setting to the actual empty value use `set("name", - /// Value::Nothing)`. - pub fn unset(&mut self, key: &str) { - self.config.remove(key); - } -} - -impl AliasesModifier<'_> { - /// Set a module alias - pub fn set(&mut self, key: &str, value: &str) { - self.data.insert(key.into(), value.into()); - } - /// Unsed a module alias - pub fn unset(&mut self, key: &str) { - self.data.remove(key); - } -} - -impl, I: IntoIterator> GlobalsDelta for Unset { - fn apply(self, man: &mut GlobalsModifier) { - for key in self.0.into_iter() { - man.unset(key.as_ref()); - } - } -} - -impl, I: IntoIterator> ConfigDelta for Unset { - fn apply(self, man: &mut ConfigModifier) { - for key in self.0.into_iter() { - man.unset(key.as_ref()); - } - } -} - -impl, I: IntoIterator> AliasesDelta for Unset { - fn apply(self, man: &mut AliasesModifier) { - for key in self.0.into_iter() { - man.unset(key.as_ref()); - } - } -} - -impl)> GlobalsDelta for Fn { - fn apply(self, man: &mut GlobalsModifier) { - self.0(man) - } -} - -impl)> ConfigDelta for Fn { - fn apply(self, man: &mut ConfigModifier) { - self.0(man) - } -} - -impl)> AliasesDelta for Fn { - fn apply(self, man: &mut AliasesModifier) { - self.0(man) - } -} - -impl, V: AsRef> AliasesDelta for BTreeMap { - fn apply(self, man: &mut AliasesModifier) { - for (key, value) in self { - man.set(key.as_ref(), value.as_ref()); - } - } -} - -impl, V: AsRef> AliasesDelta for HashMap { - fn apply(self, man: &mut AliasesModifier) { - for (key, value) in self { - man.set(key.as_ref(), value.as_ref()); - } - } -} - -impl, V: AsRef> AliasesDelta for &'_ BTreeMap { - fn apply(self, man: &mut AliasesModifier) { - for (key, value) in self { - man.set(key.as_ref(), value.as_ref()); - } - } -} - -impl, V: AsRef> AliasesDelta for &'_ HashMap { - fn apply(self, man: &mut AliasesModifier) { - for (key, value) in self { - man.set(key.as_ref(), value.as_ref()); - } - } -} - -impl, V: QueryArg> GlobalsDelta for BTreeMap { - fn apply(self, man: &mut GlobalsModifier) { - for (key, value) in self { - let value = value.to_value().expect("global can be encoded"); - man.set(key.as_ref(), value); - } - } -} - -impl, V: QueryArg> ConfigDelta for BTreeMap { - fn apply(self, man: &mut ConfigModifier) { - for (key, value) in self { - let value = value.to_value().expect("global can be encoded"); - man.set(key.as_ref(), value); - } - } -} - -impl PoolState { - pub fn with_default_module(&self, module: Option) -> Self { - PoolState { - raw_state: RawState { - common: Arc::new(CommonState { - module, - aliases: self.raw_state.common.aliases.clone(), - config: self.raw_state.common.config.clone(), - }), - globals: self.raw_state.globals.clone(), - }, - cache: ArcSwapOption::new(None), - } - } - pub fn with_globals(&self, delta: impl GlobalsDelta) -> Self { - let mut globals = self.raw_state.globals.clone(); - delta.apply(&mut GlobalsModifier { - module: self.raw_state.common.module.as_deref().unwrap_or("default"), - aliases: &self.raw_state.common.aliases, - globals: &mut globals, - }); - PoolState { - raw_state: RawState { - common: self.raw_state.common.clone(), - globals, - }, - cache: ArcSwapOption::new(None), - } - } - pub fn with_config(&self, delta: impl ConfigDelta) -> Self { - let mut config = self.raw_state.common.config.clone(); - delta.apply(&mut ConfigModifier { - config: &mut config, - }); - PoolState { - raw_state: RawState { - common: Arc::new(CommonState { - module: self.raw_state.common.module.clone(), - aliases: self.raw_state.common.aliases.clone(), - config, - }), - globals: self.raw_state.globals.clone(), - }, - cache: ArcSwapOption::new(None), - } - } - - pub fn with_aliases(&self, delta: impl AliasesDelta) -> Self { - let mut aliases = self.raw_state.common.aliases.clone(); - delta.apply(&mut AliasesModifier { data: &mut aliases }); - PoolState { - raw_state: RawState { - common: Arc::new(CommonState { - module: self.raw_state.common.module.clone(), - aliases, - config: self.raw_state.common.config.clone(), - }), - globals: self.raw_state.globals.clone(), - }, - cache: ArcSwapOption::new(None), - } - } - pub fn encode(&self, desc: &RawTypedesc) -> Result { - if let Some(cache) = &*self.cache.load() { - if cache.typedesc_id == desc.id { - return Ok((**cache).clone()); - } - } - let typedesc = desc.decode().map_err(ProtocolEncodingError::with_source)?; - let result = typedesc.serialize_state(&StateBorrow { - module: &self.raw_state.common.module, - aliases: &self.raw_state.common.aliases, - config: &self.raw_state.common.config, - globals: &self.raw_state.globals, - })?; - self.cache.store(Some(Arc::new(result.clone()))); - Ok(result) - } -} - -impl SealedState for &PoolState { - fn encode(&self, desc: &RawTypedesc) -> Result { - PoolState::encode(self, desc) - } -} -impl State for &PoolState {} -impl SealedState for Arc { - fn encode(&self, desc: &RawTypedesc) -> Result { - PoolState::encode(self, desc) - } -} -impl State for Arc {} - -impl SealedState for EncodedState { - fn encode(&self, desc: &RawTypedesc) -> Result { - if self.typedesc_id == Uuid::from_u128(0) || self.typedesc_id == desc.id { - return Ok((*self).clone()); - } - Err(ClientError::with_message( - "state doesn't match state descriptor", - )) - } -} -impl State for EncodedState {} -impl SealedState for Arc { - fn encode(&self, desc: &RawTypedesc) -> Result { - (**self).encode(desc) - } -} -impl State for Arc {} - -impl Default for PoolState { - fn default() -> PoolState { - PoolState { - raw_state: RawState { - common: Arc::new(CommonState { - module: None, - aliases: Default::default(), - config: Default::default(), - }), - globals: Default::default(), - }, - cache: ArcSwapOption::new(None), - } - } -} diff --git a/edgedb-tokio/src/sealed.rs b/edgedb-tokio/src/sealed.rs deleted file mode 100644 index e894357d..00000000 --- a/edgedb-tokio/src/sealed.rs +++ /dev/null @@ -1 +0,0 @@ -pub trait SealedParam {} diff --git a/edgedb-tokio/src/server_params.rs b/edgedb-tokio/src/server_params.rs deleted file mode 100644 index 2cfe042f..00000000 --- a/edgedb-tokio/src/server_params.rs +++ /dev/null @@ -1,70 +0,0 @@ -//! Parameters returned by the server on initial handshake -use std::any::{Any, TypeId}; -use std::collections::HashMap; -use std::fmt; -use std::time::Duration; - -use serde::{Deserialize, Serialize}; - -use crate::sealed::SealedParam; - -#[derive(Debug)] -pub(crate) struct ServerParams(HashMap>); - -/// Address of the underlying postgres, available only in dev mode. -#[derive(Deserialize, Debug, Serialize)] -pub struct PostgresAddress { - pub host: String, - pub port: u16, - pub user: String, - pub password: Option, - pub database: String, - pub server_settings: HashMap, -} - -/// A trait that represents a param sent from the server. -pub trait ServerParam: SealedParam + 'static { - type Value: fmt::Debug + Send + Sync + 'static; -} - -impl ServerParam for PostgresAddress { - type Value = PostgresAddress; -} - -impl SealedParam for PostgresAddress {} - -#[derive(Debug)] -#[allow(dead_code)] -pub struct PostgresDsn(pub String); - -impl ServerParam for PostgresDsn { - type Value = PostgresDsn; -} - -impl SealedParam for PostgresDsn {} - -/// ParameterStatus_SystemConfig -#[derive(Debug)] -pub struct SystemConfig { - pub session_idle_timeout: Option, -} - -impl ServerParam for SystemConfig { - type Value = SystemConfig; -} - -impl SealedParam for SystemConfig {} - -impl ServerParams { - pub fn new() -> ServerParams { - ServerParams(HashMap::new()) - } - pub fn set(&mut self, value: T::Value) { - self.0.insert(TypeId::of::(), Box::new(value)); - } - pub fn get(&self) -> Option<&T::Value> { - self.0 - .get(&TypeId::of::()) - .and_then(|v| v.downcast_ref()) - } -} diff --git a/edgedb-tokio/src/state.rs b/edgedb-tokio/src/state.rs deleted file mode 100644 index 183c6669..00000000 --- a/edgedb-tokio/src/state.rs +++ /dev/null @@ -1,7 +0,0 @@ -//! Connection state modification utilities -//! - -pub use crate::raw::state::State; -pub use crate::raw::state::{AliasesDelta, ConfigDelta, GlobalsDelta}; -pub use crate::raw::state::{AliasesModifier, ConfigModifier, GlobalsModifier}; -pub use crate::raw::state::{Fn, Unset}; diff --git a/edgedb-tokio/src/tls.rs b/edgedb-tokio/src/tls.rs deleted file mode 100644 index 249b2a91..00000000 --- a/edgedb-tokio/src/tls.rs +++ /dev/null @@ -1,183 +0,0 @@ -use std::io; -use std::sync::Arc; - -use anyhow::Context; -use rustls::client::danger::HandshakeSignatureValid; -use rustls::client::danger::{ServerCertVerified, ServerCertVerifier}; -use rustls::crypto::WebPkiSupportedAlgorithms; -use rustls::crypto::{self, ring}; -use rustls::crypto::{verify_tls12_signature, verify_tls13_signature}; -use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; -use rustls::{DigitallySignedStruct, SignatureScheme}; -use tls_api::TlsConnectorBox; -use tls_api::{TlsConnector as _, TlsConnectorBuilder as _}; -use tls_api_rustls::TlsConnector; - -#[derive(Debug)] -pub struct NullVerifier; - -#[derive(Debug)] -pub struct NoHostnameVerifier { - roots: Arc, - supported: WebPkiSupportedAlgorithms, -} - -impl NoHostnameVerifier { - pub fn new(roots: Arc) -> Self { - NoHostnameVerifier { - roots, - supported: ring::default_provider().signature_verification_algorithms, - } - } -} - -impl ServerCertVerifier for NoHostnameVerifier { - fn verify_server_cert( - &self, - end_entity: &CertificateDer<'_>, - intermediates: &[CertificateDer<'_>], - _server_name: &ServerName, - _ocsp_response: &[u8], - now: UnixTime, - ) -> Result { - let cert = webpki::EndEntityCert::try_from(end_entity).map_err(pki_error)?; - - let result = cert.verify_for_usage( - self.supported.all, - &self.roots.roots, - intermediates, - now, - webpki::KeyUsage::server_auth(), - None, - None, - ); - - match result { - Ok(_) => Ok(ServerCertVerified::assertion()), - Err(e) => Err(pki_error(e)), - } - } - - fn verify_tls12_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &DigitallySignedStruct, - ) -> Result { - verify_tls12_signature(message, cert, dss, &self.supported) - } - - fn verify_tls13_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &DigitallySignedStruct, - ) -> Result { - verify_tls13_signature(message, cert, dss, &self.supported) - } - - fn supported_verify_schemes(&self) -> Vec { - ring::default_provider() - .signature_verification_algorithms - .supported_schemes() - } -} - -impl ServerCertVerifier for NullVerifier { - fn verify_server_cert( - &self, - _end_entity: &CertificateDer<'_>, - _intermediates: &[CertificateDer<'_>], - _server_name: &ServerName, - _ocsp_response: &[u8], - _now: UnixTime, - ) -> Result { - Ok(ServerCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn verify_tls13_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn supported_verify_schemes(&self) -> Vec { - ring::default_provider() - .signature_verification_algorithms - .supported_schemes() - } -} - -pub fn connector(cert_verifier: Arc) -> anyhow::Result { - // ensure that crypto provider is installed per-process - // if users of edgedb-tokio have not installed a provider, we install ring here - if crypto::CryptoProvider::get_default().is_none() { - crypto::ring::default_provider().install_default().ok(); - } - - let mut builder = TlsConnector::builder()?; - builder - .config - .dangerous() - .set_certificate_verifier(cert_verifier); - builder.set_alpn_protocols(&[b"edgedb-binary"])?; - let connector = builder.build()?.into_dyn(); - Ok(connector) -} - -pub fn read_root_cert_pem(data: &str) -> anyhow::Result { - let mut cursor = io::Cursor::new(data); - let open_data = rustls_pemfile::read_all(&mut cursor); - let mut cert_store = rustls::RootCertStore::empty(); - for item in open_data { - match item { - Ok(rustls_pemfile::Item::X509Certificate(data)) => { - cert_store - .add(data) - .context("certificate data found, but is not a valid root certificate")?; - } - Ok(rustls_pemfile::Item::Pkcs1Key(_)) - | Ok(rustls_pemfile::Item::Pkcs8Key(_)) - | Ok(rustls_pemfile::Item::Sec1Key(_)) => { - log::debug!("Skipping private key in cert data"); - } - Ok(rustls_pemfile::Item::Crl(_)) => { - log::debug!("Skipping CRL in cert data"); - } - Ok(_) => { - log::debug!("Skipping unknown item cert data"); - } - Err(e) => { - log::error!("could not parse item in PEM file: {:?}", e); - } - } - } - Ok(cert_store) -} - -fn pki_error(error: webpki::Error) -> rustls::Error { - use webpki::Error::*; - match error { - BadDer | BadDerTime => { - rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding) - } - InvalidSignatureForPublicKey => { - rustls::Error::InvalidCertificate(rustls::CertificateError::BadSignature) - } - e => rustls::Error::InvalidCertificate(rustls::CertificateError::Other( - rustls::OtherError(Arc::new(e)), - )), - } -} diff --git a/edgedb-tokio/src/transaction.rs b/edgedb-tokio/src/transaction.rs deleted file mode 100644 index 8b7e6d1d..00000000 --- a/edgedb-tokio/src/transaction.rs +++ /dev/null @@ -1,420 +0,0 @@ -use std::future::Future; -use std::sync::Arc; - -use bytes::BytesMut; -use edgedb_protocol::common::CompilationOptions; -use edgedb_protocol::common::{Capabilities, Cardinality, InputLanguage, IoFormat}; -use edgedb_protocol::model::Json; -use edgedb_protocol::query_arg::{Encoder, QueryArgs}; -use edgedb_protocol::QueryResult; -use tokio::sync::oneshot; -use tokio::time::sleep; - -use crate::errors::ClientError; -use crate::errors::{Error, ErrorKind, SHOULD_RETRY}; -use crate::errors::{NoDataError, ProtocolEncodingError}; -use crate::raw::{Options, Pool, PoolConnection, Response}; -use crate::ResultVerbose; - -/// Transaction object passed to the closure via -/// [`Client::transaction()`](crate::Client::transaction) method -/// -/// The Transaction object must be dropped by the end of the closure execution. -/// -/// All database queries in transaction should be executed using methods on -/// this object instead of using original [`Client`](crate::Client) instance. -#[derive(Debug)] -pub struct Transaction { - iteration: u32, - options: Arc, - inner: Option, -} - -#[derive(Debug)] -pub struct TransactionResult { - conn: PoolConnection, - started: bool, -} - -#[derive(Debug)] -pub struct Inner { - started: bool, - conn: PoolConnection, - return_conn: oneshot::Sender, -} - -impl Drop for Transaction { - fn drop(&mut self) { - self.inner.take().map( - |Inner { - started, - conn, - return_conn, - }| { return_conn.send(TransactionResult { started, conn }).ok() }, - ); - } -} - -pub(crate) async fn transaction( - pool: &Pool, - options: Arc, - mut body: B, -) -> Result -where - B: FnMut(Transaction) -> F, - F: Future>, -{ - let mut iteration = 0; - 'transaction: loop { - let conn = pool.acquire().await?; - let (tx, mut rx) = oneshot::channel(); - let tran = Transaction { - iteration, - options: options.clone(), - inner: Some(Inner { - started: false, - conn, - return_conn: tx, - }), - }; - let result = body(tran).await; - let TransactionResult { mut conn, started } = rx.try_recv().expect( - "Transaction object must \ - be dropped by the time transaction body finishes.", - ); - match result { - Ok(val) => { - log::debug!("Comitting transaction"); - if started { - conn.statement("COMMIT", &options.state, &options.annotations) - .await?; - } - return Ok(val); - } - Err(outer) => { - log::debug!("Rolling back transaction on error"); - if started { - conn.statement("ROLLBACK", &options.state, &options.annotations) - .await?; - } - - let some_retry = outer.chain().find_map(|e| { - e.downcast_ref::().and_then(|e| { - if e.has_tag(SHOULD_RETRY) { - Some(e) - } else { - None - } - }) - }); - - if some_retry.is_none() { - return Err(outer); - } else { - let e = some_retry.unwrap(); - let rule = options.retry.get_rule(e); - if iteration >= rule.attempts { - return Err(outer); - } else { - log::info!("Retrying transaction on {:#}", e); - iteration += 1; - sleep((rule.backoff)(iteration)).await; - continue 'transaction; - } - } - } - } - } -} - -fn assert_transaction(x: &mut Option) -> &mut PoolConnection { - &mut x.as_mut().expect("transaction object is dropped").conn -} - -impl Transaction { - /// Zero-based iteration (attempt) number for the current transaction - /// - /// First attempt gets `iteration = 0`, second attempt `iteration = 1`, - /// etc. - pub fn iteration(&self) -> u32 { - self.iteration - } - async fn ensure_started(&mut self) -> anyhow::Result<(), Error> { - if let Some(inner) = &mut self.inner { - if !inner.started { - let options = &self.options; - inner - .conn - .statement("START TRANSACTION", &options.state, &options.annotations) - .await?; - inner.started = true; - } - return Ok(()); - } - Err(ClientError::with_message("using transaction after drop")) - } - - async fn query_helper( - &mut self, - query: impl AsRef + Send, - arguments: &A, - io_format: IoFormat, - cardinality: Cardinality, - ) -> Result>, Error> - where - A: QueryArgs, - R: QueryResult, - { - self.ensure_started().await?; - - let conn = assert_transaction(&mut self.inner); - - conn.inner() - .query( - query.as_ref(), - arguments, - &self.options.state, - &self.options.annotations, - Capabilities::MODIFICATIONS, - io_format, - cardinality, - ) - .await - } - - /// Execute a query and return a collection of results. - /// - /// You will usually have to specify the return type for the query: - /// - /// ```rust,ignore - /// let greeting = tran.query::("SELECT 'hello'", &()); - /// // or - /// let greeting: Vec = tran.query("SELECT 'hello'", &()); - /// ``` - /// - /// This method can be used with both static arguments, like a tuple of - /// scalars, and with dynamic arguments [`edgedb_protocol::value::Value`]. - /// Similarly, dynamically typed results are also supported. - pub async fn query( - &mut self, - query: impl AsRef + Send, - arguments: &A, - ) -> Result, Error> - where - A: QueryArgs, - R: QueryResult, - { - self.query_helper(query, arguments, IoFormat::Binary, Cardinality::Many) - .await - .map(|x| x.data) - } - - /// Execute a query and return a collection of results and warnings produced by the server. - /// - /// You will usually have to specify the return type for the query: - /// - /// ```rust,ignore - /// let greeting: (Vec, _) = tran.query_with_warnings("select 'hello'", &()).await?; - /// ``` - /// - /// This method can be used with both static arguments, like a tuple of - /// scalars, and with dynamic arguments [`edgedb_protocol::value::Value`]. - /// Similarly, dynamically typed results are also supported. - pub async fn query_verbose( - &mut self, - query: impl AsRef + Send, - arguments: &A, - ) -> Result>, Error> - where - A: QueryArgs, - R: QueryResult, - { - self.query_helper(query, arguments, IoFormat::Binary, Cardinality::Many) - .await - .map(|Response { data, warnings, .. }| ResultVerbose { data, warnings }) - } - - /// Execute a query and return a single result - /// - /// The query must return exactly one element. If the query returns more - /// than one element, a - /// [`ResultCardinalityMismatchError`][crate::errors::ResultCardinalityMismatchError] - /// is raised. If the query returns an empty set, a - /// [`NoDataError`][crate::errors::NoDataError] is raised. - /// - /// You will usually have to specify the return type for the query: - /// - /// ```rust,ignore - /// let greeting = tran.query_required_single::( - /// "SELECT 'hello'", - /// &(), - /// ); - /// // or - /// let greeting: String = tran.query_required_single( - /// "SELECT 'hello'", - /// &(), - /// ); - /// ``` - /// - /// This method can be used with both static arguments, like a tuple of - /// scalars, and with dynamic arguments [`edgedb_protocol::value::Value`]. - /// Similarly, dynamically typed results are also supported. - pub async fn query_single( - &mut self, - query: impl AsRef + Send, - arguments: &A, - ) -> Result, Error> - where - A: QueryArgs, - R: QueryResult + Send, - { - self.query_helper(query, arguments, IoFormat::Binary, Cardinality::AtMostOne) - .await - .map(|x| x.data.into_iter().next()) - } - - /// Execute a query and return a single result - /// - /// The query must return exactly one element. If the query returns more - /// than one element, a - /// [`ResultCardinalityMismatchError`][crate::errors::ResultCardinalityMismatchError] - /// is raised. If the query returns an empty set, a - /// [`NoDataError`][crate::errors::NoDataError] is raised. - /// - /// You will usually have to specify the return type for the query: - /// - /// ```rust,ignore - /// let greeting = tran.query_required_single::( - /// "SELECT 'hello'", - /// &(), - /// ); - /// // or - /// let greeting: String = tran.query_required_single( - /// "SELECT 'hello'", - /// &(), - /// ); - /// ``` - /// - /// This method can be used with both static arguments, like a tuple of - /// scalars, and with dynamic arguments [`edgedb_protocol::value::Value`]. - /// Similarly, dynamically typed results are also supported. - pub async fn query_required_single( - &mut self, - query: impl AsRef + Send, - arguments: &A, - ) -> Result - where - A: QueryArgs, - R: QueryResult + Send, - { - self.query_helper(query, arguments, IoFormat::Binary, Cardinality::AtMostOne) - .await - .and_then(|x| { - x.data - .into_iter() - .next() - .ok_or_else(|| NoDataError::with_message("query row returned zero results")) - }) - } - - /// Execute a query and return the result as JSON. - pub async fn query_json( - &mut self, - query: &str, - arguments: &impl QueryArgs, - ) -> Result { - let res = self - .query_helper::(query, arguments, IoFormat::Json, Cardinality::Many) - .await?; - - let json = res - .data - .into_iter() - .next() - .ok_or_else(|| NoDataError::with_message("query row returned zero results"))?; - - // we trust database to produce valid json - Ok(Json::new_unchecked(json)) - } - - /// Execute a query and return a single result as JSON. - /// - /// The query must return exactly one element. If the query returns more - /// than one element, a - /// [`ResultCardinalityMismatchError`][crate::errors::ResultCardinalityMismatchError] - /// is raised. - pub async fn query_single_json( - &mut self, - query: &str, - arguments: &impl QueryArgs, - ) -> Result, Error> { - let res = self - .query_helper::(query, arguments, IoFormat::Json, Cardinality::AtMostOne) - .await?; - - // we trust database to produce valid json - Ok(res.data.into_iter().next().map(Json::new_unchecked)) - } - - /// Execute a query and return a single result as JSON. - /// - /// The query must return exactly one element. If the query returns more - /// than one element, a - /// [`ResultCardinalityMismatchError`][crate::errors::ResultCardinalityMismatchError] - /// is raised. If the query returns an empty set, a - /// [`NoDataError`][crate::errors::NoDataError] is raised. - pub async fn query_required_single_json( - &mut self, - query: &str, - arguments: &impl QueryArgs, - ) -> Result { - self.query_single_json(query, arguments) - .await? - .ok_or_else(|| NoDataError::with_message("query row returned zero results")) - } - - /// Execute a query and don't expect result - /// - /// This method can be used with both static arguments, like a tuple of - /// scalars, and with dynamic arguments [`edgedb_protocol::value::Value`]. - /// Similarly, dynamically typed results are also supported. - pub async fn execute(&mut self, query: &str, arguments: &A) -> Result<(), Error> - where - A: QueryArgs, - { - self.ensure_started().await?; - let flags = CompilationOptions { - implicit_limit: None, - implicit_typenames: false, - implicit_typeids: false, - explicit_objectids: true, - allow_capabilities: Capabilities::MODIFICATIONS, - input_language: InputLanguage::EdgeQL, - io_format: IoFormat::Binary, - expected_cardinality: Cardinality::Many, - }; - let state = &self.options.state; - let conn = assert_transaction(&mut self.inner); - let desc = conn - .parse(&flags, query, state, &self.options.annotations) - .await?; - let inp_desc = desc.input().map_err(ProtocolEncodingError::with_source)?; - - let mut arg_buf = BytesMut::with_capacity(8); - arguments.encode(&mut Encoder::new( - &inp_desc.as_query_arg_context(), - &mut arg_buf, - ))?; - - conn.execute( - &flags, - query, - state, - &self.options.annotations, - &desc, - &arg_buf.freeze(), - ) - .await?; - Ok(()) - } -} diff --git a/edgedb-tokio/src/tutorial.md b/edgedb-tokio/src/tutorial.md deleted file mode 100644 index 553e91e9..00000000 --- a/edgedb-tokio/src/tutorial.md +++ /dev/null @@ -1,473 +0,0 @@ -# EdgeDB Rust client tutorial - -## Getting started - -### From examples repo - -If you just want a working repo to get started, clone the [Rust client examples repo](https://github.com/Dhghomon/edgedb_rust_client_examples), type `edgedb project init` to start an EdgeDB project, and then `cargo run` to run the samples. - -This tutorial contains a lot of similar examples to those found in the `main.rs` file inside that repo. - -### From scratch - -The minimum to add to your Cargo.toml to use the client is [edgedb-tokio](https://docs.rs/edgedb-tokio/latest/edgedb_tokio/): - - edgedb-tokio = "0.4.0" - -The next most common dependency is [edgedb-protocol](https://docs.rs/edgedb-protocol/latest/edgedb_protocol/), which includes the EdgeDB types used for data modeling: - - edgedb-protocol = "0.4.0" - -A third crate called [edgedb-derive](https://docs.rs/edgedb-derive/latest/edgedb_derive/) contains the `#[derive(Queryable)]` derive macro which is the main way to unpack EdgeDB output into Rust types: - - edgedb-derive = "0.5.0" - -The Rust client uses tokio so add this to Cargo.toml as well: - - tokio = { version = "1.28.0", features = ["macros", "rt-multi-thread"] }` - -If you are avoiding async code and want to emulate a blocking client, you will still need to use tokio as a dependency but can bridge with async using [one of the bridging methods recommended by tokio](https://tokio.rs/tokio/topics/bridging). This won't require any added features: - - tokio = "1.28.0" - -Then you can start a runtime. Block and wait for futures to resolve by calling the runtime's `.block_on()` method: - -```rust -let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build()?; -let just_a_string: String = - rt.block_on(client.query_required_single("select 'Just a string'", &()))?; -``` - -## Edgedb project setup - - -The EdgeDB CLI initializes an EdgeDB project with a single command in the same way that Cargo initializes a Rust project, except it does not create a new directory. So to start a project: - -* Use `cargo new ` as usual, then: -* Go into the directory and type `edgedb project init`. - -The CLI will prompt you for the instance name and version of EdgeDB to use. It will look something like this: - - PS C:\rust\my_db> edgedb project init - No `edgedb.toml` found in `\\?\C:\rust\my_db` or above - Do you want to initialize a new project? [Y/n] - > Y - Specify the name of EdgeDB instance to use with this project [default: my_db]: - > my_db - Checking EdgeDB versions... - Specify the version of EdgeDB to use with this project [default: 3.0]: - > 3.0 - ┌─────────────────────┬─────────────────────────────────┐ - │ Project directory │ \\?\C:\rust\my_db │ - │ Project config │ \\?\C:\rust\my_db\edgedb.toml │ - │ Schema dir (empty) │ \\?\C:\rust\my_db\dbschema │ - │ Installation method │ WSL │ - │ Version │ 3.0+e7d38e9 │ - │ Instance name │ my_db │ - └─────────────────────┴─────────────────────────────────┘ - Version 3.0+e7d38e9 is already installed - Initializing EdgeDB instance... - Applying migrations... - Everything is up to date. Revision initial - Project initialized. - To connect to my_db, run `edgedb` - -Inside your project directory you'll notice some new items: - -* `edgedb.toml`, which is used to mark the directory as an EdgeDB project. The file itself doesn't contain much — just the version of EdgeDB being used — but is used by the CLI to run commands without connection flags. (E.g., `edgedb -I my_project migrate` becomes simply `edgedb migrate`). See more on edgedb.toml [in the blog post introducing the EdgeDB projects CLI](https://www.edgedb.com/blog/introducing-edgedb-projects). - -* A `/dbschema` folder containing: - * a `default.esdl` file which holds your schema. You can change the schema by directly modifying this file followed by `edgedb migration create` and `edgedb migrate`. - * a `/migrations` folder with `.edgeql` files named starting at `00001`. These hold the [ddl](https://www.edgedb.com/docs/reference/ddl/index) commands that were used to migrate your schema. A new file will show up in this directory every time your schema is migrated. - -If you are running EdgeDB 3.0 and above, you also have the option of using the [edgedb watch](https://www.edgedb.com/docs/cli/edgedb_watch) command. Doing so starts a long-running process that keeps an eye on changes in `/dbschema`, automatically applying these changes in real time. - -Now that you have the right dependencies and an EdgeDB instance, you can create a client. - -# Using the client - -Creating a new EdgeDB client can be done in a single line: - -```rust -let client = edgedb_tokio::create_client().await?; -``` - -Under the hood, this will create a [Builder](crate::Builder), look for environment variables and/or an `edgedb.toml` file and return an `Ok(Self)` if successful. This `Builder` can be used on its own instead of `create_client()` if you need a more customized setup. - -# Queries with the client - -Here are the simplified signatures of the client methods used for querying: - -(Note: `R` here means a type that implements [`QueryResult`](https://docs.rs/edgedb-protocol/0.4.0/edgedb_protocol/trait.QueryResult.html)) - -```rust -fn query -> Result, Error> -fn query_json -> Result - -fn query_single -> Result, Error> -fn query_single_json -> Result> - -fn query_required_single -> Result -fn query_required_single_json -> Result -``` - -Note the difference between the `_single` and the `_required_single` methods: - -* The `_required_single` methods return empty results as a `NoDataError` which allows propagating errors normally through an application -* The `_single` methods will simply give you an `Ok(None)` in this case - -These methods all take a *query* (a `&str`) and *arguments* (something that implements the [`QueryArgs`](https://docs.rs/edgedb-protocol/latest/edgedb_protocol/query_arg/trait.QueryArgs.html) trait). - -The `()` unit type implements `QueryArgs` and is used when no arguments are present so `&()` is a pretty common sight when using the Rust client. - -```rust -// Without arguments: just add &() after the query -let query_res: String = client.query_required_single("select 'Just a string'", &()).await?; - -// With arguments, same output -let one = " a "; -let two = "string"; -let query_res: String = client - .query_required_single("select 'Just' ++ $0 ++ $1", &(first, second)) - .await?; -``` - -For more information, see the ["Passing in arguments" section](#passing-in-arguments) below. - -These methods take two generic parameters which can be specified with the turbofish syntax: - -```rust -let query_res = client - .query_required_single::("select 'Just a string'", &()) - .await?; -// or -let query_res = client - .query_required_single::("select 'Just a string'", &()) - .await?; -``` - -But declaring the final expected type upfront tends to look neater. - -```rust -let query_res: String = client - .query_required_single("select 'Just a string'", &()) - .await?; -``` - -# Sample queries - -## When cardinality is guaranteed to be 1 - -Using the `.query()` method works fine for any cardinality, but returns a `Vec` of results. This query with a cardinality of 1 returns a `Result>` which becomes a `Vec` after the error is handled: - -```rust -let query = "select 'Just a string'"; -let query_res: Vec = client.query(query, &()).await?; -``` - -But if you know that only a single result will be returned, using `.query_required_single()` or `.query_single()` will be more ergonomic: - -```rust -let query = "select 'Just a string'"; -let query_res: String = client.query_required_single(query, &()).await?; -let query_res_opt: Option = client.query_single(query, &()).await?; -``` - -## Using the `Queryable` macro - -The easiest way to unpack an EdgeDB query result is the built-in `Queryable` macro from the `edgedb-derive` crate. This turns queries directly into Rust types without having to match on a `Value` (more in the section on [the `Value` enum](#the-value-enum)), cast to JSON, etc. - -```rust -#[derive(Debug, Deserialize, Queryable)] -pub struct QueryableAccount { - pub username: String, - pub id: Uuid, -} - -let query = "select account { - username, - id - };"; -let as_queryable_account: QueryableAccount = client - .query_required_single(query, &()) - .await?; -``` - -Note: Field order within the shape of the query matters when using the `Queryable` macro. In the example below, a query is done in the order `id, username` instead of `username, id` as defined in the struct: - -```rust -let query = "select account { - id, - username - };"; -let wrong_order: Result = client - .query_required_single(query, &()) - .await; -assert!( - format!("{wrong_order:?}") - .contains(r#"WrongField { unexpected: "id", expected: "username" }"#); -); -``` - -You can use [`cargo expand`](https://github.com/dtolnay/cargo-expand) with the nightly compiler to see the code generated by the `Queryable` macro, but the minimal example repo also contains [a somewhat cleaned up version of the generated `Queryable` code](https://github.com/Dhghomon/edgedb_rust_client_examples/blob/master/src/lib.rs#L12). - -## Passing in arguments - -A regular EdgeQL query without arguments looks like this: - -``` -with - message1 := 'Hello there', - message2 := 'General Kenobi', -select message1 ++ ' ' ++ message2; -``` - -And the same query with arguments: - -``` -with - message1 := $0, - message2 := $1, -select message1 ++ ' ' ++ message2; -``` - -In the EdgeQL REPL you are prompted to enter arguments: - -``` -db> with -... message1 := $0, -... message2 := $1, -... select message1 ++ ' ' ++ message2; -Parameter $0: Hello there -Parameter $1: General Kenobi -{'Hello there General Kenobi'} -``` - -But when using the Rust client, there is no prompt to do so. At present, arguments also have to be in the order `$0`, `$1`, and so on while in the REPL, they can be named (e.g. `$message` and `$person` instead of `$0` and `$1`). The arguments in the client are then passed in as a tuple: - -```rust -let arguments = ("Nice movie", 2023); -let query = "with -movie := (insert Movie { - title := $0, - release_year := $1 -}) - select { - title, - release_year, - id -}"; -let query_res: Value = client.query_required_single(query, &(arguments)).await?; -``` - -A note on the casting syntax: EdgeDB requires arguments to have a cast in the same way that Rust requires a type declaration in function signatures. As such, arguments in queries are used as type specification for the EdgeDB compiler, not to cast from queries from the Rust side. Take this query as an example: - -```rust -let query = "select $0"; -``` - -This simply means "select an argument that must be an `int32`", not "take the received argument and cast it into an `int32`". - -As such, this will return an error: - -```rust -let query = "select $0"; -let argument = 9i16; // Rust client will expect an int16 -let query_res: Result = client.query_required_single(query, &(argument,)).await; -assert!(query_res - .unwrap_err() - .to_string() - .contains("expected std::int16")); -``` - -## The `Value` enum - -The [`Value`](https://docs.rs/edgedb-protocol/latest/edgedb_protocol/value/enum.Value.html) enum can be found in the edgedb-protocol crate. A `Value` represents anything returned from EdgeDB. This means you can always return a `Value` from any of the query methods without needing to deserialize into a Rust type, and the enum can be instructive in getting to know the protocol. On the other hand, returning a `Value` leads to pattern matching to get to the inner value and is not the most ergonomic way to work with results from EdgeDB. - -```rust -pub enum Value { - Nothing, - Uuid(Uuid), - Str(String), - Bytes(Vec), - Int16(i16), - Int32(i32), - Int64(i64), - Float32(f32), - Float64(f64), - BigInt(BigInt), - // ... and so on -} -``` - -Most variants of the `Value` enum correspond to a Rust type from the standard library, while some are from the `edgedb-protocol` crate and will have to be constructed. For example, this query expecting an EdgeDB `bigint` type will return an error as it receives a `20`, which is *not* a `bigint` but an `i32`: - -```rust -let query = "select $0"; -let argument = 20; -let query_res: Result = client.query_required_single(query, &(argument,)).await; -assert!(format!("{query_res:?}").contains("expected std::int32")); -``` - -Instead, first construct a `BigInt` from the `i32` and pass that in as an argument: - -```rust -use edgedb_protocol::model::BigInt; - -let query = "select $0"; -let bigint_arg = BigInt::from(20); -let query_res: Result = client.query_required_single(query, &(bigint_arg,)).await; -assert_eq!( - format!("{query_res:?}"), - "Ok(BigInt(BigInt { negative: false, weight: 0, digits: [20] }))" -); -``` - -## Using JSON - -EdgeDB can cast any type to JSON with ``, but the `_json` methods don't require this cast in the query. This result can be turned into a `String` and used to respond to some JSON API request directly, unpacked into a struct using `serde` and `serde_json`, etc. - -```rust -#[derive(Debug, Deserialize)] -pub struct Account { - pub username: String, - pub id: Uuid, -} - -// No need for cast here -let query = "select Account { - username, - id - } filter .username = $0;"; - -// Assuming we know there will only be one result we can use query_single_json; -// otherwise query_json which returns a map of json -let json_res = client - .query_single_json(query, &("SomeUserName",)) - .await? - .unwrap(); - -// Format: {"username" : "SomeUser1", "id" : "7093944a-fd3a-11ed-a013-c7de12ffe7a9"} -let as_string = json_res.to_string(); -let as_account: Account = serde_json::from_str(&json_res)?; -``` - -## Execute - -The `execute` method doesn't return anything (a successful execute returns an `Ok(())`) which is convenient for things like updates or commands where we don't care about getting output if it works: - -```rust -client.execute("update Account set {username := .username ++ '!'};", &()).await?; -client.execute("create superuser role project;", &()).await?; -client.execute("alter role project set password := 'STRONGpassword';", &()).await?; - -// Returns Ok(()) upon success but error info will be returned of course -let command = client.execute("create type MyType {};", &()).await; -assert!(command.unwrap_err().to_string().contains("bare DDL statements are not allowed")); -``` - -## Transactions - -The client also has a `.transaction()` method that allows for atomic [transactions](https://www.edgedb.com/docs/edgeql/transactions). - -Wikipedia has a good example of a scenario requiring a transaction which we can then implement: - -``` -An example of an atomic transaction is a monetary transfer from bank account A -to account B. It consists of two operations, withdrawing the money from -account A and saving it to account B. Performing these operations in an atomic -transaction ensures that the database remains in a consistent state, that is, -money is neither lost nor created if either of those two operations fails. -``` - -A transaction removing 10 cents from one customer's account and placing it in another's would look like this: - -```rust -#[derive(Debug, Deserialize, Queryable)] -pub struct BankCustomer { - pub name: String, - pub bank_balance: i32, -} -// Customer1 has an account with 110 cents in it. -// Customer2 has an account with 90 cents in it. -// Customer1 is going to send 10 cents to Customer 2. This will be a transaction -// because we don't want the case to ever occur - even for a split second - -// where one account has sent money while the other has not received it yet. - -// After the transaction is over, each customer should have 100 cents. - -let sender_name = "Customer1"; -let receiver_name = "Customer2"; -let balance_check_query = "select BankCustomer { name, bank_balance } - filter .name = $0"; -let balance_change_query = "update BankCustomer - filter .name = $0 - set { bank_balance := .bank_balance + $1 }"; -let send_amount = 10; - -client - .transaction(|mut conn| async move { - let sender: BankCustomer = conn - .query_required_single(balance_check_query, &(sender_name,)) - .await?; - if sender.bank_balance < send_amount { - println!("Not enough money to send, bailing from transaction"); - return Ok(()); - }; - conn.execute(balance_change_query, &(sender_name, send_amount.neg())) - .await?; - conn.execute(balance_change_query, &(receiver_name, send_amount)).await?; - Ok(()) - }) - .await?; -``` - -Note: What often may seem to require an atomic transaction can instead be achieved with links and [backlinks](https://www.edgedb.com/docs/edgeql/paths#backlinks) which are both idiomatic and easy to use in EdgeDB. For example, if one object holds a `required link` to two other objects and each of these two objects has a single backlink to the first one, simply updating the first object will effectively change the state of the other two instantaneously. - -## Client configuration - -The Client can still be configured after initialization via the `with_` methods ([`with_retry_options`](crate::Client::with_retry_options), [`with_transaction_options`](crate::Client::with_transaction_options), etc.) that create a shallow copy of the client with adjusted options. - -```rust -// Take a schema with matching Rust structs: -// -// module default { -// type User { -// required property name -> str; -// } -// } - -// module test { -// type User { -// required property name -> str; -// } -// }; - -// The regular client will query from module 'default' by default -let client = edgedb_tokio::create_client().await?; - -// This client will query from module 'test' by default -// The original client is unaffected -let test_client = client.with_default_module(Some("test")); - -// Each client queries separately with different behavior -let query = "select User {name};"; -let users: Vec = client.query(query, &()).await?; -let test_users: Vec = test_client.query(query, &()).await?; - -// Many other clients can be created with different options, -// all independent of the main client: -let transaction_opts = TransactionOptions::default().read_only(true); -let _read_only_client = client.with_transaction_options(transaction_opts); - -let retry_opts = RetryOptions::default().with_rule( - RetryCondition::TransactionConflict, - // No. of retries - 1, - // Retry immediately instead of default with increasing backoff - |_| std::time::Duration::from_millis(0), -); -let _immediate_retry_once_client = client.with_retry_options(retry_opts); -``` \ No newline at end of file diff --git a/edgedb-tokio/src/tutorial.rs b/edgedb-tokio/src/tutorial.rs deleted file mode 100644 index 356b96d3..00000000 --- a/edgedb-tokio/src/tutorial.rs +++ /dev/null @@ -1,2 +0,0 @@ -#![allow(rustdoc::invalid_rust_codeblocks)] -#![cfg_attr(not(doctest), doc = include_str!("tutorial.md"))] diff --git a/edgedb-tokio/tests/credentials1.json b/edgedb-tokio/tests/credentials1.json deleted file mode 100644 index b6fea337..00000000 --- a/edgedb-tokio/tests/credentials1.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "port": 10702, - "user": "test3n", - "password": "lZTBy1RVCfOpBAOwSCwIyBIR", - "database": "test3n" -} diff --git a/edgedb-tokio/tests/func/.gitignore b/edgedb-tokio/tests/func/.gitignore deleted file mode 100644 index 9ff8c18e..00000000 --- a/edgedb-tokio/tests/func/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/dbschema/migrations diff --git a/edgedb-tokio/tests/func/client.rs b/edgedb-tokio/tests/func/client.rs deleted file mode 100644 index 4d318efa..00000000 --- a/edgedb-tokio/tests/func/client.rs +++ /dev/null @@ -1,303 +0,0 @@ -use std::str::FromStr; - -use edgedb_errors::NoDataError; -use edgedb_protocol::model::{Json, Uuid}; -use edgedb_protocol::named_args; -use edgedb_protocol::value::{EnumValue, Value}; -use edgedb_tokio::{Client, Queryable}; -use futures_util::stream::{self, StreamExt}; -use serde::{Deserialize, Serialize}; - -use crate::server::SERVER; - -#[tokio::test] -async fn simple() -> anyhow::Result<()> { - let client = Client::new(&SERVER.config); - client.ensure_connected().await?; - - let value = client.query::("SELECT 7*93", &()).await?; - assert_eq!(value, vec![651]); - - let value = client.query_single::("SELECT 5*11", &()).await?; - assert_eq!(value, Some(55)); - - let value = client - .query_single::("SELECT {}", &()) - .await?; - assert_eq!(value, None); - - let value = client - .query_required_single::("SELECT 5*11", &()) - .await?; - assert_eq!(value, 55); - - let err = client - .query_required_single::("SELECT {}", &()) - .await - .unwrap_err(); - assert!(err.is::()); - - let value = client.query_json("SELECT 'x' ++ 'y'", &()).await?; - assert_eq!(value.as_ref(), r#"["xy"]"#); - - let value = client.query_single_json("SELECT 'x' ++ 'y'", &()).await?; - assert_eq!(value.as_deref(), Some(r#""xy""#)); - - let value = client.query_single_json("SELECT {}", &()).await?; - assert_eq!(value.as_deref(), None); - - let value = client.query_json("SELECT {}", &()).await?; - assert_eq!(value, Json::new_unchecked("[]".to_string())); - - let err = client - .query_required_single_json("SELECT {}", &()) - .await - .unwrap_err(); - assert!(err.is::()); - - client.execute("SELECT 1+1", &()).await?; - client - .execute("START MIGRATION TO {}; ABORT MIGRATION", &()) - .await?; - - // basic enum param - let enum_query = "SELECT ($0) = 'waiting'"; - assert!(client - .query_required_single::(enum_query, &(Value::Enum(EnumValue::from("waiting")),)) - .await - .unwrap()); - - // unsupported: enum param as Value::Str - client - .query_required_single::(enum_query, &(Value::Str("waiting".to_string()),)) - .await - .unwrap_err(); - - // unsupported: enum param as String - client - .query_required_single::(enum_query, &("waiting".to_string(),)) - .await - .unwrap_err(); - - // enum param as &str - assert!(client - .query_required_single::(enum_query, &("waiting",),) - .await - .unwrap()); - - // named args - let value = client - .query_required_single::( - "select ( - std::array_join(>$msg1, ' ') - ++ ($question ?? ' the ultimate question of life') - ++ ': ' - ++ $answer - );", - &named_args! { - "msg1" => vec!["the".to_string(), "answer".to_string(), "to".to_string()], - "question" => None::, - "answer" => 42_i64, - }, - ) - .await - .unwrap(); - assert_eq!( - value.as_str(), - "the answer to the ultimate question of life: 42" - ); - - // args for values - let uuid = "43299d0a-f993-4dcb-a8a2-50041bf5af79"; - let value = client - .query_required_single::( - "select $my_uuid;", - &named_args! { - "my_uuid" => Uuid::from_str("43299d0a-f993-4dcb-a8a2-50041bf5af79").unwrap(), - }, - ) - .await - .unwrap(); - assert_eq!(value, Uuid::from_str(uuid).unwrap()); - - Ok(()) -} - -#[tokio::test] -async fn parallel_queries() -> anyhow::Result<()> { - let client = Client::new(&SERVER.config); - client.ensure_connected().await?; - - let result = stream::iter(0..10i64) - .map(|idx| { - let cli = client.clone(); - async move { - cli.query_required_single::("SELECT $0*10", &(idx,)) - .await - } - }) - .buffer_unordered(7) - .collect::>() - .await; - let mut result: Vec<_> = result.into_iter().collect::>()?; - result.sort(); - - assert_eq!(result, (0..100).step_by(10).collect::>()); - - Ok(()) -} - -#[tokio::test] -async fn json() -> anyhow::Result<()> { - let client = Client::new(&SERVER.config); - client.ensure_connected().await?; - - client - .execute::<_>( - "insert test::OtpPhoneRequest { - phone := '0123456789', - sent_at := datetime_of_statement(), - otp := 98271 - }", - &(), - ) - .await - .unwrap(); - - #[derive(Clone, Debug, Serialize, Deserialize, Queryable)] - #[edgedb(json)] - pub struct OtpPhoneRequest { - pub phone: String, - pub otp: i32, - } - - let res = client.query::( - "select (select test::OtpPhoneRequest { phone, otp } filter .phone = '0123456789')", - &() - ) - .await?; - let res = res.into_iter().next().unwrap(); - assert_eq!(res.phone, "0123456789"); - assert_eq!(res.otp, 98271); - - Ok(()) -} - -#[tokio::test] -async fn big_num() -> anyhow::Result<()> { - let client = Client::new(&SERVER.config); - client.ensure_connected().await?; - - let res = client - .query_required_single::("select 1234567890123456789012345678900000n", &()) - .await - .unwrap(); - let Value::BigInt(res) = res else { panic!() }; - assert_eq!(res.to_string(), "1234567890123456789012345678900000"); - - let res = client - .query_required_single::("select 1234567891234567890.12345678900000n", &()) - .await - .unwrap(); - let Value::Decimal(res) = res else { panic!() }; - assert_eq!(res.to_string(), "1234567891234567890.12345678900000"); - - let res = client - .query_required_single::("select 0.00012n", &()) - .await - .unwrap(); - let Value::Decimal(res) = res else { panic!() }; - assert!(!res.negative()); - assert_eq!(res.decimal_digits(), 5); - assert_eq!(res.digits(), [1, 2000]); - assert_eq!(res.weight() * 4, -4); - assert_eq!(res.to_string(), "0.00012"); - - let res = client - .query_required_single::("select 0.000000000000000000001", &()) - .await - .unwrap(); - let Value::Decimal(res) = res else { panic!() }; - assert!(!res.negative()); - assert_eq!(res.decimal_digits(), 21); - assert_eq!(res.digits(), [1000]); - assert_eq!(res.weight() * 4, -24); - assert_eq!(res.to_string(), "0.000000000000000000001"); - - Ok(()) -} - -#[tokio::test] -async fn bytes() -> anyhow::Result<()> { - let client = Client::new(&SERVER.config); - client.ensure_connected().await?; - - #[derive(Queryable)] - struct MyResult { - data: bytes::Bytes, - } - - let res = client - .query_required_single::("select { data := b'101' } limit 1", &()) - .await - .unwrap(); - - assert_eq!(res.data, b"101"[..]); - Ok(()) -} - -#[tokio::test] -async fn wrong_field_number() -> anyhow::Result<()> { - let client = Client::new(&SERVER.config); - client.ensure_connected().await?; - - #[derive(Queryable, PartialEq, Debug)] - struct Thing { - a: String, - b: String, - } - let err = client - .query_required_single::("select { a := 'hello' }", &()) - .await - .unwrap_err(); - assert_eq!( - format!("{err:#}"), - "DescriptorMismatch: expected 2 fields, got 1" - ); - - let err = client - .query_required_single::("select { a := 'hello', b := 'world', c := 42 }", &()) - .await - .unwrap_err(); - assert_eq!( - format!("{err:#}"), - "DescriptorMismatch: expected 2 fields, got 3" - ); - - let err = client - .query_required_single::("select { a := 'hello', c := 'world' }", &()) - .await - .unwrap_err(); - assert_eq!( - format!("{err:#}"), - "DescriptorMismatch: unexpected field c, expected b" - ); - - Ok(()) -} - -#[tokio::test] -async fn warnings() -> anyhow::Result<()> { - let client = Client::new(&SERVER.config); - client.ensure_connected().await?; - - let res = client - .query_verbose::("select std::_warn_on_call()", &()) - .await - .unwrap(); - assert_eq!(res.warnings.len(), 1); - - // TODO: test that the warning is logged - - Ok(()) -} diff --git a/edgedb-tokio/tests/func/dbschema/test.esdl b/edgedb-tokio/tests/func/dbschema/test.esdl deleted file mode 100644 index 91123edf..00000000 --- a/edgedb-tokio/tests/func/dbschema/test.esdl +++ /dev/null @@ -1,22 +0,0 @@ -module test { - scalar type State extending enum<'done', 'waiting', 'blocked'>; - - type Counter { - required property name -> str { - constraint std::exclusive; - } - required property value -> int32 { - default := 0; - } - } - - global str_val -> str; - global int_val -> int32; - - type OtpPhoneRequest { - required phone: str; - required otp: int32; - required sent_at: datetime; - confirmed_at: datetime; - } -} diff --git a/edgedb-tokio/tests/func/derive.rs b/edgedb-tokio/tests/func/derive.rs deleted file mode 100644 index bc92d086..00000000 --- a/edgedb-tokio/tests/func/derive.rs +++ /dev/null @@ -1,73 +0,0 @@ -use edgedb_derive::Queryable; -use edgedb_protocol::model::Uuid; -use edgedb_tokio::Client; - -use crate::server::SERVER; - -#[derive(Queryable, Debug, PartialEq)] -struct FreeOb { - one: i64, - two: i64, -} - -#[derive(Queryable, Debug, PartialEq)] -struct SchemaType { - name: String, -} - -#[derive(Queryable, Debug, PartialEq)] -struct SchemaTypeId { - id: Uuid, - name: String, -} - -#[tokio::test] -async fn free_object() -> anyhow::Result<()> { - let client = Client::new(&SERVER.config); - client.ensure_connected().await?; - - let value = client - .query_required_single::("SELECT { one := 1, two := 2 }", &()) - .await?; - assert_eq!(value, FreeOb { one: 1, two: 2 }); - - Ok(()) -} - -#[tokio::test] -async fn schema_type() -> anyhow::Result<()> { - let client = Client::new(&SERVER.config); - client.ensure_connected().await?; - - let value = client - .query_required_single::( - " - SELECT schema::ObjectType { name } - FILTER .name = 'schema::Object' - LIMIT 1 - ", - &(), - ) - .await?; - assert_eq!( - value, - SchemaType { - name: "schema::Object".into(), - } - ); - - let value = client - .query_required_single::( - " - SELECT schema::ObjectType { id, name } - FILTER .name = 'schema::Object' - LIMIT 1 - ", - &(), - ) - .await?; - // id is unstable - assert_eq!(value.name, "schema::Object"); - - Ok(()) -} diff --git a/edgedb-tokio/tests/func/globals.rs b/edgedb-tokio/tests/func/globals.rs deleted file mode 100644 index 3ce7f2e5..00000000 --- a/edgedb-tokio/tests/func/globals.rs +++ /dev/null @@ -1,58 +0,0 @@ -use edgedb_tokio::Client; - -use crate::server::SERVER; - -#[tokio::test] -async fn global_fn() -> anyhow::Result<()> { - let client = Client::new(&SERVER.config); - client.ensure_connected().await?; - - let value = client - .with_default_module(Some("test")) - .with_globals_fn(|m| m.set("str_val", "hello")) - .query::("SELECT (global str_val)", &()) - .await?; - assert_eq!(value, vec![String::from("hello")]); - - let value = client - .with_default_module(Some("test")) - .with_globals_fn(|m| m.set("int_val", 127)) - .query::("SELECT (global int_val)", &()) - .await?; - assert_eq!(value, vec![127]); - Ok(()) -} - -#[derive(edgedb_derive::GlobalsDelta)] -struct Globals { - str_val: &'static str, - int_val: i32, -} - -#[cfg(feature = "derive")] -#[tokio::test] -async fn global_struct() -> anyhow::Result<()> { - let client = Client::new(&SERVER.config); - client.ensure_connected().await?; - - let value = client - .with_default_module(Some("test")) - .with_globals(&Globals { - str_val: "value1", - int_val: 345, - }) - .query::("SELECT (global str_val)", &()) - .await?; - assert_eq!(value, vec![String::from("value1")]); - - let value = client - .with_default_module(Some("test")) - .with_globals(&Globals { - str_val: "value2", - int_val: 678, - }) - .query::("SELECT (global int_val)", &()) - .await?; - assert_eq!(value, vec![678]); - Ok(()) -} diff --git a/edgedb-tokio/tests/func/main.rs b/edgedb-tokio/tests/func/main.rs deleted file mode 100644 index 4d5c2474..00000000 --- a/edgedb-tokio/tests/func/main.rs +++ /dev/null @@ -1,17 +0,0 @@ -#[cfg(not(windows))] -mod server; - -#[cfg(all(not(windows), feature = "unstable"))] -mod raw; - -#[cfg(not(windows))] -mod client; - -#[cfg(not(windows))] -mod transactions; - -#[cfg(not(windows))] -mod globals; - -#[cfg(not(windows))] -mod derive; diff --git a/edgedb-tokio/tests/func/raw.rs b/edgedb-tokio/tests/func/raw.rs deleted file mode 100644 index 9f77070f..00000000 --- a/edgedb-tokio/tests/func/raw.rs +++ /dev/null @@ -1,51 +0,0 @@ -use std::sync::Arc; - -use bytes::Bytes; - -use edgedb_protocol::common::Capabilities; -use edgedb_protocol::common::{Cardinality, CompilationOptions, InputLanguage, IoFormat}; -use edgedb_protocol::encoding::Annotations; -use edgedb_tokio::raw::{Pool, PoolState}; - -use crate::server::SERVER; - -#[tokio::test] -async fn poll_connect() -> anyhow::Result<()> { - let pool = Pool::new(&SERVER.config); - let mut conn = pool.acquire().await?; - assert!(conn.is_consistent()); - - let state = Arc::new(PoolState::default()); - let annotations = Arc::new(Annotations::default()); - let options = CompilationOptions { - implicit_limit: None, - implicit_typenames: false, - implicit_typeids: false, - allow_capabilities: Capabilities::empty(), - explicit_objectids: true, - input_language: InputLanguage::EdgeQL, - io_format: IoFormat::Binary, - expected_cardinality: Cardinality::Many, - }; - - let desc = conn - .parse(&options, "SELECT 7*8", &state, &annotations) - .await?; - assert!(conn.is_consistent()); - - let data = conn - .execute( - &options, - "SELECT 7*8", - &state, - &annotations, - &desc, - &Bytes::new(), - ) - .await?; - assert!(conn.is_consistent()); - assert_eq!(data.len(), 1); - assert_eq!(data[0].data.len(), 1); - assert_eq!(&data[0].data[0][..], b"\0\0\0\0\0\0\0\x38"); - Ok(()) -} diff --git a/edgedb-tokio/tests/func/server.rs b/edgedb-tokio/tests/func/server.rs deleted file mode 100644 index 5aef96f7..00000000 --- a/edgedb-tokio/tests/func/server.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::{path::PathBuf, str::FromStr}; - -use edgedb_tokio::{Builder, Config}; -use once_cell::sync::Lazy; -use test_utils::server::ServerInstance; - -pub struct ServerGuard { - instance: ServerInstance, - pub config: Config, -} - -pub static SERVER: Lazy = Lazy::new(start_server); - -/// Starts edgedb-server. Stops it after the test process exits. -/// Writes its log into a tmp file. -/// -/// To debug, run any test with --nocapture Rust flag. -fn start_server() -> ServerGuard { - shutdown_hooks::add_shutdown_hook(stop_server); - - let instance = ServerInstance::start(); - - instance.apply_schema(&PathBuf::from_str("./tests/func/dbschema").unwrap()); - - let cert_data = std::fs::read_to_string(&instance.info.tls_cert_file) - .expect("cert file should be readable"); - let config = Builder::new() - .port(instance.info.port) - .unwrap() - .pem_certificates(&cert_data) - .unwrap() - .constrained_build() // if this method is not found, you need --features=unstable - .unwrap(); - ServerGuard { instance, config } -} - -extern "C" fn stop_server() { - SERVER.instance.stop() -} diff --git a/edgedb-tokio/tests/func/transactions.rs b/edgedb-tokio/tests/func/transactions.rs deleted file mode 100644 index 12b3b3b0..00000000 --- a/edgedb-tokio/tests/func/transactions.rs +++ /dev/null @@ -1,208 +0,0 @@ -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::sync::Arc; - -use tokio::sync::Mutex; - -use edgedb_errors::NoDataError; -use edgedb_tokio::{Client, Transaction}; - -use crate::server::SERVER; - -struct OnceBarrier(AtomicBool, tokio::sync::Barrier); - -impl OnceBarrier { - fn new(n: usize) -> OnceBarrier { - OnceBarrier(AtomicBool::new(false), tokio::sync::Barrier::new(n)) - } - async fn wait(&self) { - if self.0.load(Ordering::SeqCst) { - return; - } - self.1.wait().await; - self.0.store(true, Ordering::SeqCst) - } -} - -async fn transaction1( - client: Client, - name: &str, - iterations: Arc, - barrier: Arc, - lock: Arc>, -) -> anyhow::Result { - let val = client - .transaction(|mut tx| { - let lock = lock.clone(); - let iterations = iterations.clone(); - let barrier = barrier.clone(); - async move { - iterations.fetch_add(1, Ordering::SeqCst); - // This magic query makes starts real transaction, - // that is otherwise started lazily - tx.query::("SELECT 1", &()).await?; - barrier.wait().await; - let val = { - let _lock = lock.lock().await; - tx.query_required_single( - " - SELECT ( - INSERT test::Counter { - name := $0, - value := 1, - } UNLESS CONFLICT ON .name - ELSE ( - UPDATE test::Counter - SET { value := .value + 1 } - ) - ).value - ", - &(name,), - ) - .await? - }; - Ok(val) - } - }) - .await?; - Ok(val) -} - -#[test_log::test(tokio::test)] -async fn transaction_conflict() -> anyhow::Result<()> { - let cli1 = Client::new(&SERVER.config); - let cli2 = Client::new(&SERVER.config); - tokio::try_join!(cli1.ensure_connected(), cli2.ensure_connected())?; - let barrier = Arc::new(OnceBarrier::new(2)); - let lock = Arc::new(Mutex::new(())); - let iters = Arc::new(AtomicUsize::new(0)); - - // TODO(tailhook) set retry options - let res = tokio::try_join!( - transaction1(cli1, "x", iters.clone(), barrier.clone(), lock.clone()), - transaction1(cli2, "x", iters.clone(), barrier.clone(), lock.clone()), - ); - println!("Result {:#?}", res); - let tup = res?; - - assert!(tup == (1, 2) || tup == (2, 1), "Wrong result: {:?}", tup); - assert_eq!(iters.load(Ordering::SeqCst), 3); - Ok(()) -} - -async fn get_counter_value(tx: &mut Transaction, name: &str) -> anyhow::Result { - let value = tx - .query_required_single( - " - SELECT ( - INSERT test::Counter { - name := $0, - value := 1, - } UNLESS CONFLICT ON .name - ELSE ( - UPDATE test::Counter - SET { value := .value + 1 } - ) - ).value - ", - &(name,), - ) - .await?; - Ok(value) -} - -async fn transaction1e( - client: Client, - name: &str, - iterations: Arc, - barrier: Arc, - lock: Arc>, -) -> anyhow::Result { - let val = client - .transaction(|mut tx| { - let lock = lock.clone(); - let iterations = iterations.clone(); - let barrier = barrier.clone(); - async move { - iterations.fetch_add(1, Ordering::SeqCst); - // This magic query makes starts real transaction, - // that is otherwise started lazily - tx.query::("SELECT 1", &()).await?; - barrier.wait().await; - let _lock = lock.lock().await; - let val = get_counter_value(&mut tx, name).await?; - Ok(val) - } - }) - .await?; - Ok(val) -} - -#[tokio::test] -async fn transaction_conflict_with_complex_err() -> anyhow::Result<()> { - let cli1 = Client::new(&SERVER.config); - let cli2 = Client::new(&SERVER.config); - tokio::try_join!(cli1.ensure_connected(), cli2.ensure_connected())?; - let barrier = Arc::new(OnceBarrier::new(2)); - let lock = Arc::new(Mutex::new(())); - let iters = Arc::new(AtomicUsize::new(0)); - - // TODO(tailhook) set retry options - let res = tokio::try_join!( - transaction1e(cli1, "y", iters.clone(), barrier.clone(), lock.clone()), - transaction1e(cli2, "y", iters.clone(), barrier.clone(), lock.clone()), - ); - println!("Result {:#?}", res); - let tup = res?; - - assert!(tup == (1, 2) || tup == (2, 1), "Wrong result: {:?}", tup); - assert_eq!(iters.load(Ordering::SeqCst), 3); - Ok(()) -} - -#[tokio::test] -async fn queries() -> anyhow::Result<()> { - let client = Client::new(&SERVER.config); - client - .transaction(|mut tx| async move { - let value = tx.query::("SELECT 7*93", &()).await?; - assert_eq!(value, vec![651]); - - let value = tx.query_single::("SELECT 5*11", &()).await?; - assert_eq!(value, Some(55)); - - let value = tx.query_single::("SELECT {}", &()).await?; - assert_eq!(value, None); - - let value = tx - .query_required_single::("SELECT 5*11", &()) - .await?; - assert_eq!(value, 55); - - let err = tx - .query_required_single::("SELECT {}", &()) - .await - .unwrap_err(); - assert!(err.is::()); - - let value = tx.query_json("SELECT 'x' ++ 'y'", &()).await?; - assert_eq!(value.as_ref(), r#"["xy"]"#); - - let value = tx.query_single_json("SELECT 'x' ++ 'y'", &()).await?; - assert_eq!(value.as_deref(), Some(r#""xy""#)); - - let value = tx.query_single_json("SELECT {}", &()).await?; - assert_eq!(value.as_deref(), None); - - let err = tx - .query_required_single_json("SELECT {}", &()) - .await - .unwrap_err(); - assert!(err.is::()); - - tx.execute("SELECT 1+1", &()).await?; - - Ok(()) - }) - .await?; - Ok(()) -} diff --git a/examples/globals/Cargo.toml b/examples/globals/Cargo.toml deleted file mode 100644 index 2874203c..00000000 --- a/examples/globals/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "globals-example" -version = "0.1.0" -edition = "2021" -publish = false - -[dependencies] -anyhow = "1.0" -edgedb-tokio = { version = "0.5.0", path="../../edgedb-tokio"} -edgedb-derive = { version = "0.5.0", path="../../edgedb-derive"} -tokio = {version="1.20", features=["macros", "rt", "rt-multi-thread"]} -env_logger = "0.11.3" diff --git a/examples/globals/dbschema/default.esdl b/examples/globals/dbschema/default.esdl deleted file mode 100644 index d08521e3..00000000 --- a/examples/globals/dbschema/default.esdl +++ /dev/null @@ -1,3 +0,0 @@ -module default { - global str_global -> str; -} diff --git a/examples/globals/dbschema/migrations/00001.edgeql b/examples/globals/dbschema/migrations/00001.edgeql deleted file mode 100644 index ead23c7e..00000000 --- a/examples/globals/dbschema/migrations/00001.edgeql +++ /dev/null @@ -1,5 +0,0 @@ -CREATE MIGRATION m1ke2c655q6ptmr2xyqn7xtw6mntgqajamhauhqyrayamqchqz5awa - ONTO initial -{ - CREATE GLOBAL default::str_global -> std::str; -}; diff --git a/examples/globals/src/main.rs b/examples/globals/src/main.rs deleted file mode 100644 index 07589618..00000000 --- a/examples/globals/src/main.rs +++ /dev/null @@ -1,18 +0,0 @@ -use edgedb_derive::GlobalsDelta; - -#[derive(GlobalsDelta)] -struct Globals<'a> { - str_global: &'a str, -} - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("warn")).init(); - let conn = edgedb_tokio::create_client().await?; - let conn = conn.with_globals(&Globals { str_global: "val1" }); - let val = conn - .query_required_single::("SELECT (GLOBAL str_global)", &()) - .await?; - assert_eq!(val, "val1"); - Ok(()) -} diff --git a/examples/query-error/Cargo.toml b/examples/query-error/Cargo.toml deleted file mode 100644 index be0693df..00000000 --- a/examples/query-error/Cargo.toml +++ /dev/null @@ -1,13 +0,0 @@ -[package] -name = "query-error-example" -version = "0.1.0" -edition = "2021" -publish = false - -[dependencies] -anyhow = "1.0" -edgedb-tokio = { path = "../../edgedb-tokio", features = ["miette-errors"] } -edgedb-derive = { path = "../../edgedb-derive" } -tokio = { version = "1.20", features = ["macros", "rt", "rt-multi-thread"] } -env_logger = "0.11.3" -miette = { version = "7.2.0", features = ["fancy"] } diff --git a/examples/query-error/src/main.rs b/examples/query-error/src/main.rs deleted file mode 100644 index b8824eac..00000000 --- a/examples/query-error/src/main.rs +++ /dev/null @@ -1,33 +0,0 @@ -use anyhow::Context; - -async fn do_something() -> anyhow::Result<()> { - let conn = edgedb_tokio::create_client().await?; - conn.query::("SELECT 1+2)", &()) - .await - .context("Query `select 1+2`")?; - Ok(()) -} - -#[tokio::main] -async fn main() { - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("warn")).init(); - match do_something().await { - Ok(res) => res, - Err(e) => { - e.downcast::() - .map(|e| eprintln!("{:?}", miette::Report::new(e))) - .unwrap_or_else(|e| eprintln!("{:#}", e)); - std::process::exit(1); - } - } -} - -/* -/// Alternative error handling if you use miette thorough your application -#[tokio::main] -async fn main() -> miette::Result<()> { - let conn = edgedb_tokio::create_client().await?; - conn.query::("SELECT 1+2)", &()).await?; - Ok(()) -} -*/