From 7783a4358b8bcfbd2c8fefe760bf1c73bee941a7 Mon Sep 17 00:00:00 2001 From: ian Date: Thu, 27 Jun 2024 10:13:42 +0800 Subject: [PATCH 1/2] feat: allow sha256 for HTLC Allow choosing the hash algorithm when adding the HTLC output. Related PR: https://github.com/nervosnetwork/cfn-scripts/pull/6 - Added a field `hash_algorithm` when adding TLC - Save the `hash_algorithm` in the internal TLC struct - Check hash against preimage using the saved `hash_algorithm` - Added `hash_algorithm` in the invoice as a new attribute --- docs/specs/payment-invoice.md | 9 +- src/ckb/channel.rs | 29 +++++- src/ckb/gen/cfn.rs | 49 ++++++--- src/ckb/gen/invoice.rs | 190 +++++++++++++++++++++++++++++++++- src/ckb/hash_algorithm.rs | 79 ++++++++++++++ src/ckb/mod.rs | 2 + src/ckb/schema/cfn.mol | 1 + src/ckb/schema/invoice.mol | 7 ++ src/ckb/types.rs | 24 ++++- src/invoice/invoice_impl.rs | 30 +++++- src/rpc/channel.rs | 3 + src/rpc/invoice.rs | 5 + 12 files changed, 403 insertions(+), 25 deletions(-) create mode 100644 src/ckb/hash_algorithm.rs diff --git a/docs/specs/payment-invoice.md b/docs/specs/payment-invoice.md index fcd11230b..409bac781 100644 --- a/docs/specs/payment-invoice.md +++ b/docs/specs/payment-invoice.md @@ -20,7 +20,6 @@ The human-readable part contains these two most important fields: - A standalone number, means the amount of CKB or UDT, for CKB it will be in unit of `shannon`, 1 CKB = 10^8 shannon - An empty value for this field means the amount of payment is not specified, which maybe used in the scenario of donation. - ## Encoding and Decoding With `molecule`, the data part can be easily converted to bytes. Considering that the bytes generated by molecule are not optimized for space and may contain consecutive zeros when certain fields are empty, the result from `bechm32` encoding is relatively long. We use [arcode-rs](https://github.com/cgbur/arcode-rs) to compress the bytes losslessly before `bechm32` encoding, resulting in a length reduction of almost half: @@ -35,9 +34,9 @@ The `signature` field: [optional] with type of `[u8; 65]` = 520 bits - The secp256k1 signature of the entire invoice, can be used to verify the integrity and correctness of the invoice, may also be used to imply the generator node of this invoice. By default, this filed is none, the method to generate signature: - - `message_hash = SHA256-hash (((human-readable part) → bytes) + (data bytes))` + - `message_hash = SHA256-hash (((human-readable part) → bytes) + (data bytes))` then sign it with `Secp256k1` - - It may use a customized sign function: `Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key)` + - It may use a customized sign function: `Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key)` ## Data Part @@ -66,3 +65,7 @@ The data part is designed to add non-mandatory fields easily, and it is very lik - The public key of the payee 9. `udt_script`: [optional] variable length - The script specified for the UDT token +10. `hash_algorithm`: [optional] 1 byte + - The hash algorithm used to generate the `payment_hash` from the preimage. When this is missing, the default hash algorithm ckb hash is used. + - 0: ckb hash + - 1: sha256 diff --git a/src/ckb/channel.rs b/src/ckb/channel.rs index c0ac8b26e..cff5de129 100644 --- a/src/ckb/channel.rs +++ b/src/ckb/channel.rs @@ -41,6 +41,7 @@ use crate::{ use super::{ config::{DEFAULT_CHANNEL_MINIMAL_CKB_AMOUNT, MIN_UDT_OCCUPIED_CAPACITY}, + hash_algorithm::HashAlgorithm, key::blake2b_hash_with_salt, network::CFNMessageWithPeerId, serde_utils::EntityHex, @@ -103,6 +104,7 @@ pub struct AddTlcCommand { pub preimage: Option, pub payment_hash: Option, pub expiry: LockTime, + pub hash_algorithm: HashAlgorithm, } #[derive(Debug)] @@ -585,6 +587,7 @@ impl ChannelActor { amount: tlc.amount, payment_hash: tlc.payment_hash, expiry: tlc.lock_time, + hash_algorithm: tlc.hash_algorithm, }), }; debug!("Sending AddTlc message: {:?}", &msg); @@ -2158,8 +2161,11 @@ impl ChannelActorState { reason, removed_at, current ); if let RemoveTlcReason::RemoveTlcFulfill(fulfill) = reason { - let filled_payment_hash: Hash256 = - blake2b_256(fulfill.payment_preimage).into(); + let filled_payment_hash: Hash256 = current + .tlc + .hash_algorithm + .hash(fulfill.payment_preimage) + .into(); if current.tlc.payment_hash != filled_payment_hash { return Err(ProcessingChannelError::InvalidParameter(format!( "Preimage {:?} is hashed to {}, which does not match payment hash {:?}", @@ -2464,7 +2470,7 @@ impl ChannelActorState { tlcs.iter() .map(|(tlc, local, remote)| { [ - (if tlc.tlc.is_offered() { [0] } else { [1] }).to_vec(), + vec![tlc.tlc.get_htlc_type()], tlc.tlc.amount.to_le_bytes().to_vec(), tlc.tlc.get_hash().to_vec(), local.serialize().to_vec(), @@ -2564,7 +2570,7 @@ impl ChannelActorState { let preimage = command.preimage.unwrap_or(get_random_preimage()); let payment_hash = command .payment_hash - .unwrap_or(blake2b_256(&preimage).into()); + .unwrap_or_else(|| command.hash_algorithm.hash(&preimage).into()); TLC { id: TLCId::Offered(id), @@ -2572,6 +2578,7 @@ impl ChannelActorState { payment_hash, lock_time: command.expiry, payment_preimage: Some(preimage), + hash_algorithm: command.hash_algorithm, } } @@ -2595,6 +2602,7 @@ impl ChannelActorState { payment_hash: message.payment_hash, lock_time: message.expiry, payment_preimage: None, + hash_algorithm: message.hash_algorithm, }) } } @@ -3215,6 +3223,7 @@ impl ChannelActorState { amount: info.tlc.amount, payment_hash: info.tlc.payment_hash, expiry: info.tlc.lock_time, + hash_algorithm: info.tlc.hash_algorithm, }), }), )) @@ -4214,6 +4223,8 @@ pub struct TLC { pub payment_hash: Hash256, /// The preimage of the hash to be sent to the counterparty. pub payment_preimage: Option, + /// Which hash algorithm is applied on the preimage + pub hash_algorithm: HashAlgorithm, } impl TLC { @@ -4230,6 +4241,16 @@ impl TLC { self.id.flip_mut() } + /// Get the value for the field `htlc_type` in commitment lock witness. + /// - Lowest 1 bit: 0 if the tlc is offered by the remote party, 1 otherwise. + /// - High 7 bits: + /// - 0: ckb hash + /// - 1: sha256 + pub fn get_htlc_type(&self) -> u8 { + let offered_flag = if self.is_offered() { 0u8 } else { 1u8 }; + ((self.hash_algorithm as u8) << 1) + offered_flag + } + fn get_hash(&self) -> ShortHash { self.payment_hash.as_ref()[..20].try_into().unwrap() } diff --git a/src/ckb/gen/cfn.rs b/src/ckb/gen/cfn.rs index 1579b5e5c..1d132457c 100644 --- a/src/ckb/gen/cfn.rs +++ b/src/ckb/gen/cfn.rs @@ -7042,6 +7042,7 @@ impl ::core::fmt::Display for AddTlc { write!(f, ", {}: {}", "amount", self.amount())?; write!(f, ", {}: {}", "payment_hash", self.payment_hash())?; write!(f, ", {}: {}", "expiry", self.expiry())?; + write!(f, ", {}: {}", "hash_algorithm", self.hash_algorithm())?; let extra_count = self.count_extra_fields(); if extra_count != 0 { write!(f, ", .. ({} fields)", extra_count)?; @@ -7056,14 +7057,14 @@ impl ::core::default::Default for AddTlc { } } impl AddTlc { - const DEFAULT_VALUE: [u8; 120] = [ - 120, 0, 0, 0, 24, 0, 0, 0, 56, 0, 0, 0, 64, 0, 0, 0, 80, 0, 0, 0, 112, 0, 0, 0, 0, 0, 0, 0, + const DEFAULT_VALUE: [u8; 125] = [ + 125, 0, 0, 0, 28, 0, 0, 0, 60, 0, 0, 0, 68, 0, 0, 0, 84, 0, 0, 0, 116, 0, 0, 0, 124, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, ]; - pub const FIELD_COUNT: usize = 5; + pub const FIELD_COUNT: usize = 6; pub fn total_size(&self) -> usize { molecule::unpack_number(self.as_slice()) as usize } @@ -7107,11 +7108,17 @@ impl AddTlc { pub fn expiry(&self) -> Uint64 { let slice = self.as_slice(); let start = molecule::unpack_number(&slice[20..]) as usize; + let end = molecule::unpack_number(&slice[24..]) as usize; + Uint64::new_unchecked(self.0.slice(start..end)) + } + pub fn hash_algorithm(&self) -> Byte { + let slice = self.as_slice(); + let start = molecule::unpack_number(&slice[24..]) as usize; if self.has_extra_fields() { - let end = molecule::unpack_number(&slice[24..]) as usize; - Uint64::new_unchecked(self.0.slice(start..end)) + let end = molecule::unpack_number(&slice[28..]) as usize; + Byte::new_unchecked(self.0.slice(start..end)) } else { - Uint64::new_unchecked(self.0.slice(start..)) + Byte::new_unchecked(self.0.slice(start..)) } } pub fn as_reader<'r>(&'r self) -> AddTlcReader<'r> { @@ -7146,6 +7153,7 @@ impl molecule::prelude::Entity for AddTlc { .amount(self.amount()) .payment_hash(self.payment_hash()) .expiry(self.expiry()) + .hash_algorithm(self.hash_algorithm()) } } #[derive(Clone, Copy)] @@ -7172,6 +7180,7 @@ impl<'r> ::core::fmt::Display for AddTlcReader<'r> { write!(f, ", {}: {}", "amount", self.amount())?; write!(f, ", {}: {}", "payment_hash", self.payment_hash())?; write!(f, ", {}: {}", "expiry", self.expiry())?; + write!(f, ", {}: {}", "hash_algorithm", self.hash_algorithm())?; let extra_count = self.count_extra_fields(); if extra_count != 0 { write!(f, ", .. ({} fields)", extra_count)?; @@ -7180,7 +7189,7 @@ impl<'r> ::core::fmt::Display for AddTlcReader<'r> { } } impl<'r> AddTlcReader<'r> { - pub const FIELD_COUNT: usize = 5; + pub const FIELD_COUNT: usize = 6; pub fn total_size(&self) -> usize { molecule::unpack_number(self.as_slice()) as usize } @@ -7224,11 +7233,17 @@ impl<'r> AddTlcReader<'r> { pub fn expiry(&self) -> Uint64Reader<'r> { let slice = self.as_slice(); let start = molecule::unpack_number(&slice[20..]) as usize; + let end = molecule::unpack_number(&slice[24..]) as usize; + Uint64Reader::new_unchecked(&self.as_slice()[start..end]) + } + pub fn hash_algorithm(&self) -> ByteReader<'r> { + let slice = self.as_slice(); + let start = molecule::unpack_number(&slice[24..]) as usize; if self.has_extra_fields() { - let end = molecule::unpack_number(&slice[24..]) as usize; - Uint64Reader::new_unchecked(&self.as_slice()[start..end]) + let end = molecule::unpack_number(&slice[28..]) as usize; + ByteReader::new_unchecked(&self.as_slice()[start..end]) } else { - Uint64Reader::new_unchecked(&self.as_slice()[start..]) + ByteReader::new_unchecked(&self.as_slice()[start..]) } } } @@ -7283,6 +7298,7 @@ impl<'r> molecule::prelude::Reader<'r> for AddTlcReader<'r> { Uint128Reader::verify(&slice[offsets[2]..offsets[3]], compatible)?; Byte32Reader::verify(&slice[offsets[3]..offsets[4]], compatible)?; Uint64Reader::verify(&slice[offsets[4]..offsets[5]], compatible)?; + ByteReader::verify(&slice[offsets[5]..offsets[6]], compatible)?; Ok(()) } } @@ -7293,9 +7309,10 @@ pub struct AddTlcBuilder { pub(crate) amount: Uint128, pub(crate) payment_hash: Byte32, pub(crate) expiry: Uint64, + pub(crate) hash_algorithm: Byte, } impl AddTlcBuilder { - pub const FIELD_COUNT: usize = 5; + pub const FIELD_COUNT: usize = 6; pub fn channel_id(mut self, v: Byte32) -> Self { self.channel_id = v; self @@ -7316,6 +7333,10 @@ impl AddTlcBuilder { self.expiry = v; self } + pub fn hash_algorithm(mut self, v: Byte) -> Self { + self.hash_algorithm = v; + self + } } impl molecule::prelude::Builder for AddTlcBuilder { type Entity = AddTlc; @@ -7327,6 +7348,7 @@ impl molecule::prelude::Builder for AddTlcBuilder { + self.amount.as_slice().len() + self.payment_hash.as_slice().len() + self.expiry.as_slice().len() + + self.hash_algorithm.as_slice().len() } fn write(&self, writer: &mut W) -> molecule::io::Result<()> { let mut total_size = molecule::NUMBER_SIZE * (Self::FIELD_COUNT + 1); @@ -7341,6 +7363,8 @@ impl molecule::prelude::Builder for AddTlcBuilder { total_size += self.payment_hash.as_slice().len(); offsets.push(total_size); total_size += self.expiry.as_slice().len(); + offsets.push(total_size); + total_size += self.hash_algorithm.as_slice().len(); writer.write_all(&molecule::pack_number(total_size as molecule::Number))?; for offset in offsets.into_iter() { writer.write_all(&molecule::pack_number(offset as molecule::Number))?; @@ -7350,6 +7374,7 @@ impl molecule::prelude::Builder for AddTlcBuilder { writer.write_all(self.amount.as_slice())?; writer.write_all(self.payment_hash.as_slice())?; writer.write_all(self.expiry.as_slice())?; + writer.write_all(self.hash_algorithm.as_slice())?; Ok(()) } fn build(&self) -> Self::Entity { diff --git a/src/ckb/gen/invoice.rs b/src/ckb/gen/invoice.rs index 42ca45575..7995faffb 100644 --- a/src/ckb/gen/invoice.rs +++ b/src/ckb/gen/invoice.rs @@ -4696,6 +4696,154 @@ impl molecule::prelude::Builder for PayeePublicKeyBuilder { } } #[derive(Clone)] +pub struct HashAlgorithm(molecule::bytes::Bytes); +impl ::core::fmt::LowerHex for HashAlgorithm { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + use molecule::hex_string; + if f.alternate() { + write!(f, "0x")?; + } + write!(f, "{}", hex_string(self.as_slice())) + } +} +impl ::core::fmt::Debug for HashAlgorithm { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + write!(f, "{}({:#x})", Self::NAME, self) + } +} +impl ::core::fmt::Display for HashAlgorithm { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + write!(f, "{} {{ ", Self::NAME)?; + write!(f, "{}: {}", "value", self.value())?; + write!(f, " }}") + } +} +impl ::core::default::Default for HashAlgorithm { + fn default() -> Self { + let v = molecule::bytes::Bytes::from_static(&Self::DEFAULT_VALUE); + HashAlgorithm::new_unchecked(v) + } +} +impl HashAlgorithm { + const DEFAULT_VALUE: [u8; 1] = [0]; + pub const TOTAL_SIZE: usize = 1; + pub const FIELD_SIZES: [usize; 1] = [1]; + pub const FIELD_COUNT: usize = 1; + pub fn value(&self) -> Byte { + Byte::new_unchecked(self.0.slice(0..1)) + } + pub fn as_reader<'r>(&'r self) -> HashAlgorithmReader<'r> { + HashAlgorithmReader::new_unchecked(self.as_slice()) + } +} +impl molecule::prelude::Entity for HashAlgorithm { + type Builder = HashAlgorithmBuilder; + const NAME: &'static str = "HashAlgorithm"; + fn new_unchecked(data: molecule::bytes::Bytes) -> Self { + HashAlgorithm(data) + } + fn as_bytes(&self) -> molecule::bytes::Bytes { + self.0.clone() + } + fn as_slice(&self) -> &[u8] { + &self.0[..] + } + fn from_slice(slice: &[u8]) -> molecule::error::VerificationResult { + HashAlgorithmReader::from_slice(slice).map(|reader| reader.to_entity()) + } + fn from_compatible_slice(slice: &[u8]) -> molecule::error::VerificationResult { + HashAlgorithmReader::from_compatible_slice(slice).map(|reader| reader.to_entity()) + } + fn new_builder() -> Self::Builder { + ::core::default::Default::default() + } + fn as_builder(self) -> Self::Builder { + Self::new_builder().value(self.value()) + } +} +#[derive(Clone, Copy)] +pub struct HashAlgorithmReader<'r>(&'r [u8]); +impl<'r> ::core::fmt::LowerHex for HashAlgorithmReader<'r> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + use molecule::hex_string; + if f.alternate() { + write!(f, "0x")?; + } + write!(f, "{}", hex_string(self.as_slice())) + } +} +impl<'r> ::core::fmt::Debug for HashAlgorithmReader<'r> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + write!(f, "{}({:#x})", Self::NAME, self) + } +} +impl<'r> ::core::fmt::Display for HashAlgorithmReader<'r> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + write!(f, "{} {{ ", Self::NAME)?; + write!(f, "{}: {}", "value", self.value())?; + write!(f, " }}") + } +} +impl<'r> HashAlgorithmReader<'r> { + pub const TOTAL_SIZE: usize = 1; + pub const FIELD_SIZES: [usize; 1] = [1]; + pub const FIELD_COUNT: usize = 1; + pub fn value(&self) -> ByteReader<'r> { + ByteReader::new_unchecked(&self.as_slice()[0..1]) + } +} +impl<'r> molecule::prelude::Reader<'r> for HashAlgorithmReader<'r> { + type Entity = HashAlgorithm; + const NAME: &'static str = "HashAlgorithmReader"; + fn to_entity(&self) -> Self::Entity { + Self::Entity::new_unchecked(self.as_slice().to_owned().into()) + } + fn new_unchecked(slice: &'r [u8]) -> Self { + HashAlgorithmReader(slice) + } + fn as_slice(&self) -> &'r [u8] { + self.0 + } + fn verify(slice: &[u8], _compatible: bool) -> molecule::error::VerificationResult<()> { + use molecule::verification_error as ve; + let slice_len = slice.len(); + if slice_len != Self::TOTAL_SIZE { + return ve!(Self, TotalSizeNotMatch, Self::TOTAL_SIZE, slice_len); + } + Ok(()) + } +} +#[derive(Clone, Debug, Default)] +pub struct HashAlgorithmBuilder { + pub(crate) value: Byte, +} +impl HashAlgorithmBuilder { + pub const TOTAL_SIZE: usize = 1; + pub const FIELD_SIZES: [usize; 1] = [1]; + pub const FIELD_COUNT: usize = 1; + pub fn value(mut self, v: Byte) -> Self { + self.value = v; + self + } +} +impl molecule::prelude::Builder for HashAlgorithmBuilder { + type Entity = HashAlgorithm; + const NAME: &'static str = "HashAlgorithmBuilder"; + fn expected_length(&self) -> usize { + Self::TOTAL_SIZE + } + fn write(&self, writer: &mut W) -> molecule::io::Result<()> { + writer.write_all(self.value.as_slice())?; + Ok(()) + } + fn build(&self) -> Self::Entity { + let mut inner = Vec::with_capacity(self.expected_length()); + self.write(&mut inner) + .unwrap_or_else(|_| panic!("{} build should be ok", Self::NAME)); + HashAlgorithm::new_unchecked(inner.into()) + } +} +#[derive(Clone)] pub struct InvoiceAttr(molecule::bytes::Bytes); impl ::core::fmt::LowerHex for InvoiceAttr { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { @@ -4726,7 +4874,7 @@ impl ::core::default::Default for InvoiceAttr { } impl InvoiceAttr { const DEFAULT_VALUE: [u8; 20] = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - pub const ITEMS_COUNT: usize = 8; + pub const ITEMS_COUNT: usize = 9; pub fn item_id(&self) -> molecule::Number { molecule::unpack_number(self.as_slice()) } @@ -4741,6 +4889,7 @@ impl InvoiceAttr { 5 => Feature::new_unchecked(inner).into(), 6 => UdtScript::new_unchecked(inner).into(), 7 => PayeePublicKey::new_unchecked(inner).into(), + 8 => HashAlgorithm::new_unchecked(inner).into(), _ => panic!("{}: invalid data", Self::NAME), } } @@ -4797,7 +4946,7 @@ impl<'r> ::core::fmt::Display for InvoiceAttrReader<'r> { } } impl<'r> InvoiceAttrReader<'r> { - pub const ITEMS_COUNT: usize = 8; + pub const ITEMS_COUNT: usize = 9; pub fn item_id(&self) -> molecule::Number { molecule::unpack_number(self.as_slice()) } @@ -4812,6 +4961,7 @@ impl<'r> InvoiceAttrReader<'r> { 5 => FeatureReader::new_unchecked(inner).into(), 6 => UdtScriptReader::new_unchecked(inner).into(), 7 => PayeePublicKeyReader::new_unchecked(inner).into(), + 8 => HashAlgorithmReader::new_unchecked(inner).into(), _ => panic!("{}: invalid data", Self::NAME), } } @@ -4845,6 +4995,7 @@ impl<'r> molecule::prelude::Reader<'r> for InvoiceAttrReader<'r> { 5 => FeatureReader::verify(inner_slice, compatible), 6 => UdtScriptReader::verify(inner_slice, compatible), 7 => PayeePublicKeyReader::verify(inner_slice, compatible), + 8 => HashAlgorithmReader::verify(inner_slice, compatible), _ => ve!(Self, UnknownItem, Self::ITEMS_COUNT, item_id), }?; Ok(()) @@ -4853,7 +5004,7 @@ impl<'r> molecule::prelude::Reader<'r> for InvoiceAttrReader<'r> { #[derive(Clone, Debug, Default)] pub struct InvoiceAttrBuilder(pub(crate) InvoiceAttrUnion); impl InvoiceAttrBuilder { - pub const ITEMS_COUNT: usize = 8; + pub const ITEMS_COUNT: usize = 9; pub fn set(mut self, v: I) -> Self where I: ::core::convert::Into, @@ -4889,6 +5040,7 @@ pub enum InvoiceAttrUnion { Feature(Feature), UdtScript(UdtScript), PayeePublicKey(PayeePublicKey), + HashAlgorithm(HashAlgorithm), } #[derive(Debug, Clone, Copy)] pub enum InvoiceAttrUnionReader<'r> { @@ -4900,6 +5052,7 @@ pub enum InvoiceAttrUnionReader<'r> { Feature(FeatureReader<'r>), UdtScript(UdtScriptReader<'r>), PayeePublicKey(PayeePublicKeyReader<'r>), + HashAlgorithm(HashAlgorithmReader<'r>), } impl ::core::default::Default for InvoiceAttrUnion { fn default() -> Self { @@ -4939,6 +5092,9 @@ impl ::core::fmt::Display for InvoiceAttrUnion { InvoiceAttrUnion::PayeePublicKey(ref item) => { write!(f, "{}::{}({})", Self::NAME, PayeePublicKey::NAME, item) } + InvoiceAttrUnion::HashAlgorithm(ref item) => { + write!(f, "{}::{}({})", Self::NAME, HashAlgorithm::NAME, item) + } } } } @@ -4975,6 +5131,9 @@ impl<'r> ::core::fmt::Display for InvoiceAttrUnionReader<'r> { InvoiceAttrUnionReader::PayeePublicKey(ref item) => { write!(f, "{}::{}({})", Self::NAME, PayeePublicKey::NAME, item) } + InvoiceAttrUnionReader::HashAlgorithm(ref item) => { + write!(f, "{}::{}({})", Self::NAME, HashAlgorithm::NAME, item) + } } } } @@ -4989,6 +5148,7 @@ impl InvoiceAttrUnion { InvoiceAttrUnion::Feature(ref item) => write!(f, "{}", item), InvoiceAttrUnion::UdtScript(ref item) => write!(f, "{}", item), InvoiceAttrUnion::PayeePublicKey(ref item) => write!(f, "{}", item), + InvoiceAttrUnion::HashAlgorithm(ref item) => write!(f, "{}", item), } } } @@ -5003,6 +5163,7 @@ impl<'r> InvoiceAttrUnionReader<'r> { InvoiceAttrUnionReader::Feature(ref item) => write!(f, "{}", item), InvoiceAttrUnionReader::UdtScript(ref item) => write!(f, "{}", item), InvoiceAttrUnionReader::PayeePublicKey(ref item) => write!(f, "{}", item), + InvoiceAttrUnionReader::HashAlgorithm(ref item) => write!(f, "{}", item), } } } @@ -5046,6 +5207,11 @@ impl ::core::convert::From for InvoiceAttrUnion { InvoiceAttrUnion::PayeePublicKey(item) } } +impl ::core::convert::From for InvoiceAttrUnion { + fn from(item: HashAlgorithm) -> Self { + InvoiceAttrUnion::HashAlgorithm(item) + } +} impl<'r> ::core::convert::From> for InvoiceAttrUnionReader<'r> { fn from(item: ExpiryTimeReader<'r>) -> Self { InvoiceAttrUnionReader::ExpiryTime(item) @@ -5088,6 +5254,11 @@ impl<'r> ::core::convert::From> for InvoiceAttrUnionRea InvoiceAttrUnionReader::PayeePublicKey(item) } } +impl<'r> ::core::convert::From> for InvoiceAttrUnionReader<'r> { + fn from(item: HashAlgorithmReader<'r>) -> Self { + InvoiceAttrUnionReader::HashAlgorithm(item) + } +} impl InvoiceAttrUnion { pub const NAME: &'static str = "InvoiceAttrUnion"; pub fn as_bytes(&self) -> molecule::bytes::Bytes { @@ -5100,6 +5271,7 @@ impl InvoiceAttrUnion { InvoiceAttrUnion::Feature(item) => item.as_bytes(), InvoiceAttrUnion::UdtScript(item) => item.as_bytes(), InvoiceAttrUnion::PayeePublicKey(item) => item.as_bytes(), + InvoiceAttrUnion::HashAlgorithm(item) => item.as_bytes(), } } pub fn as_slice(&self) -> &[u8] { @@ -5112,6 +5284,7 @@ impl InvoiceAttrUnion { InvoiceAttrUnion::Feature(item) => item.as_slice(), InvoiceAttrUnion::UdtScript(item) => item.as_slice(), InvoiceAttrUnion::PayeePublicKey(item) => item.as_slice(), + InvoiceAttrUnion::HashAlgorithm(item) => item.as_slice(), } } pub fn item_id(&self) -> molecule::Number { @@ -5124,6 +5297,7 @@ impl InvoiceAttrUnion { InvoiceAttrUnion::Feature(_) => 5, InvoiceAttrUnion::UdtScript(_) => 6, InvoiceAttrUnion::PayeePublicKey(_) => 7, + InvoiceAttrUnion::HashAlgorithm(_) => 8, } } pub fn item_name(&self) -> &str { @@ -5136,6 +5310,7 @@ impl InvoiceAttrUnion { InvoiceAttrUnion::Feature(_) => "Feature", InvoiceAttrUnion::UdtScript(_) => "UdtScript", InvoiceAttrUnion::PayeePublicKey(_) => "PayeePublicKey", + InvoiceAttrUnion::HashAlgorithm(_) => "HashAlgorithm", } } pub fn as_reader<'r>(&'r self) -> InvoiceAttrUnionReader<'r> { @@ -5148,6 +5323,7 @@ impl InvoiceAttrUnion { InvoiceAttrUnion::Feature(item) => item.as_reader().into(), InvoiceAttrUnion::UdtScript(item) => item.as_reader().into(), InvoiceAttrUnion::PayeePublicKey(item) => item.as_reader().into(), + InvoiceAttrUnion::HashAlgorithm(item) => item.as_reader().into(), } } } @@ -5163,6 +5339,7 @@ impl<'r> InvoiceAttrUnionReader<'r> { InvoiceAttrUnionReader::Feature(item) => item.as_slice(), InvoiceAttrUnionReader::UdtScript(item) => item.as_slice(), InvoiceAttrUnionReader::PayeePublicKey(item) => item.as_slice(), + InvoiceAttrUnionReader::HashAlgorithm(item) => item.as_slice(), } } pub fn item_id(&self) -> molecule::Number { @@ -5175,6 +5352,7 @@ impl<'r> InvoiceAttrUnionReader<'r> { InvoiceAttrUnionReader::Feature(_) => 5, InvoiceAttrUnionReader::UdtScript(_) => 6, InvoiceAttrUnionReader::PayeePublicKey(_) => 7, + InvoiceAttrUnionReader::HashAlgorithm(_) => 8, } } pub fn item_name(&self) -> &str { @@ -5187,6 +5365,7 @@ impl<'r> InvoiceAttrUnionReader<'r> { InvoiceAttrUnionReader::Feature(_) => "Feature", InvoiceAttrUnionReader::UdtScript(_) => "UdtScript", InvoiceAttrUnionReader::PayeePublicKey(_) => "PayeePublicKey", + InvoiceAttrUnionReader::HashAlgorithm(_) => "HashAlgorithm", } } } @@ -5230,6 +5409,11 @@ impl From for InvoiceAttr { Self::new_builder().set(value).build() } } +impl From for InvoiceAttr { + fn from(value: HashAlgorithm) -> Self { + Self::new_builder().set(value).build() + } +} #[derive(Clone)] pub struct InvoiceAttrsVec(molecule::bytes::Bytes); impl ::core::fmt::LowerHex for InvoiceAttrsVec { diff --git a/src/ckb/hash_algorithm.rs b/src/ckb/hash_algorithm.rs new file mode 100644 index 000000000..99618411d --- /dev/null +++ b/src/ckb/hash_algorithm.rs @@ -0,0 +1,79 @@ +use bitcoin::hashes::{sha256::Hash as Sha256, Hash as _}; +use ckb_hash::blake2b_256; +use ckb_types::packed; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +#[repr(u8)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum HashAlgorithm { + CkbHash = 0, + Sha256 = 1, +} + +impl HashAlgorithm { + pub fn hash>(&self, s: T) -> [u8; 32] { + match self { + HashAlgorithm::CkbHash => blake2b_256(s), + HashAlgorithm::Sha256 => sha256(s), + } + } +} + +impl Default for HashAlgorithm { + fn default() -> Self { + HashAlgorithm::CkbHash + } +} + +/// The error type wrap various ser/de errors. +#[derive(Error, Debug)] +#[error("Unknown Hash Algorithm: {0}")] +pub struct UnknownHashAlgorithmError(pub u8); + +impl TryFrom for HashAlgorithm { + type Error = UnknownHashAlgorithmError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(HashAlgorithm::CkbHash), + 1 => Ok(HashAlgorithm::Sha256), + _ => Err(UnknownHashAlgorithmError(value)), + } + } +} + +impl TryFrom for HashAlgorithm { + type Error = UnknownHashAlgorithmError; + + fn try_from(value: packed::Byte) -> Result { + let value: u8 = value.into(); + value.try_into() + } +} + +pub fn sha256>(s: T) -> [u8; 32] { + Sha256::hash(s.as_ref()).to_byte_array() +} + +#[cfg(test)] +mod tests { + #[test] + fn test_hash_algorithm_serialization_sha256() { + let algorithm = super::HashAlgorithm::Sha256; + let serialized = serde_json::to_string(&algorithm).unwrap(); + assert_eq!(serialized, r#""sha256""#); + let deserialized: super::HashAlgorithm = serde_json::from_str(&serialized).unwrap(); + assert_eq!(deserialized, algorithm); + } + + #[test] + fn test_hash_algorithm_serialization_ckb_hash() { + let algorithm = super::HashAlgorithm::CkbHash; + let serialized = serde_json::to_string(&algorithm).unwrap(); + assert_eq!(serialized, r#""ckb_hash""#); + let deserialized: super::HashAlgorithm = serde_json::from_str(&serialized).unwrap(); + assert_eq!(deserialized, algorithm); + } +} diff --git a/src/ckb/mod.rs b/src/ckb/mod.rs index 2c77fc69d..3418b52da 100644 --- a/src/ckb/mod.rs +++ b/src/ckb/mod.rs @@ -16,6 +16,8 @@ pub mod channel; pub mod types; +pub mod hash_algorithm; + pub mod serde_utils; #[cfg(test)] diff --git a/src/ckb/schema/cfn.mol b/src/ckb/schema/cfn.mol index ce7f9e1f5..814d990ac 100644 --- a/src/ckb/schema/cfn.mol +++ b/src/ckb/schema/cfn.mol @@ -101,6 +101,7 @@ table AddTlc { amount: Uint128, payment_hash: Byte32, expiry: Uint64, + hash_algorithm: byte, } table RevokeAndAck { diff --git a/src/ckb/schema/invoice.mol b/src/ckb/schema/invoice.mol index d48158b4d..9edcb43fe 100644 --- a/src/ckb/schema/invoice.mol +++ b/src/ckb/schema/invoice.mol @@ -46,6 +46,12 @@ table PayeePublicKey { value: Bytes, } +// 0 - ckb hash (Default) +// 1 - sha256 +struct HashAlgorithm { + value: byte, +} + union InvoiceAttr { ExpiryTime, Description, @@ -55,6 +61,7 @@ union InvoiceAttr { Feature, UdtScript, PayeePublicKey, + HashAlgorithm, } vector InvoiceAttrsVec ; diff --git a/src/ckb/types.rs b/src/ckb/types.rs index 9804ea59a..639efcd5a 100644 --- a/src/ckb/types.rs +++ b/src/ckb/types.rs @@ -1,6 +1,7 @@ use std::str::FromStr; use super::gen::cfn::{self as molecule_cfn, PubNonce as Byte66}; +use super::hash_algorithm::{HashAlgorithm, UnknownHashAlgorithmError}; use super::serde_utils::SliceHex; use anyhow::anyhow; use ckb_sdk::{Since, SinceType}; @@ -920,13 +921,14 @@ impl TryFrom for ClosingSigned { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq)] pub struct AddTlc { pub channel_id: Hash256, pub tlc_id: u64, pub amount: u128, pub payment_hash: Hash256, pub expiry: LockTime, + pub hash_algorithm: HashAlgorithm, } impl From for molecule_cfn::AddTlc { @@ -937,6 +939,7 @@ impl From for molecule_cfn::AddTlc { .amount(add_tlc.amount.pack()) .payment_hash(add_tlc.payment_hash.into()) .expiry(add_tlc.expiry.into()) + .hash_algorithm(Byte::new(add_tlc.hash_algorithm as u8)) .build() } } @@ -951,6 +954,10 @@ impl TryFrom for AddTlc { amount: add_tlc.amount().unpack(), payment_hash: add_tlc.payment_hash().into(), expiry: add_tlc.expiry().try_into()?, + hash_algorithm: add_tlc + .hash_algorithm() + .try_into() + .map_err(|err: UnknownHashAlgorithmError| Error::AnyHow(err.into()))?, }) } } @@ -1335,4 +1342,19 @@ mod tests { let pubkey: Pubkey = serde_json::from_str(&pk_str).unwrap(); assert_eq!(pubkey, public_key) } + + #[test] + fn test_add_tlc_serialization() { + let add_tlc = super::AddTlc { + channel_id: [42; 32].into(), + tlc_id: 42, + amount: 42, + payment_hash: [42; 32].into(), + expiry: 42.into(), + hash_algorithm: super::HashAlgorithm::Sha256, + }; + let add_tlc_mol: super::molecule_cfn::AddTlc = add_tlc.clone().into(); + let add_tlc2 = add_tlc_mol.try_into().expect("decode"); + assert_eq!(add_tlc, add_tlc2); + } } diff --git a/src/invoice/invoice_impl.rs b/src/invoice/invoice_impl.rs index 32f0ffe74..ee5c4f54e 100644 --- a/src/invoice/invoice_impl.rs +++ b/src/invoice/invoice_impl.rs @@ -1,6 +1,7 @@ use super::errors::VerificationError; use super::utils::*; use crate::ckb::gen::invoice::{self as gen_invoice, *}; +use crate::ckb::hash_algorithm::HashAlgorithm; use crate::ckb::serde_utils::EntityHex; use crate::ckb::serde_utils::U128Hex; use crate::ckb::types::Hash256; @@ -16,7 +17,6 @@ use bitcoin::{ Message, PublicKey, }, }; -use ckb_hash::blake2b_256; use ckb_types::{ packed::{Byte, Script}, prelude::{Pack, Unpack}, @@ -87,6 +87,7 @@ pub enum Attribute { FallbackAddr(String), UdtScript(CkbScript), PayeePublicKey(PublicKey), + HashAlgorithm(HashAlgorithm), Feature(u64), } @@ -452,6 +453,11 @@ impl From for InvoiceAttr { .value(pubkey.serialize().pack()) .build(), ), + Attribute::HashAlgorithm(hash_algorithm) => InvoiceAttrUnion::HashAlgorithm( + gen_invoice::HashAlgorithm::new_builder() + .value(Byte::new(hash_algorithm as u8)) + .build(), + ), }; InvoiceAttr::new_builder().set(a).build() } @@ -487,6 +493,12 @@ impl From for Attribute { let value: Vec = x.value().unpack(); Attribute::PayeePublicKey(PublicKey::from_slice(&value).unwrap()) } + InvoiceAttrUnion::HashAlgorithm(x) => { + let value = x.value(); + // Consider unknown algorithm as the default one. + let hash_algorithm = value.try_into().unwrap_or_default(); + Attribute::HashAlgorithm(hash_algorithm) + } } } } @@ -553,6 +565,10 @@ impl InvoiceBuilder { self.add_attr(Attribute::UdtScript(CkbScript(script))) } + pub fn hash_algorithm(self, algorithm: HashAlgorithm) -> Self { + self.add_attr(Attribute::HashAlgorithm(algorithm)) + } + attr_setter!(description, Description, String); attr_setter!(payee_pub_key, PayeePublicKey, PublicKey); attr_setter!(expiry_time, ExpiryTime, Duration); @@ -566,7 +582,16 @@ impl InvoiceBuilder { return Err(InvoiceError::BothPaymenthashAndPreimage); } let payment_hash: Hash256 = if let Some(preimage) = preimage { - blake2b_256(preimage.as_ref()).into() + let algo = self + .attrs + .iter() + .find_map(|attr| match attr { + Attribute::HashAlgorithm(algo) => Some(algo), + _ => None, + }) + .copied() + .unwrap_or_default(); + algo.hash(preimage.as_ref()).into() } else if let Some(payment_hash) = self.payment_hash { payment_hash } else { @@ -708,6 +733,7 @@ mod tests { key::{KeyPair, Secp256k1}, secp256k1::SecretKey, }; + use ckb_hash::blake2b_256; use std::time::{SystemTime, UNIX_EPOCH}; fn gen_rand_public_key() -> PublicKey { diff --git a/src/rpc/channel.rs b/src/rpc/channel.rs index b6d80932d..e35fc8fb7 100644 --- a/src/rpc/channel.rs +++ b/src/rpc/channel.rs @@ -3,6 +3,7 @@ use crate::ckb::{ AddTlcCommand, ChannelActorStateStore, ChannelCommand, ChannelCommandWithId, ChannelState, RemoveTlcCommand, ShutdownCommand, }, + hash_algorithm::HashAlgorithm, network::{AcceptChannelCommand, OpenChannelCommand}, serde_utils::{U128Hex, U32Hex, U64Hex}, types::{Hash256, LockTime, RemoveTlcFail, RemoveTlcFulfill}, @@ -93,6 +94,7 @@ pub struct AddTlcParams { pub amount: u128, pub payment_hash: Hash256, pub expiry: LockTime, + pub hash_algorithm: Option, } #[serde_as] @@ -278,6 +280,7 @@ where preimage: None, payment_hash: Some(params.payment_hash), expiry: params.expiry, + hash_algorithm: params.hash_algorithm.unwrap_or_default(), }, rpc_reply, ), diff --git a/src/rpc/invoice.rs b/src/rpc/invoice.rs index 0cf5ab234..f435146fb 100644 --- a/src/rpc/invoice.rs +++ b/src/rpc/invoice.rs @@ -1,5 +1,6 @@ use std::time::Duration; +use crate::ckb::hash_algorithm::HashAlgorithm; use crate::ckb::serde_utils::{U128Hex, U64Hex}; use crate::ckb::types::Hash256; use crate::invoice::{CkbInvoice, Currency, InvoiceBuilder, InvoiceStore}; @@ -25,6 +26,7 @@ pub struct NewInvoiceParams { #[serde_as(as = "Option")] pub final_htlc_timeout: Option, pub udt_type_script: Option