From 53bdf650aa2851fe7b53cca8148018d0d7d4a111 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 2 Dec 2024 12:47:56 +0200 Subject: [PATCH] lnwire: add NodeAnnouncement2 --- lnwire/lnwire_test.go | 85 +++++++ lnwire/message.go | 5 + lnwire/node_announcement_2.go | 429 ++++++++++++++++++++++++++++++++++ 3 files changed, 519 insertions(+) create mode 100644 lnwire/node_announcement_2.go diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 339c9ed782..6c5d384121 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -3,6 +3,7 @@ package lnwire import ( "bytes" crand "crypto/rand" + "encoding/base64" "encoding/binary" "encoding/hex" "fmt" @@ -1660,6 +1661,83 @@ func TestLightningWireProtocol(t *testing.T) { ChanUpdateDisableOutgoing } + v[0] = reflect.ValueOf(req) + }, + MsgNodeAnnouncement2: func(v []reflect.Value, r *rand.Rand) { + var req NodeAnnouncement2 + + req.ExtraSignedFields = ExtraSignedFields( + randSignedRangeRecords(t, r), + ) + req.Signature.Val = testSchnorrSig + + req.NodeID.Val = randRawKey(t) + req.BlockHeight.Val = r.Uint32() + req.Features.Val = *randRawFeatureVector(r) + + // Sometimes set the colour field. + if r.Int31()%2 == 0 { + color := tlv.ZeroRecordT[tlv.TlvType1, Color]() + color.Val = Color{ + R: uint8(r.Int31()), + G: uint8(r.Int31()), + B: uint8(r.Int31()), + } + req.Color = tlv.SomeRecordT(color) + } + + n := r.Intn(33) + b := make([]byte, n) + _, err := rand.Read(b) + require.NoError(t, err) + if n > 0 { + alias := []byte( + base64.StdEncoding.EncodeToString(b), + ) + if len(alias) > 32 { + alias = alias[:32] + } + + aliasRec := tlv.ZeroRecordT[ + tlv.TlvType3, []byte, + ]() + aliasRec.Val = alias + } + + // Sometimes add some ipv4 addrs. + if r.Int31()%2 == 0 { + ipv4Addr, err := randTCP4Addr(r) + require.NoError(t, err) + + ipv4AddrRecord := tlv.ZeroRecordT[ + tlv.TlvType5, IPV4Addrs, + ]() + ipv4AddrRecord.Val = IPV4Addrs{ipv4Addr} + req.IPV4Addrs = tlv.SomeRecordT(ipv4AddrRecord) + } + // Sometimes add some ipv6 addrs. + if r.Int31()%2 == 0 { + ipv6Addr, err := randTCP6Addr(r) + require.NoError(t, err) + + ipv6AddrRecord := tlv.ZeroRecordT[ + tlv.TlvType7, IPV6Addrs, + ]() + ipv6AddrRecord.Val = IPV6Addrs{ipv6Addr} + req.IPV6Addrs = tlv.SomeRecordT(ipv6AddrRecord) + } + // Sometimes add some torv3 addrs. + if r.Int31()%2 == 0 { + torAddr, err := randV3OnionAddr(r) + require.NoError(t, err) + + torAddrRecord := tlv.ZeroRecordT[ + tlv.TlvType9, TorV3Addrs, + ]() + torAddrRecord.Val = TorV3Addrs{torAddr} + req.TorV3Addrs = tlv.SomeRecordT(torAddrRecord) + } + v[0] = reflect.ValueOf(req) }, } @@ -1902,12 +1980,19 @@ func TestLightningWireProtocol(t *testing.T) { return mainScenario(&m) }, }, + { msgType: MsgChannelUpdate2, scenario: func(m ChannelUpdate2) bool { return mainScenario(&m) }, }, + { + msgType: MsgNodeAnnouncement2, + scenario: func(m NodeAnnouncement2) bool { + return mainScenario(&m) + }, + }, } for _, test := range tests { t.Run(test.msgType.String(), func(t *testing.T) { diff --git a/lnwire/message.go b/lnwire/message.go index 68b09692e5..18d0e59978 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -59,6 +59,7 @@ const ( MsgReplyChannelRange = 264 MsgGossipTimestampRange = 265 MsgChannelAnnouncement2 = 267 + MsgNodeAnnouncement2 = 269 MsgChannelUpdate2 = 271 MsgKickoffSig = 777 ) @@ -181,6 +182,8 @@ func (t MessageType) String() string { return "MsgAnnounceSignatures2" case MsgChannelAnnouncement2: return "ChannelAnnouncement2" + case MsgNodeAnnouncement2: + return "NodeAnnouncement2" case MsgChannelUpdate2: return "ChannelUpdate2" default: @@ -316,6 +319,8 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { msg = &AnnounceSignatures2{} case MsgChannelAnnouncement2: msg = &ChannelAnnouncement2{} + case MsgNodeAnnouncement2: + msg = &NodeAnnouncement2{} case MsgChannelUpdate2: msg = &ChannelUpdate2{} default: diff --git a/lnwire/node_announcement_2.go b/lnwire/node_announcement_2.go new file mode 100644 index 0000000000..183f94c638 --- /dev/null +++ b/lnwire/node_announcement_2.go @@ -0,0 +1,429 @@ +package lnwire + +import ( + "bytes" + "encoding/binary" + "fmt" + "image/color" + "io" + "net" + + "github.com/lightningnetwork/lnd/tlv" + "github.com/lightningnetwork/lnd/tor" +) + +type NodeAnnouncement2 struct { + // Features is the feature vector that encodes the features supported + // by the target node. + Features tlv.RecordT[tlv.TlvType0, RawFeatureVector] + + // Color is an optional field used to customize a node's appearance in + // maps and graphs. + Color tlv.OptionalRecordT[tlv.TlvType1, Color] + + // BlockHeight allows ordering in the case of multiple announcements. We + // should ignore the message if block height is not greater than the + // last-received. The block height must always be greater or equal to + // the block height that the channel funding transaction was confirmed + // in. + BlockHeight tlv.RecordT[tlv.TlvType2, uint32] + + // Alias is used to customize their node's appearance in maps and + // graphs. + Alias tlv.OptionalRecordT[tlv.TlvType3, []byte] + + // NodeID is the public key of the node creating the announcement. + NodeID tlv.RecordT[tlv.TlvType6, [33]byte] + + // IPV4Addrs is an optional list of ipv4 addresses that the node is + // reachable at. + IPV4Addrs tlv.OptionalRecordT[tlv.TlvType5, IPV4Addrs] + + // IPV6Addrs is an optional list of ipv6 addresses that the node is + // reachable at. + IPV6Addrs tlv.OptionalRecordT[tlv.TlvType7, IPV6Addrs] + + // TorV3Addrs is an optional list of tor v3 addresses that the node is + // reachable at. + TorV3Addrs tlv.OptionalRecordT[tlv.TlvType9, TorV3Addrs] + + // Signature is used to validate the announced data and prove the + // ownership of node id. + Signature tlv.RecordT[tlv.TlvType160, Sig] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraSignedFields +} + +// AllRecords returns all the TLV records for the message. This will include all +// the records we know about along with any that we don't know about but that +// fall in the signed TLV range. +// +// NOTE: this is part of the PureTLVMessage interface. +func (n *NodeAnnouncement2) AllRecords() []tlv.Record { + recordProducers := []tlv.RecordProducer{ + &n.Features, + &n.BlockHeight, + &n.NodeID, + &n.Signature, + } + + n.Color.WhenSome(func(r tlv.RecordT[tlv.TlvType1, Color]) { + recordProducers = append(recordProducers, &r) + }) + + n.Alias.WhenSome(func(a tlv.RecordT[tlv.TlvType3, []byte]) { + recordProducers = append(recordProducers, &a) + }) + + n.IPV4Addrs.WhenSome(func(r tlv.RecordT[tlv.TlvType5, IPV4Addrs]) { + recordProducers = append(recordProducers, &r) + }) + + n.IPV6Addrs.WhenSome(func(r tlv.RecordT[tlv.TlvType7, IPV6Addrs]) { + recordProducers = append(recordProducers, &r) + }) + + n.TorV3Addrs.WhenSome(func(r tlv.RecordT[tlv.TlvType9, TorV3Addrs]) { + recordProducers = append(recordProducers, &r) + }) + + recordProducers = append(recordProducers, RecordsAsProducers( + tlv.MapToRecords(n.ExtraSignedFields), + )...) + + return ProduceRecordsSorted(recordProducers...) +} + +// Decode deserializes a serialized ChannelUpdate2 stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (n *NodeAnnouncement2) Decode(r io.Reader, _ uint32) error { + var ( + color = tlv.ZeroRecordT[tlv.TlvType1, Color]() + alias = tlv.ZeroRecordT[tlv.TlvType3, []byte]() + ipv4 = tlv.ZeroRecordT[tlv.TlvType5, IPV4Addrs]() + ipv6 = tlv.ZeroRecordT[tlv.TlvType7, IPV6Addrs]() + torV3 = tlv.ZeroRecordT[tlv.TlvType9, TorV3Addrs]() + ) + stream, err := tlv.NewStream(ProduceRecordsSorted( + &n.Features, + &n.BlockHeight, + &n.NodeID, + &n.Signature, + &alias, + &color, + &ipv4, + &ipv6, + &torV3, + )...) + if err != nil { + return err + } + n.Signature.Val.ForceSchnorr() + + typeMap, err := stream.DecodeWithParsedTypesP2P(r) + if err != nil { + return err + } + + if _, ok := typeMap[n.Alias.TlvType()]; ok { + n.Alias = tlv.SomeRecordT(alias) + } + + if _, ok := typeMap[n.Color.TlvType()]; ok { + n.Color = tlv.SomeRecordT(color) + } + + if _, ok := typeMap[n.IPV4Addrs.TlvType()]; ok { + n.IPV4Addrs = tlv.SomeRecordT(ipv4) + } + + if _, ok := typeMap[n.IPV6Addrs.TlvType()]; ok { + n.IPV6Addrs = tlv.SomeRecordT(ipv6) + } + + if _, ok := typeMap[n.TorV3Addrs.TlvType()]; ok { + n.TorV3Addrs = tlv.SomeRecordT(torV3) + } + + n.ExtraSignedFields = ExtraSignedFieldsFromTypeMap(typeMap) + + return nil +} + +// Encode serializes the target ChannelUpdate2 into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (n *NodeAnnouncement2) Encode(w *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(n, w) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (n *NodeAnnouncement2) MsgType() MessageType { + return MsgNodeAnnouncement2 +} + +// A compile-time check to ensure NodeAnnouncement2 implements the Message +// interface. +var _ Message = (*NodeAnnouncement2)(nil) + +// A compile-time check to ensure NodeAnnouncement2 implements the +// PureTLVMessage interface. +var _ PureTLVMessage = (*NodeAnnouncement2)(nil) + +type Color color.RGBA + +func (c *Color) Record() tlv.Record { + return tlv.MakeStaticRecord(0, c, 3, rgbEncoder, rgbDecoder) +} + +func rgbEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*Color); ok { + buf := bytes.NewBuffer(nil) + err := WriteColorRGBA(buf, color.RGBA(*v)) + if err != nil { + return err + } + _, err = w.Write(buf.Bytes()) + return err + } + return tlv.NewTypeForEncodingErr(val, "Color") +} + +func rgbDecoder(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if v, ok := val.(*Color); ok { + return ReadElements(r, &v.R, &v.G, &v.B) + } + return tlv.NewTypeForDecodingErr(val, "Color", l, 3) +} + +// ipv4AddrEncodedSize is the number of bytes required to encode a single ipv4 +// address. Four bytes are used to encode the IP address and two bytes for the +// port number. +const ipv4AddrEncodedSize = 4 + 2 + +// IPV4Addrs is a list of ipv4 addresses that can be encoded as a TLV record. +type IPV4Addrs []*net.TCPAddr + +// Record returns a Record that can be used to encode/decode a IPV4Addrs +// to/from a TLV stream. +func (a *IPV4Addrs) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, a, a.EncodedSize, ipv4AddrsEncoder, ipv4AddrsDecoder, + ) +} + +// EncodedSize returns the number of bytes required to encode an IPV4Addrs +// variable. +func (a *IPV4Addrs) EncodedSize() uint64 { + return uint64(len(*a) * ipv4AddrEncodedSize) +} + +func ipv4AddrsEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*IPV4Addrs); ok { + for _, ip := range *v { + _, err := w.Write(ip.IP.To4()) + if err != nil { + return err + } + var port [2]byte + binary.BigEndian.PutUint16(port[:], uint16(ip.Port)) + _, err = w.Write(port[:]) + return err + } + } + return tlv.NewTypeForEncodingErr(val, "lnwire.IPV4Addrs") +} + +func ipv4AddrsDecoder(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + if v, ok := val.(*IPV4Addrs); ok { + if l%(ipv4AddrEncodedSize) != 0 { + return fmt.Errorf("invalid ipv4 list encoding") + } + var ( + numAddrs = int(l / ipv4AddrEncodedSize) + addrs = make([]*net.TCPAddr, 0, numAddrs) + ip [4]byte + port [2]byte + ) + for len(addrs) < numAddrs { + _, err := r.Read(ip[:]) + if err != nil { + return err + } + _, err = r.Read(port[:]) + if err != nil { + return err + } + addrs = append(addrs, &net.TCPAddr{ + IP: ip[:], + Port: int(binary.BigEndian.Uint16(port[:])), + }) + } + *v = addrs + return nil + } + return tlv.NewTypeForEncodingErr(val, "lnwire.IPV4Addrs") +} + +// IPV6Addrs is a list of ipv6 addresses that can be encoded as a TLV record. +type IPV6Addrs []*net.TCPAddr + +// Record returns a Record that can be used to encode/decode a IPV4Addrs +// to/from a TLV stream. +func (a *IPV6Addrs) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, a, a.EncodedSize, ipv6AddrsEncoder, ipv6AddrsDecoder, + ) +} + +// ipv6AddrEncodedSize is the number of bytes required to encode a single ipv6 +// address. Sixteen bytes are used to encode the IP address and two bytes for +// the port number. +const ipv6AddrEncodedSize = 16 + 2 + +// EncodedSize returns the number of bytes required to encode an IPV6Addrs +// variable. +func (a *IPV6Addrs) EncodedSize() uint64 { + return uint64(len(*a) * ipv6AddrEncodedSize) +} + +func ipv6AddrsEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*IPV6Addrs); ok { + for _, ip := range *v { + _, err := w.Write(ip.IP.To16()) + if err != nil { + return err + } + var port [2]byte + binary.BigEndian.PutUint16(port[:], uint16(ip.Port)) + _, err = w.Write(port[:]) + return err + } + } + return tlv.NewTypeForEncodingErr(val, "lnwire.IPV6Addrs") +} + +func ipv6AddrsDecoder(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + if v, ok := val.(*IPV6Addrs); ok { + if l%(ipv6AddrEncodedSize) != 0 { + return fmt.Errorf("invalid ipv6 list encoding") + } + var ( + numAddrs = int(l / ipv6AddrEncodedSize) + addrs = make([]*net.TCPAddr, 0, numAddrs) + ip [16]byte + port [2]byte + ) + for len(addrs) < numAddrs { + _, err := r.Read(ip[:]) + if err != nil { + return err + } + _, err = r.Read(port[:]) + if err != nil { + return err + } + addrs = append(addrs, &net.TCPAddr{ + IP: ip[:], + Port: int(binary.BigEndian.Uint16(port[:])), + }) + } + *v = addrs + return nil + } + return tlv.NewTypeForEncodingErr(val, "lnwire.IPV6Addrs") +} + +// TorV3Addrs is a list of tor v3 addresses that can be encoded as a TLV record. +type TorV3Addrs []*tor.OnionAddr + +// torV3AddrEncodedSize is the number of bytes required to encode a single tor +// v3 address. +const torV3AddrEncodedSize = tor.V3DecodedLen + 2 + +// EncodedSize returns the number of bytes required to encode an TorV3Addrs +// variable. +func (a *TorV3Addrs) EncodedSize() uint64 { + return uint64(len(*a) * torV3AddrEncodedSize) +} + +// Record returns a Record that can be used to encode/decode a IPV4Addrs +// to/from a TLV stream. +func (a *TorV3Addrs) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, a, a.EncodedSize, torV3AddrsEncoder, torV3AddrsDecoder, + ) +} + +func torV3AddrsEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*TorV3Addrs); ok { + for _, addr := range *v { + encodedHostLen := tor.V3Len - tor.OnionSuffixLen + host, err := tor.Base32Encoding.DecodeString( + addr.OnionService[:encodedHostLen], + ) + if err != nil { + return err + } + if len(host) != tor.V3DecodedLen { + return fmt.Errorf("expected a tor v3 host "+ + "length of %d, got: %d", + tor.V2DecodedLen, len(host)) + } + if _, err = w.Write(host); err != nil { + return err + } + var port [2]byte + binary.BigEndian.PutUint16(port[:], uint16(addr.Port)) + _, err = w.Write(port[:]) + return err + } + } + return tlv.NewTypeForEncodingErr(val, "lnwire.TorV3Addrs") +} + +func torV3AddrsDecoder(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + if v, ok := val.(*TorV3Addrs); ok { + if l%torV3AddrEncodedSize != 0 { + return fmt.Errorf("invalid tor v3 list encoding") + } + var ( + numAddrs = int(l / torV3AddrEncodedSize) + addrs = make([]*tor.OnionAddr, 0, numAddrs) + ip [tor.V3DecodedLen]byte + p [2]byte + ) + for len(addrs) < numAddrs { + _, err := r.Read(ip[:]) + if err != nil { + return err + } + _, err = r.Read(p[:]) + if err != nil { + return err + } + onionService := tor.Base32Encoding.EncodeToString(ip[:]) + onionService += tor.OnionSuffix + port := int(binary.BigEndian.Uint16(p[:])) + addrs = append(addrs, &tor.OnionAddr{ + OnionService: onionService, + Port: port, + }) + } + *v = addrs + return nil + } + return tlv.NewTypeForEncodingErr(val, "lnwire.TorV3Addrs") +}