Skip to content

Commit

Permalink
Simplify balance conversion
Browse files Browse the repository at this point in the history
This makes the newly added proptests pass
  • Loading branch information
cronokirby committed Dec 13, 2024
1 parent a37c046 commit f34601f
Showing 1 changed file with 28 additions and 155 deletions.
183 changes: 28 additions & 155 deletions crates/core/asset/src/balance.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use anyhow::anyhow;
use ark_r1cs_std::prelude::*;
use ark_r1cs_std::uint8::UInt8;
use ark_relations::r1cs::SynthesisError;
Expand Down Expand Up @@ -25,7 +26,7 @@ mod imbalance;
mod iter;
use commitment::VALUE_BLINDING_GENERATOR;
use decaf377::{r1cs::ElementVar, Fq, Fr};
use imbalance::Imbalance;
use imbalance::{Imbalance, Sign};

use self::commitment::BalanceCommitmentVar;
use penumbra_proto::{penumbra::core::asset::v1 as pb, DomainType};
Expand All @@ -39,6 +40,23 @@ pub struct Balance {
pub balance: BTreeMap<Id, Imbalance<NonZeroU128>>,
}

impl Balance {
fn from_signed_value(negated: bool, value: Value) -> Option<Self> {
let non_zero = NonZeroU128::try_from(value.amount.value()).ok()?;
Some(Self {
negated: false,
balance: BTreeMap::from([(
value.asset_id,
if negated {
Imbalance::Required(non_zero)
} else {
Imbalance::Provided(non_zero)
},
)]),
})
}
}

