From c99730eb623eb052b8e538f055eec3c8a45981c6 Mon Sep 17 00:00:00 2001 From: dzmitry-lahoda Date: Tue, 23 Jul 2024 20:29:48 +0100 Subject: [PATCH] feat: borsh and json schemas support (#43) --- .github/workflows/test-borsh.yml | 15 ++++ .github/workflows/test-serde.yml | 17 ++--- Cargo.toml | 6 ++ rust-toolchain | 1 + src/lib.rs | 113 +++++++++++++++++++++++++++++++ tests/tests.rs | 64 +++++++++++++++++ 6 files changed, 206 insertions(+), 10 deletions(-) create mode 100644 .github/workflows/test-borsh.yml create mode 100644 rust-toolchain diff --git a/.github/workflows/test-borsh.yml b/.github/workflows/test-borsh.yml new file mode 100644 index 0000000..fc29be4 --- /dev/null +++ b/.github/workflows/test-borsh.yml @@ -0,0 +1,15 @@ +name: test borsh +run-name: ${{ github.actor }}'s patch +on: [push] +jobs: + build-and-test: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v4 + - uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + cache: true + toolchain: nightly + - run: | + cargo test --no-default-features --features=borsh + cargo test --no-default-features --features=borsh,std diff --git a/.github/workflows/test-serde.yml b/.github/workflows/test-serde.yml index e8dbf1c..94e9fe8 100644 --- a/.github/workflows/test-serde.yml +++ b/.github/workflows/test-serde.yml @@ -5,15 +5,12 @@ jobs: build-and-test: runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 - - uses: actions/setup-node@v3 - with: - node-version: '14' - - uses: actions-rs/toolchain@v1 + - uses: actions/checkout@v4 + - uses: actions-rust-lang/setup-rust-toolchain@v1 with: + cache: true toolchain: nightly - override: true - - uses: actions-rs/cargo@v1 - with: - command: test - args: --no-default-features --features serde + - run: | + cargo test --no-default-features --features=serde + cargo test --no-default-features --features=serde,std + cargo test --no-default-features --features=schemars diff --git a/Cargo.toml b/Cargo.toml index 5357666..ab12cb6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,10 +27,16 @@ defmt = ["dep:defmt"] # Supports serde serde = ["dep:serde"] +borsh = ["dep:borsh"] + +schemars = ["dep:schemars", "std"] + [dependencies] num-traits = { version = "0.2.17", default-features = false, optional = true } defmt = { version = "0.3.5", optional = true } serde = { version = "1.0", optional = true, default-features = false} +borsh = { version = "1.5.1", optional = true, features = ["unstable__schema"], default-features = false } +schemars = { version = "0.8.1", optional = true, features = ["derive"], default-features = false } [dev-dependencies] serde_test = "1.0" diff --git a/rust-toolchain b/rust-toolchain new file mode 100644 index 0000000..870bbe4 --- /dev/null +++ b/rust-toolchain @@ -0,0 +1 @@ +stable \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 4c2b9c3..d937325 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,9 @@ )] #![cfg_attr(feature = "step_trait", feature(step_trait))] +#[cfg(all(feature = "borsh", not(feature = "std")))] +extern crate alloc; + use core::fmt::{Binary, Debug, Display, Formatter, LowerHex, Octal, UpperHex}; use core::hash::{Hash, Hasher}; #[cfg(feature = "step_trait")] @@ -18,6 +21,18 @@ use core::ops::{ #[cfg(feature = "serde")] use serde::{Deserialize, Deserializer, Serialize, Serializer}; +#[cfg(feature = "borsh")] +use borsh::{BorshDeserialize, BorshSchema, BorshSerialize}; + +#[cfg(all(feature = "borsh", not(feature = "std")))] +use alloc::{collections::BTreeMap, string::ToString}; + +#[cfg(all(feature = "borsh", feature = "std"))] +use std::{collections::BTreeMap, string::ToString}; + +#[cfg(feature = "schemars")] +use schemars::JsonSchema; + #[derive(Debug, Clone, Eq, PartialEq)] pub struct TryNewError; @@ -1054,6 +1069,78 @@ where } } +// Borsh is byte-size little-endian de-needs-external-schema no-bit-compression serde. +// Current ser/de for it is not optimal impl because const math is not stable nor primitives has bits traits. +// Uses minimal amount of bytes to fit needed amount of bits without compression (borsh does not have it anyway). +#[cfg(feature = "borsh")] +impl BorshSerialize for UInt +where + Self: Number, + T: BorshSerialize + + From + + BitAnd + + TryInto + + Copy + + Shr, + as Number>::UnderlyingType: + Shr + TryInto + From + BitAnd, +{ + fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { + let value = self.value(); + let length = (BITS + 7) / 8; + let mut bytes = 0; + let mask: T = u8::MAX.into(); + while bytes < length { + let le_byte: u8 = ((value >> (bytes << 3)) & mask) + .try_into() + .ok() + .expect("we cut to u8 via mask"); + writer.write(&[le_byte])?; + bytes += 1; + } + Ok(()) + } +} + +#[cfg(feature = "borsh")] +impl< + T: BorshDeserialize + core::cmp::PartialOrd< as Number>::UnderlyingType>, + const BITS: usize, + > BorshDeserialize for UInt +where + Self: Number, +{ + fn deserialize_reader(reader: &mut R) -> borsh::io::Result { + let mut buf = vec![0u8; core::mem::size_of::()]; + reader.read(&mut buf)?; + let value = T::deserialize(&mut &buf[..])?; + if value >= Self::MIN.value() && value <= Self::MAX.value() { + Ok(Self { value }) + } else { + Err(borsh::io::Error::new( + borsh::io::ErrorKind::InvalidData, + "Value out of range", + )) + } + } +} + +#[cfg(feature = "borsh")] +impl BorshSchema for UInt { + fn add_definitions_recursively( + definitions: &mut BTreeMap, + ) { + definitions.insert( + ["u", &BITS.to_string()].concat(), + borsh::schema::Definition::Primitive(((BITS + 7) / 8) as u8), + ); + } + + fn declaration() -> borsh::schema::Declaration { + ["u", &BITS.to_string()].concat() + } +} + #[cfg(feature = "serde")] impl Serialize for UInt where @@ -1106,6 +1193,32 @@ where } } +#[cfg(feature = "schemars")] +impl JsonSchema for UInt +where + Self: Number, +{ + fn schema_name() -> String { + ["uint", &BITS.to_string()].concat() + } + + fn json_schema(_gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { + use schemars::schema::{NumberValidation, Schema, SchemaObject}; + let schema_object = SchemaObject { + instance_type: Some(schemars::schema::InstanceType::Integer.into()), + format: Some(Self::schema_name()), + number: Some(Box::new(NumberValidation { + // can be done with https://github.com/rust-lang/rfcs/pull/2484 + // minimum: Some(Self::MIN.value().try_into().ok().unwrap()), + // maximum: Some(Self::MAX.value().try_into().ok().unwrap()), + ..Default::default() + })), + ..Default::default() + }; + Schema::Object(schema_object) + } +} + impl Hash for UInt where T: Hash, diff --git a/tests/tests.rs b/tests/tests.rs index e050f00..3687875 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1911,3 +1911,67 @@ fn serde() { "invalid value: integer `-1`, expected u128", ); } + +#[cfg(all(feature = "borsh", feature = "std"))] +#[test] +fn borsh() { + use borsh::schema::BorshSchemaContainer; + use borsh::{BorshDeserialize, BorshSerialize}; + let mut buf = Vec::new(); + let base_input: u8 = 42; + let input = u9::new(base_input.into()); + input.serialize(&mut buf).unwrap(); + let output = u9::deserialize(&mut buf.as_ref()).unwrap(); + let fits = u16::new(base_input.into()); + assert_eq!(buf, fits.to_le_bytes()); + assert_eq!(input, output); + + let input = u63::MAX; + let fits = u64::new(input.value()); + let mut buf = Vec::new(); + input.serialize(&mut buf).unwrap(); + let output: u63 = u63::deserialize(&mut buf.as_ref()).unwrap(); + assert_eq!(buf, fits.to_le_bytes()); + assert_eq!(input, output); + + let schema = BorshSchemaContainer::for_type::(); + match schema.get_definition("u9").expect("exists") { + borsh::schema::Definition::Primitive(2) => {} + _ => panic!("unexpected schema"), + } + + let input = u50::MAX; + let fits = u64::new(input.value()); + let mut buf = Vec::new(); + input.serialize(&mut buf).unwrap(); + assert!(buf.len() < fits.to_le_bytes().len()); + assert_eq!(buf, fits.to_le_bytes()[0..((u50::BITS + 7) / 8)]); + let output: u50 = u50::deserialize(&mut buf.as_ref()).unwrap(); + assert_eq!(input, output); +} + +#[cfg(feature = "schemars")] +#[test] +fn schemars() { + use schemars::schema_for; + let mut u8 = schema_for!(u8); + let u9 = schema_for!(u9); + assert_eq!( + u8.schema.format.clone().unwrap().replace("8", "9"), + u9.schema.format.clone().unwrap() + ); + u8.schema.format = u9.schema.format.clone(); + assert_eq!( + u8.schema + .metadata + .clone() + .unwrap() + .title + .unwrap() + .replace("8", "9"), + u9.schema.metadata.clone().unwrap().title.unwrap() + ); + u8.schema.metadata = u9.schema.metadata.clone(); + u8.schema.number = u9.schema.number.clone(); + assert_eq!(u8, u9); +}