From 6a6499770a40f12369991a8952672c88826f549d Mon Sep 17 00:00:00 2001 From: Andrew Westberg Date: Mon, 9 Sep 2024 14:03:01 +0000 Subject: [PATCH] fix[pallas-math]: use malachite as default --- .github/workflows/validate.yml | 10 +- pallas-crypto/Cargo.toml | 3 +- pallas-crypto/src/nonce/epoch_nonce.rs | 86 ----- pallas-crypto/src/nonce/mod.rs | 199 ++++++++++- pallas-crypto/src/nonce/rolling_nonce.rs | 168 --------- pallas-crypto/src/vrf/mod.rs | 143 ++++++-- pallas-math/Cargo.toml | 7 +- pallas-math/src/lib.rs | 12 +- pallas-math/src/math.rs | 287 +++++++++++++-- pallas-math/src/math_gmp.rs | 50 +++ .../src/{math_num.rs => math_malachite.rs} | 334 +++++++++++++----- 11 files changed, 858 insertions(+), 441 deletions(-) delete mode 100644 pallas-crypto/src/nonce/epoch_nonce.rs delete mode 100644 pallas-crypto/src/nonce/rolling_nonce.rs rename pallas-math/src/{math_num.rs => math_malachite.rs} (60%) diff --git a/.github/workflows/validate.yml b/.github/workflows/validate.yml index dd24b439..9499dfed 100644 --- a/.github/workflows/validate.yml +++ b/.github/workflows/validate.yml @@ -27,11 +27,11 @@ jobs: - name: Run cargo check Windows if: matrix.os == 'windows-latest' - run: cargo check --no-default-features --features num + run: cargo check - name: Run cargo check if: matrix.os != 'windows-latest' - run: cargo check + run: cargo check --features gmp test: name: Test Suite @@ -46,7 +46,7 @@ jobs: toolchain: stable - name: Run cargo test - run: cargo test + run: cargo test --features gmp test-windows: name: Test Suite Windows @@ -61,7 +61,7 @@ jobs: toolchain: stable - name: Run cargo test - run: cargo test --no-default-features --features num + run: cargo test lints: name: Lints @@ -82,4 +82,4 @@ jobs: - name: Run cargo clippy run: | cargo clippy -- -D warnings - cargo clippy --no-default-features --features num -- -D warnings \ No newline at end of file + cargo clippy --features gmp -- -D warnings diff --git a/pallas-crypto/Cargo.toml b/pallas-crypto/Cargo.toml index 6ba4181e..dffe4beb 100644 --- a/pallas-crypto/Cargo.toml +++ b/pallas-crypto/Cargo.toml @@ -21,10 +21,9 @@ rand_core = "0.6" pallas-codec = { version = "=0.30.2", path = "../pallas-codec" } serde = "1.0.143" -# FIXME: This needs to be a properly deployed crate from the input-output-hk/vrf repository after my PR is merged # The vrf crate has not been fully tested in production environments and still has several upstream issues that # are open PRs but not merged yet. -vrf_dalek = { git = "https://github.com/AndrewWestberg/vrf", rev = "6fc1440b197098feb6d75e2b71517019b8e2e9c2" } +vrf_dalek = { git = "https://github.com/input-output-hk/vrf", rev = "a3185620b72e6a9647285941b961021186f16693" } [dev-dependencies] itertools = "0.13" diff --git a/pallas-crypto/src/nonce/epoch_nonce.rs b/pallas-crypto/src/nonce/epoch_nonce.rs deleted file mode 100644 index a3ea955d..00000000 --- a/pallas-crypto/src/nonce/epoch_nonce.rs +++ /dev/null @@ -1,86 +0,0 @@ -use crate::hash::{Hash, Hasher}; -use crate::nonce::{Error, NonceGenerator}; - -/// A nonce generator that calculates an epoch nonce from the eta_v value (nc) of the block right before -/// the stability window and the block hash of the first block from the previous epoch (nh). -#[derive(Debug, Clone)] -pub struct EpochNonceGenerator { - pub nonce: Hash<32>, -} - -impl EpochNonceGenerator { - /// Create a new [`EpochNonceGenerator`] generator. - /// params: - /// - nc: the eta_v value of the block right before the stability window. - /// - nh: the block hash of the first block from the previous epoch. - /// - extra_entropy: optional extra entropy to be used in the nonce calculation. - pub fn new(nc: Hash<32>, nh: Hash<32>, extra_entropy: Option<&[u8]>) -> Self { - let mut hasher = Hasher::<256>::new(); - hasher.input(nc.as_ref()); - hasher.input(nh.as_ref()); - let epoch_nonce = hasher.finalize(); - if let Some(extra_entropy) = extra_entropy { - let mut hasher = Hasher::<256>::new(); - hasher.input(epoch_nonce.as_ref()); - hasher.input(extra_entropy); - let extra_nonce = hasher.finalize(); - Self { nonce: extra_nonce } - } else { - Self { nonce: epoch_nonce } - } - } -} - -impl NonceGenerator for EpochNonceGenerator { - fn finalize(&mut self) -> Result, Error> { - Ok(self.nonce) - } -} - -#[cfg(test)] -mod tests { - use itertools::izip; - - use crate::hash::Hash; - - use super::*; - - #[test] - fn test_epoch_nonce() { - let nc_values = vec![ - hex::decode("e86e133bd48ff5e79bec43af1ac3e348b539172f33e502d2c96735e8c51bd04d") - .unwrap(), - hex::decode("d1340a9c1491f0face38d41fd5c82953d0eb48320d65e952414a0c5ebaf87587") - .unwrap(), - ]; - let nh_values = vec![ - hex::decode("d7a1ff2a365abed59c9ae346cba842b6d3df06d055dba79a113e0704b44cc3e9") - .unwrap(), - hex::decode("ee91d679b0a6ce3015b894c575c799e971efac35c7a8cbdc2b3f579005e69abd") - .unwrap(), - ]; - let ee = hex::decode("d982e06fd33e7440b43cefad529b7ecafbaa255e38178ad4189a37e4ce9bf1fa") - .unwrap(); - let extra_entropy_values: Vec> = vec![None, Some(&ee)]; - let expected_epoch_nonces = vec![ - hex::decode("e536a0081ddd6d19786e9d708a85819a5c3492c0da7349f59c8ad3e17e4acd98") - .unwrap(), - hex::decode("0022cfa563a5328c4fb5c8017121329e964c26ade5d167b1bd9b2ec967772b60") - .unwrap(), - ]; - - for (nc_value, nh_value, extra_entropy_value, expected_epoch_nonce) in izip!( - nc_values.iter(), - nh_values.iter(), - extra_entropy_values.iter(), - expected_epoch_nonces.iter() - ) { - let nc: Hash<32> = Hash::from(nc_value.as_slice()); - let nh: Hash<32> = Hash::from(nh_value.as_slice()); - let extra_entropy = *extra_entropy_value; - let mut epoch_nonce = EpochNonceGenerator::new(nc, nh, extra_entropy); - let nonce = epoch_nonce.finalize().unwrap(); - assert_eq!(nonce.as_ref(), expected_epoch_nonce.as_slice()); - } - } -} diff --git a/pallas-crypto/src/nonce/mod.rs b/pallas-crypto/src/nonce/mod.rs index 174b7498..c40364c4 100644 --- a/pallas-crypto/src/nonce/mod.rs +++ b/pallas-crypto/src/nonce/mod.rs @@ -1,17 +1,192 @@ -use thiserror::Error; +use crate::hash::{Hash, Hasher}; -use crate::hash::Hash; - -pub mod epoch_nonce; -pub mod rolling_nonce; +/// A nonce generator function that calculates an epoch nonce from the eta_v value (nc) of the block right before +/// the stability window and the block hash of the first block from the previous epoch (nh). +pub fn generate_epoch_nonce(nc: Hash<32>, nh: Hash<32>, extra_entropy: Option<&[u8]>) -> Hash<32> { + let mut hasher = Hasher::<256>::new(); + hasher.input(nc.as_ref()); + hasher.input(nh.as_ref()); + let epoch_nonce = hasher.finalize(); + if let Some(extra_entropy) = extra_entropy { + let mut hasher = Hasher::<256>::new(); + hasher.input(epoch_nonce.as_ref()); + hasher.input(extra_entropy); + hasher.finalize() + } else { + epoch_nonce + } +} -#[derive(Error, Debug)] -pub enum Error { - #[error("Nonce error: {0}")] - Nonce(String), +/// A nonce generator function that calculates a rolling nonce (eta_v) by applying each cardano block in +/// the shelley era and beyond. These rolling nonce values are used to help calculate the epoch +/// nonce values used in consensus for the Ouroboros protocols (tpraos, praos, cpraos). +pub fn generate_rolling_nonce(previous_block_eta_v: Hash<32>, block_eta_vrf_0: &[u8]) -> Hash<32> { + assert!( + block_eta_vrf_0.len() == 32 || block_eta_vrf_0.len() == 64, + "Invalid block_eta_vrf_0 length: {}, expected 32 or 64", + block_eta_vrf_0.len() + ); + let mut hasher = Hasher::<256>::new(); + hasher.input(previous_block_eta_v.as_ref()); + hasher.input(Hasher::<256>::hash(block_eta_vrf_0).as_ref()); + hasher.finalize() } -/// A trait for generating nonces. -pub trait NonceGenerator: Sized { - fn finalize(&mut self) -> Result, Error>; +#[cfg(test)] +mod tests { + use itertools::izip; + + use crate::hash::Hash; + + use super::*; + + #[test] + fn test_epoch_nonce() { + let nc_values = vec![ + hex::decode("e86e133bd48ff5e79bec43af1ac3e348b539172f33e502d2c96735e8c51bd04d") + .unwrap(), + hex::decode("d1340a9c1491f0face38d41fd5c82953d0eb48320d65e952414a0c5ebaf87587") + .unwrap(), + ]; + let nh_values = vec![ + hex::decode("d7a1ff2a365abed59c9ae346cba842b6d3df06d055dba79a113e0704b44cc3e9") + .unwrap(), + hex::decode("ee91d679b0a6ce3015b894c575c799e971efac35c7a8cbdc2b3f579005e69abd") + .unwrap(), + ]; + let ee = hex::decode("d982e06fd33e7440b43cefad529b7ecafbaa255e38178ad4189a37e4ce9bf1fa") + .unwrap(); + let extra_entropy_values: Vec> = vec![None, Some(&ee)]; + let expected_epoch_nonces = vec![ + hex::decode("e536a0081ddd6d19786e9d708a85819a5c3492c0da7349f59c8ad3e17e4acd98") + .unwrap(), + hex::decode("0022cfa563a5328c4fb5c8017121329e964c26ade5d167b1bd9b2ec967772b60") + .unwrap(), + ]; + + for (nc_value, nh_value, extra_entropy_value, expected_epoch_nonce) in izip!( + nc_values.iter(), + nh_values.iter(), + extra_entropy_values.iter(), + expected_epoch_nonces.iter() + ) { + let nc: Hash<32> = Hash::from(nc_value.as_slice()); + let nh: Hash<32> = Hash::from(nh_value.as_slice()); + let extra_entropy = *extra_entropy_value; + let epoch_nonce = generate_epoch_nonce(nc, nh, extra_entropy); + assert_eq!(epoch_nonce.as_ref(), expected_epoch_nonce.as_slice()); + } + } + + #[test] + fn test_rolling_nonce() { + let shelley_genesis_hash = + hex::decode("1a3be38bcbb7911969283716ad7aa550250226b76a61fc51cc9a9a35d9276d81") + .unwrap(); + + let eta_vrf_0_values = vec![ + hex::decode("36ec5378d1f5041a59eb8d96e61de96f0950fb41b49ff511f7bc7fd109d4383e1d24be7034e6749c6612700dd5ceb0c66577b88a19ae286b1321d15bce1ab736").unwrap(), + hex::decode("e0bf34a6b73481302f22987cde4c12807cbc2c3fea3f7fcb77261385a50e8ccdda3226db3efff73e9fb15eecf841bbc85ce37550de0435ebcdcb205e0ed08467").unwrap(), + hex::decode("7107ef8c16058b09f4489715297e55d145a45fc0df75dfb419cab079cd28992854a034ad9dc4c764544fb70badd30a9611a942a03523c6f3d8967cf680c4ca6b").unwrap(), + hex::decode("6f561aad83884ee0d7b19fd3d757c6af096bfd085465d1290b13a9dfc817dfcdfb0b59ca06300206c64d1ba75fd222a88ea03c54fbbd5d320b4fbcf1c228ba4e").unwrap(), + hex::decode("3d3ba80724db0a028783afa56a85d684ee778ae45b9aa9af3120f5e1847be1983bd4868caf97fcfd82d5a3b0b7c1a6d53491d75440a75198014eb4e707785cad").unwrap(), + hex::decode("0b07976bc04321c2e7ba0f1acb3c61bd92b5fc780a855632e30e6746ab4ac4081490d816928762debd3e512d22ad512a558612adc569718df1784261f5c26aff").unwrap(), + hex::decode("5e9e001fb1e2ddb0dc7ff40af917ecf4ba9892491d4bcbf2c81db2efc57627d40d7aac509c9bcf5070d4966faaeb84fd76bb285af2e51af21a8c024089f598c1").unwrap(), + hex::decode("182e83f8c67ad2e6bddead128e7108499ebcbc272b50c42783ef08f035aa688fecc7d15be15a90dbfe7fe5d7cd9926987b6ec12b05f2eadfe0eb6cad5130aca4").unwrap(), + hex::decode("275e7404b2385a9d606d67d0e29f5516fb84c1c14aaaf91afa9a9b3dcdfe09075efdadbaf158cfa1e9f250cc7c691ed2db4a29288d2426bd74a371a2a4b91b57").unwrap(), + hex::decode("0f35c7217792f8b0cbb721ae4ae5c9ae7f2869df49a3db256aacc10d23997a09e0273261b44ebbcecd6bf916f2c1cd79cf25b0c2851645d75dd0747a8f6f92f5").unwrap(), + hex::decode("14c28bf9b10421e9f90ffc9ab05df0dc8c8a07ffac1c51725fba7e2b7972d0769baea248f93ed0f2067d11d719c2858c62fc1d8d59927b41d4c0fbc68d805b32").unwrap(), + hex::decode("e4ce96fee9deb9378a107db48587438cddf8e20a69e21e5e4fbd35ef0c56530df77eba666cb152812111ba66bbd333ed44f627c727115f8f4f15b31726049a19").unwrap(), + hex::decode("b38f315e3ce369ea2551bf4f44e723dd15c7d67ba4b3763997909f65e46267d6540b9b00a7a65ae3d1f3a3316e57a821aeaac33e4e42ded415205073134cd185").unwrap(), + hex::decode("4bcbf774af9c8ff24d4d96099001ec06a24802c88fea81680ea2411392d32dbd9b9828a690a462954b894708d511124a2db34ec4179841e07a897169f0f1ac0e").unwrap(), + hex::decode("65247ace6355f978a12235265410c44f3ded02849ec8f8e6db2ac705c3f57d322ea073c13cf698e15d7e1d7f2bc95e7b3533be0dee26f58864f1664df0c1ebba").unwrap(), + hex::decode("d0c2bb451d0a3465a7fef7770718e5e49bf092a85dbf5af66ea26ec9c1b359026905fc1457e2b98b01ede7ba42aedcc525301f747a0ed9a9b61c37f27f9d8812").unwrap(), + hex::decode("250d9ec7ebec73e885798ae9427e1ea47b5ae66059b465b7c0fd132d17a9c2dcae29ba72863c1861cfb776d342812c4e9000981c4a40819430d0e84aa8bfeb0d").unwrap(), + hex::decode("0549cc0a5e5b9920796b88784c49b7d9a04cf2e86ab18d5af7b00780e60fb0fb5a7129945f4f918201dbad5348d4ccface4370f266540f8e072cdb46d3705930").unwrap(), + hex::decode("e543a26031dbdc8597b1beeba48a4f1cf6ab90c0e5b9343936b6e948a791198fc4fa22928e21edec812a04d0c9629772bf78e475d91a323cd8a8a6e005f92b4d").unwrap(), + hex::decode("4e4be69ad170fb8b3b17835913391ee537098d49e4452844a71ab2147ac55e45871c8943271806034ee9450b31c9486db9d26942946f48040ece7eea81424af1").unwrap(), + hex::decode("cb8a528288f902349250f9e8015e8334b0e24c2eeb9bb7d75e73c39024685804577565e62aca35948d2686ea38e9f8de97837ea30d2fb08347768394416e4a38").unwrap(), + hex::decode("fce94c47196a56a5cb94d5151ca429daf1c563ae889d0a42c2d03cfe43c94a636221c7e21b0668de9e5b6b32ee1e78b2c9aabc16537bf79c7b85eb956f433ac7").unwrap(), + hex::decode("fc8a125c9e2418c87907db4437a0ad6a378bba728ac8e0ce0e64f2a2f4b8201315e1b08d7983ce597cb68be2a2400d6d0d59b7359fe3dc9daca73d468da48972").unwrap(), + hex::decode("49290417311420d67f029a80b013b754150dd0097aa64de1c14a2467ab2e26cc2724071c04cb90cb0cf6c6353cf31f63235af7849d6ba023fd0fc0bc79d32f0b").unwrap(), + hex::decode("45c65effdc8007c9f2fc9057af986e94eb5c12b755465058d4b933ee37638452c5eeca4b43b8cbddabc60f29cbe5676b0bc55c0da88f8d0c36068e7d17ee603a").unwrap(), + hex::decode("a51e4e0f28aee3024207d87a5a1965313bdba4df44c6b845f7ca3408e5dabfe873df6b6ba26000e841f83f69e1de7857122ba538b42f255da2d013208af806ba").unwrap(), + hex::decode("5dbd891bf3bcfd5d054274759c13552aeaa187949875d81ee62ed394253ae25182e78b3a4a1976a7674e425bab860931d57f8a1d4fdc81fa4c3e8e8bf9016d5d").unwrap(), + hex::decode("3b5b044026e9066d62ce2f5a1fb01052a8cfe200dea28d421fc70f42c4d2b890b90ffef5675de1e47e4a20c9ca8700ceea23a61338ac759a098d167fa71642cb").unwrap(), + hex::decode("bb4017880cfa1e37f256dfe2a9cdb1349ed5dea8f69de75dc5933540dcf49e69afc33c837ba8a791857e16fad8581c4e9046778c49ca1ecd1fb675983be6d721").unwrap(), + hex::decode("517bbdb6e9e5f4702193064543204e780f5d33a866d0dcd65ada19f05715dea60ca81b842de5dca8f6b84a9cf469c8fb81991369dba21571476cc9c8d4ff2136").unwrap(), + ]; + + let expected_eta_v_values = vec![ + hex::decode("2af15f57076a8ff225746624882a77c8d2736fe41d3db70154a22b50af851246") + .unwrap(), + hex::decode("a815ff978369b57df09b0072485c26920dc0ec8e924a852a42f0715981cf0042") + .unwrap(), + hex::decode("f112d91435b911b6b5acaf27198762905b1cdec8c5a7b712f925ce3c5c76bb5f") + .unwrap(), + hex::decode("5450d95d9be4194a0ded40fbb4036b48d1f1d6da796e933fefd2c5c888794b4b") + .unwrap(), + hex::decode("c5c0f406cb522ad3fead4ecc60bce9c31e80879bc17eb1bb9acaa9b998cdf8bf") + .unwrap(), + hex::decode("5857048c728580549de645e087ba20ef20bb7c51cc84b5bc89df6b8b0ed98c41") + .unwrap(), + hex::decode("d6f40ef403687115db061b2cb9b1ab4ddeb98222075d5a3e03c8d217d4d7c40e") + .unwrap(), + hex::decode("5489d75a9f4971c1824462b5e2338609a91f121241f21fee09811bd5772ae0a8") + .unwrap(), + hex::decode("04716326833ecdb595153adac9566a4b39e5c16e8d02526cb4166e4099a00b1a") + .unwrap(), + hex::decode("39db709f50c8a279f0a94adcefb9360dbda6cdce168aed4288329a9cd53492b6") + .unwrap(), + hex::decode("c784b8c8678e0a04748a3ad851dd7c34ed67141cd9dc0c50ceaff4df804699a7") + .unwrap(), + hex::decode("cc1a5861358c075de93a26a91c5a951d5e71190d569aa2dc786d4ca8fc80cc38") + .unwrap(), + hex::decode("514979c89313c49e8f59fb8445113fa7623e99375cc4917fe79df54f8d4bdfce") + .unwrap(), + hex::decode("6a783e04481b9e04e8f3498a3b74c90c06a1031fb663b6793ce592a6c26f56f4") + .unwrap(), + hex::decode("1190f5254599dcee4f3cf1afdf4181085c36a6db6c30f334bfe6e6f320a6ed91") + .unwrap(), + hex::decode("91c777d6db066fe58edd67cd751fc7240268869b365393f6910e0e8f0fa58af3") + .unwrap(), + hex::decode("c545d83926c011b5c68a72de9a4e2f9da402703f4aab1b967456eae73d9f89b3") + .unwrap(), + hex::decode("ec31d2348bf543482842843a61d5b32691dedf801f198d68126c423ddf391e8b") + .unwrap(), + hex::decode("de223867d5c972895dd99ac0280a3e02947a7fb018ed42ed048266f913d2dfc2") + .unwrap(), + hex::decode("4dd9801752aade9c6e06bf03e9d2ec8a30ef7c6f30106790a23a9599e90ee08a") + .unwrap(), + hex::decode("fcb183abd512271f40408a5872827ce79cc2dda685a986a7dbdc61d842495a91") + .unwrap(), + hex::decode("e834d8ffd6dd042167b13e38512c62afdaf4d635d5b1ab0d513e08e9bef0ef63") + .unwrap(), + hex::decode("270a78257a958cd5fdb26f0b9ab302df2d2196fd04989f7ca1bb703e4dd904f0") + .unwrap(), + hex::decode("7e324f67af787dfddee10354128c60c60bf601bd8147c867d2471749a7b0f334") + .unwrap(), + hex::decode("54521ed42e0e782b5268ec55f80cff582162bc23fdcee5cdaa0f1a2ce7fa1f02") + .unwrap(), + hex::decode("557c296a71d8c9cb3fe7dcd95fbf4d70f6a3974d93c71b450d62a41b9a85d5a1") + .unwrap(), + hex::decode("20e078301ca282857378bbf10ac40965445c4c9fa73a160e0a116b4cf808b4b4") + .unwrap(), + hex::decode("b5a741dd3ff6a5a3d27b4d046dfb7a3901aacd37df7e931ba05e1320ad155c1c") + .unwrap(), + hex::decode("8b445f35f4a7b76e5d279d71fa9e05376a7c4533ca8b2b98fd2dbaf814d3bf8f") + .unwrap(), + hex::decode("08e7b5277abc139deb50f61264375fa091c580f8a85f259be78a002f7023c31f") + .unwrap(), + ]; + + let mut previous_block_eta_v = Hash::<32>::from(shelley_genesis_hash.as_slice()); + + for (eta_vrf_0, expected_eta_v) in eta_vrf_0_values.iter().zip(expected_eta_v_values.iter()) + { + let rolling_nonce = generate_rolling_nonce(previous_block_eta_v, eta_vrf_0); + assert_eq!(rolling_nonce.as_ref(), expected_eta_v.as_slice()); + previous_block_eta_v = rolling_nonce; + } + } } diff --git a/pallas-crypto/src/nonce/rolling_nonce.rs b/pallas-crypto/src/nonce/rolling_nonce.rs deleted file mode 100644 index 84cf8f65..00000000 --- a/pallas-crypto/src/nonce/rolling_nonce.rs +++ /dev/null @@ -1,168 +0,0 @@ -use crate::hash::{Hash, Hasher}; -use crate::nonce::{Error, NonceGenerator}; - -/// A nonce generator that calculates a rolling nonce by applying each cardano block in -/// the shelley era and beyond. These rolling nonce values are used to help calculate the epoch -/// nonce values used in consensus for the Ouroboros protocols (tpraos, praos, cpraos). -#[derive(Debug, Clone)] -pub struct RollingNonceGenerator { - pub nonce: Hash<32>, - block_eta_v: Option>, -} - -impl RollingNonceGenerator { - pub fn new(nonce: Hash<32>) -> Self { - Self { - nonce, - block_eta_v: None, - } - } - - pub fn apply_block(&mut self, eta_vrf_0: &[u8]) -> Result<(), Error> { - let len = eta_vrf_0.len(); - if len != 64 && len != 32 { - return Err(Error::Nonce(format!( - "Invalid eta_vrf_0 length: {}, expected 32 or 64", - eta_vrf_0.len() - ))); - } - self.block_eta_v = Some(Hasher::<256>::hash(eta_vrf_0)); - Ok(()) - } -} - -impl NonceGenerator for RollingNonceGenerator { - fn finalize(&mut self) -> Result, Error> { - if self.block_eta_v.is_none() { - return Err(Error::Nonce( - "Must call apply_block before finalize!".to_string(), - )); - } - let mut hasher = Hasher::<256>::new(); - hasher.input(self.nonce.as_ref()); - hasher.input(self.block_eta_v.unwrap().as_ref()); - Ok(hasher.finalize()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_rolling_nonce() { - let shelley_genesis_hash = - hex::decode("1a3be38bcbb7911969283716ad7aa550250226b76a61fc51cc9a9a35d9276d81") - .unwrap(); - - let eta_vrf_0_values = vec![ - hex::decode("36ec5378d1f5041a59eb8d96e61de96f0950fb41b49ff511f7bc7fd109d4383e1d24be7034e6749c6612700dd5ceb0c66577b88a19ae286b1321d15bce1ab736").unwrap(), - hex::decode("e0bf34a6b73481302f22987cde4c12807cbc2c3fea3f7fcb77261385a50e8ccdda3226db3efff73e9fb15eecf841bbc85ce37550de0435ebcdcb205e0ed08467").unwrap(), - hex::decode("7107ef8c16058b09f4489715297e55d145a45fc0df75dfb419cab079cd28992854a034ad9dc4c764544fb70badd30a9611a942a03523c6f3d8967cf680c4ca6b").unwrap(), - hex::decode("6f561aad83884ee0d7b19fd3d757c6af096bfd085465d1290b13a9dfc817dfcdfb0b59ca06300206c64d1ba75fd222a88ea03c54fbbd5d320b4fbcf1c228ba4e").unwrap(), - hex::decode("3d3ba80724db0a028783afa56a85d684ee778ae45b9aa9af3120f5e1847be1983bd4868caf97fcfd82d5a3b0b7c1a6d53491d75440a75198014eb4e707785cad").unwrap(), - hex::decode("0b07976bc04321c2e7ba0f1acb3c61bd92b5fc780a855632e30e6746ab4ac4081490d816928762debd3e512d22ad512a558612adc569718df1784261f5c26aff").unwrap(), - hex::decode("5e9e001fb1e2ddb0dc7ff40af917ecf4ba9892491d4bcbf2c81db2efc57627d40d7aac509c9bcf5070d4966faaeb84fd76bb285af2e51af21a8c024089f598c1").unwrap(), - hex::decode("182e83f8c67ad2e6bddead128e7108499ebcbc272b50c42783ef08f035aa688fecc7d15be15a90dbfe7fe5d7cd9926987b6ec12b05f2eadfe0eb6cad5130aca4").unwrap(), - hex::decode("275e7404b2385a9d606d67d0e29f5516fb84c1c14aaaf91afa9a9b3dcdfe09075efdadbaf158cfa1e9f250cc7c691ed2db4a29288d2426bd74a371a2a4b91b57").unwrap(), - hex::decode("0f35c7217792f8b0cbb721ae4ae5c9ae7f2869df49a3db256aacc10d23997a09e0273261b44ebbcecd6bf916f2c1cd79cf25b0c2851645d75dd0747a8f6f92f5").unwrap(), - hex::decode("14c28bf9b10421e9f90ffc9ab05df0dc8c8a07ffac1c51725fba7e2b7972d0769baea248f93ed0f2067d11d719c2858c62fc1d8d59927b41d4c0fbc68d805b32").unwrap(), - hex::decode("e4ce96fee9deb9378a107db48587438cddf8e20a69e21e5e4fbd35ef0c56530df77eba666cb152812111ba66bbd333ed44f627c727115f8f4f15b31726049a19").unwrap(), - hex::decode("b38f315e3ce369ea2551bf4f44e723dd15c7d67ba4b3763997909f65e46267d6540b9b00a7a65ae3d1f3a3316e57a821aeaac33e4e42ded415205073134cd185").unwrap(), - hex::decode("4bcbf774af9c8ff24d4d96099001ec06a24802c88fea81680ea2411392d32dbd9b9828a690a462954b894708d511124a2db34ec4179841e07a897169f0f1ac0e").unwrap(), - hex::decode("65247ace6355f978a12235265410c44f3ded02849ec8f8e6db2ac705c3f57d322ea073c13cf698e15d7e1d7f2bc95e7b3533be0dee26f58864f1664df0c1ebba").unwrap(), - hex::decode("d0c2bb451d0a3465a7fef7770718e5e49bf092a85dbf5af66ea26ec9c1b359026905fc1457e2b98b01ede7ba42aedcc525301f747a0ed9a9b61c37f27f9d8812").unwrap(), - hex::decode("250d9ec7ebec73e885798ae9427e1ea47b5ae66059b465b7c0fd132d17a9c2dcae29ba72863c1861cfb776d342812c4e9000981c4a40819430d0e84aa8bfeb0d").unwrap(), - hex::decode("0549cc0a5e5b9920796b88784c49b7d9a04cf2e86ab18d5af7b00780e60fb0fb5a7129945f4f918201dbad5348d4ccface4370f266540f8e072cdb46d3705930").unwrap(), - hex::decode("e543a26031dbdc8597b1beeba48a4f1cf6ab90c0e5b9343936b6e948a791198fc4fa22928e21edec812a04d0c9629772bf78e475d91a323cd8a8a6e005f92b4d").unwrap(), - hex::decode("4e4be69ad170fb8b3b17835913391ee537098d49e4452844a71ab2147ac55e45871c8943271806034ee9450b31c9486db9d26942946f48040ece7eea81424af1").unwrap(), - hex::decode("cb8a528288f902349250f9e8015e8334b0e24c2eeb9bb7d75e73c39024685804577565e62aca35948d2686ea38e9f8de97837ea30d2fb08347768394416e4a38").unwrap(), - hex::decode("fce94c47196a56a5cb94d5151ca429daf1c563ae889d0a42c2d03cfe43c94a636221c7e21b0668de9e5b6b32ee1e78b2c9aabc16537bf79c7b85eb956f433ac7").unwrap(), - hex::decode("fc8a125c9e2418c87907db4437a0ad6a378bba728ac8e0ce0e64f2a2f4b8201315e1b08d7983ce597cb68be2a2400d6d0d59b7359fe3dc9daca73d468da48972").unwrap(), - hex::decode("49290417311420d67f029a80b013b754150dd0097aa64de1c14a2467ab2e26cc2724071c04cb90cb0cf6c6353cf31f63235af7849d6ba023fd0fc0bc79d32f0b").unwrap(), - hex::decode("45c65effdc8007c9f2fc9057af986e94eb5c12b755465058d4b933ee37638452c5eeca4b43b8cbddabc60f29cbe5676b0bc55c0da88f8d0c36068e7d17ee603a").unwrap(), - hex::decode("a51e4e0f28aee3024207d87a5a1965313bdba4df44c6b845f7ca3408e5dabfe873df6b6ba26000e841f83f69e1de7857122ba538b42f255da2d013208af806ba").unwrap(), - hex::decode("5dbd891bf3bcfd5d054274759c13552aeaa187949875d81ee62ed394253ae25182e78b3a4a1976a7674e425bab860931d57f8a1d4fdc81fa4c3e8e8bf9016d5d").unwrap(), - hex::decode("3b5b044026e9066d62ce2f5a1fb01052a8cfe200dea28d421fc70f42c4d2b890b90ffef5675de1e47e4a20c9ca8700ceea23a61338ac759a098d167fa71642cb").unwrap(), - hex::decode("bb4017880cfa1e37f256dfe2a9cdb1349ed5dea8f69de75dc5933540dcf49e69afc33c837ba8a791857e16fad8581c4e9046778c49ca1ecd1fb675983be6d721").unwrap(), - hex::decode("517bbdb6e9e5f4702193064543204e780f5d33a866d0dcd65ada19f05715dea60ca81b842de5dca8f6b84a9cf469c8fb81991369dba21571476cc9c8d4ff2136").unwrap(), - ]; - - let expected_eta_v_values = vec![ - hex::decode("2af15f57076a8ff225746624882a77c8d2736fe41d3db70154a22b50af851246") - .unwrap(), - hex::decode("a815ff978369b57df09b0072485c26920dc0ec8e924a852a42f0715981cf0042") - .unwrap(), - hex::decode("f112d91435b911b6b5acaf27198762905b1cdec8c5a7b712f925ce3c5c76bb5f") - .unwrap(), - hex::decode("5450d95d9be4194a0ded40fbb4036b48d1f1d6da796e933fefd2c5c888794b4b") - .unwrap(), - hex::decode("c5c0f406cb522ad3fead4ecc60bce9c31e80879bc17eb1bb9acaa9b998cdf8bf") - .unwrap(), - hex::decode("5857048c728580549de645e087ba20ef20bb7c51cc84b5bc89df6b8b0ed98c41") - .unwrap(), - hex::decode("d6f40ef403687115db061b2cb9b1ab4ddeb98222075d5a3e03c8d217d4d7c40e") - .unwrap(), - hex::decode("5489d75a9f4971c1824462b5e2338609a91f121241f21fee09811bd5772ae0a8") - .unwrap(), - hex::decode("04716326833ecdb595153adac9566a4b39e5c16e8d02526cb4166e4099a00b1a") - .unwrap(), - hex::decode("39db709f50c8a279f0a94adcefb9360dbda6cdce168aed4288329a9cd53492b6") - .unwrap(), - hex::decode("c784b8c8678e0a04748a3ad851dd7c34ed67141cd9dc0c50ceaff4df804699a7") - .unwrap(), - hex::decode("cc1a5861358c075de93a26a91c5a951d5e71190d569aa2dc786d4ca8fc80cc38") - .unwrap(), - hex::decode("514979c89313c49e8f59fb8445113fa7623e99375cc4917fe79df54f8d4bdfce") - .unwrap(), - hex::decode("6a783e04481b9e04e8f3498a3b74c90c06a1031fb663b6793ce592a6c26f56f4") - .unwrap(), - hex::decode("1190f5254599dcee4f3cf1afdf4181085c36a6db6c30f334bfe6e6f320a6ed91") - .unwrap(), - hex::decode("91c777d6db066fe58edd67cd751fc7240268869b365393f6910e0e8f0fa58af3") - .unwrap(), - hex::decode("c545d83926c011b5c68a72de9a4e2f9da402703f4aab1b967456eae73d9f89b3") - .unwrap(), - hex::decode("ec31d2348bf543482842843a61d5b32691dedf801f198d68126c423ddf391e8b") - .unwrap(), - hex::decode("de223867d5c972895dd99ac0280a3e02947a7fb018ed42ed048266f913d2dfc2") - .unwrap(), - hex::decode("4dd9801752aade9c6e06bf03e9d2ec8a30ef7c6f30106790a23a9599e90ee08a") - .unwrap(), - hex::decode("fcb183abd512271f40408a5872827ce79cc2dda685a986a7dbdc61d842495a91") - .unwrap(), - hex::decode("e834d8ffd6dd042167b13e38512c62afdaf4d635d5b1ab0d513e08e9bef0ef63") - .unwrap(), - hex::decode("270a78257a958cd5fdb26f0b9ab302df2d2196fd04989f7ca1bb703e4dd904f0") - .unwrap(), - hex::decode("7e324f67af787dfddee10354128c60c60bf601bd8147c867d2471749a7b0f334") - .unwrap(), - hex::decode("54521ed42e0e782b5268ec55f80cff582162bc23fdcee5cdaa0f1a2ce7fa1f02") - .unwrap(), - hex::decode("557c296a71d8c9cb3fe7dcd95fbf4d70f6a3974d93c71b450d62a41b9a85d5a1") - .unwrap(), - hex::decode("20e078301ca282857378bbf10ac40965445c4c9fa73a160e0a116b4cf808b4b4") - .unwrap(), - hex::decode("b5a741dd3ff6a5a3d27b4d046dfb7a3901aacd37df7e931ba05e1320ad155c1c") - .unwrap(), - hex::decode("8b445f35f4a7b76e5d279d71fa9e05376a7c4533ca8b2b98fd2dbaf814d3bf8f") - .unwrap(), - hex::decode("08e7b5277abc139deb50f61264375fa091c580f8a85f259be78a002f7023c31f") - .unwrap(), - ]; - - let mut rolling_nonce_generator = - RollingNonceGenerator::new(Hash::from(shelley_genesis_hash.as_slice())); - - for (eta_vrf_0, expected_eta_v) in eta_vrf_0_values.iter().zip(expected_eta_v_values.iter()) - { - rolling_nonce_generator.apply_block(eta_vrf_0).unwrap(); - rolling_nonce_generator = - RollingNonceGenerator::new(rolling_nonce_generator.finalize().unwrap()); - assert_eq!( - rolling_nonce_generator.nonce.as_ref(), - expected_eta_v.as_slice() - ); - } - } -} diff --git a/pallas-crypto/src/vrf/mod.rs b/pallas-crypto/src/vrf/mod.rs index 250c23f6..a0143fac 100644 --- a/pallas-crypto/src/vrf/mod.rs +++ b/pallas-crypto/src/vrf/mod.rs @@ -1,36 +1,111 @@ +use crate::hash::Hash; +use crate::memsec::Scrubbed; use thiserror::Error; use vrf_dalek::vrf03::{PublicKey03, SecretKey03, VrfProof03}; #[derive(Error, Debug)] pub enum Error { - #[error("TryFromSlice {0}")] - TryFromSlice(#[from] std::array::TryFromSliceError), - #[error("VrfError {0}")] VrfError(#[from] vrf_dalek::errors::VrfError), } -/// Sign a seed value with a vrf secret key and produce a proof signature -pub fn vrf_prove(secret_key: &[u8], seed: &[u8]) -> Result, Error> { - let sk = SecretKey03::from_bytes(secret_key[..32].try_into()?); - let pk = PublicKey03::from(&sk); - let proof = VrfProof03::generate(&pk, &sk, seed); - Ok(proof.to_bytes().to_vec()) +pub const VRF_SEED_SIZE: usize = 32; +pub const VRF_PROOF_SIZE: usize = 80; +pub const VRF_PUBLIC_KEY_SIZE: usize = 32; +pub const VRF_SECRET_KEY_SIZE: usize = 32; +pub const VRF_PROOF_HASH_SIZE: usize = 64; + +// Wrapper for VRF secret key +pub struct VrfSecretKey { + secret_key_03: SecretKey03, } -/// Convert a proof signature to a hash -pub fn vrf_proof_to_hash(proof: &[u8]) -> Result, Error> { - let proof = VrfProof03::from_bytes(proof[..80].try_into()?)?; - Ok(proof.proof_to_hash().to_vec()) +// Wrapper for VRF public key +pub struct VrfPublicKey { + public_key_03: PublicKey03, } -/// Verify a proof signature with a vrf public key. This will return a hash to compare with the original -/// signature hash, but any non-error result is considered a successful verification without needing -/// to do the extra comparison check. -pub fn vrf_verify(public_key: &[u8], signature: &[u8], seed: &[u8]) -> Result, Error> { - let pk = PublicKey03::from_bytes(public_key.try_into()?); - let proof = VrfProof03::from_bytes(signature.try_into()?)?; - Ok(proof.verify(&pk, seed)?.to_vec()) +// Wrapper for VRF proof +pub struct VrfProof { + proof_03: VrfProof03, +} + +// Create a VrfSecretKey from a slice +impl From<&[u8; VRF_SECRET_KEY_SIZE]> for VrfSecretKey { + fn from(slice: &[u8; VRF_SECRET_KEY_SIZE]) -> Self { + VrfSecretKey { + secret_key_03: SecretKey03::from_bytes(slice), + } + } +} + +// Create a VrfPublicKey from a slice +impl From<&[u8; VRF_PUBLIC_KEY_SIZE]> for VrfPublicKey { + fn from(slice: &[u8; VRF_PUBLIC_KEY_SIZE]) -> Self { + VrfPublicKey { + public_key_03: PublicKey03::from_bytes(slice), + } + } +} + +// Create a VrfProof from a slice +impl From<&[u8; VRF_PROOF_SIZE]> for VrfProof { + fn from(slice: &[u8; VRF_PROOF_SIZE]) -> Self { + VrfProof { + proof_03: VrfProof03::from_bytes(slice).unwrap(), + } + } +} + +// Create a VrfPublicKey from a VrfSecretKey +impl From<&VrfSecretKey> for VrfPublicKey { + fn from(secret_key: &VrfSecretKey) -> Self { + VrfPublicKey { + public_key_03: PublicKey03::from(&secret_key.secret_key_03), + } + } +} + +// Wipe the secret key +impl Scrubbed for VrfSecretKey { + fn scrub(&mut self) { + self.secret_key_03.to_bytes().scrub(); + } +} + +impl Drop for VrfSecretKey { + fn drop(&mut self) { + self.scrub(); + } +} + +impl VrfSecretKey { + /// Sign a challenge message value with a vrf secret key and produce a proof signature + pub fn prove(&self, challenge: &[u8]) -> VrfProof { + let pk = PublicKey03::from(&self.secret_key_03); + let proof = VrfProof03::generate(&pk, &self.secret_key_03, challenge); + VrfProof { proof_03: proof } + } +} + +impl VrfProof { + /// Convert a proof signature to a hash + pub fn to_hash(&self) -> Hash { + Hash::from(self.proof_03.proof_to_hash()) + } + + /// Verify a proof signature with a vrf public key. This will return a hash to compare with the original + /// signature hash, but any non-error result is considered a successful verification without needing + /// to do the extra comparison check. + pub fn verify( + &self, + public_key: &VrfPublicKey, + seed: &[u8], + ) -> Result, Error> { + Ok(Hash::from( + self.proof_03.verify(&public_key.public_key_03, seed)?, + )) + } } #[cfg(test)] @@ -53,22 +128,32 @@ mod tests { // "description": "VRF Signing Key", // "cborHex": "5840adb9c97bec60189aa90d01d113e3ef405f03477d82a94f81da926c90cd46a374e0ff2371508ac339431b50af7d69cde0f120d952bb876806d3136f9a7fda4381" // } - - let vrf_skey = hex::decode("adb9c97bec60189aa90d01d113e3ef405f03477d82a94f81da926c90cd46a374e0ff2371508ac339431b50af7d69cde0f120d952bb876806d3136f9a7fda4381").unwrap(); - let vrf_vkey = + let raw_vrf_skey: Vec = hex::decode("adb9c97bec60189aa90d01d113e3ef405f03477d82a94f81da926c90cd46a374e0ff2371508ac339431b50af7d69cde0f120d952bb876806d3136f9a7fda4381").unwrap(); + let raw_vrf_vkey: Vec = hex::decode("e0ff2371508ac339431b50af7d69cde0f120d952bb876806d3136f9a7fda4381") .unwrap(); - // random seed to sign with vrf_skey - let mut seed = [0u8; 64]; - thread_rng().fill(&mut seed); + let vrf_skey = VrfSecretKey::from(&raw_vrf_skey[..VRF_SECRET_KEY_SIZE].try_into().unwrap()); + let vrf_vkey = + VrfPublicKey::from(&raw_vrf_vkey[..VRF_PUBLIC_KEY_SIZE].try_into().unwrap() + as &[u8; VRF_PUBLIC_KEY_SIZE]); + + let calculated_vrf_vkey = VrfPublicKey::from(&vrf_skey); + assert_eq!( + vrf_vkey.public_key_03.as_bytes(), + calculated_vrf_vkey.public_key_03.as_bytes() + ); + + // random challenge to sign with vrf_skey + let mut challenge = [0u8; 64]; + thread_rng().fill(&mut challenge); // create a proof signature and hash of the seed - let proof_signature = vrf_prove(&vrf_skey, &seed).unwrap(); - let proof_hash = vrf_proof_to_hash(&proof_signature).unwrap(); + let proof = vrf_skey.prove(&challenge); + let proof_hash = proof.to_hash(); // verify the proof signature with the public vrf public key - let verified_hash = vrf_verify(&vrf_vkey, &proof_signature, &seed).unwrap(); + let verified_hash = proof.verify(&vrf_vkey, &challenge).unwrap(); assert_eq!(proof_hash, verified_hash); } } diff --git a/pallas-math/Cargo.toml b/pallas-math/Cargo.toml index fe3a9cd7..c558625f 100644 --- a/pallas-math/Cargo.toml +++ b/pallas-math/Cargo.toml @@ -12,16 +12,13 @@ authors = ["Andrew Westberg "] exclude = ["tests/data/*"] [features] -default = ["gmp"] gmp = ["dep:gmp-mpfr-sys"] -num = ["dep:num-bigint", "dep:num-integer", "dep:num-traits"] [dependencies] gmp-mpfr-sys = { version = "1.6.4", features = ["mpc"], default-features = false, optional = true } once_cell = "1.19.0" -num-bigint = { version = "0.4.6", optional = true } -num-integer = { version = "0.1.46", optional = true } -num-traits = { version = "0.2.19", optional = true } +malachite = "0.4.16" +malachite-base = "0.4.16" regex = "1.10.5" thiserror = "1.0.61" diff --git a/pallas-math/src/lib.rs b/pallas-math/src/lib.rs index 485178c8..ca3da9da 100644 --- a/pallas-math/src/lib.rs +++ b/pallas-math/src/lib.rs @@ -1,14 +1,6 @@ pub mod math; -// Ensure only one of `gmp` or `num` is enabled, not both. -#[cfg(all(feature = "gmp", feature = "num"))] -compile_error!("Features `gmp` and `num` are mutually exclusive."); - -#[cfg(all(not(feature = "gmp"), not(feature = "num")))] -compile_error!("One of the features `gmp` or `num` must be enabled."); - #[cfg(feature = "gmp")] pub mod math_gmp; - -#[cfg(feature = "num")] -pub mod math_num; +#[cfg(not(feature = "gmp"))] +pub mod math_malachite; diff --git a/pallas-math/src/math.rs b/pallas-math/src/math.rs index df9eba81..a50ab1c1 100644 --- a/pallas-math/src/math.rs +++ b/pallas-math/src/math.rs @@ -8,9 +8,9 @@ use std::ops::{Div, Mul, Neg, Sub}; use thiserror::Error; #[cfg(feature = "gmp")] -use crate::math_gmp::Decimal; -#[cfg(feature = "num")] -use crate::math_num::Decimal; +pub type FixedDecimal = crate::math_gmp::Decimal; +#[cfg(not(feature = "gmp"))] +pub type FixedDecimal = crate::math_malachite::Decimal; #[derive(Debug, Error)] pub enum Error { @@ -49,6 +49,18 @@ pub trait FixedPrecision: /// Entry point for bounded iterations for comparing two exp values. fn exp_cmp(&self, max_n: u64, bound_self: i64, compare: &Self) -> ExpCmpOrdering; + + /// Round to the nearest integer number + fn round(&self) -> Self; + + /// Round down to the nearest integer number + fn floor(&self) -> Self; + + /// Round up to the nearest integer number + fn ceil(&self) -> Self; + + /// Truncate to the nearest integer number + fn trunc(&self) -> Self; } #[derive(Debug, Clone, PartialEq)] @@ -72,54 +84,51 @@ impl From<&str> for ExpOrdering { pub struct ExpCmpOrdering { pub iterations: u64, pub estimation: ExpOrdering, - pub approx: Decimal, + pub approx: FixedDecimal, } #[cfg(test)] mod tests { + use super::*; use std::fs::File; use std::io::BufRead; use std::path::PathBuf; - #[cfg(feature = "gmp")] - use crate::math_gmp::Decimal; - #[cfg(feature = "num")] - use crate::math_num::Decimal; - - use super::*; - #[test] fn test_fixed_precision() { - let fp: Decimal = Decimal::new(34); + let fp: FixedDecimal = FixedDecimal::new(34); assert_eq!(fp.precision(), 34); assert_eq!(fp.to_string(), "0.0000000000000000000000000000000000"); } #[test] fn test_fixed_precision_eq() { - let fp1: Decimal = Decimal::new(34); - let fp2: Decimal = Decimal::new(34); + let fp1: FixedDecimal = FixedDecimal::new(34); + let fp2: FixedDecimal = FixedDecimal::new(34); assert_eq!(fp1, fp2); } #[test] fn test_fixed_precision_from_str() { - let fp: Decimal = Decimal::from_str("1234567890123456789012345678901234", 34).unwrap(); + let fp: FixedDecimal = + FixedDecimal::from_str("1234567890123456789012345678901234", 34).unwrap(); assert_eq!(fp.precision(), 34); assert_eq!(fp.to_string(), "0.1234567890123456789012345678901234"); - let fp: Decimal = Decimal::from_str("-1234567890123456789012345678901234", 30).unwrap(); + let fp: FixedDecimal = + FixedDecimal::from_str("-1234567890123456789012345678901234", 30).unwrap(); assert_eq!(fp.precision(), 30); assert_eq!(fp.to_string(), "-1234.567890123456789012345678901234"); - let fp: Decimal = Decimal::from_str("-1234567890123456789012345678901234", 34).unwrap(); + let fp: FixedDecimal = + FixedDecimal::from_str("-1234567890123456789012345678901234", 34).unwrap(); assert_eq!(fp.precision(), 34); assert_eq!(fp.to_string(), "-0.1234567890123456789012345678901234"); } #[test] fn test_fixed_precision_exp() { - let fp: Decimal = Decimal::from(1u64); + let fp: FixedDecimal = FixedDecimal::from(1u64); assert_eq!(fp.to_string(), "1.0000000000000000000000000000000000"); let exp_fp = fp.exp(); assert_eq!(exp_fp.to_string(), "2.7182818284590452353602874043083282"); @@ -127,8 +136,10 @@ mod tests { #[test] fn test_fixed_precision_mul() { - let fp1: Decimal = Decimal::from_str("52500000000000000000000000000000000", 34).unwrap(); - let fp2: Decimal = Decimal::from_str("43000000000000000000000000000000000", 34).unwrap(); + let fp1: FixedDecimal = + FixedDecimal::from_str("52500000000000000000000000000000000", 34).unwrap(); + let fp2: FixedDecimal = + FixedDecimal::from_str("43000000000000000000000000000000000", 34).unwrap(); let fp3 = &fp1 * &fp2; assert_eq!(fp3.to_string(), "22.5750000000000000000000000000000000"); let fp4 = fp1 * fp2; @@ -137,8 +148,8 @@ mod tests { #[test] fn test_fixed_precision_div() { - let fp1: Decimal = Decimal::from_str("1", 34).unwrap(); - let fp2: Decimal = Decimal::from_str("10", 34).unwrap(); + let fp1: FixedDecimal = FixedDecimal::from_str("1", 34).unwrap(); + let fp2: FixedDecimal = FixedDecimal::from_str("10", 34).unwrap(); let fp3 = &fp1 / &fp2; assert_eq!(fp3.to_string(), "0.1000000000000000000000000000000000"); let fp4 = fp1 / fp2; @@ -147,9 +158,9 @@ mod tests { #[test] fn test_fixed_precision_sub() { - let fp1: Decimal = Decimal::from_str("1", 34).unwrap(); + let fp1: FixedDecimal = FixedDecimal::from_str("1", 34).unwrap(); assert_eq!(fp1.to_string(), "0.0000000000000000000000000000000001"); - let fp2: Decimal = Decimal::from_str("10", 34).unwrap(); + let fp2: FixedDecimal = FixedDecimal::from_str("10", 34).unwrap(); assert_eq!(fp2.to_string(), "0.0000000000000000000000000000000010"); let fp3 = &fp1 - &fp2; assert_eq!(fp3.to_string(), "-0.0000000000000000000000000000000009"); @@ -157,6 +168,214 @@ mod tests { assert_eq!(fp4.to_string(), "-0.0000000000000000000000000000000009"); } + #[test] + fn test_fixed_precision_round() { + let fp1: FixedDecimal = + FixedDecimal::from_str("11234567890123456789012345678901234", 34).unwrap(); + assert_eq!( + fp1.round().to_string(), + "1.0000000000000000000000000000000000" + ); + let fp2: FixedDecimal = + FixedDecimal::from_str("14999999999999999999999999999999999", 34).unwrap(); + assert_eq!( + fp2.round().to_string(), + "1.0000000000000000000000000000000000" + ); + let fp3: FixedDecimal = + FixedDecimal::from_str("15000000000000000000000000000000000", 34).unwrap(); + assert_eq!( + fp3.round().to_string(), + "2.0000000000000000000000000000000000" + ); + let fp4: FixedDecimal = FixedDecimal::from_str("1500", 3).unwrap(); + assert_eq!(fp4.round().to_string(), "2.000"); + let fp5: FixedDecimal = FixedDecimal::from_str("1499", 3).unwrap(); + assert_eq!(fp5.round().to_string(), "1.000"); + let fp6: FixedDecimal = + FixedDecimal::from_str("-11234567890123456789012345678901234", 34).unwrap(); + assert_eq!( + fp6.round().to_string(), + "-1.0000000000000000000000000000000000" + ); + let fp2: FixedDecimal = + FixedDecimal::from_str("-14999999999999999999999999999999999", 34).unwrap(); + assert_eq!( + fp2.round().to_string(), + "-1.0000000000000000000000000000000000" + ); + let fp3: FixedDecimal = + FixedDecimal::from_str("-15000000000000000000000000000000000", 34).unwrap(); + assert_eq!( + fp3.round().to_string(), + "-2.0000000000000000000000000000000000" + ); + let fp4: FixedDecimal = FixedDecimal::from_str("-1500", 3).unwrap(); + assert_eq!(fp4.round().to_string(), "-2.000"); + let fp5: FixedDecimal = FixedDecimal::from_str("-1499", 3).unwrap(); + assert_eq!(fp5.round().to_string(), "-1.000"); + let fp6: FixedDecimal = FixedDecimal::from_str("1000", 3).unwrap(); + assert_eq!(fp6.round().to_string(), "1.000"); + let fp7: FixedDecimal = FixedDecimal::from_str("-1000", 3).unwrap(); + assert_eq!(fp7.round().to_string(), "-1.000"); + } + + #[test] + fn test_fixed_precision_floor() { + let fp1: FixedDecimal = + FixedDecimal::from_str("11234567890123456789012345678901234", 34).unwrap(); + assert_eq!( + fp1.floor().to_string(), + "1.0000000000000000000000000000000000" + ); + let fp2: FixedDecimal = + FixedDecimal::from_str("14999999999999999999999999999999999", 34).unwrap(); + assert_eq!( + fp2.floor().to_string(), + "1.0000000000000000000000000000000000" + ); + let fp3: FixedDecimal = + FixedDecimal::from_str("15000000000000000000000000000000000", 34).unwrap(); + assert_eq!( + fp3.floor().to_string(), + "1.0000000000000000000000000000000000" + ); + let fp4: FixedDecimal = FixedDecimal::from_str("1500", 3).unwrap(); + assert_eq!(fp4.floor().to_string(), "1.000"); + let fp5: FixedDecimal = FixedDecimal::from_str("1499", 3).unwrap(); + assert_eq!(fp5.floor().to_string(), "1.000"); + let fp6: FixedDecimal = + FixedDecimal::from_str("-11234567890123456789012345678901234", 34).unwrap(); + assert_eq!( + fp6.floor().to_string(), + "-2.0000000000000000000000000000000000" + ); + let fp2: FixedDecimal = + FixedDecimal::from_str("-14999999999999999999999999999999999", 34).unwrap(); + assert_eq!( + fp2.floor().to_string(), + "-2.0000000000000000000000000000000000" + ); + let fp3: FixedDecimal = + FixedDecimal::from_str("-15000000000000000000000000000000000", 34).unwrap(); + assert_eq!( + fp3.floor().to_string(), + "-2.0000000000000000000000000000000000" + ); + let fp4: FixedDecimal = FixedDecimal::from_str("-1500", 3).unwrap(); + assert_eq!(fp4.floor().to_string(), "-2.000"); + let fp5: FixedDecimal = FixedDecimal::from_str("-1499", 3).unwrap(); + assert_eq!(fp5.floor().to_string(), "-2.000"); + let fp6: FixedDecimal = FixedDecimal::from_str("1000", 3).unwrap(); + assert_eq!(fp6.floor().to_string(), "1.000"); + let fp7: FixedDecimal = FixedDecimal::from_str("-1000", 3).unwrap(); + assert_eq!(fp7.floor().to_string(), "-1.000"); + } + + #[test] + fn test_fixed_precision_ceil() { + let fp1: FixedDecimal = + FixedDecimal::from_str("11234567890123456789012345678901234", 34).unwrap(); + assert_eq!( + fp1.ceil().to_string(), + "2.0000000000000000000000000000000000" + ); + let fp2: FixedDecimal = + FixedDecimal::from_str("14999999999999999999999999999999999", 34).unwrap(); + assert_eq!( + fp2.ceil().to_string(), + "2.0000000000000000000000000000000000" + ); + let fp3: FixedDecimal = + FixedDecimal::from_str("15000000000000000000000000000000000", 34).unwrap(); + assert_eq!( + fp3.ceil().to_string(), + "2.0000000000000000000000000000000000" + ); + let fp4: FixedDecimal = FixedDecimal::from_str("1500", 3).unwrap(); + assert_eq!(fp4.ceil().to_string(), "2.000"); + let fp5: FixedDecimal = FixedDecimal::from_str("1499", 3).unwrap(); + assert_eq!(fp5.ceil().to_string(), "2.000"); + let fp6: FixedDecimal = + FixedDecimal::from_str("-11234567890123456789012345678901234", 34).unwrap(); + assert_eq!( + fp6.ceil().to_string(), + "-1.0000000000000000000000000000000000" + ); + let fp2: FixedDecimal = + FixedDecimal::from_str("-14999999999999999999999999999999999", 34).unwrap(); + assert_eq!( + fp2.ceil().to_string(), + "-1.0000000000000000000000000000000000" + ); + let fp3: FixedDecimal = + FixedDecimal::from_str("-15000000000000000000000000000000000", 34).unwrap(); + assert_eq!( + fp3.ceil().to_string(), + "-1.0000000000000000000000000000000000" + ); + let fp4: FixedDecimal = FixedDecimal::from_str("-1500", 3).unwrap(); + assert_eq!(fp4.ceil().to_string(), "-1.000"); + let fp5: FixedDecimal = FixedDecimal::from_str("-1499", 3).unwrap(); + assert_eq!(fp5.ceil().to_string(), "-1.000"); + let fp6: FixedDecimal = FixedDecimal::from_str("1000", 3).unwrap(); + assert_eq!(fp6.ceil().to_string(), "1.000"); + let fp7: FixedDecimal = FixedDecimal::from_str("-1000", 3).unwrap(); + assert_eq!(fp7.ceil().to_string(), "-1.000"); + } + + #[test] + fn test_fixed_precision_trunc() { + let fp1: FixedDecimal = + FixedDecimal::from_str("11234567890123456789012345678901234", 34).unwrap(); + assert_eq!( + fp1.trunc().to_string(), + "1.0000000000000000000000000000000000" + ); + let fp2: FixedDecimal = + FixedDecimal::from_str("14999999999999999999999999999999999", 34).unwrap(); + assert_eq!( + fp2.trunc().to_string(), + "1.0000000000000000000000000000000000" + ); + let fp3: FixedDecimal = + FixedDecimal::from_str("15000000000000000000000000000000000", 34).unwrap(); + assert_eq!( + fp3.trunc().to_string(), + "1.0000000000000000000000000000000000" + ); + let fp4: FixedDecimal = FixedDecimal::from_str("1500", 3).unwrap(); + assert_eq!(fp4.trunc().to_string(), "1.000"); + let fp5: FixedDecimal = FixedDecimal::from_str("1499", 3).unwrap(); + assert_eq!(fp5.trunc().to_string(), "1.000"); + let fp6: FixedDecimal = + FixedDecimal::from_str("-11234567890123456789012345678901234", 34).unwrap(); + assert_eq!( + fp6.trunc().to_string(), + "-1.0000000000000000000000000000000000" + ); + let fp2: FixedDecimal = + FixedDecimal::from_str("-14999999999999999999999999999999999", 34).unwrap(); + assert_eq!( + fp2.trunc().to_string(), + "-1.0000000000000000000000000000000000" + ); + let fp3: FixedDecimal = + FixedDecimal::from_str("-15000000000000000000000000000000000", 34).unwrap(); + assert_eq!( + fp3.trunc().to_string(), + "-1.0000000000000000000000000000000000" + ); + let fp4: FixedDecimal = FixedDecimal::from_str("-1500", 3).unwrap(); + assert_eq!(fp4.trunc().to_string(), "-1.000"); + let fp5: FixedDecimal = FixedDecimal::from_str("-1499", 3).unwrap(); + assert_eq!(fp5.trunc().to_string(), "-1.000"); + let fp6: FixedDecimal = FixedDecimal::from_str("1000", 3).unwrap(); + assert_eq!(fp6.trunc().to_string(), "1.000"); + let fp7: FixedDecimal = FixedDecimal::from_str("-1000", 3).unwrap(); + assert_eq!(fp7.trunc().to_string(), "-1.000"); + } + #[test] fn golden_tests() { let mut data_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); @@ -172,20 +391,20 @@ mod tests { let file = File::open(data_path).expect("golden_tests_result.txt: file not found"); let result_reader = std::io::BufReader::new(file); - let one: Decimal = Decimal::from(1u64); - let ten: Decimal = Decimal::from(10u64); - let f: Decimal = &one / &ten; + let one: FixedDecimal = FixedDecimal::from(1u64); + let ten: FixedDecimal = FixedDecimal::from(10u64); + let f: FixedDecimal = &one / &ten; assert_eq!(f.to_string(), "0.1000000000000000000000000000000000"); for (test_line, result_line) in reader.lines().zip(result_reader.lines()) { let test_line = test_line.expect("failed to read line"); // println!("test_line: {}", test_line); let mut parts = test_line.split_whitespace(); - let x = Decimal::from_str(parts.next().unwrap(), DEFAULT_PRECISION) + let x = FixedDecimal::from_str(parts.next().unwrap(), DEFAULT_PRECISION) .expect("failed to parse x"); - let a = Decimal::from_str(parts.next().unwrap(), DEFAULT_PRECISION) + let a = FixedDecimal::from_str(parts.next().unwrap(), DEFAULT_PRECISION) .expect("failed to parse a"); - let b = Decimal::from_str(parts.next().unwrap(), DEFAULT_PRECISION) + let b = FixedDecimal::from_str(parts.next().unwrap(), DEFAULT_PRECISION) .expect("failed to parse b"); let result_line = result_line.expect("failed to read line"); // println!("result_line: {}", result_line); @@ -210,7 +429,13 @@ mod tests { let c = &one - &f; assert_eq!(c.to_string(), "0.9000000000000000000000000000000000"); let threshold_b = c.pow(&b); - assert_eq!((&one - &threshold_b).to_string(), expected_threshold_b); + assert_eq!( + (&one - &threshold_b).to_string(), + expected_threshold_b, + "(1 - f) *** b failed to match! - (1 - f)={}, b={}", + &c, + &b + ); // do Taylor approximation for // a < 1 - (1 - f) *** b <=> 1/(1-a) < exp(-b * ln' (1 - f)) diff --git a/pallas-math/src/math_gmp.rs b/pallas-math/src/math_gmp.rs index 3a3b2e22..b64e2e88 100644 --- a/pallas-math/src/math_gmp.rs +++ b/pallas-math/src/math_gmp.rs @@ -124,6 +124,19 @@ impl Neg for Decimal { } } +// Implement Neg for a reference to Decimal +impl<'a> Neg for &'a Decimal { + type Output = Decimal; + + fn neg(self) -> Self::Output { + unsafe { + let mut result = Decimal::new(self.precision); + mpz_neg(&mut result.data, &self.data); + result + } + } +} + impl Mul for Decimal { type Output = Self; @@ -201,6 +214,31 @@ impl<'a, 'b> Sub<&'b Decimal> for &'a Decimal { } } +impl Add for Decimal { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + unsafe { + let mut result = Decimal::new(self.precision); + mpz_add(&mut result.data, &self.data, &rhs.data); + result + } + } +} + +// Implement Add for a reference to Decimal +impl<'a, 'b> Add<&'b Decimal> for &'a Decimal { + type Output = Decimal; + + fn add(self, rhs: &'b Decimal) -> Self::Output { + unsafe { + let mut result = Decimal::new(self.precision); + mpz_add(&mut result.data, &self.data, &rhs.data); + result + } + } +} + impl FixedPrecision for Decimal { fn new(precision: u64) -> Self { unsafe { @@ -279,6 +317,18 @@ impl FixedPrecision for Decimal { ) } } + + fn round(&self) -> Self { + todo!() + } + + fn floor(&self) -> Self { + todo!() + } + + fn ceil(&self) -> Self { + todo!() + } } /// # Safety diff --git a/pallas-math/src/math_num.rs b/pallas-math/src/math_malachite.rs similarity index 60% rename from pallas-math/src/math_num.rs rename to pallas-math/src/math_malachite.rs index 0ae7ae37..d56b526f 100644 --- a/pallas-math/src/math_num.rs +++ b/pallas-math/src/math_malachite.rs @@ -2,24 +2,24 @@ # Cardano Math functions using the num-bigint crate */ +use crate::math::{Error, ExpCmpOrdering, ExpOrdering, FixedPrecision, DEFAULT_PRECISION}; +use malachite::num::arithmetic::traits::{Abs, DivRem, DivRound, Pow, PowAssign}; +use malachite::num::basic::traits::One; +use malachite::rounding_modes::RoundingMode; +use malachite::{Integer, Natural}; +use malachite_base::num::arithmetic::traits::Sign; +use once_cell::sync::Lazy; +use regex::Regex; use std::cmp::Ordering; use std::fmt::{Display, Formatter}; -use std::ops::{Div, Mul, Neg, Sub}; +use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::str::FromStr; -use num_bigint::BigInt; -use num_integer::Integer; -use num_traits::{Signed, ToPrimitive}; -use once_cell::sync::Lazy; -use regex::Regex; - -use crate::math::{Error, ExpCmpOrdering, ExpOrdering, FixedPrecision, DEFAULT_PRECISION}; - #[derive(Debug, Clone)] pub struct Decimal { precision: u64, - precision_multiplier: BigInt, - data: BigInt, + precision_multiplier: Integer, + data: Integer, } impl PartialEq for Decimal { @@ -58,7 +58,7 @@ impl Display for Decimal { impl From for Decimal { fn from(n: u64) -> Self { let mut result = Decimal::new(DEFAULT_PRECISION); - result.data = BigInt::from(n) * &result.precision_multiplier; + result.data = Integer::from(n) * &result.precision_multiplier; result } } @@ -66,19 +66,43 @@ impl From for Decimal { impl From for Decimal { fn from(n: i64) -> Self { let mut result = Decimal::new(DEFAULT_PRECISION); - result.data = BigInt::from(n) * &result.precision_multiplier; + result.data = Integer::from(n) * &result.precision_multiplier; + result + } +} + +impl From for Decimal { + fn from(n: Integer) -> Self { + let mut result = Decimal::new(DEFAULT_PRECISION); + result.data.clone_from(&n); result } } -impl From<&BigInt> for Decimal { - fn from(n: &BigInt) -> Self { +impl From<&Integer> for Decimal { + fn from(n: &Integer) -> Self { let mut result = Decimal::new(DEFAULT_PRECISION); result.data.clone_from(n); result } } +impl From for Decimal { + fn from(n: Natural) -> Self { + let mut result = Decimal::new(DEFAULT_PRECISION); + result.data.clone_from(&Integer::from(n)); + result + } +} + +impl From<&Natural> for Decimal { + fn from(n: &Natural) -> Self { + let mut result = Decimal::new(DEFAULT_PRECISION); + result.data.clone_from(&Integer::from(n)); + result + } +} + impl Neg for Decimal { type Output = Self; @@ -89,6 +113,17 @@ impl Neg for Decimal { } } +// Implement Neg for a reference to Decimal +impl<'a> Neg for &'a Decimal { + type Output = Decimal; + + fn neg(self) -> Self::Output { + let mut result = Decimal::new(self.precision); + result.data = -&self.data; + result + } +} + impl Mul for Decimal { type Output = Self; @@ -100,6 +135,13 @@ impl Mul for Decimal { } } +impl MulAssign for Decimal { + fn mul_assign(&mut self, rhs: Self) { + self.data *= &rhs.data; + scale(&mut self.data); + } +} + // Implement Mul for a reference to Decimal impl<'a, 'b> Mul<&'b Decimal> for &'a Decimal { type Output = Decimal; @@ -112,6 +154,13 @@ impl<'a, 'b> Mul<&'b Decimal> for &'a Decimal { } } +impl<'a, 'b> MulAssign<&'b Decimal> for &'a mut Decimal { + fn mul_assign(&mut self, rhs: &'b Decimal) { + self.data *= &rhs.data; + scale(&mut self.data); + } +} + impl Div for Decimal { type Output = Self; @@ -122,6 +171,13 @@ impl Div for Decimal { } } +impl DivAssign for Decimal { + fn div_assign(&mut self, rhs: Self) { + let temp = self.data.clone(); + div(&mut self.data, &temp, &rhs.data); + } +} + // Implement Div for a reference to Decimal impl<'a, 'b> Div<&'b Decimal> for &'a Decimal { type Output = Decimal; @@ -133,6 +189,13 @@ impl<'a, 'b> Div<&'b Decimal> for &'a Decimal { } } +impl<'a, 'b> DivAssign<&'b Decimal> for &'a mut Decimal { + fn div_assign(&mut self, rhs: &'b Decimal) { + let temp = self.data.clone(); + div(&mut self.data, &temp, &rhs.data); + } +} + impl Sub for Decimal { type Output = Self; @@ -143,6 +206,12 @@ impl Sub for Decimal { } } +impl SubAssign for Decimal { + fn sub_assign(&mut self, rhs: Self) { + self.data -= &rhs.data; + } +} + // Implement Sub for a reference to Decimal impl<'a, 'b> Sub<&'b Decimal> for &'a Decimal { type Output = Decimal; @@ -154,11 +223,50 @@ impl<'a, 'b> Sub<&'b Decimal> for &'a Decimal { } } +impl<'a, 'b> SubAssign<&'b Decimal> for &'a mut Decimal { + fn sub_assign(&mut self, rhs: &'b Decimal) { + self.data -= &rhs.data; + } +} + +impl Add for Decimal { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + let mut result = Decimal::new(self.precision); + result.data = &self.data + &rhs.data; + result + } +} + +impl AddAssign for Decimal { + fn add_assign(&mut self, rhs: Self) { + self.data += &rhs.data; + } +} + +// Implement Add for a reference to Decimal +impl<'a, 'b> Add<&'b Decimal> for &'a Decimal { + type Output = Decimal; + + fn add(self, rhs: &'b Decimal) -> Self::Output { + let mut result = Decimal::new(self.precision); + result.data = &self.data + &rhs.data; + result + } +} + +impl<'a, 'b> AddAssign<&'b Decimal> for &'a mut Decimal { + fn add_assign(&mut self, rhs: &'b Decimal) { + self.data += &rhs.data; + } +} + impl FixedPrecision for Decimal { fn new(precision: u64) -> Self { - let ten = BigInt::from(10); - let precision_multiplier = ten.pow(precision as u32); - let data = BigInt::from(0); + let mut precision_multiplier = Integer::from(10); + precision_multiplier.pow_assign(precision); + let data = Integer::from(0); Decimal { precision, precision_multiplier, @@ -175,7 +283,7 @@ impl FixedPrecision for Decimal { } let mut decimal = Decimal::new(precision); - decimal.data = BigInt::from_str(s).unwrap(); + decimal.data = Integer::from_str(s).unwrap(); Ok(decimal) } @@ -211,9 +319,51 @@ impl FixedPrecision for Decimal { &compare.data, ) } + + fn round(&self) -> Self { + let mut result = self.clone(); + let half = &self.precision_multiplier / Integer::from(2); + let remainder = &self.data % &self.precision_multiplier; + if (&remainder).abs() >= half { + if self.data.sign() == Ordering::Less { + result.data -= &self.precision_multiplier + remainder; + } else { + result.data += &self.precision_multiplier - remainder; + } + } else { + result.data -= remainder; + } + result + } + + fn floor(&self) -> Self { + let mut result = self.clone(); + let remainder = &self.data % &self.precision_multiplier; + if self.data.sign() == Ordering::Less && remainder != 0 { + result.data -= &self.precision_multiplier; + } + result.data -= remainder; + result + } + + fn ceil(&self) -> Self { + let mut result = self.clone(); + let remainder = &self.data % &self.precision_multiplier; + if self.data.sign() == Ordering::Greater && remainder != 0 { + result.data += &self.precision_multiplier; + } + result.data -= remainder; + result + } + + fn trunc(&self) -> Self { + let mut result = self.clone(); + result.data -= &self.data % &self.precision_multiplier; + result + } } -fn print_fixedp(n: &BigInt, precision: &BigInt, width: usize) -> String { +fn print_fixedp(n: &Integer, precision: &Integer, width: usize) -> String { let (mut temp_q, mut temp_r) = n.div_rem(precision); let is_negative_q = temp_q < ZERO.value; @@ -243,11 +393,11 @@ fn print_fixedp(n: &BigInt, precision: &BigInt, width: usize) -> String { } struct Constant { - value: BigInt, + value: Integer, } impl Constant { - pub fn new(init: fn() -> BigInt) -> Constant { + pub fn new(init: fn() -> Integer) -> Constant { Constant { value: init() } } } @@ -256,14 +406,14 @@ unsafe impl Sync for Constant {} unsafe impl Send for Constant {} static DIGITS_REGEX: Lazy = Lazy::new(|| Regex::new(r"^-?\d+$").unwrap()); -static TEN: Lazy = Lazy::new(|| Constant::new(|| BigInt::from(10))); -static PRECISION: Lazy = Lazy::new(|| Constant::new(|| TEN.value.pow(34))); -static EPS: Lazy = Lazy::new(|| Constant::new(|| TEN.value.pow(34 - 24))); -static ONE: Lazy = Lazy::new(|| Constant::new(|| BigInt::from(1) * &PRECISION.value)); -static ZERO: Lazy = Lazy::new(|| Constant::new(|| BigInt::from(0))); +static TEN: Lazy = Lazy::new(|| Constant::new(|| Integer::from(10))); +static PRECISION: Lazy = Lazy::new(|| Constant::new(|| TEN.value.clone().pow(34))); +static EPS: Lazy = Lazy::new(|| Constant::new(|| TEN.value.clone().pow(34 - 24))); +static ONE: Lazy = Lazy::new(|| Constant::new(|| Integer::from(1) * &PRECISION.value)); +static ZERO: Lazy = Lazy::new(|| Constant::new(|| Integer::from(0))); static E: Lazy = Lazy::new(|| { Constant::new(|| { - let mut e = BigInt::from(0); + let mut e = Integer::from(0); ref_exp(&mut e, &ONE.value); e }) @@ -271,29 +421,28 @@ static E: Lazy = Lazy::new(|| { /// Entry point for 'exp' approximation. First does the scaling of 'x' to [0,1] /// and then calls the continued fraction approximation function. -fn ref_exp(rop: &mut BigInt, x: &BigInt) -> i32 { +fn ref_exp(rop: &mut Integer, x: &Integer) -> i32 { let mut iterations = 0; match x.cmp(&ZERO.value) { - std::cmp::Ordering::Equal => { + Ordering::Equal => { // rop = 1 rop.clone_from(&ONE.value); } - std::cmp::Ordering::Less => { + Ordering::Less => { let x_ = -x; - let mut temp = BigInt::from(0); + let mut temp = Integer::from(0); iterations = ref_exp(&mut temp, &x_); // rop = 1 / temp div(rop, &ONE.value, &temp); } - std::cmp::Ordering::Greater => { - let mut n_exponent = x.div_ceil(&PRECISION.value); - let n = n_exponent.to_u32().expect("n_exponent to_u32 failed"); - n_exponent *= &PRECISION.value; /* ceil(x) */ - let x_ = x / n; + Ordering::Greater => { + let (n_exponent, _) = x.div_round(&PRECISION.value, RoundingMode::Ceiling); + let x_ = x / &n_exponent; iterations = mp_exp_taylor(rop, 1000, &x_, &EPS.value); // rop = rop.pow(n) - ipow(rop, &rop.clone(), n as i64); + let n_exponent_i64: i64 = i64::try_from(&n_exponent).expect("n_exponent to_i64 failed"); + ipow(rop, &rop.clone(), n_exponent_i64); } } @@ -302,15 +451,15 @@ fn ref_exp(rop: &mut BigInt, x: &BigInt) -> i32 { /// Division with quotent and remainder #[inline] -fn div_qr(q: &mut BigInt, r: &mut BigInt, x: &BigInt, y: &BigInt) { +fn div_qr(q: &mut Integer, r: &mut Integer, x: &Integer, y: &Integer) { (*q, *r) = x.div_rem(y); } /// Division -pub fn div(rop: &mut BigInt, x: &BigInt, y: &BigInt) { - let mut temp_q = BigInt::from(0); - let mut temp_r = BigInt::from(0); - let mut temp: BigInt; +pub fn div(rop: &mut Integer, x: &Integer, y: &Integer) { + let mut temp_q = Integer::from(0); + let mut temp_r = Integer::from(0); + let mut temp: Integer; div_qr(&mut temp_q, &mut temp_r, x, y); temp = &temp_q * &PRECISION.value; @@ -322,7 +471,7 @@ pub fn div(rop: &mut BigInt, x: &BigInt, y: &BigInt) { *rop = temp; } /// Taylor / MacLaurin series approximation -fn mp_exp_taylor(rop: &mut BigInt, max_n: i32, x: &BigInt, epsilon: &BigInt) -> i32 { +fn mp_exp_taylor(rop: &mut Integer, max_n: i32, x: &Integer, epsilon: &Integer) -> i32 { let mut divisor = ONE.value.clone(); let mut last_x = ONE.value.clone(); rop.clone_from(&ONE.value); @@ -333,12 +482,12 @@ fn mp_exp_taylor(rop: &mut BigInt, max_n: i32, x: &BigInt, epsilon: &BigInt) -> let next_x2 = next_x.clone(); div(&mut next_x, &next_x2, &divisor); - if next_x.abs() < epsilon.abs() { + if (&next_x).abs() < epsilon.abs() { break; } divisor += &ONE.value; - *rop += &next_x; + *rop = &*rop + &next_x; last_x.clone_from(&next_x); n += 1; } @@ -346,27 +495,27 @@ fn mp_exp_taylor(rop: &mut BigInt, max_n: i32, x: &BigInt, epsilon: &BigInt) -> n } -fn scale(rop: &mut BigInt) { - let mut temp = BigInt::from(0); - let mut a = BigInt::from(0); +pub(crate) fn scale(rop: &mut Integer) { + let mut temp = Integer::from(0); + let mut a = Integer::from(0); div_qr(&mut a, &mut temp, rop, &PRECISION.value); if *rop < ZERO.value && temp != ZERO.value { - a -= 1; + a -= Integer::ONE; } *rop = a; } /// Integer power internal function -fn ipow_(rop: &mut BigInt, x: &BigInt, n: i64) { +fn ipow_(rop: &mut Integer, x: &Integer, n: i64) { if n == 0 { rop.clone_from(&ONE.value); } else if n % 2 == 0 { - let mut res = BigInt::from(0); + let mut res = Integer::from(0); ipow_(&mut res, x, n / 2); *rop = &res * &res; scale(rop); } else { - let mut res = BigInt::from(0); + let mut res = Integer::from(0); ipow_(&mut res, x, n - 1); *rop = res * x; scale(rop); @@ -374,9 +523,9 @@ fn ipow_(rop: &mut BigInt, x: &BigInt, n: i64) { } /// Integer power -fn ipow(rop: &mut BigInt, x: &BigInt, n: i64) { +fn ipow(rop: &mut Integer, x: &Integer, n: i64) { if n < 0 { - let mut temp = BigInt::from(0); + let mut temp = Integer::from(0); ipow_(&mut temp, x, -n); div(rop, &ONE.value, &temp); } else { @@ -388,32 +537,32 @@ fn ipow(rop: &mut BigInt, x: &BigInt, n: i64) { /// maximum of 'maxN' iterations or until the absolute difference between two /// succeeding convergents is smaller than 'eps'. Assumes 'x' to be within /// [1,e). -fn mp_ln_n(rop: &mut BigInt, max_n: i32, x: &BigInt, epsilon: &BigInt) { - let mut ba: BigInt; - let mut aa: BigInt; - let mut ab: BigInt; - let mut bb: BigInt; - let mut a_: BigInt; - let mut b_: BigInt; - let mut diff: BigInt; - let mut convergent: BigInt = BigInt::from(0); - let mut last: BigInt = BigInt::from(0); +fn mp_ln_n(rop: &mut Integer, max_n: i32, x: &Integer, epsilon: &Integer) { + let mut ba: Integer; + let mut aa: Integer; + let mut ab: Integer; + let mut bb: Integer; + let mut a_: Integer; + let mut b_: Integer; + let mut diff: Integer; + let mut convergent: Integer = Integer::from(0); + let mut last: Integer = Integer::from(0); let mut first = true; let mut n = 1; - let mut a: BigInt; + let mut a: Integer; let mut b = ONE.value.clone(); let mut an_m2 = ONE.value.clone(); - let mut bn_m2 = BigInt::from(0); - let mut an_m1 = BigInt::from(0); + let mut bn_m2 = Integer::from(0); + let mut an_m1 = Integer::from(0); let mut bn_m1 = ONE.value.clone(); let mut curr_a = 1; while n <= max_n + 2 { let curr_a_2 = curr_a * curr_a; - a = x * curr_a_2; + a = x * Integer::from(curr_a_2); if n > 1 && n % 2 == 1 { curr_a += 1; } @@ -455,12 +604,11 @@ fn mp_ln_n(rop: &mut BigInt, max_n: i32, x: &BigInt, epsilon: &BigInt) { *rop = convergent; } -fn find_e(x: &BigInt) -> i64 { - let mut x_: BigInt = BigInt::from(0); - let mut x__: BigInt; +fn find_e(x: &Integer) -> i64 { + let mut x_: Integer = Integer::from(0); + let mut x__: Integer = E.value.clone(); div(&mut x_, &ONE.value, &E.value); - x__ = E.value.clone(); let mut l = -1; let mut u = 1; @@ -491,17 +639,17 @@ fn find_e(x: &BigInt) -> i64 { /// Entry point for 'ln' approximation. First does the necessary scaling, and /// then calls the continued fraction calculation. For any value outside the /// domain, i.e., 'x in (-inf,0]', the function returns '-INFINITY'. -fn ref_ln(rop: &mut BigInt, x: &BigInt) -> bool { - let mut factor = BigInt::from(0); - let mut x_ = BigInt::from(0); +fn ref_ln(rop: &mut Integer, x: &Integer) -> bool { + let mut factor = Integer::from(0); + let mut x_ = Integer::from(0); if x <= &ZERO.value { return false; } let n = find_e(x); - *rop = BigInt::from(n); - *rop = rop.clone() * &PRECISION.value; + *rop = Integer::from(n); + *rop = &*rop * &PRECISION.value; ref_exp(&mut factor, rop); div(&mut x_, x, &factor); @@ -510,14 +658,14 @@ fn ref_ln(rop: &mut BigInt, x: &BigInt) -> bool { let x_2 = x_.clone(); mp_ln_n(&mut x_, 1000, &x_2, &EPS.value); - *rop = rop.clone() + &x_; + *rop = &*rop + &x_; true } -fn ref_pow(rop: &mut BigInt, base: &BigInt, exponent: &BigInt) { +fn ref_pow(rop: &mut Integer, base: &Integer, exponent: &Integer) { /* x^y = exp(y * ln x) */ - let mut tmp: BigInt = BigInt::from(0); + let mut tmp: Integer = Integer::from(0); ref_ln(&mut tmp, base); tmp *= exponent; scale(&mut tmp); @@ -535,20 +683,20 @@ fn ref_pow(rop: &mut BigInt, base: &BigInt, exponent: &BigInt) { /// Lagrange remainder require knowledge of the maximum value to compute the /// maximal error of the remainder. fn ref_exp_cmp( - rop: &mut BigInt, + rop: &mut Integer, max_n: u64, - x: &BigInt, + x: &Integer, bound_x: i64, - compare: &BigInt, + compare: &Integer, ) -> ExpCmpOrdering { rop.clone_from(&ONE.value); let mut n = 0u64; - let mut divisor: BigInt; - let mut next_x: BigInt; - let mut error: BigInt; - let mut upper: BigInt; - let mut lower: BigInt; - let mut error_term: BigInt; + let mut divisor: Integer; + let mut next_x: Integer; + let mut error: Integer; + let mut upper: Integer; + let mut lower: Integer; + let mut error_term: Integer; divisor = ONE.value.clone(); error = x.clone(); @@ -556,7 +704,7 @@ fn ref_exp_cmp( let mut estimate = ExpOrdering::UNKNOWN; while n < max_n { next_x = error.clone(); - if next_x.abs() < EPS.value.abs() { + if (&next_x).abs() < (&EPS.value).abs() { break; } divisor += &ONE.value; @@ -568,8 +716,8 @@ fn ref_exp_cmp( scale(&mut error); let e2 = error.clone(); div(&mut error, &e2, &divisor); - error_term = &error * bound_x; - *rop += &next_x; + error_term = &error * Integer::from(bound_x); + *rop = &*rop + &next_x; /* compare is guaranteed to be above overall result */ upper = &*rop + &error_term;