diff --git a/crates/core/asset/src/balance.rs b/crates/core/asset/src/balance.rs index f5013f1d83..db2aa0482e 100644 --- a/crates/core/asset/src/balance.rs +++ b/crates/core/asset/src/balance.rs @@ -1,3 +1,4 @@ +use anyhow::anyhow; use ark_r1cs_std::prelude::*; use ark_r1cs_std::uint8::UInt8; use ark_relations::r1cs::SynthesisError; @@ -25,7 +26,7 @@ mod imbalance; mod iter; use commitment::VALUE_BLINDING_GENERATOR; use decaf377::{r1cs::ElementVar, Fq, Fr}; -use imbalance::Imbalance; +use imbalance::{Imbalance, Sign}; use self::commitment::BalanceCommitmentVar; use penumbra_proto::{penumbra::core::asset::v1 as pb, DomainType}; @@ -39,6 +40,23 @@ pub struct Balance { pub balance: BTreeMap>, } +impl Balance { + fn from_signed_value(negated: bool, value: Value) -> Option { + let non_zero = NonZeroU128::try_from(value.amount.value()).ok()?; + Some(Self { + negated: false, + balance: BTreeMap::from([( + value.asset_id, + if negated { + Imbalance::Required(non_zero) + } else { + Imbalance::Provided(non_zero) + }, + )]), + }) + } +} + impl DomainType for Balance { type Proto = pb::Balance; } @@ -53,111 +71,15 @@ impl DomainType for Balance { impl TryFrom for Balance { type Error = anyhow::Error; - fn try_from(v: pb::Balance) -> Result { - let mut balance_map = BTreeMap::new(); - - for signed_value in v.values { - let proto_value = signed_value - .value - .ok_or_else(|| anyhow::anyhow!("missing value"))?; - let value: Value = proto_value.try_into()?; - let amount = match NonZeroU128::new(value.amount.into()) { - Some(amount) => amount, - None => continue, - }; - - // The 'negated' flag in SignedValue determines the imbalance type: - // true = Required, false = Provided - - match balance_map.entry(value.asset_id) { - // First entry for this asset ID in BTreeMap - std::collections::btree_map::Entry::Vacant(entry) => { - let imbalance = if signed_value.negated { - Imbalance::Required(amount) - } else { - Imbalance::Provided(amount) - }; - entry.insert(imbalance); - } - // Subsequent entries for this asset ID in BTreeMap - // - // 1. Asset ID has a Required imbalance - accumulate another required amount - // 2. Asset ID has a Required imbalance - accumulate another provided amount - // 3. Asset ID has a Provided imbalance - accumulate another required amount - // 4. Asset ID has a Provided imbalance - accumulate another provided amount - std::collections::btree_map::Entry::Occupied(mut entry) => { - let existing = entry.get_mut(); - match (existing, signed_value.negated) { - (Imbalance::Required(existing_amount), true) => { - *existing_amount = NonZeroU128::new( - existing_amount - .get() - .checked_add(amount.get()) - .ok_or_else(|| anyhow::anyhow!("Combining required amounts"))?, - ) - .ok_or_else(|| anyhow::anyhow!("Combining required amounts"))?; - } - (Imbalance::Required(existing_amount), false) => { - match existing_amount.get().checked_sub(amount.get()) { - Some(diff) if diff > 0 => { - *existing_amount = NonZeroU128::new(diff) - .ok_or_else(|| anyhow::anyhow!("Reduce required amount"))?; - } - Some(0) => { - entry.remove(); - } - _ => { - *entry.get_mut() = Imbalance::Provided( - NonZeroU128::new(amount.get() - existing_amount.get()) - .ok_or_else(|| { - anyhow::anyhow!( - "Convert required to provided amount" - ) - })?, - ); - } - } - } - (Imbalance::Provided(existing_amount), true) => { - match existing_amount.get().checked_sub(amount.get()) { - Some(diff) if diff > 0 => { - *existing_amount = NonZeroU128::new(diff) - .ok_or_else(|| anyhow::anyhow!("Reduce provided amount"))?; - } - Some(0) => { - entry.remove(); - } - _ => { - *entry.get_mut() = Imbalance::Required( - NonZeroU128::new(amount.get() - existing_amount.get()) - .ok_or_else(|| { - anyhow::anyhow!( - "Convert provided to required amount" - ) - })?, - ); - } - }; - } - (Imbalance::Provided(existing_amount), false) => { - *existing_amount = NonZeroU128::new( - existing_amount - .get() - .checked_add(amount.get()) - .ok_or_else(|| anyhow::anyhow!("Combining provided amounts"))?, - ) - .ok_or_else(|| anyhow::anyhow!("Combining provided amounts"))?; - } - } - } + fn try_from(balance: pb::Balance) -> Result { + let mut out = Self::default(); + for v in balance.values { + let value = v.value.ok_or_else(|| anyhow!("missing value"))?; + if let Some(b) = Balance::from_signed_value(v.negated, Value::try_from(value)?) { + out += b; } } - - // Normalize the `Balance`. - Ok(Self { - negated: false, - balance: balance_map, - }) + Ok(out) } } @@ -169,7 +91,7 @@ impl From for pb::Balance { .map(|(id, imbalance)| { // Decompose imbalance into it sign and magnitude, and convert // magnitude into raw amount and determine negation based on the sign. - let (_sign, magnitude) = imbalance.into_inner(); + let (sign, magnitude) = if v.negated { -imbalance } else { imbalance }.into_inner(); let amount = u128::from(magnitude); pb::balance::SignedValue { @@ -177,7 +99,7 @@ impl From for pb::Balance { asset_id: Some(id.into()), amount: Some(Amount::from(amount).into()), }), - negated: v.negated, + negated: matches!(sign, Sign::Required), } }) .collect(); @@ -921,55 +843,6 @@ mod test { assert!(Balance::try_from(proto_balance).is_err()); } - /// Implement fallible conversion (protobuf to domain type) for cases where [-x UM, +x UM] - /// [+x UM, -x UM]. - #[test] - fn try_from_fallible_conversion_failure_zero_invariant() { - let proto_balance_0 = pb::Balance { - values: vec![ - pb::balance::SignedValue { - value: Some(pb::Value { - asset_id: Some((*STAKING_TOKEN_ASSET_ID).into()), - amount: Some(Amount::from(100u128).into()), - }), - negated: true, - }, - pb::balance::SignedValue { - value: Some(pb::Value { - asset_id: Some((*STAKING_TOKEN_ASSET_ID).into()), - amount: Some(Amount::from(100u128).into()), - }), - negated: false, - }, - ], - }; - - let proto_balance_1 = pb::Balance { - values: vec![ - pb::balance::SignedValue { - value: Some(pb::Value { - asset_id: Some((*STAKING_TOKEN_ASSET_ID).into()), - amount: Some(Amount::from(100u128).into()), - }), - negated: false, - }, - pb::balance::SignedValue { - value: Some(pb::Value { - asset_id: Some((*STAKING_TOKEN_ASSET_ID).into()), - amount: Some(Amount::from(100u128).into()), - }), - negated: true, - }, - ], - }; - - let balance_0 = Balance::try_from(proto_balance_0).expect("fallible conversion"); - let balance_1 = Balance::try_from(proto_balance_1).expect("fallible conversion"); - - assert!(balance_0.balance.get(&STAKING_TOKEN_ASSET_ID).is_none()); - assert!(balance_1.balance.get(&STAKING_TOKEN_ASSET_ID).is_none()); - } - /// Implement infallible conversion (domain type to protobuf). #[test] fn from_infallible_conversion() {