diff --git a/lnwire/announcement_signatures_2.go b/lnwire/announcement_signatures_2.go index a104470321..526b995485 100644 --- a/lnwire/announcement_signatures_2.go +++ b/lnwire/announcement_signatures_2.go @@ -3,6 +3,8 @@ package lnwire import ( "bytes" "io" + + "github.com/lightningnetwork/lnd/tlv" ) // AnnounceSignatures2 is a direct message between two endpoints of a @@ -14,27 +16,40 @@ type AnnounceSignatures2 struct { // Channel id is better for users and debugging and short channel id is // used for quick test on existence of the particular utxo inside the // blockchain, because it contains information about block. - ChannelID ChannelID + ChannelID tlv.RecordT[tlv.TlvType0, ChannelID] // ShortChannelID is the unique description of the funding transaction. // It is constructed with the most significant 3 bytes as the block // height, the next 3 bytes indicating the transaction index within the // block, and the least significant two bytes indicating the output // index which pays to the channel. - ShortChannelID ShortChannelID + ShortChannelID tlv.RecordT[tlv.TlvType2, ShortChannelID] // PartialSignature is the combination of the partial Schnorr signature // created for the node's bitcoin key with the partial signature created // for the node's node ID key. - PartialSignature PartialSig - - // ExtraOpaqueData is the set of data that was appended to this - // message, some of which we may not actually know how to iterate or - // parse. By holding onto this data, we ensure that we're able to - // properly validate the set of signatures that cover these new fields, - // and ensure we're able to make upgrades to the network in a forwards - // compatible manner. - ExtraOpaqueData ExtraOpaqueData + PartialSignature tlv.RecordT[tlv.TlvType4, PartialSig] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraFieldsInSignedRange map[uint64][]byte +} + +// NewAnnSigs2 is a constructor for AnnounceSignatures2. +func NewAnnSigs2(chanID ChannelID, scid ShortChannelID, + partialSig PartialSig) *AnnounceSignatures2 { + + return &AnnounceSignatures2{ + ChannelID: tlv.NewRecordT[tlv.TlvType0, ChannelID](chanID), + ShortChannelID: tlv.NewRecordT[tlv.TlvType2, ShortChannelID]( + scid, + ), + PartialSignature: tlv.NewRecordT[tlv.TlvType4, PartialSig]( + partialSig, + ), + ExtraFieldsInSignedRange: make(map[uint64][]byte, 0), + } } // A compile time check to ensure AnnounceSignatures2 implements the @@ -46,32 +61,29 @@ var _ Message = (*AnnounceSignatures2)(nil) // // This is part of the lnwire.Message interface. func (a *AnnounceSignatures2) Decode(r io.Reader, _ uint32) error { - return ReadElements(r, - &a.ChannelID, - &a.ShortChannelID, - &a.PartialSignature, - &a.ExtraOpaqueData, - ) -} - -// Encode serializes the target AnnounceSignatures2 into the passed io.Writer -// observing the protocol version specified. -// -// This is part of the lnwire.Message interface. -func (a *AnnounceSignatures2) Encode(w *bytes.Buffer, _ uint32) error { - if err := WriteChannelID(w, a.ChannelID); err != nil { + stream, err := tlv.NewStream(ProduceRecordsSorted( + &a.ChannelID, &a.ShortChannelID, &a.PartialSignature, + )...) + if err != nil { return err } - if err := WriteShortChannelID(w, a.ShortChannelID); err != nil { + typeMap, err := stream.DecodeWithParsedTypesP2P(r) + if err != nil { return err } - if err := WriteElement(w, a.PartialSignature); err != nil { - return err - } + a.ExtraFieldsInSignedRange = ExtraSignedFieldsFromTypeMap(typeMap) + + return nil +} - return WriteBytes(w, a.ExtraOpaqueData) +// Encode serializes the target AnnounceSignatures2 into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (a *AnnounceSignatures2) Encode(w *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(a, w) } // MsgType returns the integer uniquely identifying this message type on the @@ -82,16 +94,34 @@ func (a *AnnounceSignatures2) MsgType() MessageType { return MsgAnnounceSignatures2 } +// AllRecords returns all the TLV records for the message. This will include all +// the records we know about along with any that we don't know about but that +// fall in the signed TLV range. +// +// NOTE: this is part of the PureTLVMessage interface. +func (a *AnnounceSignatures2) AllRecords() []tlv.Record { + recordProducers := []tlv.RecordProducer{ + &a.ChannelID, &a.ShortChannelID, + &a.PartialSignature, + } + + recordProducers = append(recordProducers, RecordsAsProducers( + tlv.MapToRecords(a.ExtraFieldsInSignedRange), + )...) + + return ProduceRecordsSorted(recordProducers...) +} + // SCID returns the ShortChannelID of the channel. // // NOTE: this is part of the AnnounceSignatures interface. func (a *AnnounceSignatures2) SCID() ShortChannelID { - return a.ShortChannelID + return a.ShortChannelID.Val } // ChanID returns the ChannelID identifying the channel. // // NOTE: this is part of the AnnounceSignatures interface. func (a *AnnounceSignatures2) ChanID() ChannelID { - return a.ChannelID + return a.ChannelID.Val } diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 6be13fea91..0dbe672579 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -1519,29 +1519,19 @@ func TestLightningWireProtocol(t *testing.T) { MsgAnnounceSignatures2: func(v []reflect.Value, r *rand.Rand) { - req := AnnounceSignatures2{ - ShortChannelID: NewShortChanIDFromInt( - uint64(r.Int63()), - ), - ExtraOpaqueData: make([]byte, 0), - } + var req AnnounceSignatures2 + + req.ExtraFieldsInSignedRange = randSignedRangeRecords( + t, r, + ) - _, err := r.Read(req.ChannelID[:]) + _, err := r.Read(req.ChannelID.Val[:]) require.NoError(t, err) partialSig, err := randPartialSig(r) require.NoError(t, err) - req.PartialSignature = *partialSig - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make( - []byte, numExtraBytes, - ) - _, err := r.Read(req.ExtraOpaqueData[:]) - require.NoError(t, err) - } + req.PartialSignature.Val = *partialSig v[0] = reflect.ValueOf(req) },