diff --git a/channeldb/graph.go b/channeldb/graph.go index 5647326dc4..c0cbc32874 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -1009,7 +1009,9 @@ func (c *ChannelGraph) AddChannelEdge(edge *models.ChannelEdgeInfo, // addChannelEdge is the private form of AddChannelEdge that allows callers to // utilize an existing db transaction. -func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, edge *models.ChannelEdgeInfo) error { +func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, + edge *models.ChannelEdgeInfo) error { + // Construct the channel's primary key which is the 8-byte channel ID. var chanKey [8]byte binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID) @@ -1265,7 +1267,8 @@ const ( // with the current UTXO state. A slice of channels that have been closed by // the target block are returned if the function succeeds without error. func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, - blockHash *chainhash.Hash, blockHeight uint32) ([]*models.ChannelEdgeInfo, error) { + blockHash *chainhash.Hash, blockHeight uint32) ( + []*models.ChannelEdgeInfo, error) { c.cacheMu.Lock() defer c.cacheMu.Unlock() @@ -1518,8 +1521,8 @@ func (c *ChannelGraph) pruneGraphNodes(nodes kvdb.RwBucket, // set to the last prune height valid for the remaining chain. // Channels that were removed from the graph resulting from the // disconnected block are returned. -func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*models.ChannelEdgeInfo, - error) { +func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ( + []*models.ChannelEdgeInfo, error) { // Every channel having a ShortChannelID starting at 'height' // will no longer be confirmed. @@ -2552,7 +2555,9 @@ func (c *ChannelGraph) UpdateEdgePolicy(edge *models.ChannelEdgePolicy, return c.chanScheduler.Execute(r) } -func (c *ChannelGraph) updateEdgeCache(e *models.ChannelEdgePolicy, isUpdate1 bool) { +func (c *ChannelGraph) updateEdgeCache(e *models.ChannelEdgePolicy, + isUpdate1 bool) { + // If an entry for this channel is found in reject cache, we'll modify // the entry with the updated timestamp for the direction that was just // written. If the edge doesn't exist, we'll load the cache entry lazily @@ -2964,7 +2969,8 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro // nodeTraversal is used to traverse all channels of a node given by its // public key and passes channel information into the specified callback. func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, - cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { + cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error) error { traversal := func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) @@ -3067,8 +3073,9 @@ func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, nodePub route.Vertex, // the target node in the channel. This is useful when one knows the pubkey of // one of the nodes, and wishes to obtain the full LightningNode for the other // end of the channel. -func (c *ChannelGraph) FetchOtherNode(tx kvdb.RTx, channel *models.ChannelEdgeInfo, - thisNodeKey []byte) (*LightningNode, error) { +func (c *ChannelGraph) FetchOtherNode(tx kvdb.RTx, + channel *models.ChannelEdgeInfo, thisNodeKey []byte) (*LightningNode, + error) { // Ensure that the node passed in is actually a member of the channel. var targetNodeBytes [33]byte @@ -3117,8 +3124,9 @@ func (c *ChannelGraph) FetchOtherNode(tx kvdb.RTx, channel *models.ChannelEdgeIn // found, then ErrEdgeNotFound is returned. A struct which houses the general // information for the channel itself is returned as well as two structs that // contain the routing policies for the channel in either direction. -func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, -) (*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) { +func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) { var ( edgeInfo *models.ChannelEdgeInfo @@ -3201,8 +3209,9 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, // ErrZombieEdge an be returned if the edge is currently marked as a zombie // within the database. In this case, the ChannelEdgePolicy's will be nil, and // the ChannelEdgeInfo will only include the public keys of each node. -func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64, -) (*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) { +func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) { var ( edgeInfo *models.ChannelEdgeInfo @@ -3904,7 +3913,9 @@ func deserializeLightningNode(r io.Reader) (LightningNode, error) { return node, nil } -func putChanEdgeInfo(edgeIndex kvdb.RwBucket, edgeInfo *models.ChannelEdgeInfo, chanID [8]byte) error { +func putChanEdgeInfo(edgeIndex kvdb.RwBucket, + edgeInfo *models.ChannelEdgeInfo, chanID [8]byte) error { + var b bytes.Buffer if _, err := b.Write(edgeInfo.NodeKey1Bytes[:]); err != nil { @@ -4059,8 +4070,8 @@ func deserializeChanEdgeInfo(r io.Reader) (models.ChannelEdgeInfo, error) { return edgeInfo, nil } -func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy, from, - to []byte) error { +func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy, + from, to []byte) error { var edgeKey [33 + 8]byte copy(edgeKey[:], from) @@ -4213,7 +4224,8 @@ func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte, } func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, - chanID []byte) (*models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) { + chanID []byte) (*models.ChannelEdgePolicy, *models.ChannelEdgePolicy, + error) { edgeInfo := edgeIndex.Get(chanID) if edgeInfo == nil { @@ -4320,7 +4332,9 @@ func deserializeChanEdgePolicy(r io.Reader) (*models.ChannelEdgePolicy, error) { return edge, deserializeErr } -func deserializeChanEdgePolicyRaw(r io.Reader) (*models.ChannelEdgePolicy, error) { +func deserializeChanEdgePolicyRaw(r io.Reader) (*models.ChannelEdgePolicy, + error) { + edge := &models.ChannelEdgePolicy{} var err error diff --git a/channeldb/graph_cache.go b/channeldb/graph_cache.go index 9e142ed73b..7bf9c41d97 100644 --- a/channeldb/graph_cache.go +++ b/channeldb/graph_cache.go @@ -28,7 +28,8 @@ type GraphCacheNode interface { // error, then the iteration is halted with the error propagated back up // to the caller. ForEachChannel(kvdb.RTx, - func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + func(kvdb.RTx, *models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error } diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 25336d28d6..12e20a9e2c 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -385,7 +385,8 @@ func TestEdgeInsertionDeletion(t *testing.T) { } func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, - node1, node2 *LightningNode) (models.ChannelEdgeInfo, lnwire.ShortChannelID) { + node1, node2 *LightningNode) (models.ChannelEdgeInfo, + lnwire.ShortChannelID) { shortChanID := lnwire.ShortChannelID{ BlockHeight: height, @@ -616,8 +617,9 @@ func assertEdgeInfoEqual(t *testing.T, e1 *models.ChannelEdgeInfo, } } -func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) (*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) { +func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) { var ( firstNode [33]byte @@ -1038,7 +1040,8 @@ func TestGraphTraversal(t *testing.T) { // Iterate through all the known channels within the graph DB, once // again if the map is empty that indicates that all edges have // properly been reached. - err = graph.ForEachChannel(func(ei *models.ChannelEdgeInfo, _ *models.ChannelEdgePolicy, + err = graph.ForEachChannel(func(ei *models.ChannelEdgeInfo, + _ *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { delete(chanIndex, ei.ChannelID) @@ -1132,9 +1135,10 @@ func TestGraphTraversalCacheable(t *testing.T) { err = graph.db.View(func(tx kvdb.RTx) error { for _, node := range nodes { err := node.ForEachChannel( - tx, func(tx kvdb.RTx, info *models.ChannelEdgeInfo, + tx, func(tx kvdb.RTx, + info *models.ChannelEdgeInfo, policy *models.ChannelEdgePolicy, - policy2 *models.ChannelEdgePolicy) error { + policy2 *models.ChannelEdgePolicy) error { //nolint:lll delete(chanIndex, info.ChannelID) return nil @@ -1316,7 +1320,8 @@ func assertPruneTip(t *testing.T, graph *ChannelGraph, blockHash *chainhash.Hash func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { numChans := 0 - if err := graph.ForEachChannel(func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + if err := graph.ForEachChannel(func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error { numChans++ @@ -3435,24 +3440,24 @@ func BenchmarkForEachChannel(b *testing.B) { err = graph.db.View(func(tx kvdb.RTx) error { for _, n := range nodes { - err := n.ForEachChannel( - tx, func(tx kvdb.RTx, - info *models.ChannelEdgeInfo, - policy *models.ChannelEdgePolicy, - policy2 *models.ChannelEdgePolicy) error { - - // We need to do something with - // the data here, otherwise the - // compiler is going to optimize - // this away, and we get bogus - // results. - totalCapacity += info.Capacity - maxHTLCs += policy.MaxHTLC - maxHTLCs += policy2.MaxHTLC - - return nil - }, - ) + cb := func(tx kvdb.RTx, + info *models.ChannelEdgeInfo, + policy *models.ChannelEdgePolicy, + policy2 *models.ChannelEdgePolicy) error { //nolint:lll + + // We need to do something with + // the data here, otherwise the + // compiler is going to optimize + // this away, and we get bogus + // results. + totalCapacity += info.Capacity + maxHTLCs += policy.MaxHTLC + maxHTLCs += policy2.MaxHTLC + + return nil + } + + err := n.ForEachChannel(tx, cb) if err != nil { return err } diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 7f42ffb763..a47d98bb95 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -222,11 +222,6 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, return nil } -func (r *mockGraphSource) ForEachChannel(func(chanInfo *models.ChannelEdgeInfo, - e1, e2 *models.ChannelEdgePolicy) error) error { - return nil -} - func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, diff --git a/lnrpc/devrpc/dev_server.go b/lnrpc/devrpc/dev_server.go index 152d988b6c..afe30b3616 100644 --- a/lnrpc/devrpc/dev_server.go +++ b/lnrpc/devrpc/dev_server.go @@ -289,7 +289,7 @@ func (s *Server) ImportGraph(ctx context.Context, rpcEdge.ChanPoint, err) } - makePolicy := func(rpcPolicy *lnrpc.RoutingPolicy) *models.ChannelEdgePolicy { + makePolicy := func(rpcPolicy *lnrpc.RoutingPolicy) *models.ChannelEdgePolicy { //nolint:lll policy := &models.ChannelEdgePolicy{ ChannelID: rpcEdge.ChannelId, LastUpdate: time.Unix( diff --git a/lnrpc/invoicesrpc/addinvoice_test.go b/lnrpc/invoicesrpc/addinvoice_test.go index 26641648d0..1b15921d0a 100644 --- a/lnrpc/invoicesrpc/addinvoice_test.go +++ b/lnrpc/invoicesrpc/addinvoice_test.go @@ -16,9 +16,14 @@ import ( ) type hopHintsConfigMock struct { + t *testing.T mock.Mock } +func newHopHintsConfigMock(t *testing.T) *hopHintsConfigMock { + return &hopHintsConfigMock{t: t} +} + // IsPublicNode mocks node public state lookup. func (h *hopHintsConfigMock) IsPublicNode(pubKey [33]byte) (bool, error) { args := h.Mock.Called(pubKey) @@ -65,9 +70,14 @@ func (h *hopHintsConfigMock) FetchChannelEdgesByID(chanID uint64) ( return nil, nil, nil, err } - edgeInfo := args.Get(0).(*models.ChannelEdgeInfo) - policy1 := args.Get(1).(*models.ChannelEdgePolicy) - policy2 := args.Get(2).(*models.ChannelEdgePolicy) + edgeInfo, ok := args.Get(0).(*models.ChannelEdgeInfo) + require.True(h.t, ok) + + policy1, ok := args.Get(1).(*models.ChannelEdgePolicy) + require.True(h.t, ok) + + policy2, ok := args.Get(2).(*models.ChannelEdgePolicy) + require.True(h.t, ok) return edgeInfo, policy1, policy2, err } @@ -429,7 +439,7 @@ func TestShouldIncludeChannel(t *testing.T) { t.Parallel() // Create mock and prime it for the test case. - mock := &hopHintsConfigMock{} + mock := newHopHintsConfigMock(t) if tc.setupMock != nil { tc.setupMock(mock) } @@ -862,7 +872,7 @@ func TestPopulateHopHints(t *testing.T) { t.Parallel() // Create mock and prime it for the test case. - mock := &hopHintsConfigMock{} + mock := newHopHintsConfigMock(t) if tc.setupMock != nil { tc.setupMock(mock) } diff --git a/routing/notifications_test.go b/routing/notifications_test.go index f3bc24ad86..5d6695010c 100644 --- a/routing/notifications_test.go +++ b/routing/notifications_test.go @@ -471,6 +471,7 @@ func TestEdgeUpdateNotification(t *testing.T) { assertEdgeCorrect := func(t *testing.T, edgeUpdate *ChannelEdgeUpdate, edgeAnn *models.ChannelEdgePolicy) { + if edgeUpdate.ChanID != edgeAnn.ChannelID { t.Fatalf("channel ID of edge doesn't match: "+ "expected %v, got %v", chanID.ToUint64(), edgeUpdate.ChanID) diff --git a/rpcserver.go b/rpcserver.go index ec14b55391..fa23f04852 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -5933,7 +5933,7 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, return nil } - edge := marshalDbEdge(edgeInfo, c1, c2) + edge := marshalDBEdge(edgeInfo, c1, c2) resp.Edges = append(resp.Edges, edge) return nil @@ -5978,7 +5978,7 @@ func marshalExtraOpaqueData(data []byte) map[uint64][]byte { return records } -func marshalDbEdge(edgeInfo *models.ChannelEdgeInfo, +func marshalDBEdge(edgeInfo *models.ChannelEdgeInfo, c1, c2 *models.ChannelEdgePolicy) *lnrpc.ChannelEdge { // Make sure the policies match the node they belong to. c1 should point @@ -6114,7 +6114,7 @@ func (r *rpcServer) GetChanInfo(ctx context.Context, // Convert the database's edge format into the network/RPC edge format // which couples the edge itself along with the directional node // routing policies of each node involved within the channel. - channelEdge := marshalDbEdge(edgeInfo, edge1, edge2) + channelEdge := marshalDBEdge(edgeInfo, edge1, edge2) return channelEdge, nil } @@ -6171,7 +6171,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, // Convert the database's edge format into the // network/RPC edge format. - channelEdge := marshalDbEdge(edge, c1, c2) + channelEdge := marshalDBEdge(edge, c1, c2) channels = append(channels, channelEdge) }