impl DomainType for Balance {
type Proto = pb::Balance;
}
Expand All @@ -53,111 +71,15 @@ impl DomainType for Balance {
impl TryFrom<pb::Balance> for Balance {
type Error = anyhow::Error;

fn try_from(v: pb::Balance) -> Result<Self, Self::Error> {
let mut balance_map = BTreeMap::new();

for signed_value in v.values {
let proto_value = signed_value
.value
.ok_or_else(|| anyhow::anyhow!("missing value"))?;
let value: Value = proto_value.try_into()?;
let amount = match NonZeroU128::new(value.amount.into()) {
Some(amount) => amount,
None => continue,
};

// The 'negated' flag in SignedValue determines the imbalance type:
// true = Required, false = Provided

match balance_map.entry(value.asset_id) {
// First entry for this asset ID in BTreeMap
std::collections::btree_map::Entry::Vacant(entry) => {
let imbalance = if signed_value.negated {
Imbalance::Required(amount)
} else {
Imbalance::Provided(amount)
};
entry.insert(imbalance);
}
// Subsequent entries for this asset ID in BTreeMap
//
// 1. Asset ID has a Required imbalance - accumulate another required amount
// 2. Asset ID has a Required imbalance - accumulate another provided amount
// 3. Asset ID has a Provided imbalance - accumulate another required amount
// 4. Asset ID has a Provided imbalance - accumulate another provided amount
std::collections::btree_map::Entry::Occupied(mut entry) => {
let existing = entry.get_mut();
match (existing, signed_value.negated) {
(Imbalance::Required(existing_amount), true) => {
*existing_amount = NonZeroU128::new(
existing_amount
.get()
.checked_add(amount.get())
.ok_or_else(|| anyhow::anyhow!("Combining required amounts"))?,
)
.ok_or_else(|| anyhow::anyhow!("Combining required amounts"))?;
}
(Imbalance::Required(existing_amount), false) => {
match existing_amount.get().checked_sub(amount.get()) {
Some(diff) if diff > 0 => {
*existing_amount = NonZeroU128::new(diff)
.ok_or_else(|| anyhow::anyhow!("Reduce required amount"))?;
}
Some(0) => {
entry.remove();
}
_ => {
*entry.get_mut() = Imbalance::Provided(
NonZeroU128::new(amount.get() - existing_amount.get())
.ok_or_else(|| {
anyhow::anyhow!(
"Convert required to provided amount"
)
})?,
);
}
}
}
(Imbalance::Provided(existing_amount), true) => {
match existing_amount.get().checked_sub(amount.get()) {
Some(diff) if diff > 0 => {
*existing_amount = NonZeroU128::new(diff)
.ok_or_else(|| anyhow::anyhow!("Reduce provided amount"))?;
}
Some(0) => {
entry.remove();
}
_ => {
*entry.get_mut() = Imbalance::Required(
NonZeroU128::new(amount.get() - existing_amount.get())
.ok_or_else(|| {
anyhow::anyhow!(
"Convert provided to required amount"
)
})?,
);
}
};
}
(Imbalance::Provided(existing_amount), false) => {
*existing_amount = NonZeroU128::new(
existing_amount
.get()
.checked_add(amount.get())
.ok_or_else(|| anyhow::anyhow!("Combining provided amounts"))?,
)
.ok_or_else(|| anyhow::anyhow!("Combining provided amounts"))?;
}
}
}
fn try_from(balance: pb::Balance) -> Result<Self, Self::Error> {
let mut out = Self::default();
for v in balance.values {
let value = v.value.ok_or_else(|| anyhow!("missing value"))?;
if let Some(b) = Balance::from_signed_value(v.negated, Value::try_from(value)?) {
out += b;
}
}

// Normalize the `Balance`.
Ok(Self {
negated: false,
balance: balance_map,
})
Ok(out)
}
}

Expand All @@ -169,15 +91,15 @@ impl From<Balance> for pb::Balance {
.map(|(id, imbalance)| {
// Decompose imbalance into it sign and magnitude, and convert
// magnitude into raw amount and determine negation based on the sign.
let (_sign, magnitude) = imbalance.into_inner();
let (sign, magnitude) = if v.negated { -imbalance } else { imbalance }.into_inner();
let amount = u128::from(magnitude);

pb::balance::SignedValue {
value: Some(pb::Value {
asset_id: Some(id.into()),
amount: Some(Amount::from(amount).into()),
}),
negated: v.negated,
negated: matches!(sign, Sign::Required),
}
})
.collect();
Expand Down Expand Up @@ -921,55 +843,6 @@ mod test {
assert!(Balance::try_from(proto_balance).is_err());
}

/// Implement fallible conversion (protobuf to domain type) for cases where [-x UM, +x UM]
/// [+x UM, -x UM].
#[test]
fn try_from_fallible_conversion_failure_zero_invariant() {
let proto_balance_0 = pb::Balance {
values: vec![
pb::balance::SignedValue {
value: Some(pb::Value {
asset_id: Some((*STAKING_TOKEN_ASSET_ID).into()),
amount: Some(Amount::from(100u128).into()),
}),
negated: true,
},
pb::balance::SignedValue {
value: Some(pb::Value {
asset_id: Some((*STAKING_TOKEN_ASSET_ID).into()),
amount: Some(Amount::from(100u128).into()),
}),
negated: false,
},
],
};

let proto_balance_1 = pb::Balance {
values: vec![
pb::balance::SignedValue {
value: Some(pb::Value {
asset_id: Some((*STAKING_TOKEN_ASSET_ID).into()),
amount: Some(Amount::from(100u128).into()),
}),
negated: false,
},
pb::balance::SignedValue {
value: Some(pb::Value {
asset_id: Some((*STAKING_TOKEN_ASSET_ID).into()),
amount: Some(Amount::from(100u128).into()),
}),
negated: true,
},
],
};

let balance_0 = Balance::try_from(proto_balance_0).expect("fallible conversion");
let balance_1 = Balance::try_from(proto_balance_1).expect("fallible conversion");

assert!(balance_0.balance.get(&STAKING_TOKEN_ASSET_ID).is_none());
assert!(balance_1.balance.get(&STAKING_TOKEN_ASSET_ID).is_none());
}

/// Implement infallible conversion (domain type to protobuf).
#[test]
fn from_infallible_conversion() {
Expand Down

0 comments on commit f34601f

Please sign in to comment.