diff --git a/lnwire/channel_update_2.go b/lnwire/channel_update_2.go index 79a76aad61..08be40a7c5 100644 --- a/lnwire/channel_update_2.go +++ b/lnwire/channel_update_2.go @@ -22,10 +22,6 @@ const ( // HTLCs and other parameters. This message is also used to redeclare initially // set channel parameters. type ChannelUpdate2 struct { - // Signature is used to validate the announced data and prove the - // ownership of node id. - Signature Sig - // ChainHash denotes the target chain that this channel was opened // within. This value should be the genesis hash of the target chain. // Along with the short channel ID, this uniquely identifies the @@ -74,10 +70,22 @@ type ChannelUpdate2 struct { // millionth of a satoshi. FeeProportionalMillionths tlv.RecordT[tlv.TlvType18, uint32] - // ExtraOpaqueData is the set of data that was appended to this message - // to fill out the full maximum transport message size. These fields can - // be used to specify optional data such as custom TLV fields. - ExtraOpaqueData ExtraOpaqueData + // Signature is used to validate the announced data and prove the + // ownership of node id. + Signature tlv.RecordT[tlv.TlvType160, Sig] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraFieldsInSignedRange map[uint64][]byte +} + +// Encode serializes the target ChannelUpdate2 into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *ChannelUpdate2) Encode(w *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(c, w) } // Decode deserializes a serialized ChannelUpdate2 stored in the passed @@ -85,17 +93,6 @@ type ChannelUpdate2 struct { // // This is part of the lnwire.Message interface. func (c *ChannelUpdate2) Decode(r io.Reader, _ uint32) error { - err := ReadElement(r, &c.Signature) - if err != nil { - return err - } - c.Signature.ForceSchnorr() - - return c.DecodeTLVRecords(r) -} - -// DecodeTLVRecords decodes only the TLV section of the message. -func (c *ChannelUpdate2) DecodeTLVRecords(r io.Reader) error { // First extract into extra opaque data. var tlvRecords ExtraOpaqueData if err := ReadElements(r, &tlvRecords); err != nil { @@ -111,10 +108,12 @@ func (c *ChannelUpdate2) DecodeTLVRecords(r io.Reader) error { &secondPeer, &c.CLTVExpiryDelta, &c.HTLCMinimumMsat, &c.HTLCMaximumMsat, &c.FeeBaseMsat, &c.FeeProportionalMillionths, + &c.Signature, ) if err != nil { return err } + c.Signature.Val.ForceSchnorr() // By default, the chain-hash is the bitcoin mainnet genesis block hash. c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash @@ -150,38 +149,21 @@ func (c *ChannelUpdate2) DecodeTLVRecords(r io.Reader) error { c.FeeProportionalMillionths.Val = defaultFeeProportionalMillionths //nolint:lll } - if len(tlvRecords) != 0 { - c.ExtraOpaqueData = tlvRecords - } + c.ExtraFieldsInSignedRange = ExtraSignedFieldsFromTypeMap(typeMap) return nil } -// Encode serializes the target ChannelUpdate2 into the passed io.Writer -// observing the protocol version specified. +// AllRecords returns all the TLV records for the message. This will include all +// the records we know about along with any that we don't know about but that +// fall in the signed TLV range. // -// This is part of the lnwire.Message interface. -func (c *ChannelUpdate2) Encode(w *bytes.Buffer, _ uint32) error { - _, err := w.Write(c.Signature.RawBytes()) - if err != nil { - return err - } - - _, err = c.DataToSign() - if err != nil { - return err - } - - return WriteBytes(w, c.ExtraOpaqueData) -} +// NOTE: this is part of the PureTLVMessage interface. +func (c *ChannelUpdate2) AllRecords() []tlv.Record { + var recordProducers []tlv.RecordProducer -// DataToSign is used to retrieve part of the announcement message which should -// be signed. For the ChannelUpdate2 message, this includes the serialised TLV -// records. -func (c *ChannelUpdate2) DataToSign() ([]byte, error) { // The chain-hash record is only included if it is _not_ equal to the // bitcoin mainnet genisis block hash. - var recordProducers []tlv.RecordProducer if !c.ChainHash.Val.IsEqual(chaincfg.MainNetParams.GenesisHash) { hash := tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() hash.Val = c.ChainHash.Val @@ -190,7 +172,7 @@ func (c *ChannelUpdate2) DataToSign() ([]byte, error) { } recordProducers = append(recordProducers, - &c.ShortChannelID, &c.BlockHeight, + &c.ShortChannelID, &c.BlockHeight, &c.Signature, ) // Only include the disable flags if any bit is set. @@ -225,12 +207,11 @@ func (c *ChannelUpdate2) DataToSign() ([]byte, error) { ) } - err := EncodeMessageExtraData(&c.ExtraOpaqueData, recordProducers...) - if err != nil { - return nil, err - } + recordProducers = append(recordProducers, RecordsAsProducers( + tlv.MapToRecords(c.ExtraFieldsInSignedRange), + )...) - return c.ExtraOpaqueData, nil + return ProduceRecordsSorted(recordProducers...) } // MsgType returns the integer uniquely identifying this message type on the @@ -241,8 +222,14 @@ func (c *ChannelUpdate2) MsgType() MessageType { return MsgChannelUpdate2 } -func (c *ChannelUpdate2) ExtraData() ExtraOpaqueData { - return c.ExtraOpaqueData +func (c *ChannelUpdate2) ExtraData() (ExtraOpaqueData, error) { + var buf *bytes.Buffer + err := EncodeRecordsTo(buf, tlv.MapToRecords(c.ExtraFieldsInSignedRange)) + if err != nil { + return nil, err + } + + return buf.Bytes(), nil } // A compile time check to ensure ChannelUpdate2 implements the diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index e5ae1707eb..6be13fea91 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -1603,11 +1603,13 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(req) }, MsgChannelUpdate2: func(v []reflect.Value, r *rand.Rand) { - req := ChannelUpdate2{ - Signature: testSchnorrSig, - ExtraOpaqueData: make([]byte, 0), - } + var req ChannelUpdate2 + + req.ExtraFieldsInSignedRange = randSignedRangeRecords( + t, r, + ) + req.Signature.Val = testSchnorrSig req.ShortChannelID.Val = NewShortChanIDFromInt( uint64(r.Int63()), ) @@ -1668,15 +1670,6 @@ func TestLightningWireProtocol(t *testing.T) { ChanUpdateDisableOutgoing } - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make( - []byte, numExtraBytes, - ) - _, err := r.Read(req.ExtraOpaqueData[:]) - require.NoError(t, err) - } - v[0] = reflect.ValueOf(req) }, } diff --git a/netann/channel_update.go b/netann/channel_update.go index af91abdd24..bf64f10f36 100644 --- a/netann/channel_update.go +++ b/netann/channel_update.go @@ -235,7 +235,7 @@ func verifyChannelUpdate2Signature(c *lnwire.ChannelUpdate2, return fmt.Errorf("unable to reconstruct message data: %w", err) } - nodeSig, err := c.Signature.ToSignature() + nodeSig, err := c.Signature.Val.ToSignature() if err != nil { return err } @@ -323,7 +323,7 @@ func ChanUpdate2DigestTag() []byte { // chanUpdate2DigestToSign computes the digest of the ChannelUpdate2 message to // be signed. func chanUpdate2DigestToSign(c *lnwire.ChannelUpdate2) ([]byte, error) { - data, err := c.DataToSign() + data, err := lnwire.SerialiseFieldsToSign(c) if err != nil { return nil, err }