From ed2989ae33e230188aaaf8a6052440db7e39810a Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Tue, 3 Dec 2024 15:51:05 -0700 Subject: [PATCH 01/59] multi: update to fn v2 --- aliasmgr/aliasmgr.go | 22 +++-- chainntnfs/bitcoindnotify/bitcoind.go | 2 +- chainntnfs/btcdnotify/btcd.go | 2 +- chainntnfs/interface.go | 2 +- chainntnfs/mocks.go | 2 +- chainreg/chainregistry.go | 2 +- chanbackup/backup.go | 2 +- chanbackup/single.go | 2 +- chanbackup/single_test.go | 2 +- channeldb/channel.go | 2 +- channeldb/channel_test.go | 2 +- .../migration/lnwire21/custom_records.go | 15 +-- .../migration32/mission_control_store.go | 4 +- channeldb/revocation_log.go | 2 +- channeldb/revocation_log_test.go | 2 +- cmd/commands/cmd_macaroon.go | 17 ++-- config_builder.go | 2 +- contractcourt/anchor_resolver.go | 2 +- contractcourt/breach_arbitrator.go | 10 +- contractcourt/breach_arbitrator_test.go | 2 +- contractcourt/briefcase.go | 2 +- contractcourt/briefcase_test.go | 2 +- contractcourt/chain_arbitrator.go | 2 +- contractcourt/chain_watcher.go | 12 +-- contractcourt/channel_arbitrator.go | 6 +- contractcourt/channel_arbitrator_test.go | 2 +- contractcourt/commit_sweep_resolver.go | 2 +- contractcourt/contract_resolver.go | 2 +- .../htlc_incoming_contest_resolver.go | 2 +- contractcourt/htlc_lease_resolver.go | 2 +- .../htlc_outgoing_contest_resolver.go | 2 +- contractcourt/htlc_success_resolver.go | 2 +- contractcourt/htlc_success_resolver_test.go | 2 +- contractcourt/htlc_timeout_resolver.go | 2 +- contractcourt/htlc_timeout_resolver_test.go | 2 +- contractcourt/utxonursery.go | 2 +- contractcourt/utxonursery_test.go | 2 +- discovery/gossiper.go | 2 +- discovery/gossiper_test.go | 2 +- funding/aux_funding.go | 2 +- funding/manager.go | 2 +- funding/manager_test.go | 2 +- go.mod | 4 +- go.sum | 8 +- graph/builder.go | 2 +- graph/db/models/channel_edge_info.go | 2 +- htlcswitch/interceptable_switch.go | 2 +- htlcswitch/interfaces.go | 2 +- htlcswitch/link.go | 2 +- htlcswitch/link_test.go | 2 +- htlcswitch/mock.go | 2 +- htlcswitch/quiescer.go | 2 +- htlcswitch/quiescer_test.go | 2 +- htlcswitch/switch.go | 2 +- htlcswitch/switch_test.go | 2 +- input/input.go | 2 +- input/mocks.go | 2 +- input/script_utils.go | 2 +- input/taproot.go | 2 +- input/taproot_test.go | 2 +- intercepted_forward.go | 2 +- invoices/modification_interceptor.go | 2 +- itest/lnd_funding_test.go | 2 +- itest/lnd_sweep_test.go | 6 +- lnrpc/devrpc/dev_server.go | 2 +- lnrpc/marshall_utils.go | 24 +++-- lnrpc/routerrpc/forward_interceptor.go | 2 +- lnrpc/routerrpc/router_backend.go | 2 +- lnrpc/routerrpc/router_server.go | 2 +- lnrpc/walletrpc/walletkit_server.go | 35 ++++--- lntest/harness.go | 2 +- lntest/harness_assertion.go | 4 +- lntest/miner/miner.go | 7 +- lntest/mock/walletcontroller.go | 2 +- lntest/node/state.go | 6 +- lnwallet/aux_leaf_store.go | 2 +- lnwallet/aux_resolutions.go | 2 +- lnwallet/aux_signer.go | 2 +- lnwallet/btcwallet/btcwallet.go | 2 +- lnwallet/chainfee/filtermanager.go | 2 +- lnwallet/chancloser/aux_closer.go | 2 +- lnwallet/chancloser/chancloser.go | 2 +- lnwallet/chancloser/chancloser_test.go | 2 +- lnwallet/chancloser/interface.go | 2 +- lnwallet/chanfunding/canned_assembler.go | 2 +- lnwallet/chanfunding/interface.go | 2 +- lnwallet/chanfunding/psbt_assembler.go | 2 +- lnwallet/channel.go | 93 ++++++++++--------- lnwallet/channel_test.go | 11 ++- lnwallet/commitment.go | 10 +- lnwallet/commitment_chain.go | 2 +- lnwallet/config.go | 2 +- lnwallet/interface.go | 2 +- lnwallet/mock.go | 2 +- 
lnwallet/musig_session.go | 2 +- lnwallet/reservation.go | 2 +- lnwallet/rpcwallet/rpcwallet.go | 2 +- lnwallet/test/test_interface.go | 2 +- lnwallet/test_utils.go | 2 +- lnwallet/transactions_test.go | 2 +- lnwallet/update_log.go | 2 +- lnwallet/wallet.go | 2 +- lnwire/channel_reestablish.go | 2 +- lnwire/custom_records.go | 15 +-- lnwire/custom_records_test.go | 11 ++- lnwire/dyn_ack.go | 2 +- lnwire/dyn_propose.go | 2 +- lnwire/extra_bytes.go | 2 +- lnwire/lnwire_test.go | 2 +- lnwire/onion_error.go | 2 +- lnwire/onion_error_test.go | 2 +- msgmux/msg_router.go | 4 +- peer/brontide.go | 2 +- peer/brontide_test.go | 2 +- peer/musig_chan_closer.go | 2 +- peer/test_utils.go | 2 +- protofsm/daemon_events.go | 2 +- protofsm/msg_mapper.go | 2 +- protofsm/state_machine.go | 44 +++++++-- protofsm/state_machine_test.go | 2 +- routing/bandwidth.go | 2 +- routing/bandwidth_test.go | 2 +- routing/blinding.go | 2 +- routing/blinding_test.go | 2 +- routing/integrated_routing_context_test.go | 2 +- routing/localchans/manager.go | 2 +- routing/missioncontrol.go | 2 +- routing/mock_test.go | 2 +- routing/pathfind.go | 2 +- routing/pathfind_test.go | 2 +- routing/payment_lifecycle.go | 2 +- routing/payment_lifecycle_test.go | 2 +- routing/payment_session_source.go | 2 +- routing/result_interpretation.go | 4 +- routing/result_interpretation_test.go | 2 +- routing/router.go | 2 +- routing/router_test.go | 2 +- rpcserver.go | 6 +- rpcserver_test.go | 2 +- server.go | 2 +- subrpcserver_config.go | 2 +- sweep/aggregator.go | 2 +- sweep/aggregator_test.go | 2 +- sweep/fee_bumper.go | 8 +- sweep/fee_bumper_test.go | 2 +- sweep/fee_function.go | 2 +- sweep/fee_function_test.go | 2 +- sweep/interface.go | 2 +- sweep/mock_test.go | 2 +- sweep/sweeper.go | 2 +- sweep/sweeper_test.go | 2 +- sweep/tx_input_set.go | 6 +- sweep/tx_input_set_test.go | 2 +- sweep/walletsweep.go | 2 +- sweep/walletsweep_test.go | 2 +- watchtower/blob/justice_kit.go | 2 +- watchtower/blob/justice_kit_test.go | 2 +- watchtower/lookout/justice_descriptor_test.go | 2 +- .../wtclient/backup_task_internal_test.go | 2 +- watchtower/wtclient/client_test.go | 2 +- watchtower/wtclient/manager.go | 2 +- watchtower/wtdb/client_chan_summary.go | 2 +- watchtower/wtdb/client_db.go | 2 +- zpay32/decode.go | 2 +- zpay32/encode.go | 2 +- zpay32/invoice.go | 2 +- zpay32/invoice_test.go | 2 +- 167 files changed, 372 insertions(+), 302 deletions(-) diff --git a/aliasmgr/aliasmgr.go b/aliasmgr/aliasmgr.go index f06cb53d79..a3227b18b8 100644 --- a/aliasmgr/aliasmgr.go +++ b/aliasmgr/aliasmgr.go @@ -5,7 +5,7 @@ import ( "fmt" "sync" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" @@ -432,9 +432,9 @@ func (m *Manager) DeleteLocalAlias(alias, } // We'll filter the alias set and remove the alias from it. - aliasSet = fn.Filter(func(a lnwire.ShortChannelID) bool { + aliasSet = fn.Filter(aliasSet, func(a lnwire.ShortChannelID) bool { return a.ToUint64() != alias.ToUint64() - }, aliasSet) + }) // If the alias set is empty, we'll delete the base SCID from the // baseToSet map. @@ -514,11 +514,17 @@ func (m *Manager) RequestAlias() (lnwire.ShortChannelID, error) { // haveAlias returns true if the passed alias is already assigned to a // channel in the baseToSet map. 
haveAlias := func(maybeNextAlias lnwire.ShortChannelID) bool { - return fn.Any(func(aliasList []lnwire.ShortChannelID) bool { - return fn.Any(func(alias lnwire.ShortChannelID) bool { - return alias == maybeNextAlias - }, aliasList) - }, maps.Values(m.baseToSet)) + return fn.Any( + maps.Values(m.baseToSet), + func(aliasList []lnwire.ShortChannelID) bool { + return fn.Any( + aliasList, + func(alias lnwire.ShortChannelID) bool { + return alias == maybeNextAlias + }, + ) + }, + ) } err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error { diff --git a/chainntnfs/bitcoindnotify/bitcoind.go b/chainntnfs/bitcoindnotify/bitcoind.go index fc20fbb857..59c03d5171 100644 --- a/chainntnfs/bitcoindnotify/bitcoind.go +++ b/chainntnfs/bitcoindnotify/bitcoind.go @@ -15,7 +15,7 @@ import ( "github.com/btcsuite/btcwallet/chain" "github.com/lightningnetwork/lnd/blockcache" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/queue" ) diff --git a/chainntnfs/btcdnotify/btcd.go b/chainntnfs/btcdnotify/btcd.go index c3a40a00bf..e3bff289cf 100644 --- a/chainntnfs/btcdnotify/btcd.go +++ b/chainntnfs/btcdnotify/btcd.go @@ -17,7 +17,7 @@ import ( "github.com/btcsuite/btcwallet/chain" "github.com/lightningnetwork/lnd/blockcache" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/queue" ) diff --git a/chainntnfs/interface.go b/chainntnfs/interface.go index b2383636aa..1b8a5acb50 100644 --- a/chainntnfs/interface.go +++ b/chainntnfs/interface.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" ) var ( diff --git a/chainntnfs/mocks.go b/chainntnfs/mocks.go index d9ab9928d0..4a888b162e 100644 --- a/chainntnfs/mocks.go +++ b/chainntnfs/mocks.go @@ -3,7 +3,7 @@ package chainntnfs import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/stretchr/testify/mock" ) diff --git a/chainreg/chainregistry.go b/chainreg/chainregistry.go index edc422482e..a9a9ede704 100644 --- a/chainreg/chainregistry.go +++ b/chainreg/chainregistry.go @@ -23,7 +23,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs/btcdnotify" "github.com/lightningnetwork/lnd/chainntnfs/neutrinonotify" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" diff --git a/chanbackup/backup.go b/chanbackup/backup.go index 5853b37e45..afffe5a2e8 100644 --- a/chanbackup/backup.go +++ b/chanbackup/backup.go @@ -5,7 +5,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" ) // LiveChannelSource is an interface that allows us to query for the set of diff --git a/chanbackup/single.go b/chanbackup/single.go index b741320b07..01d14f6c07 100644 --- a/chanbackup/single.go +++ b/chanbackup/single.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - 
"github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnencrypt" "github.com/lightningnetwork/lnd/lnwire" diff --git a/chanbackup/single_test.go b/chanbackup/single_test.go index d2212bd859..0fe402926d 100644 --- a/chanbackup/single_test.go +++ b/chanbackup/single_test.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnencrypt" "github.com/lightningnetwork/lnd/lnwire" diff --git a/channeldb/channel.go b/channeldb/channel.go index 9ca57312aa..f4e99a6f8c 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -19,7 +19,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/walletdb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 2cac0baced..b1ca100eb3 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -18,7 +18,7 @@ import ( _ "github.com/btcsuite/btcwallet/walletdb/bdb" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/clock" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" diff --git a/channeldb/migration/lnwire21/custom_records.go b/channeldb/migration/lnwire21/custom_records.go index f0f59185e9..7771c8ec8b 100644 --- a/channeldb/migration/lnwire21/custom_records.go +++ b/channeldb/migration/lnwire21/custom_records.go @@ -6,7 +6,7 @@ import ( "io" "sort" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/tlv" ) @@ -163,9 +163,12 @@ func (c CustomRecords) SerializeTo(w io.Writer) error { // ProduceRecordsSorted converts a slice of record producers into a slice of // records and then sorts it by type. func ProduceRecordsSorted(recordProducers ...tlv.RecordProducer) []tlv.Record { - records := fn.Map(func(producer tlv.RecordProducer) tlv.Record { - return producer.Record() - }, recordProducers) + records := fn.Map( + recordProducers, + func(producer tlv.RecordProducer) tlv.Record { + return producer.Record() + }, + ) // Ensure that the set of records are sorted before we attempt to // decode from the stream, to ensure they're canonical. @@ -196,9 +199,9 @@ func TlvMapToRecords(tlvMap tlv.TypeMap) []tlv.Record { // RecordsAsProducers converts a slice of records into a slice of record // producers. func RecordsAsProducers(records []tlv.Record) []tlv.RecordProducer { - return fn.Map(func(record tlv.Record) tlv.RecordProducer { + return fn.Map(records, func(record tlv.Record) tlv.RecordProducer { return &record - }, records) + }) } // EncodeRecords encodes the given records into a byte slice. 
diff --git a/channeldb/migration32/mission_control_store.go b/channeldb/migration32/mission_control_store.go index 3ac9d6114c..76463eb6ca 100644 --- a/channeldb/migration32/mission_control_store.go +++ b/channeldb/migration32/mission_control_store.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/wire" lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/tlv" ) @@ -371,7 +371,7 @@ func extractMCRoute(r *Route) *mcRoute { // extractMCHops extracts the Hop fields that MC actually uses from a slice of // Hops. func extractMCHops(hops []*Hop) mcHops { - return fn.Map(extractMCHop, hops) + return fn.Map(hops, extractMCHop) } // extractMCHop extracts the Hop fields that MC actually uses from a Hop. diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index 3abc73f81e..ea6eaf13f2 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -7,7 +7,7 @@ import ( "math" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go index 4290552eee..2df6627e2c 100644 --- a/channeldb/revocation_log_test.go +++ b/channeldb/revocation_log_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/channels" "github.com/lightningnetwork/lnd/lnwire" diff --git a/cmd/commands/cmd_macaroon.go b/cmd/commands/cmd_macaroon.go index 15c29380a7..d7d6d5f9dc 100644 --- a/cmd/commands/cmd_macaroon.go +++ b/cmd/commands/cmd_macaroon.go @@ -10,7 +10,7 @@ import ( "strings" "unicode" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/macaroons" @@ -177,12 +177,15 @@ func bakeMacaroon(ctx *cli.Context) error { "%w", err) } - ops := fn.Map(func(p *lnrpc.MacaroonPermission) bakery.Op { - return bakery.Op{ - Entity: p.Entity, - Action: p.Action, - } - }, parsedPermissions) + ops := fn.Map( + parsedPermissions, + func(p *lnrpc.MacaroonPermission) bakery.Op { + return bakery.Op{ + Entity: p.Entity, + Action: p.Action, + } + }, + ) rawMacaroon, err = macaroons.BakeFromRootKey(macRootKey, ops) if err != nil { diff --git a/config_builder.go b/config_builder.go index 42650bb68b..42790a50c6 100644 --- a/config_builder.go +++ b/config_builder.go @@ -33,7 +33,7 @@ import ( "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/funding" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/invoices" diff --git a/contractcourt/anchor_resolver.go b/contractcourt/anchor_resolver.go index b4d6877202..e482c4c713 100644 --- a/contractcourt/anchor_resolver.go +++ b/contractcourt/anchor_resolver.go @@ -9,7 +9,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + 
"github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/sweep" ) diff --git a/contractcourt/breach_arbitrator.go b/contractcourt/breach_arbitrator.go index d59829b5e5..33bc7f7e33 100644 --- a/contractcourt/breach_arbitrator.go +++ b/contractcourt/breach_arbitrator.go @@ -15,7 +15,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" @@ -1537,9 +1537,9 @@ func (b *BreachArbitrator) createSweepTx( // outputs from the regular, BTC only outputs. So we only need one such // output, which'll carry the custom channel "valuables" from both the // breached commitment and HTLC outputs. - hasBlobs := fn.Any(func(i input.Input) bool { + hasBlobs := fn.Any(inputs, func(i input.Input) bool { return i.ResolutionBlob().IsSome() - }, inputs) + }) if hasBlobs { weightEstimate.AddP2TROutput() } @@ -1624,7 +1624,7 @@ func (b *BreachArbitrator) sweepSpendableOutputsTxn(txWeight lntypes.WeightUnit, // First, we'll add the extra sweep output if it exists, subtracting the // amount from the sweep amt. if b.cfg.AuxSweeper.IsSome() { - extraChangeOut.WhenResult(func(o sweep.SweepOutput) { + extraChangeOut.WhenOk(func(o sweep.SweepOutput) { sweepAmt -= o.Value txn.AddTxOut(&o.TxOut) @@ -1697,7 +1697,7 @@ func (b *BreachArbitrator) sweepSpendableOutputsTxn(txWeight lntypes.WeightUnit, return &justiceTxCtx{ justiceTx: txn, sweepAddr: pkScript, - extraTxOut: extraChangeOut.Option(), + extraTxOut: extraChangeOut.OkToSome(), fee: txFee, inputs: inputs, }, nil diff --git a/contractcourt/breach_arbitrator_test.go b/contractcourt/breach_arbitrator_test.go index 576009eda4..c387c21797 100644 --- a/contractcourt/breach_arbitrator_test.go +++ b/contractcourt/breach_arbitrator_test.go @@ -22,7 +22,7 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntest/channels" diff --git a/contractcourt/briefcase.go b/contractcourt/briefcase.go index a0908ea3fa..7d199c5c28 100644 --- a/contractcourt/briefcase.go +++ b/contractcourt/briefcase.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/contractcourt/briefcase_test.go b/contractcourt/briefcase_test.go index 0f44db2abb..533d0eff78 100644 --- a/contractcourt/briefcase_test.go +++ b/contractcourt/briefcase_test.go @@ -14,7 +14,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnmock" diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 6d9b30d208..646d68b869 100644 --- 
a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -14,7 +14,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index e79c8d546b..e29f21e7f4 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -18,7 +18,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" @@ -451,7 +451,7 @@ func (c *chainWatcher) handleUnknownLocalState( leaseExpiry = c.cfg.chanState.ThawHeight } - remoteAuxLeaf := fn.ChainOption( + remoteAuxLeaf := fn.FlatMapOption( func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf { return l.RemoteAuxLeaf }, @@ -468,7 +468,7 @@ func (c *chainWatcher) handleUnknownLocalState( // Next, we'll derive our script that includes the revocation base for // the remote party allowing them to claim this output before the CSV // delay if we breach. - localAuxLeaf := fn.ChainOption( + localAuxLeaf := fn.FlatMapOption( func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf { return l.LocalAuxLeaf }, @@ -1062,15 +1062,15 @@ func (c *chainWatcher) toSelfAmount(tx *wire.MsgTx) btcutil.Amount { return false } - return fn.Any(c.cfg.isOurAddr, addrs) + return fn.Any(addrs, c.cfg.isOurAddr) } // Grab all of the outputs that correspond with our delivery address // or our wallet is aware of. - outs := fn.Filter(fn.PredOr(isDeliveryOutput, isWalletOutput), tx.TxOut) + outs := fn.Filter(tx.TxOut, fn.PredOr(isDeliveryOutput, isWalletOutput)) // Grab the values for those outputs. - vals := fn.Map(func(o *wire.TxOut) int64 { return o.Value }, outs) + vals := fn.Map(outs, func(o *wire.TxOut) int64 { return o.Value }) // Return the sum. return btcutil.Amount(fn.Sum(vals)) diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 319b437e4e..0be157d971 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -15,7 +15,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" @@ -997,7 +997,7 @@ func (c *ChannelArbitrator) stateStep( getIdx := func(htlc channeldb.HTLC) uint64 { return htlc.HtlcIndex } - dustHTLCSet := fn.NewSet(fn.Map(getIdx, dustHTLCs)...) + dustHTLCSet := fn.NewSet(fn.Map(dustHTLCs, getIdx)...) err = c.abandonForwards(dustHTLCSet) if err != nil { return StateError, closeTx, err @@ -1306,7 +1306,7 @@ func (c *ChannelArbitrator) stateStep( return htlc.HtlcIndex } remoteDangling := fn.NewSet(fn.Map( - getIdx, htlcActions[HtlcFailDanglingAction], + htlcActions[HtlcFailDanglingAction], getIdx, )...) 
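
Note on the Result helpers touched in the breach arbitrator hunk above: fn/v2 renames WhenResult to WhenOk and Option() to OkToSome(). A small sketch of both, where loadChangeOutput and its int64 payload are hypothetical stand-ins (and fn.Ok/fn.Err are assumed to be the usual Result constructors), not code from this patch:

    package main

    import (
        "errors"
        "fmt"

        "github.com/lightningnetwork/lnd/fn/v2"
    )

    // loadChangeOutput is a hypothetical helper that yields a Result,
    // standing in for calls such as the aux sweeper's extra change output
    // in the hunk above.
    func loadChangeOutput(ok bool) fn.Result[int64] {
        if !ok {
            return fn.Err[int64](errors.New("no change output"))
        }

        return fn.Ok(int64(1000))
    }

    func main() {
        res := loadChangeOutput(true)

        // v1: res.WhenResult(...); v2 renames the success callback to
        // WhenOk.
        res.WhenOk(func(v int64) {
            fmt.Println("change value:", v)
        })

        // v1: res.Option(); v2 renames the conversion to OkToSome,
        // yielding an fn.Option[int64].
        fmt.Println("have change:", res.OkToSome().IsSome())
    }
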
err := c.abandonForwards(remoteDangling) if err != nil { diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 92ad608eb9..02e4b347c2 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -16,7 +16,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index 4b47a34294..6019a0dbc6 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/sweep" diff --git a/contractcourt/contract_resolver.go b/contractcourt/contract_resolver.go index 53f4f680d0..f5a88f24e6 100644 --- a/contractcourt/contract_resolver.go +++ b/contractcourt/contract_resolver.go @@ -9,7 +9,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btclog/v2" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" ) var ( diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go index 73841eb88c..e5be63cbf7 100644 --- a/contractcourt/htlc_incoming_contest_resolver.go +++ b/contractcourt/htlc_incoming_contest_resolver.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/invoices" diff --git a/contractcourt/htlc_lease_resolver.go b/contractcourt/htlc_lease_resolver.go index 53fa893553..9c5da6ee49 100644 --- a/contractcourt/htlc_lease_resolver.go +++ b/contractcourt/htlc_lease_resolver.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go index 2466544c98..1303d0af60 100644 --- a/contractcourt/htlc_outgoing_contest_resolver.go +++ b/contractcourt/htlc_outgoing_contest_resolver.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet" ) diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index b2716ad305..9d09f844dc 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -12,7 +12,7 @@ import ( "github.com/davecgh/go-spew/spew" 
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" diff --git a/contractcourt/htlc_success_resolver_test.go b/contractcourt/htlc_success_resolver_test.go index 23023729fa..c0206d8f14 100644 --- a/contractcourt/htlc_success_resolver_test.go +++ b/contractcourt/htlc_success_resolver_test.go @@ -12,7 +12,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index 9954c3c0db..545e7c6135 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -12,7 +12,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" diff --git a/contractcourt/htlc_timeout_resolver_test.go b/contractcourt/htlc_timeout_resolver_test.go index f3f23c385c..0e4f1336c2 100644 --- a/contractcourt/htlc_timeout_resolver_test.go +++ b/contractcourt/htlc_timeout_resolver_test.go @@ -14,7 +14,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" diff --git a/contractcourt/utxonursery.go b/contractcourt/utxonursery.go index aef906a0ad..a870683746 100644 --- a/contractcourt/utxonursery.go +++ b/contractcourt/utxonursery.go @@ -15,7 +15,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" diff --git a/contractcourt/utxonursery_test.go b/contractcourt/utxonursery_test.go index 796d1ed239..f1b47cc2ca 100644 --- a/contractcourt/utxonursery_test.go +++ b/contractcourt/utxonursery_test.go @@ -18,7 +18,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 41e58c404e..7db67c39ea 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -19,7 +19,7 @@ import ( "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + 
"github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 85a4e0657e..b74f69bf0a 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -24,7 +24,7 @@ import ( "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" diff --git a/funding/aux_funding.go b/funding/aux_funding.go index 492612145a..c7ef653f47 100644 --- a/funding/aux_funding.go +++ b/funding/aux_funding.go @@ -2,7 +2,7 @@ package funding import ( "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/msgmux" diff --git a/funding/manager.go b/funding/manager.go index c8a54d9588..395cccb2a6 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -23,7 +23,7 @@ import ( "github.com/lightningnetwork/lnd/chanacceptor" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/discovery" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" diff --git a/funding/manager_test.go b/funding/manager_test.go index 525f69f9a5..b6130176d1 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -27,7 +27,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/discovery" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" diff --git a/go.mod b/go.mod index 1330f9a84a..bbb421de40 100644 --- a/go.mod +++ b/go.mod @@ -36,13 +36,13 @@ require ( github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb github.com/lightningnetwork/lnd/cert v1.2.2 github.com/lightningnetwork/lnd/clock v1.1.1 - github.com/lightningnetwork/lnd/fn v1.2.5 + github.com/lightningnetwork/lnd/fn/v2 v2.0.2 github.com/lightningnetwork/lnd/healthcheck v1.2.6 github.com/lightningnetwork/lnd/kvdb v1.4.11 github.com/lightningnetwork/lnd/queue v1.1.1 github.com/lightningnetwork/lnd/sqldb v1.0.5 github.com/lightningnetwork/lnd/ticker v1.1.1 - github.com/lightningnetwork/lnd/tlv v1.2.6 + github.com/lightningnetwork/lnd/tlv v1.3.0 github.com/lightningnetwork/lnd/tor v1.1.4 github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796 github.com/miekg/dns v1.1.43 diff --git a/go.sum b/go.sum index aa04dc5fce..4c452df735 100644 --- a/go.sum +++ b/go.sum @@ -456,8 +456,8 @@ github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= github.com/lightningnetwork/lnd/clock v1.1.1/go.mod h1:mGnAhPyjYZQJmebS7aevElXKTFDuO+uNFFfMXK1W8xQ= 
-github.com/lightningnetwork/lnd/fn v1.2.5 h1:pGMz0BDUxrhvOtShD4FIysdVy+ulfFAnFvTKjZO5Pp8= -github.com/lightningnetwork/lnd/fn v1.2.5/go.mod h1:SyFohpVrARPKH3XVAJZlXdVe+IwMYc4OMAvrDY32kw0= +github.com/lightningnetwork/lnd/fn/v2 v2.0.2 h1:M7o2lYrh/zCp+lntPB3WP/rWTu5U+4ssyHW+kqNJ0fs= +github.com/lightningnetwork/lnd/fn/v2 v2.0.2/go.mod h1:TOzwrhjB/Azw1V7aa8t21ufcQmdsQOQMDtxVOQWNl8s= github.com/lightningnetwork/lnd/healthcheck v1.2.6 h1:1sWhqr93GdkWy4+6U7JxBfcyZIE78MhIHTJZfPx7qqI= github.com/lightningnetwork/lnd/healthcheck v1.2.6/go.mod h1:Mu02um4CWY/zdTOvFje7WJgJcHyX2zq/FG3MhOAiGaQ= github.com/lightningnetwork/lnd/kvdb v1.4.11 h1:fk1HMVFrsVK3xqU7q+JWHRgBltw/a2qIg1E3zazMb/8= @@ -468,8 +468,8 @@ github.com/lightningnetwork/lnd/sqldb v1.0.5 h1:ax5vBPf44tN/uD6C5+hBPBjOJ7cRMrUL github.com/lightningnetwork/lnd/sqldb v1.0.5/go.mod h1:OG09zL/PHPaBJefp4HsPz2YLUJ+zIQHbpgCtLnOx8I4= github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM= github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA= -github.com/lightningnetwork/lnd/tlv v1.2.6 h1:icvQG2yDr6k3ZuZzfRdG3EJp6pHurcuh3R6dg0gv/Mw= -github.com/lightningnetwork/lnd/tlv v1.2.6/go.mod h1:/CmY4VbItpOldksocmGT4lxiJqRP9oLxwSZOda2kzNQ= +github.com/lightningnetwork/lnd/tlv v1.3.0 h1:exS/KCPEgpOgviIttfiXAPaUqw2rHQrnUOpP7HPBPiY= +github.com/lightningnetwork/lnd/tlv v1.3.0/go.mod h1:pJuiBj1ecr1WWLOtcZ+2+hu9Ey25aJWFIsjmAoPPnmc= github.com/lightningnetwork/lnd/tor v1.1.4 h1:TUW27EXqoZCcCAQPlD4aaDfh8jMbBS9CghNz50qqwtA= github.com/lightningnetwork/lnd/tor v1.1.4/go.mod h1:qSRB8llhAK+a6kaTPWOLLXSZc6Hg8ZC0mq1sUQ/8JfI= github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796 h1:sjOGyegMIhvgfq5oaue6Td+hxZuf3tDC8lAPrFldqFw= diff --git a/graph/builder.go b/graph/builder.go index c0133e02ec..d6984af709 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -16,7 +16,7 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" diff --git a/graph/db/models/channel_edge_info.go b/graph/db/models/channel_edge_info.go index 0f91e2bbec..6aa67acc6a 100644 --- a/graph/db/models/channel_edge_info.go +++ b/graph/db/models/channel_edge_info.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" ) // ChannelEdgeInfo represents a fully authenticated channel along with all its diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index c48436173f..6414c9f802 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -8,7 +8,7 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lntypes" diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index d8f55afc69..7763a4a751 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" 
"github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 60062862ef..214144ac19 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -16,7 +16,7 @@ import ( "github.com/btcsuite/btclog/v2" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hop" diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 80632b07e9..4e5c9478a1 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -26,7 +26,7 @@ import ( "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hop" diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index ce791bef32..1d149fb0bf 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -24,7 +24,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/contractcourt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/invoices" diff --git a/htlcswitch/quiescer.go b/htlcswitch/quiescer.go index 27d0deb8c6..468ad5e708 100644 --- a/htlcswitch/quiescer.go +++ b/htlcswitch/quiescer.go @@ -6,7 +6,7 @@ import ( "time" "github.com/btcsuite/btclog/v2" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/htlcswitch/quiescer_test.go b/htlcswitch/quiescer_test.go index da08909d57..6ce9563e45 100644 --- a/htlcswitch/quiescer_test.go +++ b/htlcswitch/quiescer_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/stretchr/testify/require" diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 1a08275ec9..3e2e9a52dd 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -17,7 +17,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/contractcourt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/kvdb" diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index abfb8e4d5b..8809321460 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -17,7 +17,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" - "github.com/lightningnetwork/lnd/fn" + 
"github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hop" diff --git a/input/input.go b/input/input.go index 088b20401f..4a9a4b55c0 100644 --- a/input/input.go +++ b/input/input.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/input/mocks.go b/input/mocks.go index bbd4550c5f..6d90bc28df 100644 --- a/input/mocks.go +++ b/input/mocks.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/tlv" diff --git a/input/script_utils.go b/input/script_utils.go index 91ca55292f..000efe9585 100644 --- a/input/script_utils.go +++ b/input/script_utils.go @@ -14,7 +14,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "golang.org/x/crypto/ripemd160" diff --git a/input/taproot.go b/input/taproot.go index 2ca6e97236..5ca4dd0c66 100644 --- a/input/taproot.go +++ b/input/taproot.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/waddrmgr" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" ) const ( diff --git a/input/taproot_test.go b/input/taproot_test.go index a1259be196..3a1e000374 100644 --- a/input/taproot_test.go +++ b/input/taproot_test.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" "github.com/stretchr/testify/require" diff --git a/intercepted_forward.go b/intercepted_forward.go index 791d4bd583..5cb1ca192b 100644 --- a/intercepted_forward.go +++ b/intercepted_forward.go @@ -3,7 +3,7 @@ package lnd import ( "errors" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" diff --git a/invoices/modification_interceptor.go b/invoices/modification_interceptor.go index 97e75e8cc5..58f5b63d07 100644 --- a/invoices/modification_interceptor.go +++ b/invoices/modification_interceptor.go @@ -5,7 +5,7 @@ import ( "fmt" "sync/atomic" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" ) var ( diff --git a/itest/lnd_funding_test.go b/itest/lnd_funding_test.go index 54180abf57..0b08da32b3 100644 --- a/itest/lnd_funding_test.go +++ b/itest/lnd_funding_test.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainreg" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" 
"github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" diff --git a/itest/lnd_sweep_test.go b/itest/lnd_sweep_test.go index 099014aff0..158e8768f9 100644 --- a/itest/lnd_sweep_test.go +++ b/itest/lnd_sweep_test.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/contractcourt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" @@ -1119,9 +1119,9 @@ func testSweepHTLCs(ht *lntest.HarnessTest) { // The sweeping tx has two inputs, one from wallet, the other // from the force close tx. We now check whether the first tx // spends from the force close tx of Alice->Bob. - found := fn.Any(func(inp *wire.TxIn) bool { + found := fn.Any(txns[0].TxIn, func(inp *wire.TxIn) bool { return inp.PreviousOutPoint.Hash == abCloseTxid - }, txns[0].TxIn) + }) // If the first tx spends an outpoint from the force close tx // of Alice->Bob, then it must be the incoming HTLC sweeping diff --git a/lnrpc/devrpc/dev_server.go b/lnrpc/devrpc/dev_server.go index 60f30dd7ed..b26b144c81 100644 --- a/lnrpc/devrpc/dev_server.go +++ b/lnrpc/devrpc/dev_server.go @@ -16,7 +16,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/lnrpc" diff --git a/lnrpc/marshall_utils.go b/lnrpc/marshall_utils.go index 230fea35b6..96d3342d83 100644 --- a/lnrpc/marshall_utils.go +++ b/lnrpc/marshall_utils.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/wallet" "github.com/lightningnetwork/lnd/aliasmgr" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "golang.org/x/exp/maps" @@ -221,12 +221,18 @@ func UnmarshallCoinSelectionStrategy(strategy CoinSelectionStrategy, // MarshalAliasMap converts a ScidAliasMap to its proto counterpart. This is // used in various RPCs that handle scid alias mappings. 
func MarshalAliasMap(scidMap aliasmgr.ScidAliasMap) []*AliasMap { - return fn.Map(func(base lnwire.ShortChannelID) *AliasMap { - return &AliasMap{ - BaseScid: base.ToUint64(), - Aliases: fn.Map(func(a lnwire.ShortChannelID) uint64 { - return a.ToUint64() - }, scidMap[base]), - } - }, maps.Keys(scidMap)) + return fn.Map( + maps.Keys(scidMap), + func(base lnwire.ShortChannelID) *AliasMap { + return &AliasMap{ + BaseScid: base.ToUint64(), + Aliases: fn.Map( + scidMap[base], + func(a lnwire.ShortChannelID) uint64 { + return a.ToUint64() + }, + ), + } + }, + ) } diff --git a/lnrpc/routerrpc/forward_interceptor.go b/lnrpc/routerrpc/forward_interceptor.go index 9da831ac04..72df3d0199 100644 --- a/lnrpc/routerrpc/forward_interceptor.go +++ b/lnrpc/routerrpc/forward_interceptor.go @@ -3,7 +3,7 @@ package routerrpc import ( "errors" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnrpc" diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 9421e991b6..7d73681094 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -16,7 +16,7 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/feature" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go index 7f1a7edf07..9499fa25a3 100644 --- a/lnrpc/routerrpc/router_server.go +++ b/lnrpc/routerrpc/router_server.go @@ -16,7 +16,7 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/lightningnetwork/lnd/aliasmgr" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" "github.com/lightningnetwork/lnd/lntypes" diff --git a/lnrpc/walletrpc/walletkit_server.go b/lnrpc/walletrpc/walletkit_server.go index c6dec6fbd5..4f477cdbd4 100644 --- a/lnrpc/walletrpc/walletkit_server.go +++ b/lnrpc/walletrpc/walletkit_server.go @@ -31,7 +31,7 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/labels" @@ -1145,9 +1145,9 @@ func (w *WalletKit) getWaitingCloseChannel( return nil, err } - channel := fn.Find(func(c *channeldb.OpenChannel) bool { + channel := fn.Find(chans, func(c *channeldb.OpenChannel) bool { return c.FundingOutpoint == chanPoint - }, chans) + }) return channel.UnwrapOrErr(errors.New("channel not found")) } @@ -1231,18 +1231,23 @@ func (w *WalletKit) BumpForceCloseFee(_ context.Context, pendingSweeps := maps.Values(inputsMap) // Discard everything except for the anchor sweeps. - anchors := fn.Filter(func(sweep *sweep.PendingInputResponse) bool { - // Only filter for anchor inputs because these are the only - // inputs which can be used to bump a closed unconfirmed - // commitment transaction. 
- if sweep.WitnessType != input.CommitmentAnchor && - sweep.WitnessType != input.TaprootAnchorSweepSpend { - - return false - } + anchors := fn.Filter( + pendingSweeps, + func(sweep *sweep.PendingInputResponse) bool { + // Only filter for anchor inputs because these are the + // only inputs which can be used to bump a closed + // unconfirmed commitment transaction. + isCommitAnchor := sweep.WitnessType == + input.CommitmentAnchor + isTaprootSweepSpend := sweep.WitnessType == + input.TaprootAnchorSweepSpend + if !isCommitAnchor && !isTaprootSweepSpend { + return false + } - return commitSet.Contains(sweep.OutPoint.Hash) - }, pendingSweeps) + return commitSet.Contains(sweep.OutPoint.Hash) + }, + ) if len(anchors) == 0 { return nil, fmt.Errorf("unable to find pending anchor outputs") @@ -1754,7 +1759,7 @@ func (w *WalletKit) fundPsbtInternalWallet(account string, return true } - eligibleUtxos := fn.Filter(filterFn, utxos) + eligibleUtxos := fn.Filter(utxos, filterFn) // Validate all inputs against our known list of UTXOs // now. diff --git a/lntest/harness.go b/lntest/harness.go index f96a3aadd7..8e8fcd3936 100644 --- a/lntest/harness.go +++ b/lntest/harness.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/kvdb/etcd" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" diff --git a/lntest/harness_assertion.go b/lntest/harness_assertion.go index 1b079fea16..11cbefdd5c 100644 --- a/lntest/harness_assertion.go +++ b/lntest/harness_assertion.go @@ -19,7 +19,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" "github.com/lightningnetwork/lnd/lnrpc/routerrpc" @@ -270,7 +270,7 @@ func (h *HarnessTest) AssertNumActiveEdges(hn *node.HarnessNode, IncludeUnannounced: includeUnannounced, } resp := hn.RPC.DescribeGraph(req) - activeEdges := fn.Filter(filterDisabled, resp.Edges) + activeEdges := fn.Filter(resp.Edges, filterDisabled) total := len(activeEdges) if total-old == expected { diff --git a/lntest/miner/miner.go b/lntest/miner/miner.go index e9e380bbb3..0229d6a47f 100644 --- a/lntest/miner/miner.go +++ b/lntest/miner/miner.go @@ -17,7 +17,7 @@ import ( "github.com/btcsuite/btcd/integration/rpctest" "github.com/btcsuite/btcd/rpcclient" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lntest/node" "github.com/lightningnetwork/lnd/lntest/wait" "github.com/stretchr/testify/require" @@ -296,10 +296,7 @@ func (h *HarnessMiner) AssertTxInMempool(txid chainhash.Hash) *wire.MsgTx { return fmt.Errorf("empty mempool") } - isEqual := func(memTx chainhash.Hash) bool { - return memTx == txid - } - result := fn.Find(isEqual, mempool) + result := fn.Find(mempool, fn.Eq(txid)) if result.IsNone() { return fmt.Errorf("txid %v not found in "+ diff --git a/lntest/mock/walletcontroller.go b/lntest/mock/walletcontroller.go index 8b7ef55380..fa623bf84d 100644 --- a/lntest/mock/walletcontroller.go +++ b/lntest/mock/walletcontroller.go @@ -16,7 +16,7 @@ import ( base "github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/wallet/txauthor" 
"github.com/btcsuite/btcwallet/wtxmgr" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" ) diff --git a/lntest/node/state.go b/lntest/node/state.go index a89ab7d2cc..38f02f3a4c 100644 --- a/lntest/node/state.go +++ b/lntest/node/state.go @@ -7,7 +7,7 @@ import ( "time" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/walletrpc" "github.com/lightningnetwork/lnd/lntest/rpc" @@ -324,11 +324,11 @@ func (s *State) updateEdgeStats() { req := &lnrpc.ChannelGraphRequest{IncludeUnannounced: true} resp := s.rpc.DescribeGraph(req) - s.Edge.Total = len(fn.Filter(filterDisabled, resp.Edges)) + s.Edge.Total = len(fn.Filter(resp.Edges, filterDisabled)) req = &lnrpc.ChannelGraphRequest{IncludeUnannounced: false} resp = s.rpc.DescribeGraph(req) - s.Edge.Public = len(fn.Filter(filterDisabled, resp.Edges)) + s.Edge.Public = len(fn.Filter(resp.Edges, filterDisabled)) } // updateWalletBalance creates stats for the node's wallet balance. diff --git a/lnwallet/aux_leaf_store.go b/lnwallet/aux_leaf_store.go index c457a92509..28a78e09db 100644 --- a/lnwallet/aux_leaf_store.go +++ b/lnwallet/aux_leaf_store.go @@ -5,7 +5,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" diff --git a/lnwallet/aux_resolutions.go b/lnwallet/aux_resolutions.go index 382232640d..b36e2d6368 100644 --- a/lnwallet/aux_resolutions.go +++ b/lnwallet/aux_resolutions.go @@ -4,7 +4,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" diff --git a/lnwallet/aux_signer.go b/lnwallet/aux_signer.go index 01abe1aae3..510b64b5d1 100644 --- a/lnwallet/aux_signer.go +++ b/lnwallet/aux_signer.go @@ -2,7 +2,7 @@ package lnwallet import ( "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" diff --git a/lnwallet/btcwallet/btcwallet.go b/lnwallet/btcwallet/btcwallet.go index 5d28574cbe..b9a909fbd3 100644 --- a/lnwallet/btcwallet/btcwallet.go +++ b/lnwallet/btcwallet/btcwallet.go @@ -27,7 +27,7 @@ import ( "github.com/btcsuite/btcwallet/wtxmgr" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/blockcache" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" diff --git a/lnwallet/chainfee/filtermanager.go b/lnwallet/chainfee/filtermanager.go index 26fa56aef1..2d6fd0a2e1 100644 --- a/lnwallet/chainfee/filtermanager.go +++ b/lnwallet/chainfee/filtermanager.go @@ -9,7 +9,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/rpcclient" - "github.com/lightningnetwork/lnd/fn" + 
"github.com/lightningnetwork/lnd/fn/v2" ) const ( diff --git a/lnwallet/chancloser/aux_closer.go b/lnwallet/chancloser/aux_closer.go index 8b1c445ca3..62f475dd43 100644 --- a/lnwallet/chancloser/aux_closer.go +++ b/lnwallet/chancloser/aux_closer.go @@ -4,7 +4,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" diff --git a/lnwallet/chancloser/chancloser.go b/lnwallet/chancloser/chancloser.go index 17112b29e0..398a8a9f3e 100644 --- a/lnwallet/chancloser/chancloser.go +++ b/lnwallet/chancloser/chancloser.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" diff --git a/lnwallet/chancloser/chancloser_test.go b/lnwallet/chancloser/chancloser_test.go index 28709fd5f8..fe71fe5e3b 100644 --- a/lnwallet/chancloser/chancloser_test.go +++ b/lnwallet/chancloser/chancloser_test.go @@ -14,7 +14,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" diff --git a/lnwallet/chancloser/interface.go b/lnwallet/chancloser/interface.go index 729cdc545b..f774c81039 100644 --- a/lnwallet/chancloser/interface.go +++ b/lnwallet/chancloser/interface.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/lnwallet/chanfunding/canned_assembler.go b/lnwallet/chanfunding/canned_assembler.go index b3457f21bf..e28cbb96d1 100644 --- a/lnwallet/chanfunding/canned_assembler.go +++ b/lnwallet/chanfunding/canned_assembler.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" ) diff --git a/lnwallet/chanfunding/interface.go b/lnwallet/chanfunding/interface.go index 3512b32ff9..e40c4a1157 100644 --- a/lnwallet/chanfunding/interface.go +++ b/lnwallet/chanfunding/interface.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/wtxmgr" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet/chainfee" ) diff --git a/lnwallet/chanfunding/psbt_assembler.go b/lnwallet/chanfunding/psbt_assembler.go index f678f520fc..dd1bedd05a 100644 --- a/lnwallet/chanfunding/psbt_assembler.go +++ b/lnwallet/chanfunding/psbt_assembler.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" 
"github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" ) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 7f70e600c6..d190acdf5e 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -25,7 +25,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -600,7 +600,7 @@ func (lc *LightningChannel) extractPayDescs(feeRate chainfee.SatPerKWeight, htlc := htlc - auxLeaf := fn.ChainOption( + auxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { leaves := l.OutgoingHtlcLeaves if htlc.Incoming { @@ -1106,7 +1106,7 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, feeRate, wireMsg.Amount.ToSatoshis(), remoteDustLimit, ) if !isDustRemote { - auxLeaf := fn.ChainOption( + auxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { leaves := l.OutgoingHtlcLeaves return leaves[pd.HtlcIndex].AuxTapLeaf @@ -2088,7 +2088,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, // Since it is the remote breach we are reconstructing, the output // going to us will be a to-remote script with our local params. - remoteAuxLeaf := fn.ChainOption( + remoteAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { return l.RemoteAuxLeaf }, @@ -2102,7 +2102,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, return nil, err } - localAuxLeaf := fn.ChainOption( + localAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { return l.LocalAuxLeaf }, @@ -2229,7 +2229,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, return nil, fmt.Errorf("unable to aux resolve: %w", err) } - br.LocalResolutionBlob = resolveBlob.Option() + br.LocalResolutionBlob = resolveBlob.OkToSome() } // Similarly, if their balance exceeds the remote party's dust limit, @@ -2308,7 +2308,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, return nil, fmt.Errorf("unable to aux resolve: %w", err) } - br.RemoteResolutionBlob = resolveBlob.Option() + br.RemoteResolutionBlob = resolveBlob.OkToSome() } // Finally, with all the necessary data constructed, we can pad the @@ -2338,7 +2338,7 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // We'll generate the original second level witness script now, as // we'll need it if we're revoking an HTLC output on the remote // commitment transaction, and *they* go to the second level. - secondLevelAuxLeaf := fn.ChainOption( + secondLevelAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) fn.Option[input.AuxTapLeaf] { return fn.MapOption(func(val uint16) input.AuxTapLeaf { idx := input.HtlcIndex(val) @@ -2366,7 +2366,7 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // HTLC script. Otherwise, is this was an outgoing HTLC that we sent, // then from the PoV of the remote commitment state, they're the // receiver of this HTLC. 
- htlcLeaf := fn.ChainOption( + htlcLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) fn.Option[input.AuxTapLeaf] { return fn.MapOption(func(val uint16) input.AuxTapLeaf { idx := input.HtlcIndex(val) @@ -2693,13 +2693,13 @@ type HtlcView struct { // AuxOurUpdates returns the outgoing HTLCs as a read-only copy of // AuxHtlcDescriptors. func (v *HtlcView) AuxOurUpdates() []AuxHtlcDescriptor { - return fn.Map(newAuxHtlcDescriptor, v.Updates.Local) + return fn.Map(v.Updates.Local, newAuxHtlcDescriptor) } // AuxTheirUpdates returns the incoming HTLCs as a read-only copy of // AuxHtlcDescriptors. func (v *HtlcView) AuxTheirUpdates() []AuxHtlcDescriptor { - return fn.Map(newAuxHtlcDescriptor, v.Updates.Remote) + return fn.Map(v.Updates.Remote, newAuxHtlcDescriptor) } // fetchHTLCView returns all the candidate HTLC updates which should be @@ -2917,9 +2917,9 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, // The fee rate of our view is always the last UpdateFee message from // the channel's OpeningParty. openerUpdates := view.Updates.GetForParty(lc.channelState.Initiator()) - feeUpdates := fn.Filter(func(u *paymentDescriptor) bool { + feeUpdates := fn.Filter(openerUpdates, func(u *paymentDescriptor) bool { return u.EntryType == FeeUpdate - }, openerUpdates) + }) lastFeeUpdate := fn.Last(feeUpdates) lastFeeUpdate.WhenSome(func(pd *paymentDescriptor) { newView.FeePerKw = chainfee.SatPerKWeight( @@ -2942,14 +2942,17 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, for _, party := range parties { // First we run through non-add entries in both logs, // populating the skip sets. - resolutions := fn.Filter(func(pd *paymentDescriptor) bool { - switch pd.EntryType { - case Settle, Fail, MalformedFail: - return true - default: - return false - } - }, view.Updates.GetForParty(party)) + resolutions := fn.Filter( + view.Updates.GetForParty(party), + func(pd *paymentDescriptor) bool { + switch pd.EntryType { + case Settle, Fail, MalformedFail: + return true + default: + return false + } + }, + ) for _, entry := range resolutions { addEntry, err := lc.fetchParent( @@ -3002,10 +3005,16 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, // settled HTLCs, and debiting the chain state balance due to any newly // added HTLCs. for _, party := range parties { - liveAdds := fn.Filter(func(pd *paymentDescriptor) bool { - return pd.EntryType == Add && - !skip.GetForParty(party).Contains(pd.HtlcIndex) - }, view.Updates.GetForParty(party)) + liveAdds := fn.Filter( + view.Updates.GetForParty(party), + func(pd *paymentDescriptor) bool { + isAdd := pd.EntryType == Add + shouldSkip := skip.GetForParty(party). 
+ Contains(pd.HtlcIndex) + + return isAdd && !shouldSkip + }, + ) for _, entry := range liveAdds { // Skip the entries that have already had their add @@ -3063,7 +3072,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, uncommittedUpdates := lntypes.MapDual( view.Updates, func(us []*paymentDescriptor) []*paymentDescriptor { - return fn.Filter(isUncommitted, us) + return fn.Filter(us, isUncommitted) }, ) @@ -3189,7 +3198,7 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, htlcFee := HtlcTimeoutFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee - auxLeaf := fn.ChainOption( + auxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { leaves := l.IncomingHtlcLeaves return leaves[htlc.HtlcIndex].SecondLevelLeaf @@ -3270,7 +3279,7 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, htlcFee := HtlcSuccessFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee - auxLeaf := fn.ChainOption( + auxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { leaves := l.OutgoingHtlcLeaves return leaves[htlc.HtlcIndex].SecondLevelLeaf @@ -4802,7 +4811,7 @@ func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, htlcFee := HtlcSuccessFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee - auxLeaf := fn.ChainOption(func( + auxLeaf := fn.FlatMapOption(func( l CommitAuxLeaves) input.AuxTapLeaf { leaves := l.IncomingHtlcLeaves @@ -4895,7 +4904,7 @@ func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, htlcFee := HtlcTimeoutFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee - auxLeaf := fn.ChainOption(func( + auxLeaf := fn.FlatMapOption(func( l CommitAuxLeaves) input.AuxTapLeaf { leaves := l.OutgoingHtlcLeaves @@ -6766,7 +6775,7 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, //nolint:funlen // Before we can generate the proper sign descriptor, we'll need to // locate the output index of our non-delayed output on the commitment // transaction. - remoteAuxLeaf := fn.ChainOption( + remoteAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { return l.RemoteAuxLeaf }, @@ -6870,7 +6879,7 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, //nolint:funlen return nil, fmt.Errorf("unable to aux resolve: %w", err) } - commitResolution.ResolutionBlob = resolveBlob.Option() + commitResolution.ResolutionBlob = resolveBlob.OkToSome() } closeSummary := channeldb.ChannelCloseSummary{ @@ -7059,7 +7068,7 @@ func newOutgoingHtlcResolution(signer input.Signer, // First, we'll re-generate the script used to send the HTLC to the // remote party within their commitment transaction. - auxLeaf := fn.ChainOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + auxLeaf := fn.FlatMapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { return l.OutgoingHtlcLeaves[htlc.HtlcIndex].AuxTapLeaf })(auxLeaves) htlcScriptInfo, err := genHtlcScript( @@ -7149,7 +7158,7 @@ func newOutgoingHtlcResolution(signer input.Signer, return nil, fmt.Errorf("unable to aux resolve: %w", err) } - resolutionBlob := resolveRes.Option() + resolutionBlob := resolveRes.OkToSome() return &OutgoingHtlcResolution{ Expiry: htlc.RefundTimeout, @@ -7171,7 +7180,7 @@ func newOutgoingHtlcResolution(signer input.Signer, // With the fee calculated, re-construct the second level timeout // transaction. 
- secondLevelAuxLeaf := fn.ChainOption( + secondLevelAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { leaves := l.OutgoingHtlcLeaves return leaves[htlc.HtlcIndex].SecondLevelLeaf @@ -7366,7 +7375,7 @@ func newOutgoingHtlcResolution(signer input.Signer, if err := resolveRes.Err(); err != nil { return nil, fmt.Errorf("unable to aux resolve: %w", err) } - resolutionBlob := resolveRes.Option() + resolutionBlob := resolveRes.OkToSome() return &OutgoingHtlcResolution{ Expiry: htlc.RefundTimeout, @@ -7406,7 +7415,7 @@ func newIncomingHtlcResolution(signer input.Signer, // First, we'll re-generate the script the remote party used to // send the HTLC to us in their commitment transaction. - auxLeaf := fn.ChainOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + auxLeaf := fn.FlatMapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { return l.IncomingHtlcLeaves[htlc.HtlcIndex].AuxTapLeaf })(auxLeaves) scriptInfo, err := genHtlcScript( @@ -7497,7 +7506,7 @@ func newIncomingHtlcResolution(signer input.Signer, return nil, fmt.Errorf("unable to aux resolve: %w", err) } - resolutionBlob := resolveRes.Option() + resolutionBlob := resolveRes.OkToSome() return &IncomingHtlcResolution{ ClaimOutpoint: op, @@ -7507,7 +7516,7 @@ func newIncomingHtlcResolution(signer input.Signer, }, nil } - secondLevelAuxLeaf := fn.ChainOption( + secondLevelAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { leaves := l.IncomingHtlcLeaves return leaves[htlc.HtlcIndex].SecondLevelLeaf @@ -7707,7 +7716,7 @@ func newIncomingHtlcResolution(signer input.Signer, return nil, fmt.Errorf("unable to aux resolve: %w", err) } - resolutionBlob := resolveRes.Option() + resolutionBlob := resolveRes.OkToSome() return &IncomingHtlcResolution{ SignedSuccessTx: successTx, @@ -8011,7 +8020,7 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, leaseExpiry = chanState.ThawHeight } - localAuxLeaf := fn.ChainOption( + localAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { return l.LocalAuxLeaf }, @@ -8126,7 +8135,7 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, return nil, fmt.Errorf("unable to aux resolve: %w", err) } - commitResolution.ResolutionBlob = resolveBlob.Option() + commitResolution.ResolutionBlob = resolveBlob.OkToSome() } // Once the delay output has been found (if it exists), then we'll also diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index f7ecd32277..d0caa97812 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -25,7 +25,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" @@ -730,9 +730,12 @@ func TestCommitHTLCSigCustomRecordSize(t *testing.T) { // Replace the default PackSigs implementation to return a // large custom records blob. - mockSigner.ExpectedCalls = fn.Filter(func(c *mock.Call) bool { - return c.Method != "PackSigs" - }, mockSigner.ExpectedCalls) + mockSigner.ExpectedCalls = fn.Filter( + mockSigner.ExpectedCalls, + func(c *mock.Call) bool { + return c.Method != "PackSigs" + }, + ) mockSigner.On("PackSigs", mock.Anything). 
Return(fn.Ok(fn.Some(largeBlob))) }) diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index 8b364a01df..787e8a71e1 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -836,7 +836,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, continue } - auxLeaf := fn.ChainOption( + auxLeaf := fn.FlatMapOption( func(leaves input.HtlcAuxLeaves) input.AuxTapLeaf { return leaves[htlc.HtlcIndex].AuxTapLeaf }, @@ -864,7 +864,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, continue } - auxLeaf := fn.ChainOption( + auxLeaf := fn.FlatMapOption( func(leaves input.HtlcAuxLeaves) input.AuxTapLeaf { return leaves[htlc.HtlcIndex].AuxTapLeaf }, @@ -1323,7 +1323,7 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, // Compute the to_local script. From our PoV, when facing a remote // commitment, the to_local output belongs to them. - localAuxLeaf := fn.ChainOption( + localAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { return l.LocalAuxLeaf }, @@ -1338,7 +1338,7 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, // Compute the to_remote script. From our PoV, when facing a remote // commitment, the to_remote output belongs to us. - remoteAuxLeaf := fn.ChainOption( + remoteAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { return l.RemoteAuxLeaf }, diff --git a/lnwallet/commitment_chain.go b/lnwallet/commitment_chain.go index fa2abe0aa2..871a139c5c 100644 --- a/lnwallet/commitment_chain.go +++ b/lnwallet/commitment_chain.go @@ -1,7 +1,7 @@ package lnwallet import ( - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" ) // commitmentChain represents a chain of unrevoked commitments. 
The tail of the diff --git a/lnwallet/config.go b/lnwallet/config.go index 425fe15dad..c60974be6d 100644 --- a/lnwallet/config.go +++ b/lnwallet/config.go @@ -5,7 +5,7 @@ import ( "github.com/btcsuite/btcwallet/wallet" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet/chainfee" diff --git a/lnwallet/interface.go b/lnwallet/interface.go index c9dee9202a..64f8546310 100644 --- a/lnwallet/interface.go +++ b/lnwallet/interface.go @@ -19,7 +19,7 @@ import ( base "github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/wallet/txauthor" "github.com/btcsuite/btcwallet/wtxmgr" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chanvalidate" diff --git a/lnwallet/mock.go b/lnwallet/mock.go index a8610dc779..39e520d276 100644 --- a/lnwallet/mock.go +++ b/lnwallet/mock.go @@ -18,7 +18,7 @@ import ( "github.com/btcsuite/btcwallet/wtxmgr" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/tlv" diff --git a/lnwallet/musig_session.go b/lnwallet/musig_session.go index 822aa48a14..748e5fa958 100644 --- a/lnwallet/musig_session.go +++ b/lnwallet/musig_session.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" diff --git a/lnwallet/reservation.go b/lnwallet/reservation.go index fd35d95076..a8a0cacd4b 100644 --- a/lnwallet/reservation.go +++ b/lnwallet/reservation.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" diff --git a/lnwallet/rpcwallet/rpcwallet.go b/lnwallet/rpcwallet/rpcwallet.go index bf6aa61df3..426712b597 100644 --- a/lnwallet/rpcwallet/rpcwallet.go +++ b/lnwallet/rpcwallet/rpcwallet.go @@ -22,7 +22,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/waddrmgr" basewallet "github.com/btcsuite/btcwallet/wallet" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lncfg" diff --git a/lnwallet/test/test_interface.go b/lnwallet/test/test_interface.go index c006aa2e50..27de51708c 100644 --- a/lnwallet/test/test_interface.go +++ b/lnwallet/test/test_interface.go @@ -34,7 +34,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs/btcdnotify" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + 
"github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" diff --git a/lnwallet/test_utils.go b/lnwallet/test_utils.go index ff9adfbd79..738558e224 100644 --- a/lnwallet/test_utils.go +++ b/lnwallet/test_utils.go @@ -16,7 +16,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" diff --git a/lnwallet/transactions_test.go b/lnwallet/transactions_test.go index 135d1866bc..38131eaa72 100644 --- a/lnwallet/transactions_test.go +++ b/lnwallet/transactions_test.go @@ -21,7 +21,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" diff --git a/lnwallet/update_log.go b/lnwallet/update_log.go index 2d1f65c9fa..b2b8af58d1 100644 --- a/lnwallet/update_log.go +++ b/lnwallet/update_log.go @@ -1,7 +1,7 @@ package lnwallet import ( - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" ) // updateLog is an append-only log that stores updates to a node's commitment diff --git a/lnwallet/wallet.go b/lnwallet/wallet.go index ad6354e2e8..96ea85cf9e 100644 --- a/lnwallet/wallet.go +++ b/lnwallet/wallet.go @@ -23,7 +23,7 @@ import ( "github.com/btcsuite/btcwallet/wallet" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go index e523279498..577379623f 100644 --- a/lnwire/channel_reestablish.go +++ b/lnwire/channel_reestablish.go @@ -5,7 +5,7 @@ import ( "io" "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/lnwire/custom_records.go b/lnwire/custom_records.go index 8177cbe821..a63aa5dfb0 100644 --- a/lnwire/custom_records.go +++ b/lnwire/custom_records.go @@ -6,7 +6,7 @@ import ( "io" "sort" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/tlv" ) @@ -179,9 +179,12 @@ func (c CustomRecords) SerializeTo(w io.Writer) error { // ProduceRecordsSorted converts a slice of record producers into a slice of // records and then sorts it by type. func ProduceRecordsSorted(recordProducers ...tlv.RecordProducer) []tlv.Record { - records := fn.Map(func(producer tlv.RecordProducer) tlv.Record { - return producer.Record() - }, recordProducers) + records := fn.Map( + recordProducers, + func(producer tlv.RecordProducer) tlv.Record { + return producer.Record() + }, + ) // Ensure that the set of records are sorted before we attempt to // decode from the stream, to ensure they're canonical. @@ -212,9 +215,9 @@ func TlvMapToRecords(tlvMap tlv.TypeMap) []tlv.Record { // RecordsAsProducers converts a slice of records into a slice of record // producers. 
func RecordsAsProducers(records []tlv.Record) []tlv.RecordProducer { - return fn.Map(func(record tlv.Record) tlv.RecordProducer { + return fn.Map(records, func(record tlv.Record) tlv.RecordProducer { return &record - }, records) + }) } // EncodeRecords encodes the given records into a byte slice. diff --git a/lnwire/custom_records_test.go b/lnwire/custom_records_test.go index 8ff6af10ba..d4aad2e546 100644 --- a/lnwire/custom_records_test.go +++ b/lnwire/custom_records_test.go @@ -4,7 +4,7 @@ import ( "bytes" "testing" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -182,9 +182,12 @@ func TestCustomRecordsExtendRecordProducers(t *testing.T) { func serializeRecordProducers(t *testing.T, producers []tlv.RecordProducer) []byte { - tlvRecords := fn.Map(func(p tlv.RecordProducer) tlv.Record { - return p.Record() - }, producers) + tlvRecords := fn.Map( + producers, + func(p tlv.RecordProducer) tlv.Record { + return p.Record() + }, + ) stream, err := tlv.NewStream(tlvRecords...) require.NoError(t, err) diff --git a/lnwire/dyn_ack.go b/lnwire/dyn_ack.go index 24f23a228d..d477461e7b 100644 --- a/lnwire/dyn_ack.go +++ b/lnwire/dyn_ack.go @@ -5,7 +5,7 @@ import ( "io" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/lnwire/dyn_propose.go b/lnwire/dyn_propose.go index b0cc1198e9..394fff6f37 100644 --- a/lnwire/dyn_propose.go +++ b/lnwire/dyn_propose.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go index c4ca260e1e..4681426cbb 100644 --- a/lnwire/extra_bytes.go +++ b/lnwire/extra_bytes.go @@ -5,7 +5,7 @@ import ( "fmt" "io" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 952e90a7e6..6bfbb465ec 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -22,7 +22,7 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tor" diff --git a/lnwire/onion_error.go b/lnwire/onion_error.go index 5f05e1ef9f..7b65a85f4e 100644 --- a/lnwire/onion_error.go +++ b/lnwire/onion_error.go @@ -10,7 +10,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/lnwire/onion_error_test.go b/lnwire/onion_error_test.go index 9c39be6d5c..5c3d0291a5 100644 --- a/lnwire/onion_error_test.go +++ b/lnwire/onion_error_test.go @@ -9,7 +9,7 @@ import ( "testing" "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/stretchr/testify/require" ) diff --git a/msgmux/msg_router.go b/msgmux/msg_router.go index db9e783990..736c085a95 100644 --- 
a/msgmux/msg_router.go +++ b/msgmux/msg_router.go @@ -6,7 +6,7 @@ import ( "sync" "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) @@ -91,8 +91,8 @@ func sendQueryErr[Q any](sendChan chan fn.Req[Q, error], queryArg Q, quitChan chan struct{}) error { return fn.ElimEither( - fn.Iden, fn.Iden, sendQuery(sendChan, queryArg, quitChan).Either, + fn.Iden, fn.Iden, ) } diff --git a/peer/brontide.go b/peer/brontide.go index 6bc49445ee..f8ac00aa20 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -26,7 +26,7 @@ import ( "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/feature" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/funding" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" diff --git a/peer/brontide_test.go b/peer/brontide_test.go index c3d1bee48b..eded658887 100644 --- a/peer/brontide_test.go +++ b/peer/brontide_test.go @@ -13,7 +13,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/peer/musig_chan_closer.go b/peer/musig_chan_closer.go index 6f69a8c5b8..149ebcfa0c 100644 --- a/peer/musig_chan_closer.go +++ b/peer/musig_chan_closer.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chancloser" diff --git a/peer/test_utils.go b/peer/test_utils.go index eb510a53b1..34c42e2f7c 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -18,7 +18,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" diff --git a/protofsm/daemon_events.go b/protofsm/daemon_events.go index e5de0b6951..bca7283d39 100644 --- a/protofsm/daemon_events.go +++ b/protofsm/daemon_events.go @@ -5,7 +5,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/protofsm/msg_mapper.go b/protofsm/msg_mapper.go index b96d677e6b..5e24255fa3 100644 --- a/protofsm/msg_mapper.go +++ b/protofsm/msg_mapper.go @@ -1,7 +1,7 @@ package protofsm import ( - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/protofsm/state_machine.go b/protofsm/state_machine.go index b71d5efe42..a81f5746b2 100644 --- a/protofsm/state_machine.go +++ b/protofsm/state_machine.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" 
"github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwire" ) @@ -21,6 +21,12 @@ const ( pollInterval = time.Millisecond * 100 ) +var ( + // ErrStateMachineShutdown occurs when trying to feed an event to a + // StateMachine that has been asked to Stop. + ErrStateMachineShutdown = fmt.Errorf("StateMachine is shutting down") +) + // EmittedEvent is a special type that can be emitted by a state transition. // This can container internal events which are to be routed back to the state, // or external events which are to be sent to the daemon. @@ -287,7 +293,7 @@ func (s *StateMachine[Event, Env]) CurrentState() (State[Event, Env], error) { } if !fn.SendOrQuit(s.stateQuery, query, s.quit) { - return nil, fmt.Errorf("state machine is shutting down") + return nil, ErrStateMachineShutdown } return fn.RecvOrTimeout(query.CurrentState, time.Second) @@ -322,6 +328,8 @@ func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[ // executeDaemonEvent executes a daemon event, which is a special type of event // that can be emitted as part of the state transition function of the state // machine. An error is returned if the type of event is unknown. +// +//nolint:funlen func (s *StateMachine[Event, Env]) executeDaemonEvent( event DaemonEvent) error { @@ -347,7 +355,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( // If a post-send event was specified, then we'll funnel // that back into the main state machine now as well. return fn.MapOptionZ(daemonEvent.PostSendEvent, func(event Event) error { //nolint:ll - return s.wg.Go(func(ctx context.Context) { + launched := s.wg.Go(func(ctx context.Context) { log.Debugf("FSM(%v): sending "+ "post-send event: %v", s.cfg.Env.Name(), @@ -356,6 +364,12 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( s.SendEvent(event) }) + + if !launched { + return ErrStateMachineShutdown + } + + return nil }) } @@ -368,7 +382,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( // Otherwise, this has a SendWhen predicate, so we'll need // launch a goroutine to poll the SendWhen, then send only once // the predicate is true. - return s.wg.Go(func(ctx context.Context) { + launched := s.wg.Go(func(ctx context.Context) { predicateTicker := time.NewTicker( s.cfg.CustomPollInterval.UnwrapOr(pollInterval), ) @@ -407,6 +421,12 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( } }) + if !launched { + return ErrStateMachineShutdown + } + + return nil + // If this is a broadcast transaction event, then we'll broadcast with // the label attached. case *BroadcastTxn: @@ -436,7 +456,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( return fmt.Errorf("unable to register spend: %w", err) } - return s.wg.Go(func(ctx context.Context) { + launched := s.wg.Go(func(ctx context.Context) { for { select { case spend, ok := <-spendEvent.Spend: @@ -461,6 +481,12 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( } }) + if !launched { + return ErrStateMachineShutdown + } + + return nil + // The state machine has requested a new event to be sent once a // specified txid+pkScript pair has confirmed. 
case *RegisterConf[Event]: @@ -476,7 +502,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( return fmt.Errorf("unable to register conf: %w", err) } - return s.wg.Go(func(ctx context.Context) { + launched := s.wg.Go(func(ctx context.Context) { for { select { case <-confEvent.Confirmed: @@ -498,6 +524,12 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( } } }) + + if !launched { + return ErrStateMachineShutdown + } + + return nil } return fmt.Errorf("unknown daemon event: %T", event) diff --git a/protofsm/state_machine_test.go b/protofsm/state_machine_test.go index fc30fcefc3..fc7a4ccfdc 100644 --- a/protofsm/state_machine_test.go +++ b/protofsm/state_machine_test.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" diff --git a/routing/bandwidth.go b/routing/bandwidth.go index 12e82131dc..c816ed3410 100644 --- a/routing/bandwidth.go +++ b/routing/bandwidth.go @@ -3,7 +3,7 @@ package routing import ( "fmt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnwire" diff --git a/routing/bandwidth_test.go b/routing/bandwidth_test.go index 4872b5a7ec..7469bc84c6 100644 --- a/routing/bandwidth_test.go +++ b/routing/bandwidth_test.go @@ -5,7 +5,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" diff --git a/routing/blinding.go b/routing/blinding.go index 7c84063469..0c27e87439 100644 --- a/routing/blinding.go +++ b/routing/blinding.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" diff --git a/routing/blinding_test.go b/routing/blinding_test.go index 410dfaf643..8f83f7fd82 100644 --- a/routing/blinding_test.go +++ b/routing/blinding_test.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 315b0dff22..e4241dac53 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" diff --git a/routing/localchans/manager.go b/routing/localchans/manager.go index d7380439ac..cd9e58fcaa 100644 --- a/routing/localchans/manager.go +++ 
b/routing/localchans/manager.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/discovery" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwire" diff --git a/routing/missioncontrol.go b/routing/missioncontrol.go index de892392e7..3bc9be7aba 100644 --- a/routing/missioncontrol.go +++ b/routing/missioncontrol.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcwallet/walletdb" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" diff --git a/routing/mock_test.go b/routing/mock_test.go index 3cdb5ebaf2..2575514102 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" diff --git a/routing/pathfind.go b/routing/pathfind.go index 8e40c5bc4b..80f4b1e68f 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/btcutil" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/feature" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnutils" diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index da29c79a25..c463b8135b 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -21,7 +21,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 267ce3965d..292397d9df 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -10,7 +10,7 @@ import ( "github.com/davecgh/go-spew/spew" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 315c1bad58..98fe7ffd21 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -9,7 +9,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnmock" 
"github.com/lightningnetwork/lnd/lntest/wait" diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index d5f1a6af41..240f801e78 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -2,7 +2,7 @@ package routing import ( "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" diff --git a/routing/result_interpretation.go b/routing/result_interpretation.go index 089213d65e..bc1749dfb1 100644 --- a/routing/result_interpretation.go +++ b/routing/result_interpretation.go @@ -6,7 +6,7 @@ import ( "io" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" @@ -578,7 +578,7 @@ func extractMCRoute(r *route.Route) *mcRoute { // extractMCHops extracts the Hop fields that MC actually uses from a slice of // Hops. func extractMCHops(hops []*route.Hop) mcHops { - return fn.Map(extractMCHop, hops) + return fn.Map(hops, extractMCHop) } // extractMCHop extracts the Hop fields that MC actually uses from a Hop. diff --git a/routing/result_interpretation_test.go b/routing/result_interpretation_test.go index b213eb1835..8c67bdeea9 100644 --- a/routing/result_interpretation_test.go +++ b/routing/result_interpretation_test.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) diff --git a/routing/router.go b/routing/router.go index 9eabe0b2ae..0b3c90c321 100644 --- a/routing/router.go +++ b/routing/router.go @@ -19,7 +19,7 @@ import ( "github.com/lightningnetwork/lnd/amp" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" diff --git a/routing/router_test.go b/routing/router_test.go index 2923f1fb90..0824146d1e 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -23,7 +23,7 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" diff --git a/rpcserver.go b/rpcserver.go index d7d2e0186c..72e2fa4afd 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -46,7 +46,7 @@ import ( "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/feature" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/graph" graphdb "github.com/lightningnetwork/lnd/graph/db" @@ -8068,9 +8068,9 @@ func (r *rpcServer) VerifyChanBackup(ctx context.Context, } return &lnrpc.VerifyChanBackupResponse{ - ChanPoints: fn.Map(func(c chanbackup.Single) 
string { + ChanPoints: fn.Map(channels, func(c chanbackup.Single) string { return c.FundingOutpoint.String() - }, channels), + }), }, nil } diff --git a/rpcserver_test.go b/rpcserver_test.go index b4b66e719c..b686c9020a 100644 --- a/rpcserver_test.go +++ b/rpcserver_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnrpc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/server.go b/server.go index f8f8239ed6..abf898ca31 100644 --- a/server.go +++ b/server.go @@ -39,7 +39,7 @@ import ( "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/feature" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/graph" graphdb "github.com/lightningnetwork/lnd/graph/db" diff --git a/subrpcserver_config.go b/subrpcserver_config.go index 30755c05e4..102e211187 100644 --- a/subrpcserver_config.go +++ b/subrpcserver_config.go @@ -11,7 +11,7 @@ import ( "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/invoices" diff --git a/sweep/aggregator.go b/sweep/aggregator.go index a0a1b0a540..e97ccb9a21 100644 --- a/sweep/aggregator.go +++ b/sweep/aggregator.go @@ -5,7 +5,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/sweep/aggregator_test.go b/sweep/aggregator_test.go index 6df0d73fa2..2cb89bdc38 100644 --- a/sweep/aggregator_test.go +++ b/sweep/aggregator_test.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index adb4db65ed..7bb58ae29e 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/chain" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lntypes" @@ -145,13 +145,13 @@ type BumpRequest struct { func (r *BumpRequest) MaxFeeRateAllowed() (chainfee.SatPerKWeight, error) { // We'll want to know if we have any blobs, as we need to factor this // into the max fee rate for this bump request. 
- hasBlobs := fn.Any(func(i input.Input) bool { + hasBlobs := fn.Any(r.Inputs, func(i input.Input) bool { return fn.MapOptionZ( i.ResolutionBlob(), func(b tlv.Blob) bool { return len(b) > 0 }, ) - }, r.Inputs) + }) sweepAddrs := [][]byte{ r.DeliveryAddress.DeliveryAddress, @@ -1382,7 +1382,7 @@ func prepareSweepTx(inputs []input.Input, changePkScript lnwallet.AddrWithKey, return err } - extraChangeOut = extraOut.LeftToOption() + extraChangeOut = extraOut.LeftToSome() return nil }, diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index 5030dee227..c9196aee5a 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/chain" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/sweep/fee_function.go b/sweep/fee_function.go index cbf283e37d..bff44000be 100644 --- a/sweep/fee_function.go +++ b/sweep/fee_function.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/sweep/fee_function_test.go b/sweep/fee_function_test.go index c278bb7f06..a55ce79a78 100644 --- a/sweep/fee_function_test.go +++ b/sweep/fee_function_test.go @@ -3,7 +3,7 @@ package sweep import ( "testing" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/stretchr/testify/require" ) diff --git a/sweep/interface.go b/sweep/interface.go index f2fff84b08..6c8c2cfad2 100644 --- a/sweep/interface.go +++ b/sweep/interface.go @@ -4,7 +4,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/sweep/mock_test.go b/sweep/mock_test.go index 34202b1453..eeeb283969 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -4,7 +4,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 6257faac1f..9eeefc94b8 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 2b61f67933..7d99ba93b9 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" - 
"github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" diff --git a/sweep/tx_input_set.go b/sweep/tx_input_set.go index ce144a8eb3..adae7cf131 100644 --- a/sweep/tx_input_set.go +++ b/sweep/tx_input_set.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -141,7 +141,7 @@ func validateInputs(inputs []SweeperInput, deadlineHeight int32) error { // dedupInputs is a set used to track unique outpoints of the inputs. dedupInputs := fn.NewSet( // Iterate all the inputs and map the function. - fn.Map(func(inp SweeperInput) wire.OutPoint { + fn.Map(inputs, func(inp SweeperInput) wire.OutPoint { // If the input has a deadline height, we'll check if // it's the same as the specified. inp.params.DeadlineHeight.WhenSome(func(h int32) { @@ -156,7 +156,7 @@ func validateInputs(inputs []SweeperInput, deadlineHeight int32) error { }) return inp.OutPoint() - }, inputs)..., + })..., ) // Make sure the inputs share the same deadline height when there is diff --git a/sweep/tx_input_set_test.go b/sweep/tx_input_set_test.go index 8d0850b20d..73f056a964 100644 --- a/sweep/tx_input_set_test.go +++ b/sweep/tx_input_set_test.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/stretchr/testify/require" diff --git a/sweep/walletsweep.go b/sweep/walletsweep.go index 81458fbfb0..3f790dc66f 100644 --- a/sweep/walletsweep.go +++ b/sweep/walletsweep.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/wtxmgr" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" diff --git a/sweep/walletsweep_test.go b/sweep/walletsweep_test.go index 968d9cb4fb..c7a5dfc221 100644 --- a/sweep/walletsweep_test.go +++ b/sweep/walletsweep_test.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/wtxmgr" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" diff --git a/watchtower/blob/justice_kit.go b/watchtower/blob/justice_kit.go index 7780239f07..9dc1af6258 100644 --- a/watchtower/blob/justice_kit.go +++ b/watchtower/blob/justice_kit.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" secp "github.com/decred/dcrd/dcrec/secp256k1/v4" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" diff --git a/watchtower/blob/justice_kit_test.go 
b/watchtower/blob/justice_kit_test.go index a1d6ec9f2c..0d23e2e0fc 100644 --- a/watchtower/blob/justice_kit_test.go +++ b/watchtower/blob/justice_kit_test.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" diff --git a/watchtower/lookout/justice_descriptor_test.go b/watchtower/lookout/justice_descriptor_test.go index 5045b4a0f4..ded2cd6031 100644 --- a/watchtower/lookout/justice_descriptor_test.go +++ b/watchtower/lookout/justice_descriptor_test.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" secp "github.com/decred/dcrd/dcrec/secp256k1/v4" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index 7eb34f6e37..62d7609469 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index f3a4d5bf4e..e842876b65 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -19,7 +19,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" diff --git a/watchtower/wtclient/manager.go b/watchtower/wtclient/manager.go index 01a9fa01ef..7a39c8ff73 100644 --- a/watchtower/wtclient/manager.go +++ b/watchtower/wtclient/manager.go @@ -12,7 +12,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/subscribe" diff --git a/watchtower/wtdb/client_chan_summary.go b/watchtower/wtdb/client_chan_summary.go index 6fec34c842..9ab77377b4 100644 --- a/watchtower/wtdb/client_chan_summary.go +++ b/watchtower/wtdb/client_chan_summary.go @@ -3,7 +3,7 @@ package wtdb import ( "io" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index b6f6affce6..6e6adacc02 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -9,7 +9,7 @@ import ( "sync" "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" 
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" diff --git a/zpay32/decode.go b/zpay32/decode.go index 61099cf2f0..76c2c1ecf4 100644 --- a/zpay32/decode.go +++ b/zpay32/decode.go @@ -14,7 +14,7 @@ import ( "github.com/btcsuite/btcd/btcutil/bech32" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/zpay32/encode.go b/zpay32/encode.go index 3e2d799776..43ccd5ecb1 100644 --- a/zpay32/encode.go +++ b/zpay32/encode.go @@ -9,7 +9,7 @@ import ( "github.com/btcsuite/btcd/btcutil/bech32" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/zpay32/invoice.go b/zpay32/invoice.go index 7c18253eb0..9c5d86ce2f 100644 --- a/zpay32/invoice.go +++ b/zpay32/invoice.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/zpay32/invoice_test.go b/zpay32/invoice_test.go index a4753431e7..55718007db 100644 --- a/zpay32/invoice_test.go +++ b/zpay32/invoice_test.go @@ -17,7 +17,7 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" "github.com/stretchr/testify/require" ) From 117c6bc7817b312767e4b88c750700bdfd0be8a7 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Tue, 3 Dec 2024 13:23:01 +0100 Subject: [PATCH 02/59] multi: move routing.TlvTrafficShaper => htlcswitch.AuxTrafficShaper With this commit we move the traffic shaper definition from the routing package to the HTLC switch package as a preparation for being able to use it there as well. At the same time we rename it to AuxTrafficShaper to be more in line with the other auxiliary components. 
--- config_builder.go | 4 +- htlcswitch/interfaces.go | 57 ++++++++++++++++++++-- htlcswitch/link.go | 42 ++++++++++++++++ htlcswitch/mock.go | 11 +++++ routing/bandwidth.go | 81 ++++--------------------------- routing/bandwidth_test.go | 4 +- routing/mock_test.go | 18 ++++++- routing/payment_lifecycle.go | 5 +- routing/payment_lifecycle_test.go | 4 +- routing/payment_session_source.go | 4 +- routing/router.go | 4 +- routing/router_test.go | 32 +++++++----- 12 files changed, 167 insertions(+), 99 deletions(-) diff --git a/config_builder.go b/config_builder.go index 42650bb68b..ddfc97dc2c 100644 --- a/config_builder.go +++ b/config_builder.go @@ -36,6 +36,7 @@ import ( "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/funding" graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -47,7 +48,6 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/rpcwallet" "github.com/lightningnetwork/lnd/macaroons" "github.com/lightningnetwork/lnd/msgmux" - "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/rpcperms" "github.com/lightningnetwork/lnd/signal" "github.com/lightningnetwork/lnd/sqldb" @@ -166,7 +166,7 @@ type AuxComponents struct { // TrafficShaper is an optional traffic shaper that can be used to // control the outgoing channel of a payment. - TrafficShaper fn.Option[routing.TlvTrafficShaper] + TrafficShaper fn.Option[htlcswitch.AuxTrafficShaper] // MsgRouter is an optional message router that if set will be used in // place of a new blank default message router. diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index d8f55afc69..4346339796 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -206,6 +206,11 @@ const ( Outgoing LinkDirection = true ) +// OptionalBandwidth is a type alias for the result of a bandwidth query that +// may return a bandwidth value or fn.None if the bandwidth is not available or +// not applicable. +type OptionalBandwidth = fn.Option[lnwire.MilliSatoshi] + // ChannelLink is an interface which represents the subsystem for managing the // incoming htlc requests, applying the changes to the channel, and also // propagating/forwarding it to htlc switch. @@ -284,8 +289,8 @@ type ChannelLink interface { // total sent/received milli-satoshis. Stats() (uint64, lnwire.MilliSatoshi, lnwire.MilliSatoshi) - // Peer returns the serialized public key of remote peer with which we - // have the channel link opened. + // PeerPubKey returns the serialized public key of remote peer with + // which we have the channel link opened. PeerPubKey() [33]byte // AttachMailBox delivers an active MailBox to the link. The MailBox may @@ -302,9 +307,18 @@ type ChannelLink interface { // commitment of the channel that this link is associated with. CommitmentCustomBlob() fn.Option[tlv.Blob] - // Start/Stop are used to initiate the start/stop of the channel link - // functioning. + // AuxBandwidth returns the bandwidth that can be used for a channel, + // expressed in milli-satoshi. This might be different from the regular + // BTC bandwidth for custom channels. This will always return fn.None() + // for a regular (non-custom) channel. + AuxBandwidth(amount lnwire.MilliSatoshi, cid lnwire.ShortChannelID, + htlcBlob fn.Option[tlv.Blob], + ts AuxTrafficShaper) fn.Result[OptionalBandwidth] + + // Start starts the channel link. 
Start() error + + // Stop requests the channel link to be shut down. Stop() } @@ -440,7 +454,7 @@ type htlcNotifier interface { NotifyForwardingEvent(key HtlcKey, info HtlcInfo, eventType HtlcEventType) - // NotifyIncomingLinkFailEvent notifies that a htlc has failed on our + // NotifyLinkFailEvent notifies that a htlc has failed on our // incoming link. It takes an isReceive bool to differentiate between // our node's receives and forwards. NotifyLinkFailEvent(key HtlcKey, info HtlcInfo, @@ -461,3 +475,36 @@ type htlcNotifier interface { NotifyFinalHtlcEvent(key models.CircuitKey, info channeldb.FinalHtlcInfo) } + +// AuxHtlcModifier is an interface that allows the sender to modify the outgoing +// HTLC of a payment by changing the amount or the wire message tlv records. +type AuxHtlcModifier interface { + // ProduceHtlcExtraData is a function that, based on the previous extra + // data blob of an HTLC, may produce a different blob or modify the + // amount of bitcoin this htlc should carry. + ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi, + htlcCustomRecords lnwire.CustomRecords) (lnwire.MilliSatoshi, + lnwire.CustomRecords, error) +} + +// AuxTrafficShaper is an interface that allows the sender to determine if a +// payment should be carried by a channel based on the TLV records that may be +// present in the `update_add_htlc` message or the channel commitment itself. +type AuxTrafficShaper interface { + AuxHtlcModifier + + // ShouldHandleTraffic is called in order to check if the channel + // identified by the provided channel ID may have external mechanisms + // that would allow it to carry out the payment. + ShouldHandleTraffic(cid lnwire.ShortChannelID, + fundingBlob fn.Option[tlv.Blob]) (bool, error) + + // PaymentBandwidth returns the available bandwidth for a custom channel + // decided by the given channel aux blob and HTLC blob. A return value + // of 0 means there is no bandwidth available. To find out if a channel + // is a custom channel that should be handled by the traffic shaper, the + // ShouldHandleTraffic method should be called first. + PaymentBandwidth(htlcBlob, commitmentBlob fn.Option[tlv.Blob], + linkBandwidth, + htlcAmt lnwire.MilliSatoshi) (lnwire.MilliSatoshi, error) +} diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 60062862ef..60a3adbe79 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -3415,6 +3415,48 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, return nil } +// AuxBandwidth returns the bandwidth that can be used for a channel, expressed +// in milli-satoshi. This might be different from the regular BTC bandwidth for +// custom channels. This will always return fn.None() for a regular (non-custom) +// channel. +func (l *channelLink) AuxBandwidth(amount lnwire.MilliSatoshi, + cid lnwire.ShortChannelID, htlcBlob fn.Option[tlv.Blob], + ts AuxTrafficShaper) fn.Result[OptionalBandwidth] { + + unknownBandwidth := fn.None[lnwire.MilliSatoshi]() + + fundingBlob := l.FundingCustomBlob() + shouldHandle, err := ts.ShouldHandleTraffic(cid, fundingBlob) + if err != nil { + return fn.Err[OptionalBandwidth](fmt.Errorf("traffic shaper "+ + "failed to decide whether to handle traffic: %w", err)) + } + + log.Debugf("ShortChannelID=%v: aux traffic shaper is handling "+ + "traffic: %v", cid, shouldHandle) + + // If this channel isn't handled by the aux traffic shaper, we'll return + // early. + if !shouldHandle { + return fn.Ok(unknownBandwidth) + } + + // Ask for a specific bandwidth to be used for the channel. 
+ commitmentBlob := l.CommitmentCustomBlob() + auxBandwidth, err := ts.PaymentBandwidth( + htlcBlob, commitmentBlob, l.Bandwidth(), amount, + ) + if err != nil { + return fn.Err[OptionalBandwidth](fmt.Errorf("failed to get "+ + "bandwidth from external traffic shaper: %w", err)) + } + + log.Debugf("ShortChannelID=%v: aux traffic shaper reported available "+ + "bandwidth: %v", cid, auxBandwidth) + + return fn.Ok(fn.Some(auxBandwidth)) +} + // Stats returns the statistics of channel link. // // NOTE: Part of the ChannelLink interface. diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index ce791bef32..918738f415 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -968,6 +968,17 @@ func (f *mockChannelLink) CommitmentCustomBlob() fn.Option[tlv.Blob] { return fn.None[tlv.Blob]() } +// AuxBandwidth returns the bandwidth that can be used for a channel, +// expressed in milli-satoshi. This might be different from the regular +// BTC bandwidth for custom channels. This will always return fn.None() +// for a regular (non-custom) channel. +func (f *mockChannelLink) AuxBandwidth(lnwire.MilliSatoshi, + lnwire.ShortChannelID, + fn.Option[tlv.Blob], AuxTrafficShaper) fn.Result[OptionalBandwidth] { + + return fn.Ok(fn.None[lnwire.MilliSatoshi]()) +} + var _ ChannelLink = (*mockChannelLink)(nil) const testInvoiceCltvExpiry = 6 diff --git a/routing/bandwidth.go b/routing/bandwidth.go index 12e82131dc..eabd66cf82 100644 --- a/routing/bandwidth.go +++ b/routing/bandwidth.go @@ -29,39 +29,6 @@ type bandwidthHints interface { firstHopCustomBlob() fn.Option[tlv.Blob] } -// TlvTrafficShaper is an interface that allows the sender to determine if a -// payment should be carried by a channel based on the TLV records that may be -// present in the `update_add_htlc` message or the channel commitment itself. -type TlvTrafficShaper interface { - AuxHtlcModifier - - // ShouldHandleTraffic is called in order to check if the channel - // identified by the provided channel ID may have external mechanisms - // that would allow it to carry out the payment. - ShouldHandleTraffic(cid lnwire.ShortChannelID, - fundingBlob fn.Option[tlv.Blob]) (bool, error) - - // PaymentBandwidth returns the available bandwidth for a custom channel - // decided by the given channel aux blob and HTLC blob. A return value - // of 0 means there is no bandwidth available. To find out if a channel - // is a custom channel that should be handled by the traffic shaper, the - // HandleTraffic method should be called first. - PaymentBandwidth(htlcBlob, commitmentBlob fn.Option[tlv.Blob], - linkBandwidth, - htlcAmt lnwire.MilliSatoshi) (lnwire.MilliSatoshi, error) -} - -// AuxHtlcModifier is an interface that allows the sender to modify the outgoing -// HTLC of a payment by changing the amount or the wire message tlv records. -type AuxHtlcModifier interface { - // ProduceHtlcExtraData is a function that, based on the previous extra - // data blob of an HTLC, may produce a different blob or modify the - // amount of bitcoin this htlc should carry. - ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi, - htlcCustomRecords lnwire.CustomRecords) (lnwire.MilliSatoshi, - lnwire.CustomRecords, error) -} - // getLinkQuery is the function signature used to lookup a link. 
type getLinkQuery func(lnwire.ShortChannelID) ( htlcswitch.ChannelLink, error) @@ -73,7 +40,7 @@ type bandwidthManager struct { getLink getLinkQuery localChans map[lnwire.ShortChannelID]struct{} firstHopBlob fn.Option[tlv.Blob] - trafficShaper fn.Option[TlvTrafficShaper] + trafficShaper fn.Option[htlcswitch.AuxTrafficShaper] } // newBandwidthManager creates a bandwidth manager for the source node provided @@ -84,13 +51,14 @@ type bandwidthManager struct { // that are inactive, or just don't have enough bandwidth to carry the payment. func newBandwidthManager(graph Graph, sourceNode route.Vertex, linkQuery getLinkQuery, firstHopBlob fn.Option[tlv.Blob], - trafficShaper fn.Option[TlvTrafficShaper]) (*bandwidthManager, error) { + ts fn.Option[htlcswitch.AuxTrafficShaper]) (*bandwidthManager, + error) { manager := &bandwidthManager{ getLink: linkQuery, localChans: make(map[lnwire.ShortChannelID]struct{}), firstHopBlob: firstHopBlob, - trafficShaper: trafficShaper, + trafficShaper: ts, } // First, we'll collect the set of outbound edges from the target @@ -166,44 +134,15 @@ func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID, result, err := fn.MapOptionZ( b.trafficShaper, - func(ts TlvTrafficShaper) fn.Result[bandwidthResult] { - fundingBlob := link.FundingCustomBlob() - shouldHandle, err := ts.ShouldHandleTraffic( - cid, fundingBlob, - ) - if err != nil { - return bandwidthErr(fmt.Errorf("traffic "+ - "shaper failed to decide whether to "+ - "handle traffic: %w", err)) - } - - log.Debugf("ShortChannelID=%v: external traffic "+ - "shaper is handling traffic: %v", cid, - shouldHandle) - - // If this channel isn't handled by the external traffic - // shaper, we'll return early. - if !shouldHandle { - return fn.Ok(bandwidthResult{}) - } - - // Ask for a specific bandwidth to be used for the - // channel. - commitmentBlob := link.CommitmentCustomBlob() - auxBandwidth, err := ts.PaymentBandwidth( - b.firstHopBlob, commitmentBlob, linkBandwidth, - amount, - ) + func(s htlcswitch.AuxTrafficShaper) fn.Result[bandwidthResult] { + auxBandwidth, err := link.AuxBandwidth( + amount, cid, b.firstHopBlob, s, + ).Unpack() if err != nil { return bandwidthErr(fmt.Errorf("failed to get "+ - "bandwidth from external traffic "+ - "shaper: %w", err)) + "auxiliary bandwidth: %w", err)) } - log.Debugf("ShortChannelID=%v: external traffic "+ - "shaper reported available bandwidth: %v", cid, - auxBandwidth) - // We don't know the actual HTLC amount that will be // sent using the custom channel. But we'll still want // to make sure we can add another HTLC, using the @@ -213,7 +152,7 @@ func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID, // the max number of HTLCs on the channel. A proper // balance check is done elsewhere. 
return fn.Ok(bandwidthResult{ - bandwidth: fn.Some(auxBandwidth), + bandwidth: auxBandwidth, htlcAmount: fn.Some[lnwire.MilliSatoshi](0), }) }, diff --git a/routing/bandwidth_test.go b/routing/bandwidth_test.go index 4872b5a7ec..28b1dfb1ab 100644 --- a/routing/bandwidth_test.go +++ b/routing/bandwidth_test.go @@ -118,7 +118,9 @@ func TestBandwidthManager(t *testing.T) { m, err := newBandwidthManager( g, sourceNode.pubkey, testCase.linkQuery, fn.None[[]byte](), - fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), ) require.NoError(t, err) diff --git a/routing/mock_test.go b/routing/mock_test.go index 3cdb5ebaf2..3f3f5ea040 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -107,7 +107,7 @@ var _ PaymentSessionSource = (*mockPaymentSessionSourceOld)(nil) func (m *mockPaymentSessionSourceOld) NewPaymentSession( _ *LightningPayment, _ fn.Option[tlv.Blob], - _ fn.Option[TlvTrafficShaper]) (PaymentSession, error) { + _ fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, error) { return &mockPaymentSessionOld{ routes: m.routes, @@ -635,7 +635,8 @@ var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil) func (m *mockPaymentSessionSource) NewPaymentSession( payment *LightningPayment, firstHopBlob fn.Option[tlv.Blob], - tlvShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) { + tlvShaper fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, + error) { args := m.Called(payment, firstHopBlob, tlvShaper) return args.Get(0).(PaymentSession), args.Error(1) @@ -895,6 +896,19 @@ func (m *mockLink) Bandwidth() lnwire.MilliSatoshi { return m.bandwidth } +// AuxBandwidth returns the bandwidth that can be used for a channel, +// expressed in milli-satoshi. This might be different from the regular +// BTC bandwidth for custom channels. This will always return fn.None() +// for a regular (non-custom) channel. +func (m *mockLink) AuxBandwidth(lnwire.MilliSatoshi, lnwire.ShortChannelID, + fn.Option[tlv.Blob], + htlcswitch.AuxTrafficShaper) fn.Result[htlcswitch.OptionalBandwidth] { + + return fn.Ok[htlcswitch.OptionalBandwidth]( + fn.None[lnwire.MilliSatoshi](), + ) +} + // EligibleToForward returns the mock's configured eligibility. func (m *mockLink) EligibleToForward() bool { return !m.ineligible diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 267ce3965d..6f7034ea6a 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -761,7 +761,8 @@ func (p *paymentLifecycle) amendFirstHopData(rt *route.Route) error { // and apply its side effects to the UpdateAddHTLC message. 
result, err := fn.MapOptionZ( p.router.cfg.TrafficShaper, - func(ts TlvTrafficShaper) fn.Result[extraDataRequest] { + //nolint:ll + func(ts htlcswitch.AuxTrafficShaper) fn.Result[extraDataRequest] { newAmt, newRecords, err := ts.ProduceHtlcExtraData( rt.TotalAmount, p.firstHopCustomRecords, ) @@ -774,7 +775,7 @@ func (p *paymentLifecycle) amendFirstHopData(rt *route.Route) error { return fn.Err[extraDataRequest](err) } - log.Debugf("TLV traffic shaper returned custom "+ + log.Debugf("Aux traffic shaper returned custom "+ "records %v and amount %d msat for HTLC", spew.Sdump(newRecords), newAmt) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 315c1bad58..d566eb9413 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -30,7 +30,7 @@ func createTestPaymentLifecycle() *paymentLifecycle { quitChan := make(chan struct{}) rt := &ChannelRouter{ cfg: &Config{ - TrafficShaper: fn.Some[TlvTrafficShaper]( + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( &mockTrafficShaper{}, ), }, @@ -83,7 +83,7 @@ func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) { Payer: mockPayer, Clock: mockClock, MissionControl: mockMissionControl, - TrafficShaper: fn.Some[TlvTrafficShaper]( + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( &mockTrafficShaper{}, ), }, diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index d5f1a6af41..daaf7743b5 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -4,6 +4,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" @@ -52,7 +53,8 @@ type SessionSource struct { // payment's destination. func (m *SessionSource) NewPaymentSession(p *LightningPayment, firstHopBlob fn.Option[tlv.Blob], - trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) { + trafficShaper fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, + error) { getBandwidthHints := func(graph Graph) (bandwidthHints, error) { return newBandwidthManager( diff --git a/routing/router.go b/routing/router.go index 9eabe0b2ae..3405354124 100644 --- a/routing/router.go +++ b/routing/router.go @@ -157,7 +157,7 @@ type PaymentSessionSource interface { // finding a path to the payment's destination. NewPaymentSession(p *LightningPayment, firstHopBlob fn.Option[tlv.Blob], - trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession, + ts fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, error) // NewPaymentSessionEmpty creates a new paymentSession instance that is @@ -297,7 +297,7 @@ type Config struct { // TrafficShaper is an optional traffic shaper that can be used to // control the outgoing channel of a payment. - TrafficShaper fn.Option[TlvTrafficShaper] + TrafficShaper fn.Option[htlcswitch.AuxTrafficShaper] } // EdgeLocator is a struct used to identify a specific edge. 
diff --git a/routing/router_test.go b/routing/router_test.go index 2923f1fb90..a69b746f14 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -170,7 +170,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, Clock: clock.NewTestClock(time.Unix(1, 0)), ApplyChannelUpdate: graphBuilder.ApplyChannelUpdate, ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper]( + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( &mockTrafficShaper{}, ), }) @@ -2206,8 +2206,10 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Register mockers with the expected method calls. @@ -2291,8 +2293,10 @@ func TestSendToRouteSkipTempErrNonMPP(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Expect an error to be returned. @@ -2347,8 +2351,10 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Create the error to be returned. @@ -2431,8 +2437,10 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Create the error to be returned. @@ -2519,8 +2527,10 @@ func TestSendToRouteTempFailure(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Create the error to be returned. From a2e78c3984ba9d79dde20471044ba328c00b258d Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Wed, 4 Dec 2024 12:02:24 +0100 Subject: [PATCH 03/59] htlcswitch: thread through packet's inbound wire records For calculating the available auxiliary bandwidth of a channel, we need access to the inbound custom wire records of the HTLC packet, which might contain auxiliary information about the worth of the HTLC packet apart from the BTC value being transported. --- htlcswitch/interfaces.go | 11 ++++---- htlcswitch/link.go | 20 ++++++++------ htlcswitch/link_test.go | 60 +++++++++++++++++++++------------------- htlcswitch/mock.go | 4 +-- htlcswitch/switch.go | 8 +++--- 5 files changed, 55 insertions(+), 48 deletions(-) diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 4346339796..caf2abf1ae 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -272,10 +272,10 @@ type ChannelLink interface { // in order to signal to the source of the HTLC, the policy consistency // issue. 
CheckHtlcForward(payHash [32]byte, incomingAmt lnwire.MilliSatoshi, - amtToForward lnwire.MilliSatoshi, - incomingTimeout, outgoingTimeout uint32, - inboundFee models.InboundFee, - heightNow uint32, scid lnwire.ShortChannelID) *LinkError + amtToForward lnwire.MilliSatoshi, incomingTimeout, + outgoingTimeout uint32, inboundFee models.InboundFee, + heightNow uint32, scid lnwire.ShortChannelID, + customRecords lnwire.CustomRecords) *LinkError // CheckHtlcTransit should return a nil error if the passed HTLC details // satisfy the current channel policy. Otherwise, a LinkError with a @@ -283,7 +283,8 @@ type ChannelLink interface { // the violation. This call is intended to be used for locally initiated // payments for which there is no corresponding incoming htlc. CheckHtlcTransit(payHash [32]byte, amt lnwire.MilliSatoshi, - timeout uint32, heightNow uint32) *LinkError + timeout uint32, heightNow uint32, + customRecords lnwire.CustomRecords) *LinkError // Stats return the statistics of channel link. Number of updates, // total sent/received milli-satoshis. diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 60a3adbe79..75a35302cb 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -3233,11 +3233,11 @@ func (l *channelLink) UpdateForwardingPolicy( // issue. // // NOTE: Part of the ChannelLink interface. -func (l *channelLink) CheckHtlcForward(payHash [32]byte, - incomingHtlcAmt, amtToForward lnwire.MilliSatoshi, - incomingTimeout, outgoingTimeout uint32, - inboundFee models.InboundFee, - heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError { +func (l *channelLink) CheckHtlcForward(payHash [32]byte, incomingHtlcAmt, + amtToForward lnwire.MilliSatoshi, incomingTimeout, + outgoingTimeout uint32, inboundFee models.InboundFee, + heightNow uint32, originalScid lnwire.ShortChannelID, + customRecords lnwire.CustomRecords) *LinkError { l.RLock() policy := l.cfg.FwrdingPolicy @@ -3286,7 +3286,7 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // Check whether the outgoing htlc satisfies the channel policy. err := l.canSendHtlc( policy, payHash, amtToForward, outgoingTimeout, heightNow, - originalScid, + originalScid, customRecords, ) if err != nil { return err @@ -3322,8 +3322,8 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // the violation. This call is intended to be used for locally initiated // payments for which there is no corresponding incoming htlc. func (l *channelLink) CheckHtlcTransit(payHash [32]byte, - amt lnwire.MilliSatoshi, timeout uint32, - heightNow uint32) *LinkError { + amt lnwire.MilliSatoshi, timeout uint32, heightNow uint32, + customRecords lnwire.CustomRecords) *LinkError { l.RLock() policy := l.cfg.FwrdingPolicy @@ -3334,6 +3334,7 @@ func (l *channelLink) CheckHtlcTransit(payHash [32]byte, // to occur. return l.canSendHtlc( policy, payHash, amt, timeout, heightNow, hop.Source, + customRecords, ) } @@ -3341,7 +3342,8 @@ func (l *channelLink) CheckHtlcTransit(payHash [32]byte, // the channel's amount and time lock constraints. func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32, - heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError { + heightNow uint32, originalScid lnwire.ShortChannelID, + customRecords lnwire.CustomRecords) *LinkError { // As our first sanity check, we'll ensure that the passed HTLC isn't // too small for the next hop. 
If so, then we'll cancel the HTLC diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 80632b07e9..938dc2e8a5 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -6243,9 +6243,9 @@ func TestCheckHtlcForward(t *testing.T) { var hash [32]byte t.Run("satisfied", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1500, 1000, - 200, 150, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1000, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if result != nil { t.Fatalf("expected policy to be satisfied") @@ -6253,9 +6253,9 @@ func TestCheckHtlcForward(t *testing.T) { }) t.Run("below minhtlc", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 100, 50, - 200, 150, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 100, 50, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailAmountBelowMinimum); !ok { t.Fatalf("expected FailAmountBelowMinimum failure code") @@ -6263,9 +6263,9 @@ func TestCheckHtlcForward(t *testing.T) { }) t.Run("above maxhtlc", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1500, 1200, - 200, 150, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1200, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailTemporaryChannelFailure); !ok { t.Fatalf("expected FailTemporaryChannelFailure failure code") @@ -6273,9 +6273,9 @@ func TestCheckHtlcForward(t *testing.T) { }) t.Run("insufficient fee", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1005, 1000, - 200, 150, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1005, 1000, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailFeeInsufficient); !ok { t.Fatalf("expected FailFeeInsufficient failure code") @@ -6288,17 +6288,17 @@ func TestCheckHtlcForward(t *testing.T) { t.Parallel() result := link.CheckHtlcForward( - hash, 100005, 100000, 200, - 150, models.InboundFee{}, 0, lnwire.ShortChannelID{}, + hash, 100005, 100000, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) _, ok := result.WireMessage().(*lnwire.FailFeeInsufficient) require.True(t, ok, "expected FailFeeInsufficient failure code") }) t.Run("expiry too soon", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1500, 1000, - 200, 150, models.InboundFee{}, 190, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1000, 200, 150, models.InboundFee{}, 190, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailExpiryTooSoon); !ok { t.Fatalf("expected FailExpiryTooSoon failure code") @@ -6306,9 +6306,9 @@ func TestCheckHtlcForward(t *testing.T) { }) t.Run("incorrect cltv expiry", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1500, 1000, - 200, 190, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1000, 200, 190, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailIncorrectCltvExpiry); !ok { t.Fatalf("expected FailIncorrectCltvExpiry failure code") @@ -6318,9 +6318,9 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("cltv expiry too far in the future", func(t *testing.T) { // Check that expiry isn't too far in the future. 
- result := link.CheckHtlcForward(hash, 1500, 1000, - 10200, 10100, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1000, 10200, 10100, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailExpiryTooFar); !ok { t.Fatalf("expected FailExpiryTooFar failure code") @@ -6330,9 +6330,11 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("inbound fee satisfied", func(t *testing.T) { t.Parallel() - result := link.CheckHtlcForward(hash, 1000+10-2-1, 1000, - 200, 150, models.InboundFee{Base: -2, Rate: -1_000}, - 0, lnwire.ShortChannelID{}) + result := link.CheckHtlcForward( + hash, 1000+10-2-1, 1000, 200, 150, + models.InboundFee{Base: -2, Rate: -1_000}, + 0, lnwire.ShortChannelID{}, nil, + ) if result != nil { t.Fatalf("expected policy to be satisfied") } @@ -6341,9 +6343,11 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("inbound fee insufficient", func(t *testing.T) { t.Parallel() - result := link.CheckHtlcForward(hash, 1000+10-10-101-1, 1000, + result := link.CheckHtlcForward( + hash, 1000+10-10-101-1, 1000, 200, 150, models.InboundFee{Base: -10, Rate: -100_000}, - 0, lnwire.ShortChannelID{}) + 0, lnwire.ShortChannelID{}, nil, + ) msg := result.WireMessage() if _, ok := msg.(*lnwire.FailFeeInsufficient); !ok { diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 918738f415..3c201cd701 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -846,14 +846,14 @@ func (f *mockChannelLink) UpdateForwardingPolicy(_ models.ForwardingPolicy) { } func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliSatoshi, lnwire.MilliSatoshi, uint32, uint32, models.InboundFee, uint32, - lnwire.ShortChannelID) *LinkError { + lnwire.ShortChannelID, lnwire.CustomRecords) *LinkError { return f.checkHtlcForwardResult } func (f *mockChannelLink) CheckHtlcTransit(payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32, - heightNow uint32) *LinkError { + heightNow uint32, _ lnwire.CustomRecords) *LinkError { return f.checkHtlcTransitResult } diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 1a08275ec9..b2c699b140 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -917,6 +917,7 @@ func (s *Switch) getLocalLink(pkt *htlcPacket, htlc *lnwire.UpdateAddHTLC) ( currentHeight := atomic.LoadUint32(&s.bestHeight) htlcErr := link.CheckHtlcTransit( htlc.PaymentHash, htlc.Amount, htlc.Expiry, currentHeight, + htlc.CustomRecords, ) if htlcErr != nil { log.Errorf("Link %v policy for local forward not "+ @@ -2887,10 +2888,9 @@ func (s *Switch) handlePacketAdd(packet *htlcPacket, failure = link.CheckHtlcForward( htlc.PaymentHash, packet.incomingAmount, packet.amount, packet.incomingTimeout, - packet.outgoingTimeout, - packet.inboundFee, - currentHeight, - packet.originalOutgoingChanID, + packet.outgoingTimeout, packet.inboundFee, + currentHeight, packet.originalOutgoingChanID, + htlc.CustomRecords, ) } From 86b3be71feaffa9fe464cf69ce259707cede11ab Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Wed, 4 Dec 2024 12:03:55 +0100 Subject: [PATCH 04/59] multi: thread through and use AuxTrafficShaper --- htlcswitch/link.go | 36 +++++++++++++++++++++++++++++++++++- peer/brontide.go | 5 +++++ server.go | 1 + 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 75a35302cb..6a577e5694 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -293,6 +293,10 @@ type ChannelLinkConfig struct { // ShouldFwdExpEndorsement is a closure that 
indicates whether the link // should forward experimental endorsement signals. ShouldFwdExpEndorsement func() bool + + // AuxTrafficShaper is an optional auxiliary traffic shaper that can be + // used to manage the bandwidth of the link. + AuxTrafficShaper fn.Option[AuxTrafficShaper] } // channelLink is the service which drives a channel's commitment update @@ -3401,8 +3405,38 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, return NewLinkError(&lnwire.FailExpiryTooFar{}) } + // We now check the available bandwidth to see if this HTLC can be + // forwarded. + availableBandwidth := l.Bandwidth() + auxBandwidth, err := fn.MapOptionZ( + l.cfg.AuxTrafficShaper, + func(ts AuxTrafficShaper) fn.Result[OptionalBandwidth] { + var htlcBlob fn.Option[tlv.Blob] + blob, err := customRecords.Serialize() + if err != nil { + return fn.Err[OptionalBandwidth]( + fmt.Errorf("unable to serialize "+ + "custom records: %w", err)) + } + + if len(blob) > 0 { + htlcBlob = fn.Some(blob) + } + + return l.AuxBandwidth(amt, originalScid, htlcBlob, ts) + }, + ).Unpack() + if err != nil { + l.log.Errorf("Unable to determine aux bandwidth: %v", err) + return NewLinkError(&lnwire.FailTemporaryNodeFailure{}) + } + + auxBandwidth.WhenSome(func(bandwidth lnwire.MilliSatoshi) { + availableBandwidth = bandwidth + }) + // Check to see if there is enough balance in this channel. - if amt > l.Bandwidth() { + if amt > availableBandwidth { l.log.Warnf("insufficient bandwidth to route htlc: %v is "+ "larger than %v", amt, l.Bandwidth()) cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { diff --git a/peer/brontide.go b/peer/brontide.go index 6bc49445ee..7074b10071 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -400,6 +400,10 @@ type Config struct { // way contracts are resolved. AuxResolver fn.Option[lnwallet.AuxContractResolver] + // AuxTrafficShaper is an optional auxiliary traffic shaper that can be + // used to manage the bandwidth of peer links. + AuxTrafficShaper fn.Option[htlcswitch.AuxTrafficShaper] + // PongBuf is a slice we'll reuse instead of allocating memory on the // heap. Since only reads will occur and no writes, there is no need // for any synchronization primitives. As a result, it's safe to share @@ -1330,6 +1334,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint, ShouldFwdExpEndorsement: p.cfg.ShouldFwdExpEndorsement, DisallowQuiescence: p.cfg.DisallowQuiescence || !p.remoteFeatures.HasFeature(lnwire.QuiescenceOptional), + AuxTrafficShaper: p.cfg.AuxTrafficShaper, } // Before adding our new link, purge the switch of any pending or live diff --git a/server.go b/server.go index f8f8239ed6..7f725553a3 100644 --- a/server.go +++ b/server.go @@ -4222,6 +4222,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, MsgRouter: s.implCfg.MsgRouter, AuxChanCloser: s.implCfg.AuxChanCloser, AuxResolver: s.implCfg.AuxContractResolver, + AuxTrafficShaper: s.implCfg.TrafficShaper, ShouldFwdExpEndorsement: func() bool { if s.cfg.ProtocolOptions.NoExperimentalEndorsement() { return false From 17bc8827c5d2ab1da36f630140def7d9305b10c6 Mon Sep 17 00:00:00 2001 From: ziggie Date: Tue, 3 Dec 2024 15:42:14 +0100 Subject: [PATCH 05/59] contractcourt: refactor start up of arbitrators We decouple the state machine of the channel arbitrator from the start-up process so that we can startup the whole daemon reliably. 
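Editor's note: the shape of this change can be summarized with the simplified sketch below. It is not the actual ChannelArbitrator code; names and types are stripped down to illustrate the pattern of Start only launching the attendant goroutine, which then makes a best-effort attempt to progress the state machine and keeps serving events even if that attempt fails.

package arbsketch

import (
	"log"
	"sync"
)

// arbitrator is a stripped-down stand-in for the channel arbitrator, used
// only to illustrate the startup pattern introduced in this commit.
type arbitrator struct {
	wg     sync.WaitGroup
	events chan struct{}
	quit   chan struct{}
}

// Start no longer drives the state machine itself; it only launches the
// attendant goroutine, so a failed catch-up cannot block daemon startup.
func (a *arbitrator) Start() error {
	a.wg.Add(1)
	go a.channelAttendant()

	return nil
}

// channelAttendant first tries to progress the state machine and then enters
// the event loop regardless of the outcome.
func (a *arbitrator) channelAttendant() {
	defer a.wg.Done()

	// Best effort: on failure we only log, since the channel may still be
	// resolved via other means, e.g. the remote party force closing.
	if err := a.progressStateMachineAfterRestart(); err != nil {
		log.Printf("unable to progress state machine: %v", err)
	}

	for {
		select {
		case <-a.events:
			// Handle chain and contract events here.
		case <-a.quit:
			return
		}
	}
}

// progressStateMachineAfterRestart stands in for retrying the failed state
// transition and relaunching resolvers after a restart.
func (a *arbitrator) progressStateMachineAfterRestart() error {
	return nil
}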
--- contractcourt/channel_arbitrator.go | 37 ++++++++++++++++++++---- contractcourt/channel_arbitrator_test.go | 27 ++++++++++++----- 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 319b437e4e..7121253a64 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -482,6 +482,20 @@ func (c *ChannelArbitrator) Start(state *chanArbStartState) error { return err } + c.wg.Add(1) + go c.channelAttendant(bestHeight, state.commitSet) + + return nil +} + +// progressStateMachineAfterRestart attempts to progress the state machine +// after a restart. This makes sure that if the state transition failed, we +// will try to progress the state machine again. Moreover it will relaunch +// resolvers if the channel is still in the pending close state and has not +// been fully resolved yet. +func (c *ChannelArbitrator) progressStateMachineAfterRestart(bestHeight int32, + commitSet *CommitSet) error { + // If the channel has been marked pending close in the database, and we // haven't transitioned the state machine to StateContractClosed (or a // succeeding state), then a state transition most likely failed. We'll @@ -527,7 +541,7 @@ func (c *ChannelArbitrator) Start(state *chanArbStartState) error { // on-chain state, and our set of active contracts. startingState := c.state nextState, _, err := c.advanceState( - triggerHeight, trigger, state.commitSet, + triggerHeight, trigger, commitSet, ) if err != nil { switch err { @@ -564,14 +578,12 @@ func (c *ChannelArbitrator) Start(state *chanArbStartState) error { // receive a chain event from the chain watcher that the // commitment has been confirmed on chain, and before we // advance our state step, we call InsertConfirmedCommitSet. - err := c.relaunchResolvers(state.commitSet, triggerHeight) + err := c.relaunchResolvers(commitSet, triggerHeight) if err != nil { return err } } - c.wg.Add(1) - go c.channelAttendant(bestHeight) return nil } @@ -2716,13 +2728,28 @@ func (c *ChannelArbitrator) updateActiveHTLCs() { // Nursery for incubation, and ultimate sweeping. // // NOTE: This MUST be run as a goroutine. -func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { +// +//nolint:funlen +func (c *ChannelArbitrator) channelAttendant(bestHeight int32, + commitSet *CommitSet) { // TODO(roasbeef): tell top chain arb we're done defer func() { c.wg.Done() }() + err := c.progressStateMachineAfterRestart(bestHeight, commitSet) + if err != nil { + // In case of an error, we return early but we do not shutdown + // LND, because there might be other channels that still can be + // resolved and we don't want to interfere with that. + // We continue to run the channel attendant in case the channel + // closes via other means for example the remote pary force + // closes the channel. So we log the error and continue. 
+ log.Errorf("Unable to progress state machine after "+ + "restart: %v", err) + } + for { select { diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 92ad608eb9..b8e32f4e2b 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -21,6 +21,7 @@ import ( "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/mock" + "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -1045,10 +1046,19 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { // Post restart, it should be the case that our resolver was properly // supplemented, and we only have a single resolver in the final set. - if len(chanArb.activeResolvers) != 1 { - t.Fatalf("expected single resolver, instead got: %v", - len(chanArb.activeResolvers)) - } + // The resolvers are added concurrently so we need to wait here. + err = wait.NoError(func() error { + chanArb.activeResolversLock.Lock() + defer chanArb.activeResolversLock.Unlock() + + if len(chanArb.activeResolvers) != 1 { + return fmt.Errorf("expected single resolver, instead "+ + "got: %v", len(chanArb.activeResolvers)) + } + + return nil + }, defaultTimeout) + require.NoError(t, err) // We'll now examine the in-memory state of the active resolvers to // ensure t hey were populated properly. @@ -3000,9 +3010,12 @@ func TestChannelArbitratorStartForceCloseFail(t *testing.T) { { name: "Commitment is rejected with an " + "unmatched error", - broadcastErr: fmt.Errorf("Reject Commitment Tx"), - expectedState: StateBroadcastCommit, - expectedStartup: false, + broadcastErr: fmt.Errorf("Reject Commitment Tx"), + expectedState: StateBroadcastCommit, + // We should still be able to start up since we other + // channels might be closing as well and we should + // resolve the contracts. + expectedStartup: true, }, // We started after the DLP was triggered, and try to force From 0004e3199762571000a78c6fb2953d07085dc319 Mon Sep 17 00:00:00 2001 From: ziggie Date: Wed, 13 Nov 2024 17:01:29 +0100 Subject: [PATCH 06/59] docs: add release-notes --- docs/release-notes/release-notes-0.18.4.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/release-notes/release-notes-0.18.4.md b/docs/release-notes/release-notes-0.18.4.md index 1fd299f3d7..ab72c01c31 100644 --- a/docs/release-notes/release-notes-0.18.4.md +++ b/docs/release-notes/release-notes-0.18.4.md @@ -23,6 +23,10 @@ cause a nil pointer dereference during the probing of a payment request that does not contain a payment address. +* [Fixed a bug](https://github.com/lightningnetwork/lnd/pull/9324) to prevent + potential deadlocks when LND depends on external components (e.g. aux + components, hooks). 
+ # New Features The main channel state machine and database now allow for processing and storing From ac59b06f59ca5d8baec9e16db103dc44dcc4c4b4 Mon Sep 17 00:00:00 2001 From: planetBoy <140164174+Guayaba221@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:48:22 +0100 Subject: [PATCH 07/59] Update ruby.md --- docs/grpc/ruby.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/grpc/ruby.md b/docs/grpc/ruby.md index 599dd2bc7a..457246423f 100644 --- a/docs/grpc/ruby.md +++ b/docs/grpc/ruby.md @@ -58,7 +58,7 @@ $:.unshift(File.dirname(__FILE__)) require 'grpc' require 'lightning_services_pb' -# Due to updated ECDSA generated tls.cert we need to let gprc know that +# Due to updated ECDSA generated tls.cert we need to let grpc know that # we need to use that cipher suite otherwise there will be a handshake # error when we communicate with the lnd rpc server. ENV['GRPC_SSL_CIPHER_SUITES'] = "HIGH+ECDSA" From 7374392abe969be846bf386d43ba30f6fe65b55f Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 5 Dec 2024 23:39:37 +0800 Subject: [PATCH 08/59] lnrpc: sort `Invoice.HTLCs` based on `HtlcIndex` So the returned HTLCs are ordered. --- docs/release-notes/release-notes-0.19.0.md | 4 ++++ lnrpc/invoicesrpc/utils.go | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/docs/release-notes/release-notes-0.19.0.md b/docs/release-notes/release-notes-0.19.0.md index db5ca738ef..ee0b0c1360 100644 --- a/docs/release-notes/release-notes-0.19.0.md +++ b/docs/release-notes/release-notes-0.19.0.md @@ -90,6 +90,10 @@ * [The `walletrpc.FundPsbt` method now has a new option to specify the maximum fee to output amounts ratio.](https://github.com/lightningnetwork/lnd/pull/8600) +* When returning the response from list invoices RPC, the `lnrpc.Invoice.Htlcs` + are now [sorted](https://github.com/lightningnetwork/lnd/pull/9337) based on + the `InvoiceHTLC.HtlcIndex`. + ## lncli Additions * [A pre-generated macaroon root key can now be specified in `lncli create` and diff --git a/lnrpc/invoicesrpc/utils.go b/lnrpc/invoicesrpc/utils.go index 955ba6acf2..19ade28fd8 100644 --- a/lnrpc/invoicesrpc/utils.go +++ b/lnrpc/invoicesrpc/utils.go @@ -1,8 +1,10 @@ package invoicesrpc import ( + "cmp" "encoding/hex" "fmt" + "slices" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/chaincfg" @@ -160,6 +162,11 @@ func CreateRPCInvoice(invoice *invoices.Invoice, rpcHtlcs = append(rpcHtlcs, &rpcHtlc) } + // Perform an inplace sort of the HTLCs to ensure they are ordered. + slices.SortFunc(rpcHtlcs, func(i, j *lnrpc.InvoiceHTLC) int { + return cmp.Compare(i.HtlcIndex, j.HtlcIndex) + }) + rpcInvoice := &lnrpc.Invoice{ Memo: string(invoice.Memo), RHash: rHash, From eee4dbd22f824fffa3782758b0c94a621ce105ef Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 30 Apr 2024 17:13:57 +0800 Subject: [PATCH 09/59] sweep: add new state `TxFatal` for erroneous sweepings Also updated the loggings. This new state will be used in the following commit. --- sweep/fee_bumper.go | 59 ++++++++++++++++++++++++++++++---------- sweep/fee_bumper_test.go | 20 ++++++++++++++ sweep/sweeper.go | 5 +++- 3 files changed, 68 insertions(+), 16 deletions(-) diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index 7bb58ae29e..bfe550886f 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -84,6 +84,11 @@ const ( // TxConfirmed is sent when the tx is confirmed. TxConfirmed + // TxFatal is sent when the inputs in this tx cannot be retried. 
Txns + // will end up in this state if they have encountered a non-fee related + // error, which means they cannot be retried with increased budget. + TxFatal + // sentinalEvent is used to check if an event is unknown. sentinalEvent ) @@ -99,6 +104,8 @@ func (e BumpEvent) String() string { return "Replaced" case TxConfirmed: return "Confirmed" + case TxFatal: + return "Fatal" default: return "Unknown" } @@ -246,10 +253,20 @@ type BumpResult struct { requestID uint64 } +// String returns a human-readable string for the result. +func (b *BumpResult) String() string { + desc := fmt.Sprintf("Event=%v", b.Event) + if b.Tx != nil { + desc += fmt.Sprintf(", Tx=%v", b.Tx.TxHash()) + } + + return fmt.Sprintf("[%s]", desc) +} + // Validate validates the BumpResult so it's safe to use. func (b *BumpResult) Validate() error { - // Every result must have a tx. - if b.Tx == nil { + // Every result must have a tx except the fatal or failed case. + if b.Tx == nil && b.Event != TxFatal { return fmt.Errorf("%w: nil tx", ErrInvalidBumpResult) } @@ -263,9 +280,11 @@ func (b *BumpResult) Validate() error { return fmt.Errorf("%w: nil replacing tx", ErrInvalidBumpResult) } - // If it's a failed event, it must have an error. - if b.Event == TxFailed && b.Err == nil { - return fmt.Errorf("%w: nil error", ErrInvalidBumpResult) + // If it's a failed or fatal event, it must have an error. + if b.Event == TxFatal || b.Event == TxFailed { + if b.Err == nil { + return fmt.Errorf("%w: nil error", ErrInvalidBumpResult) + } } // If it's a confirmed event, it must have a fee rate and fee. @@ -659,8 +678,7 @@ func (t *TxPublisher) notifyResult(result *BumpResult) { return } - log.Debugf("Sending result for requestID=%v, tx=%v", id, - result.Tx.TxHash()) + log.Debugf("Sending result %v for requestID=%v", result, id) select { // Send the result to the subscriber. @@ -678,20 +696,31 @@ func (t *TxPublisher) notifyResult(result *BumpResult) { func (t *TxPublisher) removeResult(result *BumpResult) { id := result.requestID - // Remove the record from the maps if there's an error. This means this - // tx has failed its broadcast and cannot be retried. There are two - // cases, - // - when the budget cannot cover the fee. - // - when a non-RBF related error occurs. + var txid chainhash.Hash + if result.Tx != nil { + txid = result.Tx.TxHash() + } + + // Remove the record from the maps if there's an error or the tx is + // confirmed. When there's an error, it means this tx has failed its + // broadcast and cannot be retried. There are two cases it may fail, + // - when the budget cannot cover the increased fee calculated by the + // fee function, hence the budget is used up. + // - when a non-fee related error returned from PublishTransaction. switch result.Event { case TxFailed: log.Errorf("Removing monitor record=%v, tx=%v, due to err: %v", - id, result.Tx.TxHash(), result.Err) + id, txid, result.Err) case TxConfirmed: - // Remove the record is the tx is confirmed. + // Remove the record if the tx is confirmed. log.Debugf("Removing confirmed monitor record=%v, tx=%v", id, - result.Tx.TxHash()) + txid) + + case TxFatal: + // Remove the record if there's an error. + log.Debugf("Removing monitor record=%v due to fatal err: %v", + id, result.Err) // Do nothing if it's neither failed or confirmed. 
default: diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index c9196aee5a..0c107b29ff 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -91,6 +91,19 @@ func TestBumpResultValidate(t *testing.T) { } require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + // A failed event without a tx will give an error. + b = BumpResult{ + Event: TxFailed, + Err: errDummy, + } + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + + // A fatal event without a failure reason will give an error. + b = BumpResult{ + Event: TxFailed, + } + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + // A confirmed event without fee info will give an error. b = BumpResult{ Tx: &wire.MsgTx{}, @@ -104,6 +117,13 @@ func TestBumpResultValidate(t *testing.T) { Event: TxPublished, } require.NoError(t, b.Validate()) + + // Tx is allowed to be nil in a TxFatal event. + b = BumpResult{ + Event: TxFatal, + Err: errDummy, + } + require.NoError(t, b.Validate()) } // TestCalcSweepTxWeight checks that the weight of the sweep tx is calculated diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 9eeefc94b8..a2a48beb25 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -1729,7 +1729,7 @@ func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error { // NOTE: TxConfirmed event is not handled, since we already subscribe to the // input's spending event, we don't need to do anything here. func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error { - log.Debugf("Received bump event [%v] for tx %v", r.Event, r.Tx.TxHash()) + log.Debugf("Received bump result %v", r) switch r.Event { // The tx has been published, we update the inputs' state and create a @@ -1745,6 +1745,9 @@ func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error { // with the new one. case TxReplaced: return s.handleBumpEventTxReplaced(r) + + case TxFatal: + // TODO(yy): create a method to remove this input. } return nil From cfbfae6c5113356e507ff7f34febf2cebabbff16 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Fri, 25 Oct 2024 15:27:26 +0800 Subject: [PATCH 10/59] sweep: add new error `ErrZeroFeeRateDelta` --- sweep/fee_function.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sweep/fee_function.go b/sweep/fee_function.go index bff44000be..eb2ed4d6b1 100644 --- a/sweep/fee_function.go +++ b/sweep/fee_function.go @@ -14,6 +14,9 @@ var ( // ErrMaxPosition is returned when trying to increase the position of // the fee function while it's already at its max. ErrMaxPosition = errors.New("position already at max") + + // ErrZeroFeeRateDelta is returned when the fee rate delta is zero. + ErrZeroFeeRateDelta = errors.New("fee rate delta is zero") ) // mSatPerKWeight represents a fee rate in msat/kw. @@ -169,7 +172,7 @@ func NewLinearFeeFunction(maxFeeRate chainfee.SatPerKWeight, "endingFeeRate=%v, width=%v, delta=%v", start, end, l.width, l.deltaFeeRate) - return nil, fmt.Errorf("fee rate delta is zero") + return nil, ErrZeroFeeRateDelta } // Attach the calculated values to the fee function. From 20b71a80464b13b221fa321ac559a7be167c035a Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Fri, 25 Oct 2024 15:32:30 +0800 Subject: [PATCH 11/59] sweep: add new interface method `Immediate` This prepares the following commit where we now let the fee bumpr decides whether to broadcast immediately or not. 
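Editor's note: to make the intent of the new flag concrete, the hypothetical snippet below shows how a publisher could act on BumpRequest.Immediate, broadcasting right away instead of waiting for the next block-driven tick. This is only a sketch of the idea; the actual TxPublisher behaviour is added in the follow-up commit and is not shown here. On the input-set side, the BudgetInputSet.Immediate implementation in the diff below treats the whole set as immediate as soon as any single input was offered with the Immediate param.

package bumpsketch

import "github.com/lightningnetwork/lnd/sweep"

// decideBroadcast is a hypothetical helper illustrating what the Immediate
// flag is for: it lets the fee bumper choose between publishing the sweep tx
// right away and deferring it to the next block epoch.
func decideBroadcast(req *sweep.BumpRequest) string {
	if req.Immediate {
		return "broadcast immediately"
	}

	return "defer until the next block epoch"
}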
--- sweep/fee_bumper.go | 4 ++++ sweep/mock_test.go | 7 +++++++ sweep/sweeper.go | 1 + sweep/sweeper_test.go | 2 ++ sweep/tx_input_set.go | 22 ++++++++++++++++++++++ 5 files changed, 36 insertions(+) diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index bfe550886f..1dff9bea88 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -143,6 +143,10 @@ type BumpRequest struct { // ExtraTxOut tracks if this bump request has an optional set of extra // outputs to add to the transaction. ExtraTxOut fn.Option[SweepOutput] + + // Immediate is used to specify that the tx should be broadcast + // immediately. + Immediate bool } // MaxFeeRateAllowed returns the maximum fee rate allowed for the given diff --git a/sweep/mock_test.go b/sweep/mock_test.go index eeeb283969..d42b0320d4 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -268,6 +268,13 @@ func (m *MockInputSet) StartingFeeRate() fn.Option[chainfee.SatPerKWeight] { return args.Get(0).(fn.Option[chainfee.SatPerKWeight]) } +// Immediate returns whether the inputs should be swept immediately. +func (m *MockInputSet) Immediate() bool { + args := m.Called() + + return args.Bool(0) +} + // MockBumper is a mock implementation of the interface Bumper. type MockBumper struct { mock.Mock diff --git a/sweep/sweeper.go b/sweep/sweeper.go index a2a48beb25..6245d7941b 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -827,6 +827,7 @@ func (s *UtxoSweeper) sweep(set InputSet) error { DeliveryAddress: sweepAddr, MaxFeeRate: s.cfg.MaxFeeRate.FeePerKWeight(), StartingFeeRate: set.StartingFeeRate(), + Immediate: set.Immediate(), // TODO(yy): pass the strategy here. } diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 7d99ba93b9..2b527f5c88 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -704,11 +704,13 @@ func TestSweepPendingInputs(t *testing.T) { setNeedWallet.On("Budget").Return(btcutil.Amount(1)).Once() setNeedWallet.On("StartingFeeRate").Return( fn.None[chainfee.SatPerKWeight]()).Once() + setNeedWallet.On("Immediate").Return(false).Once() normalSet.On("Inputs").Return(nil).Maybe() normalSet.On("DeadlineHeight").Return(testHeight).Once() normalSet.On("Budget").Return(btcutil.Amount(1)).Once() normalSet.On("StartingFeeRate").Return( fn.None[chainfee.SatPerKWeight]()).Once() + normalSet.On("Immediate").Return(false).Once() // Make pending inputs for testing. We don't need real values here as // the returned clusters are mocked. diff --git a/sweep/tx_input_set.go b/sweep/tx_input_set.go index adae7cf131..b80d52b0ea 100644 --- a/sweep/tx_input_set.go +++ b/sweep/tx_input_set.go @@ -64,6 +64,13 @@ type InputSet interface { // StartingFeeRate returns the max starting fee rate found in the // inputs. StartingFeeRate() fn.Option[chainfee.SatPerKWeight] + + // Immediate returns a boolean to indicate whether the tx made from + // this input set should be published immediately. + // + // TODO(yy): create a new method `Params` to combine the informational + // methods DeadlineHeight, Budget, StartingFeeRate and Immediate. + Immediate() bool } // createWalletTxInput converts a wallet utxo into an object that can be added @@ -414,3 +421,18 @@ func (b *BudgetInputSet) StartingFeeRate() fn.Option[chainfee.SatPerKWeight] { return startingFeeRate } + +// Immediate returns whether the inputs should be swept immediately. +// +// NOTE: part of the InputSet interface. 
+func (b *BudgetInputSet) Immediate() bool { + for _, inp := range b.inputs { + // As long as one of the inputs is immediate, the whole set is + // immediate. + if inp.params.Immediate { + return true + } + } + + return false +} From fcd47e98f0aa2a3aafa9b10ff4de49b4556dc1d9 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Fri, 25 Oct 2024 17:31:54 +0800 Subject: [PATCH 12/59] sweep: handle inputs locally instead of relying on the tx This commit changes how inputs are handled upon receiving a bump result. Previously the inputs are taken from the `BumpResult.Tx`, which is now instead being handled locally as we will remember the input set when sending the bump request, and handle this input set when a result is received. --- sweep/sweeper.go | 108 ++++++------ sweep/sweeper_test.go | 375 +++++++++++++++++++----------------------- 2 files changed, 232 insertions(+), 251 deletions(-) diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 6245d7941b..ed39548014 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -309,9 +309,9 @@ type UtxoSweeper struct { // updated whenever a new block epoch is received. currentHeight int32 - // bumpResultChan is a channel that receives broadcast results from the + // bumpRespChan is a channel that receives broadcast results from the // TxPublisher. - bumpResultChan chan *BumpResult + bumpRespChan chan *bumpResp } // UtxoSweeperConfig contains dependencies of UtxoSweeper. @@ -395,7 +395,7 @@ func New(cfg *UtxoSweeperConfig) *UtxoSweeper { pendingSweepsReqs: make(chan *pendingSweepsReq), quit: make(chan struct{}), inputs: make(InputsMap), - bumpResultChan: make(chan *BumpResult, 100), + bumpRespChan: make(chan *bumpResp, 100), } } @@ -681,9 +681,9 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { s.sweepPendingInputs(inputs) } - case result := <-s.bumpResultChan: + case resp := <-s.bumpRespChan: // Handle the bump event. - err := s.handleBumpEvent(result) + err := s.handleBumpEvent(resp) if err != nil { log.Errorf("Failed to handle bump event: %v", err) @@ -840,16 +840,11 @@ func (s *UtxoSweeper) sweep(set InputSet) error { // this publish result and future RBF attempt. resp, err := s.cfg.Publisher.Broadcast(req) if err != nil { - outpoints := make([]wire.OutPoint, len(set.Inputs())) - for i, inp := range set.Inputs() { - outpoints[i] = inp.OutPoint() - } - log.Errorf("Initial broadcast failed: %v, inputs=\n%v", err, inputTypeSummary(set.Inputs())) // TODO(yy): find out which input is causing the failure. - s.markInputsPublishFailed(outpoints) + s.markInputsPublishFailed(set) return err } @@ -858,7 +853,7 @@ func (s *UtxoSweeper) sweep(set InputSet) error { // subscribing to the result chan and listen for future updates about // this tx. s.wg.Add(1) - go s.monitorFeeBumpResult(resp) + go s.monitorFeeBumpResult(set, resp) return nil } @@ -868,14 +863,14 @@ func (s *UtxoSweeper) sweep(set InputSet) error { func (s *UtxoSweeper) markInputsPendingPublish(set InputSet) { // Reschedule sweep. for _, input := range set.Inputs() { - pi, ok := s.inputs[input.OutPoint()] + op := input.OutPoint() + pi, ok := s.inputs[op] if !ok { // It could be that this input is an additional wallet // input that was attached. In that case there also // isn't a pending input to update. log.Tracef("Skipped marking input as pending "+ - "published: %v not found in pending inputs", - input.OutPoint()) + "published: %v not found in pending inputs", op) continue } @@ -886,8 +881,7 @@ func (s *UtxoSweeper) markInputsPendingPublish(set InputSet) { // publish. 
if pi.terminated() { log.Errorf("Expect input %v to not have terminated "+ - "state, instead it has %v", - input.OutPoint, pi.state) + "state, instead it has %v", op, pi.state) continue } @@ -902,9 +896,7 @@ func (s *UtxoSweeper) markInputsPendingPublish(set InputSet) { // markInputsPublished updates the sweeping tx in db and marks the list of // inputs as published. -func (s *UtxoSweeper) markInputsPublished(tr *TxRecord, - inputs []*wire.TxIn) error { - +func (s *UtxoSweeper) markInputsPublished(tr *TxRecord, set InputSet) error { // Mark this tx in db once successfully published. // // NOTE: this will behave as an overwrite, which is fine as the record @@ -916,15 +908,15 @@ func (s *UtxoSweeper) markInputsPublished(tr *TxRecord, } // Reschedule sweep. - for _, input := range inputs { - pi, ok := s.inputs[input.PreviousOutPoint] + for _, input := range set.Inputs() { + op := input.OutPoint() + pi, ok := s.inputs[op] if !ok { // It could be that this input is an additional wallet // input that was attached. In that case there also // isn't a pending input to update. log.Tracef("Skipped marking input as published: %v "+ - "not found in pending inputs", - input.PreviousOutPoint) + "not found in pending inputs", op) continue } @@ -933,8 +925,7 @@ func (s *UtxoSweeper) markInputsPublished(tr *TxRecord, if pi.state != PendingPublish { // We may get a Published if this is a replacement tx. log.Debugf("Expect input %v to have %v, instead it "+ - "has %v", input.PreviousOutPoint, - PendingPublish, pi.state) + "has %v", op, PendingPublish, pi.state) continue } @@ -950,9 +941,10 @@ func (s *UtxoSweeper) markInputsPublished(tr *TxRecord, } // markInputsPublishFailed marks the list of inputs as failed to be published. -func (s *UtxoSweeper) markInputsPublishFailed(outpoints []wire.OutPoint) { +func (s *UtxoSweeper) markInputsPublishFailed(set InputSet) { // Reschedule sweep. - for _, op := range outpoints { + for _, inp := range set.Inputs() { + op := inp.OutPoint() pi, ok := s.inputs[op] if !ok { // It could be that this input is an additional wallet @@ -1540,6 +1532,8 @@ func (s *UtxoSweeper) updateSweeperInputs() InputsMap { // sweepPendingInputs is called when the ticker fires. It will create clusters // and attempt to create and publish the sweeping transactions. func (s *UtxoSweeper) sweepPendingInputs(inputs InputsMap) { + log.Debugf("Sweeping %v inputs", len(inputs)) + // Cluster all of our inputs based on the specific Aggregator. sets := s.cfg.Aggregator.ClusterInputs(inputs) @@ -1581,11 +1575,24 @@ func (s *UtxoSweeper) sweepPendingInputs(inputs InputsMap) { } } +// bumpResp wraps the result of a bump attempt returned from the fee bumper and +// the inputs being used. +type bumpResp struct { + // result is the result of the bump attempt returned from the fee + // bumper. + result *BumpResult + + // set is the input set that was used in the bump attempt. + set InputSet +} + // monitorFeeBumpResult subscribes to the passed result chan to listen for // future updates about the sweeping tx. // // NOTE: must run as a goroutine. -func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) { +func (s *UtxoSweeper) monitorFeeBumpResult(set InputSet, + resultChan <-chan *BumpResult) { + defer s.wg.Done() for { @@ -1597,9 +1604,14 @@ func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) { continue } + resp := &bumpResp{ + result: r, + set: set, + } + // Send the result back to the main event loop. 
select { - case s.bumpResultChan <- r: + case s.bumpRespChan <- resp: case <-s.quit: log.Debug("Sweeper shutting down, skip " + "sending bump result") @@ -1635,25 +1647,25 @@ func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) { // handleBumpEventTxFailed handles the case where the tx has been failed to // publish. -func (s *UtxoSweeper) handleBumpEventTxFailed(r *BumpResult) error { +func (s *UtxoSweeper) handleBumpEventTxFailed(resp *bumpResp) { + r := resp.result tx, err := r.Tx, r.Err log.Errorf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(), err) - outpoints := make([]wire.OutPoint, 0, len(tx.TxIn)) - for _, inp := range tx.TxIn { - outpoints = append(outpoints, inp.PreviousOutPoint) - } - + // NOTE: When marking the inputs as failed, we are using the input set + // instead of the inputs found in the tx. This is fine for current + // version of the sweeper because we always create a tx using ALL of + // the inputs specified by the set. + // // TODO(yy): should we also remove the failed tx from db? - s.markInputsPublishFailed(outpoints) - - return err + s.markInputsPublishFailed(resp.set) } // handleBumpEventTxReplaced handles the case where the sweeping tx has been // replaced by a new one. -func (s *UtxoSweeper) handleBumpEventTxReplaced(r *BumpResult) error { +func (s *UtxoSweeper) handleBumpEventTxReplaced(resp *bumpResp) error { + r := resp.result oldTx := r.ReplacedTx newTx := r.Tx @@ -1693,12 +1705,13 @@ func (s *UtxoSweeper) handleBumpEventTxReplaced(r *BumpResult) error { } // Mark the inputs as published using the replacing tx. - return s.markInputsPublished(tr, r.Tx.TxIn) + return s.markInputsPublished(tr, resp.set) } // handleBumpEventTxPublished handles the case where the sweeping tx has been // successfully published. -func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error { +func (s *UtxoSweeper) handleBumpEventTxPublished(resp *bumpResp) error { + r := resp.result tx := r.Tx tr := &TxRecord{ Txid: tx.TxHash(), @@ -1708,7 +1721,7 @@ func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error { // Inputs have been successfully published so we update their // states. - err := s.markInputsPublished(tr, tx.TxIn) + err := s.markInputsPublished(tr, resp.set) if err != nil { return err } @@ -1729,10 +1742,10 @@ func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error { // // NOTE: TxConfirmed event is not handled, since we already subscribe to the // input's spending event, we don't need to do anything here. -func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error { - log.Debugf("Received bump result %v", r) +func (s *UtxoSweeper) handleBumpEvent(r *bumpResp) error { + log.Debugf("Received bump result %v", r.result) - switch r.Event { + switch r.result.Event { // The tx has been published, we update the inputs' state and create a // record to be stored in the sweeper db. case TxPublished: @@ -1740,7 +1753,8 @@ func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error { // The tx has failed, we update the inputs' state. case TxFailed: - return s.handleBumpEventTxFailed(r) + s.handleBumpEventTxFailed(r) + return nil // The tx has been replaced, we will remove the old tx and replace it // with the new one. 
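The refactor above boils down to a simple pairing pattern, sketched below with hypothetical names (`result`, `inputSet`, `resp` and `monitor` are stand-ins for `BumpResult`, `InputSet`, `bumpResp` and `monitorFeeBumpResult`): the goroutine that owns an input set wraps every result it forwards with that set, so the consumer marks inputs from the remembered set instead of re-deriving them from the tx.

    package main

    import "fmt"

    // result and inputSet are hypothetical stand-ins for BumpResult and
    // InputSet.
    type result struct{ event string }

    type inputSet struct{ outpoints []string }

    // resp pairs a result with the input set that produced it.
    type resp struct {
            result *result
            set    *inputSet
    }

    // monitor forwards every incoming result together with the set it was
    // created for, mirroring the wrapper pattern used above.
    func monitor(set *inputSet, in <-chan *result, out chan<- *resp) {
            for r := range in {
                    out <- &resp{result: r, set: set}
            }
            close(out)
    }

    func main() {
            in := make(chan *result, 1)
            out := make(chan *resp, 1)

            set := &inputSet{outpoints: []string{"txid1:0", "txid1:1"}}
            go monitor(set, in, out)

            in <- &result{event: "TxFailed"}
            close(in)

            // The handler treats the remembered set as the source of truth.
            r := <-out
            fmt.Println(r.result.event, r.set.outpoints)
    }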
diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 2b527f5c88..2e1be71c26 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -1,6 +1,7 @@ package sweep import ( + "crypto/rand" "errors" "testing" "time" @@ -12,6 +13,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/stretchr/testify/mock" @@ -33,6 +35,41 @@ var ( }) ) +// createMockInput creates a mock input and saves it to the sweeper's inputs +// map. The created input has the specified state and a random outpoint. It +// will assert the method `OutPoint` is called at least once. +func createMockInput(t *testing.T, s *UtxoSweeper, + state SweepState) *input.MockInput { + + inp := &input.MockInput{} + t.Cleanup(func() { + inp.AssertExpectations(t) + }) + + randBuf := make([]byte, lntypes.HashSize) + _, err := rand.Read(randBuf) + require.NoError(t, err, "internal error, cannot generate random bytes") + + randHash, err := chainhash.NewHash(randBuf) + require.NoError(t, err) + + inp.On("OutPoint").Return(wire.OutPoint{ + Hash: *randHash, + Index: 0, + }) + + // We don't do branch switches based on the witness type here so we + // just mock it. + inp.On("WitnessType").Return(input.CommitmentTimeLock).Maybe() + + s.inputs[inp.OutPoint()] = &SweeperInput{ + Input: inp, + state: state, + } + + return inp +} + // TestMarkInputsPendingPublish checks that given a list of inputs with // different states, only the non-terminal state will be marked as `Published`. func TestMarkInputsPendingPublish(t *testing.T) { @@ -47,50 +84,21 @@ func TestMarkInputsPendingPublish(t *testing.T) { set := &MockInputSet{} defer set.AssertExpectations(t) - // Create three testing inputs. - // - // inputNotExist specifies an input that's not found in the sweeper's - // `pendingInputs` map. - inputNotExist := &input.MockInput{} - defer inputNotExist.AssertExpectations(t) - - inputNotExist.On("OutPoint").Return(wire.OutPoint{Index: 0}) - - // inputInit specifies a newly created input. - inputInit := &input.MockInput{} - defer inputInit.AssertExpectations(t) - - inputInit.On("OutPoint").Return(wire.OutPoint{Index: 1}) - - s.inputs[inputInit.OutPoint()] = &SweeperInput{ - state: Init, - } - - // inputPendingPublish specifies an input that's about to be published. - inputPendingPublish := &input.MockInput{} - defer inputPendingPublish.AssertExpectations(t) - - inputPendingPublish.On("OutPoint").Return(wire.OutPoint{Index: 2}) - - s.inputs[inputPendingPublish.OutPoint()] = &SweeperInput{ - state: PendingPublish, - } - - // inputTerminated specifies an input that's terminated. - inputTerminated := &input.MockInput{} - defer inputTerminated.AssertExpectations(t) - - inputTerminated.On("OutPoint").Return(wire.OutPoint{Index: 3}) - - s.inputs[inputTerminated.OutPoint()] = &SweeperInput{ - state: Excluded, - } + // Create three inputs with different states. + // - inputInit specifies a newly created input. + // - inputPendingPublish specifies an input about to be published. + // - inputTerminated specifies an input that's terminated. + var ( + inputInit = createMockInput(t, s, Init) + inputPendingPublish = createMockInput(t, s, PendingPublish) + inputTerminated = createMockInput(t, s, Excluded) + ) // Mark the test inputs. 
We expect the non-exist input and the // inputTerminated to be skipped, and the rest to be marked as pending // publish. set.On("Inputs").Return([]input.Input{ - inputNotExist, inputInit, inputPendingPublish, inputTerminated, + inputInit, inputPendingPublish, inputTerminated, }) s.markInputsPendingPublish(set) @@ -122,36 +130,22 @@ func TestMarkInputsPublished(t *testing.T) { dummyTR := &TxRecord{} dummyErr := errors.New("dummy error") + // Create a mock input set. + set := &MockInputSet{} + defer set.AssertExpectations(t) + // Create a test sweeper. s := New(&UtxoSweeperConfig{ Store: mockStore, }) - // Create three testing inputs. - // - // inputNotExist specifies an input that's not found in the sweeper's - // `inputs` map. - inputNotExist := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 1}, - } - - // inputInit specifies a newly created input. When marking this as - // published, we should see an error log as this input hasn't been - // published yet. - inputInit := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 2}, - } - s.inputs[inputInit.PreviousOutPoint] = &SweeperInput{ - state: Init, - } - - // inputPendingPublish specifies an input that's about to be published. - inputPendingPublish := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 3}, - } - s.inputs[inputPendingPublish.PreviousOutPoint] = &SweeperInput{ - state: PendingPublish, - } + // Create two inputs with different states. + // - inputInit specifies a newly created input. + // - inputPendingPublish specifies an input about to be published. + var ( + inputInit = createMockInput(t, s, Init) + inputPendingPublish = createMockInput(t, s, PendingPublish) + ) // First, check that when an error is returned from db, it's properly // returned here. @@ -171,9 +165,9 @@ func TestMarkInputsPublished(t *testing.T) { // Mark the test inputs. We expect the non-exist input and the // inputInit to be skipped, and the final input to be marked as // published. - err = s.markInputsPublished(dummyTR, []*wire.TxIn{ - inputNotExist, inputInit, inputPendingPublish, - }) + set.On("Inputs").Return([]input.Input{inputInit, inputPendingPublish}) + + err = s.markInputsPublished(dummyTR, set) require.NoError(err) // We expect unchanged number of pending inputs. @@ -181,11 +175,11 @@ func TestMarkInputsPublished(t *testing.T) { // We expect the init input's state to stay unchanged. require.Equal(Init, - s.inputs[inputInit.PreviousOutPoint].state) + s.inputs[inputInit.OutPoint()].state) // We expect the pending-publish input's is now marked as published. require.Equal(Published, - s.inputs[inputPendingPublish.PreviousOutPoint].state) + s.inputs[inputPendingPublish.OutPoint()].state) // Assert mocked statements are executed as expected. mockStore.AssertExpectations(t) @@ -202,117 +196,75 @@ func TestMarkInputsPublishFailed(t *testing.T) { // Create a mock sweeper store. mockStore := NewMockSweeperStore() + // Create a mock input set. + set := &MockInputSet{} + defer set.AssertExpectations(t) + // Create a test sweeper. s := New(&UtxoSweeperConfig{ Store: mockStore, }) - // Create testing inputs for each state. - // - // inputNotExist specifies an input that's not found in the sweeper's - // `inputs` map. - inputNotExist := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 1}, - } - - // inputInit specifies a newly created input. When marking this as - // published, we should see an error log as this input hasn't been - // published yet. 
- inputInit := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 2}, - } - s.inputs[inputInit.PreviousOutPoint] = &SweeperInput{ - state: Init, - } - - // inputPendingPublish specifies an input that's about to be published. - inputPendingPublish := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 3}, - } - s.inputs[inputPendingPublish.PreviousOutPoint] = &SweeperInput{ - state: PendingPublish, - } - - // inputPublished specifies an input that's published. - inputPublished := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 4}, - } - s.inputs[inputPublished.PreviousOutPoint] = &SweeperInput{ - state: Published, - } - - // inputPublishFailed specifies an input that's failed to be published. - inputPublishFailed := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 5}, - } - s.inputs[inputPublishFailed.PreviousOutPoint] = &SweeperInput{ - state: PublishFailed, - } - - // inputSwept specifies an input that's swept. - inputSwept := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 6}, - } - s.inputs[inputSwept.PreviousOutPoint] = &SweeperInput{ - state: Swept, - } - - // inputExcluded specifies an input that's excluded. - inputExcluded := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 7}, - } - s.inputs[inputExcluded.PreviousOutPoint] = &SweeperInput{ - state: Excluded, - } - - // inputFailed specifies an input that's failed. - inputFailed := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 8}, - } - s.inputs[inputFailed.PreviousOutPoint] = &SweeperInput{ - state: Failed, - } + // Create inputs with different states. + // - inputInit specifies a newly created input. When marking this as + // published, we should see an error log as this input hasn't been + // published yet. + // - inputPendingPublish specifies an input about to be published. + // - inputPublished specifies an input that's published. + // - inputPublishFailed specifies an input that's failed to be + // published. + // - inputSwept specifies an input that's swept. + // - inputExcluded specifies an input that's excluded. + // - inputFailed specifies an input that's failed. + var ( + inputInit = createMockInput(t, s, Init) + inputPendingPublish = createMockInput(t, s, PendingPublish) + inputPublished = createMockInput(t, s, Published) + inputPublishFailed = createMockInput(t, s, PublishFailed) + inputSwept = createMockInput(t, s, Swept) + inputExcluded = createMockInput(t, s, Excluded) + inputFailed = createMockInput(t, s, Failed) + ) - // Gather all inputs' outpoints. - pendingOps := make([]wire.OutPoint, 0, len(s.inputs)+1) - for op := range s.inputs { - pendingOps = append(pendingOps, op) - } - pendingOps = append(pendingOps, inputNotExist.PreviousOutPoint) + // Gather all inputs. + set.On("Inputs").Return([]input.Input{ + inputInit, inputPendingPublish, inputPublished, + inputPublishFailed, inputSwept, inputExcluded, inputFailed, + }) // Mark the test inputs. We expect the non-exist input and the // inputInit to be skipped, and the final input to be marked as // published. - s.markInputsPublishFailed(pendingOps) + s.markInputsPublishFailed(set) // We expect unchanged number of pending inputs. require.Len(s.inputs, 7) // We expect the init input's state to stay unchanged. require.Equal(Init, - s.inputs[inputInit.PreviousOutPoint].state) + s.inputs[inputInit.OutPoint()].state) // We expect the pending-publish input's is now marked as publish // failed. 
require.Equal(PublishFailed, - s.inputs[inputPendingPublish.PreviousOutPoint].state) + s.inputs[inputPendingPublish.OutPoint()].state) // We expect the published input's is now marked as publish failed. require.Equal(PublishFailed, - s.inputs[inputPublished.PreviousOutPoint].state) + s.inputs[inputPublished.OutPoint()].state) // We expect the publish failed input to stay unchanged. require.Equal(PublishFailed, - s.inputs[inputPublishFailed.PreviousOutPoint].state) + s.inputs[inputPublishFailed.OutPoint()].state) // We expect the swept input to stay unchanged. - require.Equal(Swept, s.inputs[inputSwept.PreviousOutPoint].state) + require.Equal(Swept, s.inputs[inputSwept.OutPoint()].state) // We expect the excluded input to stay unchanged. - require.Equal(Excluded, s.inputs[inputExcluded.PreviousOutPoint].state) + require.Equal(Excluded, s.inputs[inputExcluded.OutPoint()].state) // We expect the failed input to stay unchanged. - require.Equal(Failed, s.inputs[inputFailed.PreviousOutPoint].state) + require.Equal(Failed, s.inputs[inputFailed.OutPoint()].state) // Assert mocked statements are executed as expected. mockStore.AssertExpectations(t) @@ -738,33 +690,33 @@ func TestSweepPendingInputs(t *testing.T) { func TestHandleBumpEventTxFailed(t *testing.T) { t.Parallel() + // Create a mock input set. + set := &MockInputSet{} + defer set.AssertExpectations(t) + // Create a test sweeper. s := New(&UtxoSweeperConfig{}) - var ( - // Create four testing outpoints. - op1 = wire.OutPoint{Hash: chainhash.Hash{1}} - op2 = wire.OutPoint{Hash: chainhash.Hash{2}} - op3 = wire.OutPoint{Hash: chainhash.Hash{3}} - opNotExist = wire.OutPoint{Hash: chainhash.Hash{4}} - ) + // inputNotExist specifies an input that's not found in the sweeper's + // `pendingInputs` map. + inputNotExist := &input.MockInput{} + defer inputNotExist.AssertExpectations(t) + inputNotExist.On("OutPoint").Return(wire.OutPoint{Index: 0}) + opNotExist := inputNotExist.OutPoint() // Create three mock inputs. - input1 := &input.MockInput{} - defer input1.AssertExpectations(t) - - input2 := &input.MockInput{} - defer input2.AssertExpectations(t) + var ( + input1 = createMockInput(t, s, PendingPublish) + input2 = createMockInput(t, s, PendingPublish) + input3 = createMockInput(t, s, PendingPublish) + ) - input3 := &input.MockInput{} - defer input3.AssertExpectations(t) + op1 := input1.OutPoint() + op2 := input2.OutPoint() + op3 := input3.OutPoint() // Construct the initial state for the sweeper. - s.inputs = InputsMap{ - op1: &SweeperInput{Input: input1, state: PendingPublish}, - op2: &SweeperInput{Input: input2, state: PendingPublish}, - op3: &SweeperInput{Input: input3, state: PendingPublish}, - } + set.On("Inputs").Return([]input.Input{input1, input2, input3}) // Create a testing tx that spends the first two inputs. tx := &wire.MsgTx{ @@ -782,16 +734,26 @@ func TestHandleBumpEventTxFailed(t *testing.T) { Err: errDummy, } + // Create a testing bump response. + resp := &bumpResp{ + result: br, + set: set, + } + // Call the method under test. - err := s.handleBumpEvent(br) + err := s.handleBumpEvent(resp) require.ErrorIs(t, err, errDummy) // Assert the states of the first two inputs are updated. require.Equal(t, PublishFailed, s.inputs[op1].state) require.Equal(t, PublishFailed, s.inputs[op2].state) - // Assert the state of the third input is not updated. - require.Equal(t, PendingPublish, s.inputs[op3].state) + // Assert the state of the third input. 
+ // + // NOTE: Although the tx doesn't spend it, we still mark this input as + // failed as we are treating the input set as the single source of + // truth. + require.Equal(t, PublishFailed, s.inputs[op3].state) // Assert the non-existing input is not added to the pending inputs. require.NotContains(t, s.inputs, opNotExist) @@ -810,23 +772,21 @@ func TestHandleBumpEventTxReplaced(t *testing.T) { wallet := &MockWallet{} defer wallet.AssertExpectations(t) + // Create a mock input set. + set := &MockInputSet{} + defer set.AssertExpectations(t) + // Create a test sweeper. s := New(&UtxoSweeperConfig{ Store: store, Wallet: wallet, }) - // Create a testing outpoint. - op := wire.OutPoint{Hash: chainhash.Hash{1}} - // Create a mock input. - inp := &input.MockInput{} - defer inp.AssertExpectations(t) + inp := createMockInput(t, s, PendingPublish) + set.On("Inputs").Return([]input.Input{inp}) - // Construct the initial state for the sweeper. - s.inputs = InputsMap{ - op: &SweeperInput{Input: inp, state: PendingPublish}, - } + op := inp.OutPoint() // Create a testing tx that spends the input. tx := &wire.MsgTx{ @@ -851,12 +811,18 @@ func TestHandleBumpEventTxReplaced(t *testing.T) { Event: TxReplaced, } + // Create a testing bump response. + resp := &bumpResp{ + result: br, + set: set, + } + // Mock the store to return an error. dummyErr := errors.New("dummy error") store.On("GetTx", tx.TxHash()).Return(nil, dummyErr).Once() // Call the method under test and assert the error is returned. - err := s.handleBumpEventTxReplaced(br) + err := s.handleBumpEventTxReplaced(resp) require.ErrorIs(t, err, dummyErr) // Mock the store to return the old tx record. @@ -871,7 +837,7 @@ func TestHandleBumpEventTxReplaced(t *testing.T) { store.On("DeleteTx", tx.TxHash()).Return(dummyErr).Once() // Call the method under test and assert the error is returned. - err = s.handleBumpEventTxReplaced(br) + err = s.handleBumpEventTxReplaced(resp) require.ErrorIs(t, err, dummyErr) // Mock the store to return the old tx record and delete it without @@ -891,7 +857,7 @@ func TestHandleBumpEventTxReplaced(t *testing.T) { wallet.On("CancelRebroadcast", tx.TxHash()).Once() // Call the method under test. - err = s.handleBumpEventTxReplaced(br) + err = s.handleBumpEventTxReplaced(resp) require.NoError(t, err) // Assert the state of the input is updated. @@ -907,22 +873,20 @@ func TestHandleBumpEventTxPublished(t *testing.T) { store := &MockSweeperStore{} defer store.AssertExpectations(t) + // Create a mock input set. + set := &MockInputSet{} + defer set.AssertExpectations(t) + // Create a test sweeper. s := New(&UtxoSweeperConfig{ Store: store, }) - // Create a testing outpoint. - op := wire.OutPoint{Hash: chainhash.Hash{1}} - // Create a mock input. - inp := &input.MockInput{} - defer inp.AssertExpectations(t) + inp := createMockInput(t, s, PendingPublish) + set.On("Inputs").Return([]input.Input{inp}) - // Construct the initial state for the sweeper. - s.inputs = InputsMap{ - op: &SweeperInput{Input: inp, state: PendingPublish}, - } + op := inp.OutPoint() // Create a testing tx that spends the input. tx := &wire.MsgTx{ @@ -938,6 +902,12 @@ func TestHandleBumpEventTxPublished(t *testing.T) { Event: TxPublished, } + // Create a testing bump response. + resp := &bumpResp{ + result: br, + set: set, + } + // Mock the store to save the new tx record. store.On("StoreTx", &TxRecord{ Txid: tx.TxHash(), @@ -945,7 +915,7 @@ func TestHandleBumpEventTxPublished(t *testing.T) { }).Return(nil).Once() // Call the method under test. 
- err := s.handleBumpEventTxPublished(br) + err := s.handleBumpEventTxPublished(resp) require.NoError(t, err) // Assert the state of the input is updated. @@ -963,25 +933,21 @@ func TestMonitorFeeBumpResult(t *testing.T) { wallet := &MockWallet{} defer wallet.AssertExpectations(t) + // Create a mock input set. + set := &MockInputSet{} + defer set.AssertExpectations(t) + // Create a test sweeper. s := New(&UtxoSweeperConfig{ Store: store, Wallet: wallet, }) - // Create a testing outpoint. - op := wire.OutPoint{Hash: chainhash.Hash{1}} - // Create a mock input. - inp := &input.MockInput{} - defer inp.AssertExpectations(t) - - // Construct the initial state for the sweeper. - s.inputs = InputsMap{ - op: &SweeperInput{Input: inp, state: PendingPublish}, - } + inp := createMockInput(t, s, PendingPublish) // Create a testing tx that spends the input. + op := inp.OutPoint() tx := &wire.MsgTx{ LockTime: 1, TxIn: []*wire.TxIn{ @@ -1060,7 +1026,8 @@ func TestMonitorFeeBumpResult(t *testing.T) { return resultChan }, shouldExit: false, - }, { + }, + { // When the sweeper is shutting down, the monitor loop // should exit. name: "exit on sweeper shutdown", @@ -1087,7 +1054,7 @@ func TestMonitorFeeBumpResult(t *testing.T) { s.wg.Add(1) go func() { - s.monitorFeeBumpResult(resultChan) + s.monitorFeeBumpResult(set, resultChan) close(done) }() From 6f1511a1cab63999dfbdd72ce8eee454daddd80d Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Fri, 25 Oct 2024 18:31:46 +0800 Subject: [PATCH 13/59] sweep: add `handleInitialBroadcast` to handle initial broadcast This commit adds a new method `handleInitialBroadcast` to handle the initial broadcast. Previously we'd broadcast immediately inside `Broadcast`, which soon will not work after the `blockbeat` is implemented as the action to publish is now always triggered by a new block. Meanwhile, we still keep the option to bypass the block trigger so users can broadcast immediately by setting `Immediate` to true. --- sweep/fee_bumper.go | 172 ++++++++++++++----- sweep/fee_bumper_test.go | 362 ++++++++++++++++++++++++++------------- 2 files changed, 370 insertions(+), 164 deletions(-) diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index 1dff9bea88..1e159fa127 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -376,40 +376,52 @@ func (t *TxPublisher) isNeutrinoBackend() bool { return t.cfg.Wallet.BackEnd() == "neutrino" } -// Broadcast is used to publish the tx created from the given inputs. It will, -// 1. init a fee function based on the given strategy. -// 2. create an RBF-compliant tx and monitor it for confirmation. -// 3. notify the initial broadcast result back to the caller. -// The initial broadcast is guaranteed to be RBF-compliant unless the budget -// specified cannot cover the fee. +// Broadcast is used to publish the tx created from the given inputs. It will +// register the broadcast request and return a chan to the caller to subscribe +// the broadcast result. The initial broadcast is guaranteed to be +// RBF-compliant unless the budget specified cannot cover the fee. // // NOTE: part of the Bumper interface. func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { log.Tracef("Received broadcast request: %s", lnutils.SpewLogClosure( req)) - // Attempt an initial broadcast which is guaranteed to comply with the - // RBF rules. - result, err := t.initialBroadcast(req) - if err != nil { - log.Errorf("Initial broadcast failed: %v", err) - - return nil, err - } + // Store the request. 
+ requestID, record := t.storeInitialRecord(req) // Create a chan to send the result to the caller. subscriber := make(chan *BumpResult, 1) - t.subscriberChans.Store(result.requestID, subscriber) + t.subscriberChans.Store(requestID, subscriber) - // Send the initial broadcast result to the caller. - t.handleResult(result) + // Publish the tx immediately if specified. + if req.Immediate { + t.handleInitialBroadcast(record, requestID) + } return subscriber, nil } +// storeInitialRecord initializes a monitor record and saves it in the map. +func (t *TxPublisher) storeInitialRecord(req *BumpRequest) ( + uint64, *monitorRecord) { + + // Increase the request counter. + // + // NOTE: this is the only place where we increase the counter. + requestID := t.requestCounter.Add(1) + + // Register the record. + record := &monitorRecord{req: req} + t.records.Store(requestID, record) + + return requestID, record +} + // initialBroadcast initializes a fee function, creates an RBF-compliant tx and // broadcasts it. -func (t *TxPublisher) initialBroadcast(req *BumpRequest) (*BumpResult, error) { +func (t *TxPublisher) initialBroadcast(requestID uint64, + req *BumpRequest) (*BumpResult, error) { + // Create a fee bumping algorithm to be used for future RBF. feeAlgo, err := t.initializeFeeFunction(req) if err != nil { @@ -418,7 +430,7 @@ func (t *TxPublisher) initialBroadcast(req *BumpRequest) (*BumpResult, error) { // Create the initial tx to be broadcasted. This tx is guaranteed to // comply with the RBF restrictions. - requestID, err := t.createRBFCompliantTx(req, feeAlgo) + err = t.createRBFCompliantTx(requestID, req, feeAlgo) if err != nil { return nil, fmt.Errorf("create RBF-compliant tx: %w", err) } @@ -465,8 +477,8 @@ func (t *TxPublisher) initializeFeeFunction( // so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee // and redo the process until the tx is valid, or return an error when non-RBF // related errors occur or the budget has been used up. -func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, - f FeeFunction) (uint64, error) { +func (t *TxPublisher) createRBFCompliantTx(requestID uint64, req *BumpRequest, + f FeeFunction) error { for { // Create a new tx with the given fee rate and check its @@ -475,18 +487,19 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, switch { case err == nil: - // The tx is valid, return the request ID. - requestID := t.storeRecord( - sweepCtx.tx, req, f, sweepCtx.fee, + // The tx is valid, store it. + t.storeRecord( + requestID, sweepCtx.tx, req, f, sweepCtx.fee, sweepCtx.outpointToTxIndex, ) - log.Infof("Created tx %v for %v inputs: feerate=%v, "+ - "fee=%v, inputs=%v", sweepCtx.tx.TxHash(), - len(req.Inputs), f.FeeRate(), sweepCtx.fee, + log.Infof("Created initial sweep tx=%v for %v inputs: "+ + "feerate=%v, fee=%v, inputs:\n%v", + sweepCtx.tx.TxHash(), len(req.Inputs), + f.FeeRate(), sweepCtx.fee, inputTypeSummary(req.Inputs)) - return requestID, nil + return nil // If the error indicates the fees paid is not enough, we will // ask the fee function to increase the fee rate and retry. @@ -517,7 +530,7 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, // cluster these inputs differetly. increased, err = f.Increment() if err != nil { - return 0, err + return err } } @@ -527,21 +540,15 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, // mempool acceptance. 
default: log.Debugf("Failed to create RBF-compliant tx: %v", err) - return 0, err + return err } } } // storeRecord stores the given record in the records map. -func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest, - f FeeFunction, fee btcutil.Amount, - outpointToTxIndex map[wire.OutPoint]int) uint64 { - - // Increase the request counter. - // - // NOTE: this is the only place where we increase the - // counter. - requestID := t.requestCounter.Add(1) +func (t *TxPublisher) storeRecord(requestID uint64, tx *wire.MsgTx, + req *BumpRequest, f FeeFunction, fee btcutil.Amount, + outpointToTxIndex map[wire.OutPoint]int) { // Register the record. t.records.Store(requestID, &monitorRecord{ @@ -551,8 +558,6 @@ func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest, fee: fee, outpointToTxIndex: outpointToTxIndex, }) - - return requestID } // createAndCheckTx creates a tx based on the given inputs, change output @@ -852,18 +857,27 @@ func (t *TxPublisher) processRecords() { // confirmed. confirmedRecords := make(map[uint64]*monitorRecord) - // feeBumpRecords stores a map of the records which need to be bumped. + // feeBumpRecords stores a map of records which need to be bumped. feeBumpRecords := make(map[uint64]*monitorRecord) - // failedRecords stores a map of the records which has inputs being - // spent by a third party. + // failedRecords stores a map of records which has inputs being spent + // by a third party. // // NOTE: this is only used for neutrino backend. failedRecords := make(map[uint64]*monitorRecord) + // initialRecords stores a map of records which are being created and + // published for the first time. + initialRecords := make(map[uint64]*monitorRecord) + // visitor is a helper closure that visits each record and divides them // into two groups. visitor := func(requestID uint64, r *monitorRecord) error { + if r.tx == nil { + initialRecords[requestID] = r + return nil + } + log.Tracef("Checking monitor recordID=%v for tx=%v", requestID, r.tx.TxHash()) @@ -891,9 +905,14 @@ func (t *TxPublisher) processRecords() { return nil } - // Iterate through all the records and divide them into two groups. + // Iterate through all the records and divide them into four groups. t.records.ForEach(visitor) + // Handle the initial broadcast. + for requestID, r := range initialRecords { + t.handleInitialBroadcast(r, requestID) + } + // For records that are confirmed, we'll notify the caller about this // result. for requestID, r := range confirmedRecords { @@ -949,6 +968,69 @@ func (t *TxPublisher) handleTxConfirmed(r *monitorRecord, requestID uint64) { t.handleResult(result) } +// handleInitialBroadcast is called when a new request is received. It will +// handle the initial tx creation and broadcast. In details, +// 1. init a fee function based on the given strategy. +// 2. create an RBF-compliant tx and monitor it for confirmation. +// 3. notify the initial broadcast result back to the caller. +func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord, + requestID uint64) { + + log.Debugf("Initial broadcast for requestID=%v", requestID) + + var ( + result *BumpResult + err error + ) + + // Attempt an initial broadcast which is guaranteed to comply with the + // RBF rules. + result, err = t.initialBroadcast(requestID, r.req) + if err != nil { + log.Errorf("Initial broadcast failed: %v", err) + + // We now decide what type of event to send. 
+ var event BumpEvent + + switch { + // When the error is due to a dust output, we'll send a + // TxFailed so these inputs can be retried with a different + // group in the next block. + case errors.Is(err, ErrTxNoOutput): + event = TxFailed + + // When the error is due to budget being used up, we'll send a + // TxFailed so these inputs can be retried with a different + // group in the next block. + case errors.Is(err, ErrMaxPosition): + event = TxFailed + + // When the error is due to zero fee rate delta, we'll send a + // TxFailed so these inputs can be retried in the next block. + case errors.Is(err, ErrZeroFeeRateDelta): + event = TxFailed + + // Otherwise this is not a fee-related error and the tx cannot + // be retried. In that case we will fail ALL the inputs in this + // tx, which means they will be removed from the sweeper and + // never be tried again. + // + // TODO(yy): Find out which input is causing the failure and + // fail that one only. + default: + event = TxFatal + } + + result = &BumpResult{ + Event: event, + Err: err, + requestID: requestID, + } + } + + t.handleResult(result) +} + // handleFeeBumpTx checks if the tx needs to be bumped, and if so, it will // attempt to bump the fee of the tx. // diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index 0c107b29ff..50c1fdced6 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -352,13 +352,10 @@ func TestStoreRecord(t *testing.T) { } // Call the method under test. - requestID := tp.storeRecord(tx, req, feeFunc, fee, utxoIndex) - - // Check the request ID is as expected. - require.Equal(t, initialCounter+1, requestID) + tp.storeRecord(initialCounter, tx, req, feeFunc, fee, utxoIndex) // Read the saved record and compare. - record, ok := tp.records.Load(requestID) + record, ok := tp.records.Load(initialCounter) require.True(t, ok) require.Equal(t, tx, record.tx) require.Equal(t, feeFunc, record.feeFunction) @@ -655,23 +652,19 @@ func TestCreateRBFCompliantTx(t *testing.T) { }, } + var requestCounter atomic.Uint64 for _, tc := range testCases { tc := tc + rid := requestCounter.Add(1) t.Run(tc.name, func(t *testing.T) { tc.setupMock() // Call the method under test. - id, err := tp.createRBFCompliantTx(req, m.feeFunc) + err := tp.createRBFCompliantTx(rid, req, m.feeFunc) // Check the result is as expected. require.ErrorIs(t, err, tc.expectedErr) - - // If there's an error, expect the requestID to be - // empty. - if tc.expectedErr != nil { - require.Zero(t, id) - } }) } } @@ -704,7 +697,8 @@ func TestTxPublisherBroadcast(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex) + requestID := uint64(1) + tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex) // Quickly check when the requestID cannot be found, an error is // returned. @@ -799,6 +793,9 @@ func TestRemoveResult(t *testing.T) { op: 0, } + // Create a test request ID counter. + requestCounter := atomic.Uint64{} + testCases := []struct { name string setupRecord func() uint64 @@ -810,12 +807,13 @@ func TestRemoveResult(t *testing.T) { // removed. 
name: "remove on TxConfirmed", setupRecord: func() uint64 { - id := tp.storeRecord( - tx, req, m.feeFunc, fee, utxoIndex, + rid := requestCounter.Add(1) + tp.storeRecord( + rid, tx, req, m.feeFunc, fee, utxoIndex, ) - tp.subscriberChans.Store(id, nil) + tp.subscriberChans.Store(rid, nil) - return id + return rid }, result: &BumpResult{ Event: TxConfirmed, @@ -827,12 +825,13 @@ func TestRemoveResult(t *testing.T) { // When the tx is failed, the records will be removed. name: "remove on TxFailed", setupRecord: func() uint64 { - id := tp.storeRecord( - tx, req, m.feeFunc, fee, utxoIndex, + rid := requestCounter.Add(1) + tp.storeRecord( + rid, tx, req, m.feeFunc, fee, utxoIndex, ) - tp.subscriberChans.Store(id, nil) + tp.subscriberChans.Store(rid, nil) - return id + return rid }, result: &BumpResult{ Event: TxFailed, @@ -845,12 +844,13 @@ func TestRemoveResult(t *testing.T) { // Noop when the tx is neither confirmed or failed. name: "noop when tx is not confirmed or failed", setupRecord: func() uint64 { - id := tp.storeRecord( - tx, req, m.feeFunc, fee, utxoIndex, + rid := requestCounter.Add(1) + tp.storeRecord( + rid, tx, req, m.feeFunc, fee, utxoIndex, ) - tp.subscriberChans.Store(id, nil) + tp.subscriberChans.Store(rid, nil) - return id + return rid }, result: &BumpResult{ Event: TxPublished, @@ -905,7 +905,8 @@ func TestNotifyResult(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex) + requestID := uint64(1) + tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex) // Create a subscription to the event. subscriber := make(chan *BumpResult, 1) @@ -953,41 +954,17 @@ func TestNotifyResult(t *testing.T) { } } -// TestBroadcastSuccess checks the public `Broadcast` method can successfully -// broadcast a tx based on the request. -func TestBroadcastSuccess(t *testing.T) { +// TestBroadcast checks the public `Broadcast` method can successfully register +// a broadcast request. +func TestBroadcast(t *testing.T) { t.Parallel() // Create a publisher using the mocks. - tp, m := createTestPublisher(t) + tp, _ := createTestPublisher(t) // Create a test feerate. feerate := chainfee.SatPerKWeight(1000) - // Mock the fee estimator to return the testing fee rate. - // - // We are not testing `NewLinearFeeFunction` here, so the actual params - // used are irrelevant. - m.estimator.On("EstimateFeePerKW", mock.Anything).Return( - feerate, nil).Once() - m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once() - - // Mock the signer to always return a valid script. - // - // NOTE: we are not testing the utility of creating valid txes here, so - // this is fine to be mocked. This behaves essentially as skipping the - // Signer check and alaways assume the tx has a valid sig. - script := &input.Script{} - m.signer.On("ComputeInputScript", mock.Anything, - mock.Anything).Return(script, nil) - - // Mock the testmempoolaccept to pass. - m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() - - // Mock the wallet to publish successfully. - m.wallet.On("PublishTransaction", - mock.Anything, mock.Anything).Return(nil).Once() - // Create a test request. inp := createTestInput(1000, input.WitnessKeyHash) @@ -1003,25 +980,23 @@ func TestBroadcastSuccess(t *testing.T) { // Send the req and expect no error. resultChan, err := tp.Broadcast(req) require.NoError(t, err) - - // Check the result is sent back. 
- select { - case <-time.After(time.Second): - t.Fatal("timeout waiting for subscriber to receive result") - - case result := <-resultChan: - // We expect the first result to be TxPublished. - require.Equal(t, TxPublished, result.Event) - } + require.NotNil(t, resultChan) // Validate the record was stored. require.Equal(t, 1, tp.records.Len()) require.Equal(t, 1, tp.subscriberChans.Len()) + + // Validate the record. + rid := tp.requestCounter.Load() + record, found := tp.records.Load(rid) + require.True(t, found) + require.Equal(t, req, record.req) } -// TestBroadcastFail checks the public `Broadcast` returns the error or a -// failed result when the broadcast fails. -func TestBroadcastFail(t *testing.T) { +// TestBroadcastImmediate checks the public `Broadcast` method can successfully +// register a broadcast request and publish the tx when `Immediate` flag is +// set. +func TestBroadcastImmediate(t *testing.T) { t.Parallel() // Create a publisher using the mocks. @@ -1040,64 +1015,28 @@ func TestBroadcastFail(t *testing.T) { Budget: btcutil.Amount(1000), MaxFeeRate: feerate * 10, DeadlineHeight: 10, + Immediate: true, } - // Mock the fee estimator to return the testing fee rate. + // Mock the fee estimator to return an error. // - // We are not testing `NewLinearFeeFunction` here, so the actual params - // used are irrelevant. + // NOTE: We are not testing `handleInitialBroadcast` here, but only + // interested in checking that this method is indeed called when + // `Immediate` is true. Thus we mock the method to return an error to + // quickly abort. As long as this mocked method is called, we know the + // `Immediate` flag works. m.estimator.On("EstimateFeePerKW", mock.Anything).Return( - feerate, nil).Twice() - m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice() - - // Mock the signer to always return a valid script. - // - // NOTE: we are not testing the utility of creating valid txes here, so - // this is fine to be mocked. This behaves essentially as skipping the - // Signer check and alaways assume the tx has a valid sig. - script := &input.Script{} - m.signer.On("ComputeInputScript", mock.Anything, - mock.Anything).Return(script, nil) - - // Mock the testmempoolaccept to return an error. - m.wallet.On("CheckMempoolAcceptance", - mock.Anything).Return(errDummy).Once() + chainfee.SatPerKWeight(0), errDummy).Once() - // Send the req and expect an error returned. + // Send the req and expect no error. resultChan, err := tp.Broadcast(req) - require.ErrorIs(t, err, errDummy) - require.Nil(t, resultChan) - - // Validate the record was NOT stored. - require.Equal(t, 0, tp.records.Len()) - require.Equal(t, 0, tp.subscriberChans.Len()) - - // Mock the testmempoolaccept again, this time it passes. - m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() - - // Mock the wallet to fail on publish. - m.wallet.On("PublishTransaction", - mock.Anything, mock.Anything).Return(errDummy).Once() - - // Send the req and expect no error returned. - resultChan, err = tp.Broadcast(req) require.NoError(t, err) + require.NotNil(t, resultChan) - // Check the result is sent back. - select { - case <-time.After(time.Second): - t.Fatal("timeout waiting for subscriber to receive result") - - case result := <-resultChan: - // We expect the result to be TxFailed and the error is set in - // the result. - require.Equal(t, TxFailed, result.Event) - require.ErrorIs(t, result.Err, errDummy) - } - - // Validate the record was removed. 
- require.Equal(t, 0, tp.records.Len()) - require.Equal(t, 0, tp.subscriberChans.Len()) + // Validate the record was removed due to an error returned in initial + // broadcast. + require.Empty(t, tp.records.Len()) + require.Empty(t, tp.subscriberChans.Len()) } // TestCreateAnPublishFail checks all the error cases are handled properly in @@ -1270,7 +1209,8 @@ func TestHandleTxConfirmed(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex) + requestID := uint64(1) + tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex) record, ok := tp.records.Load(requestID) require.True(t, ok) @@ -1350,7 +1290,8 @@ func TestHandleFeeBumpTx(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex) + requestID := uint64(1) + tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex) // Create a subscription to the event. subscriber := make(chan *BumpResult, 1) @@ -1551,3 +1492,186 @@ func TestProcessRecords(t *testing.T) { require.Equal(t, requestID2, result.requestID) } } + +// TestHandleInitialBroadcastSuccess checks `handleInitialBroadcast` method can +// successfully broadcast a tx based on the request. +func TestHandleInitialBroadcastSuccess(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test feerate. + feerate := chainfee.SatPerKWeight(1000) + + // Mock the fee estimator to return the testing fee rate. + // + // We are not testing `NewLinearFeeFunction` here, so the actual params + // used are irrelevant. + m.estimator.On("EstimateFeePerKW", mock.Anything).Return( + feerate, nil).Once() + m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once() + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to pass. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Mock the wallet to publish successfully. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(nil).Once() + + // Create a test request. + inp := createTestInput(1000, input.WitnessKeyHash) + + // Create a testing bump request. + req := &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + MaxFeeRate: feerate * 10, + DeadlineHeight: 10, + } + + // Register the testing record use `Broadcast`. + resultChan, err := tp.Broadcast(req) + require.NoError(t, err) + + // Grab the monitor record from the map. + rid := tp.requestCounter.Load() + rec, ok := tp.records.Load(rid) + require.True(t, ok) + + // Call the method under test. + tp.wg.Add(1) + tp.handleInitialBroadcast(rec, rid) + + // Check the result is sent back. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-resultChan: + // We expect the first result to be TxPublished. + require.Equal(t, TxPublished, result.Event) + } + + // Validate the record was stored. 
+ require.Equal(t, 1, tp.records.Len()) + require.Equal(t, 1, tp.subscriberChans.Len()) +} + +// TestHandleInitialBroadcastFail checks `handleInitialBroadcast` returns the +// error or a failed result when the broadcast fails. +func TestHandleInitialBroadcastFail(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test feerate. + feerate := chainfee.SatPerKWeight(1000) + + // Create a test request. + inp := createTestInput(1000, input.WitnessKeyHash) + + // Create a testing bump request. + req := &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + MaxFeeRate: feerate * 10, + DeadlineHeight: 10, + } + + // Mock the fee estimator to return the testing fee rate. + // + // We are not testing `NewLinearFeeFunction` here, so the actual params + // used are irrelevant. + m.estimator.On("EstimateFeePerKW", mock.Anything).Return( + feerate, nil).Twice() + m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice() + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to return an error. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(errDummy).Once() + + // Register the testing record use `Broadcast`. + resultChan, err := tp.Broadcast(req) + require.NoError(t, err) + + // Grab the monitor record from the map. + rid := tp.requestCounter.Load() + rec, ok := tp.records.Load(rid) + require.True(t, ok) + + // Call the method under test and expect an error returned. + tp.wg.Add(1) + tp.handleInitialBroadcast(rec, rid) + + // Check the result is sent back. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-resultChan: + // We expect the first result to be TxFatal. + require.Equal(t, TxFatal, result.Event) + } + + // Validate the record was NOT stored. + require.Equal(t, 0, tp.records.Len()) + require.Equal(t, 0, tp.subscriberChans.Len()) + + // Mock the testmempoolaccept again, this time it passes. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Mock the wallet to fail on publish. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(errDummy).Once() + + // Register the testing record use `Broadcast`. + resultChan, err = tp.Broadcast(req) + require.NoError(t, err) + + // Grab the monitor record from the map. + rid = tp.requestCounter.Load() + rec, ok = tp.records.Load(rid) + require.True(t, ok) + + // Call the method under test. + tp.wg.Add(1) + tp.handleInitialBroadcast(rec, rid) + + // Check the result is sent back. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-resultChan: + // We expect the result to be TxFailed and the error is set in + // the result. + require.Equal(t, TxFailed, result.Event) + require.ErrorIs(t, result.Err, errDummy) + } + + // Validate the record was removed. 
+ require.Equal(t, 0, tp.records.Len()) + require.Equal(t, 0, tp.subscriberChans.Len()) +} From f4635a2189f8a0402e20668102b11c96efecf94d Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 30 Apr 2024 17:39:45 +0800 Subject: [PATCH 14/59] sweep: remove redundant error from `Broadcast` --- sweep/fee_bumper.go | 10 +++++----- sweep/fee_bumper_test.go | 15 +++++---------- sweep/mock_test.go | 6 +++--- sweep/sweeper.go | 11 +---------- sweep/sweeper_test.go | 9 ++------- 5 files changed, 16 insertions(+), 35 deletions(-) diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index 1e159fa127..c2902bb60b 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -65,7 +65,7 @@ type Bumper interface { // and monitors its confirmation status for potential fee bumping. It // returns a chan that the caller can use to receive updates about the // broadcast result and potential RBF attempts. - Broadcast(req *BumpRequest) (<-chan *BumpResult, error) + Broadcast(req *BumpRequest) <-chan *BumpResult } // BumpEvent represents the event of a fee bumping attempt. @@ -382,9 +382,9 @@ func (t *TxPublisher) isNeutrinoBackend() bool { // RBF-compliant unless the budget specified cannot cover the fee. // // NOTE: part of the Bumper interface. -func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { - log.Tracef("Received broadcast request: %s", lnutils.SpewLogClosure( - req)) +func (t *TxPublisher) Broadcast(req *BumpRequest) <-chan *BumpResult { + log.Tracef("Received broadcast request: %s", + lnutils.SpewLogClosure(req)) // Store the request. requestID, record := t.storeInitialRecord(req) @@ -398,7 +398,7 @@ func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { t.handleInitialBroadcast(record, requestID) } - return subscriber, nil + return subscriber } // storeInitialRecord initializes a monitor record and saves it in the map. diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index 50c1fdced6..ba41d65695 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -978,8 +978,7 @@ func TestBroadcast(t *testing.T) { } // Send the req and expect no error. - resultChan, err := tp.Broadcast(req) - require.NoError(t, err) + resultChan := tp.Broadcast(req) require.NotNil(t, resultChan) // Validate the record was stored. @@ -1029,8 +1028,7 @@ func TestBroadcastImmediate(t *testing.T) { chainfee.SatPerKWeight(0), errDummy).Once() // Send the req and expect no error. - resultChan, err := tp.Broadcast(req) - require.NoError(t, err) + resultChan := tp.Broadcast(req) require.NotNil(t, resultChan) // Validate the record was removed due to an error returned in initial @@ -1541,8 +1539,7 @@ func TestHandleInitialBroadcastSuccess(t *testing.T) { } // Register the testing record use `Broadcast`. - resultChan, err := tp.Broadcast(req) - require.NoError(t, err) + resultChan := tp.Broadcast(req) // Grab the monitor record from the map. rid := tp.requestCounter.Load() @@ -1613,8 +1610,7 @@ func TestHandleInitialBroadcastFail(t *testing.T) { mock.Anything).Return(errDummy).Once() // Register the testing record use `Broadcast`. - resultChan, err := tp.Broadcast(req) - require.NoError(t, err) + resultChan := tp.Broadcast(req) // Grab the monitor record from the map. rid := tp.requestCounter.Load() @@ -1647,8 +1643,7 @@ func TestHandleInitialBroadcastFail(t *testing.T) { mock.Anything, mock.Anything).Return(errDummy).Once() // Register the testing record use `Broadcast`. 
- resultChan, err = tp.Broadcast(req) - require.NoError(t, err) + resultChan = tp.Broadcast(req) // Grab the monitor record from the map. rid = tp.requestCounter.Load() diff --git a/sweep/mock_test.go b/sweep/mock_test.go index d42b0320d4..f9471f22a0 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -284,14 +284,14 @@ type MockBumper struct { var _ Bumper = (*MockBumper)(nil) // Broadcast broadcasts the transaction to the network. -func (m *MockBumper) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { +func (m *MockBumper) Broadcast(req *BumpRequest) <-chan *BumpResult { args := m.Called(req) if args.Get(0) == nil { - return nil, args.Error(1) + return nil } - return args.Get(0).(chan *BumpResult), args.Error(1) + return args.Get(0).(chan *BumpResult) } // MockFeeFunction is a mock implementation of the FeeFunction interface. diff --git a/sweep/sweeper.go b/sweep/sweeper.go index ed39548014..fb4f212fca 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -838,16 +838,7 @@ func (s *UtxoSweeper) sweep(set InputSet) error { // Broadcast will return a read-only chan that we will listen to for // this publish result and future RBF attempt. - resp, err := s.cfg.Publisher.Broadcast(req) - if err != nil { - log.Errorf("Initial broadcast failed: %v, inputs=\n%v", err, - inputTypeSummary(set.Inputs())) - - // TODO(yy): find out which input is causing the failure. - s.markInputsPublishFailed(set) - - return err - } + resp := s.cfg.Publisher.Broadcast(req) // Successfully sent the broadcast attempt, we now handle the result by // subscribing to the result chan and listen for future updates about diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 2e1be71c26..5cdbba2975 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -673,13 +673,8 @@ func TestSweepPendingInputs(t *testing.T) { setNeedWallet, normalSet, }) - // Mock `Broadcast` to return an error. This should cause the - // `createSweepTx` inside `sweep` to fail. This is done so we can - // terminate the method early as we are only interested in testing the - // workflow in `sweepPendingInputs`. We don't need to test `sweep` here - // as it should be tested in its own unit test. - dummyErr := errors.New("dummy error") - publisher.On("Broadcast", mock.Anything).Return(nil, dummyErr).Twice() + // Mock `Broadcast` to return a result. + publisher.On("Broadcast", mock.Anything).Return(nil).Twice() // Call the method under test. s.sweepPendingInputs(pis) From f7b301d52d969a68aa201596792b7c052d013fcc Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 30 Apr 2024 19:25:06 +0800 Subject: [PATCH 15/59] sweep: add method `handleBumpEventError` and fix `markInputFailed` Previously in `markInputFailed`, we'd remove all inputs under the same group via `removeExclusiveGroup`. This is wrong as when the current sweep fails for this input, it shouldn't affect other inputs. --- sweep/sweeper.go | 65 +++++++++++++++++++--- sweep/sweeper_test.go | 122 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 181 insertions(+), 6 deletions(-) diff --git a/sweep/sweeper.go b/sweep/sweeper.go index fb4f212fca..0d1d40492d 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -1441,11 +1441,6 @@ func (s *UtxoSweeper) markInputFailed(pi *SweeperInput, err error) { pi.state = Failed - // Remove all other inputs in this exclusive group. 
- if pi.params.ExclusiveGroup != nil { - s.removeExclusiveGroup(*pi.params.ExclusiveGroup) - } - s.signalResult(pi, Result{Err: err}) } @@ -1728,6 +1723,62 @@ func (s *UtxoSweeper) handleBumpEventTxPublished(resp *bumpResp) error { return nil } +// handleBumpEventTxFatal handles the case where there's an unexpected error +// when creating or publishing the sweeping tx. In this case, the tx will be +// removed from the sweeper store and the inputs will be marked as `Failed`, +// which means they will not be retried. +func (s *UtxoSweeper) handleBumpEventTxFatal(resp *bumpResp) error { + r := resp.result + + // Remove the tx from the sweeper store if there is one. Since this is + // a broadcast error, it's likely there isn't a tx here. + if r.Tx != nil { + txid := r.Tx.TxHash() + log.Infof("Tx=%v failed with unexpected error: %v", txid, r.Err) + + // Remove the tx from the sweeper db if it exists. + if err := s.cfg.Store.DeleteTx(txid); err != nil { + return fmt.Errorf("delete tx record for %v: %w", txid, + err) + } + } + + // Mark the inputs as failed. + s.markInputsFailed(resp.set, r.Err) + + return nil +} + +// markInputsFailed marks all inputs found in the tx as failed. It will also +// notify all the subscribers of these inputs. +func (s *UtxoSweeper) markInputsFailed(set InputSet, err error) { + for _, inp := range set.Inputs() { + outpoint := inp.OutPoint() + + input, ok := s.inputs[outpoint] + if !ok { + // It's very likely that a spending tx contains inputs + // that we don't know. + log.Tracef("Skipped marking input as failed: %v not "+ + "found in pending inputs", outpoint) + + continue + } + + // If the input is already in a terminal state, we don't want + // to rewrite it, which also indicates an error as we only get + // an error event during the initial broadcast. + if input.terminated() { + log.Errorf("Skipped marking input=%v as failed due to "+ + "unexpected state=%v", outpoint, input.state) + + continue + } + + s.markInputFailed(input, err) + } +} + // handleBumpEvent handles the result sent from the bumper based on its event // type. // @@ -1752,8 +1803,10 @@ func (s *UtxoSweeper) handleBumpEvent(r *bumpResp) error { case TxReplaced: return s.handleBumpEventTxReplaced(r) + // There's a fatal error in creating the tx, we will remove the tx from + // the sweeper db and mark the inputs as failed. case TxFatal: - // TODO(yy): create a method to remove this input. + return s.handleBumpEventTxFatal(r) } return nil diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 5cdbba2975..39b758cb9b 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -1075,3 +1075,125 @@ func TestMonitorFeeBumpResult(t *testing.T) { }) } } + +// TestMarkInputsFailed checks that given a list of inputs with different +// states, the method `markInputsFailed` correctly marks the inputs as failed. +func TestMarkInputsFailed(t *testing.T) { + t.Parallel() + + require := require.New(t) + + // Create a mock input set. + set := &MockInputSet{} + defer set.AssertExpectations(t) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{}) + + // Create testing inputs for each state. + // - inputInit specifies a newly created input. When marking this as + // published, we should see an error log as this input hasn't been + // published yet. + // - inputPendingPublish specifies an input about to be published. + // - inputPublished specifies an input that's published. + // - inputPublishFailed specifies an input that's failed to be + // published. 
+ // - inputSwept specifies an input that's swept. + // - inputExcluded specifies an input that's excluded. + // - inputFailed specifies an input that's failed. + var ( + inputInit = createMockInput(t, s, Init) + inputPendingPublish = createMockInput(t, s, PendingPublish) + inputPublished = createMockInput(t, s, Published) + inputPublishFailed = createMockInput(t, s, PublishFailed) + inputSwept = createMockInput(t, s, Swept) + inputExcluded = createMockInput(t, s, Excluded) + inputFailed = createMockInput(t, s, Failed) + ) + + // Gather all inputs. + set.On("Inputs").Return([]input.Input{ + inputInit, inputPendingPublish, inputPublished, + inputPublishFailed, inputSwept, inputExcluded, inputFailed, + }) + + // Mark the test inputs. We expect the non-exist input and + // inputSwept/inputExcluded/inputFailed to be skipped. + s.markInputsFailed(set, errDummy) + + // We expect unchanged number of pending inputs. + require.Len(s.inputs, 7) + + // We expect the init input's to be marked as failed. + require.Equal(Failed, s.inputs[inputInit.OutPoint()].state) + + // We expect the pending-publish input to be marked as failed. + require.Equal(Failed, s.inputs[inputPendingPublish.OutPoint()].state) + + // We expect the published input to be marked as failed. + require.Equal(Failed, s.inputs[inputPublished.OutPoint()].state) + + // We expect the publish failed input to be markd as failed. + require.Equal(Failed, s.inputs[inputPublishFailed.OutPoint()].state) + + // We expect the swept input to stay unchanged. + require.Equal(Swept, s.inputs[inputSwept.OutPoint()].state) + + // We expect the excluded input to stay unchanged. + require.Equal(Excluded, s.inputs[inputExcluded.OutPoint()].state) + + // We expect the failed input to stay unchanged. + require.Equal(Failed, s.inputs[inputFailed.OutPoint()].state) +} + +// TestHandleBumpEventTxFatal checks that `handleBumpEventTxFatal` correctly +// handles a `TxFatal` event. +func TestHandleBumpEventTxFatal(t *testing.T) { + t.Parallel() + + rt := require.New(t) + + // Create a mock store. + store := &MockSweeperStore{} + defer store.AssertExpectations(t) + + // Create a mock input set. We are not testing `markInputFailed` here, + // so the actual set doesn't matter. + set := &MockInputSet{} + defer set.AssertExpectations(t) + set.On("Inputs").Return(nil) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: store, + }) + + // Create a dummy tx. + tx := &wire.MsgTx{ + LockTime: 1, + } + + // Create a testing bump response. + result := &BumpResult{ + Err: errDummy, + Tx: tx, + } + resp := &bumpResp{ + result: result, + set: set, + } + + // Mock the store to return an error. + store.On("DeleteTx", mock.Anything).Return(errDummy).Once() + + // Call the method under test and assert the error is returned. + err := s.handleBumpEventTxFatal(resp) + rt.ErrorIs(err, errDummy) + + // Mock the store to return nil. + store.On("DeleteTx", mock.Anything).Return(nil).Once() + + // Call the method under test and assert no error is returned. + err = s.handleBumpEventTxFatal(resp) + rt.NoError(err) +} From b3e7f4c51baa7a0095949a288c67c11b42ea3a9c Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 30 Apr 2024 19:28:01 +0800 Subject: [PATCH 16/59] sweep: add method `isMature` on `SweeperInput` Also updated `handlePendingSweepsReq` to skip immature inputs so the returned results are the same as those in pre-0.18.0. 
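
To make the maturity rule introduced here concrete, below is a minimal, self-contained sketch of the two checks the new `isMature` method performs. This is not the actual `SweeperInput.isMature` method from the diff that follows; the standalone `isMature` helper and the numbers used are illustrative only. It shows how the absolute locktime (CLTV) and the relative locktime (CSV) checks combine, and why the CSV check compares against `currentHeight+1`: a sweeping tx built now can at best confirm in the next block.

```go
package main

import "fmt"

// isMature is a simplified stand-in for the check added in this commit: an
// input is mature once its absolute locktime has been reached and its CSV
// delay can be satisfied by the next block.
func isMature(cltv, csvDelay, heightHint, currentHeight uint32) (bool, uint32) {
	// Absolute locktime (CLTV) not yet reached.
	if currentHeight < cltv {
		return false, cltv
	}

	// Earliest height at which a tx spending this input can be mined.
	matureHeight := heightHint + csvDelay

	// The sweeping tx would confirm at currentHeight+1 at the earliest,
	// hence the +1 in the comparison.
	if currentHeight+1 < matureHeight {
		return false, matureHeight
	}

	return true, matureHeight
}

func main() {
	// A hypothetical input first seen at height 100 with a 2-block CSV
	// becomes mature at height 101, since the sweep can confirm in block
	// 102.
	for _, h := range []uint32{100, 101, 102} {
		ok, lock := isMature(0, 2, 100, h)
		fmt.Printf("current=%d mature=%v matureHeight=%d\n", h, ok, lock)
	}
}
```

The `handlePendingSweepsReq` change in this commit then simply skips inputs for which this check returns false, matching the pre-0.18.0 listing behaviour described above.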
--- sweep/sweeper.go | 51 +++++++++++++++++++++++++++++++------------ sweep/sweeper_test.go | 2 ++ 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 0d1d40492d..7f794431f3 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -222,6 +222,34 @@ func (p *SweeperInput) terminated() bool { } } +// isMature returns a boolean indicating whether the input has a timelock that +// has been reached or not. The locktime found is also returned. +func (p *SweeperInput) isMature(currentHeight uint32) (bool, uint32) { + locktime, _ := p.RequiredLockTime() + if currentHeight < locktime { + log.Debugf("Input %v has locktime=%v, current height is %v", + p.OutPoint(), locktime, currentHeight) + + return false, locktime + } + + // If the input has a CSV that's not yet reached, we will skip + // this input and wait for the expiry. + // + // NOTE: We need to consider whether this input can be included in the + // next block or not, which means the CSV will be checked against the + // currentHeight plus one. + locktime = p.BlocksToMaturity() + p.HeightHint() + if currentHeight+1 < locktime { + log.Debugf("Input %v has CSV expiry=%v, current height is %v", + p.OutPoint(), locktime, currentHeight) + + return false, locktime + } + + return true, locktime +} + // InputsMap is a type alias for a set of pending inputs. type InputsMap = map[wire.OutPoint]*SweeperInput @@ -1038,6 +1066,12 @@ func (s *UtxoSweeper) handlePendingSweepsReq( resps := make(map[wire.OutPoint]*PendingInputResponse, len(s.inputs)) for _, inp := range s.inputs { + // Skip immature inputs for compatibility. + mature, _ := inp.isMature(uint32(s.currentHeight)) + if !mature { + continue + } + // Only the exported fields are set, as we expect the response // to only be consumed externally. op := inp.OutPoint() @@ -1485,20 +1519,9 @@ func (s *UtxoSweeper) updateSweeperInputs() InputsMap { // If the input has a locktime that's not yet reached, we will // skip this input and wait for the locktime to be reached. - locktime, _ := input.RequiredLockTime() - if uint32(s.currentHeight) < locktime { - log.Warnf("Skipping input %v due to locktime=%v not "+ - "reached, current height is %v", op, locktime, - s.currentHeight) - - continue - } - - // If the input has a CSV that's not yet reached, we will skip - // this input and wait for the expiry. - locktime = input.BlocksToMaturity() + input.HeightHint() - if s.currentHeight < int32(locktime)-1 { - log.Infof("Skipping input %v due to CSV expiry=%v not "+ + mature, locktime := input.isMature(uint32(s.currentHeight)) + if !mature { + log.Infof("Skipping input %v due to locktime=%v not "+ "reached, current height is %v", op, locktime, s.currentHeight) diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 39b758cb9b..5b827aff89 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -443,6 +443,7 @@ func TestUpdateSweeperInputs(t *testing.T) { // returned. 
inp2.On("RequiredLockTime").Return( uint32(s.currentHeight+1), true).Once() + inp2.On("OutPoint").Return(wire.OutPoint{Index: 2}).Maybe() input7 := &SweeperInput{state: Init, Input: inp2} // Mock the input to have a CSV expiry in the future so it will NOT be @@ -451,6 +452,7 @@ func TestUpdateSweeperInputs(t *testing.T) { uint32(s.currentHeight), false).Once() inp3.On("BlocksToMaturity").Return(uint32(2)).Once() inp3.On("HeightHint").Return(uint32(s.currentHeight)).Once() + inp3.On("OutPoint").Return(wire.OutPoint{Index: 3}).Maybe() input8 := &SweeperInput{state: Init, Input: inp3} // Add the inputs to the sweeper. After the update, we should see the From 22349379fc5b92a358af1139299d681e1ca764b4 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 1 May 2024 02:58:27 +0800 Subject: [PATCH 17/59] sweep: make sure defaultDeadline is derived from the mature height --- sweep/sweeper.go | 54 +++++++++++++++++++++++++++++++++++-------- sweep/sweeper_test.go | 2 +- 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 7f794431f3..813d0b0ca0 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -530,7 +530,7 @@ func (s *UtxoSweeper) SweepInput(inp input.Input, } absoluteTimeLock, _ := inp.RequiredLockTime() - log.Infof("Sweep request received: out_point=%v, witness_type=%v, "+ + log.Debugf("Sweep request received: out_point=%v, witness_type=%v, "+ "relative_time_lock=%v, absolute_time_lock=%v, amount=%v, "+ "parent=(%v), params=(%v)", inp.OutPoint(), inp.WitnessType(), inp.BlocksToMaturity(), absoluteTimeLock, @@ -736,7 +736,18 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { inputs := s.updateSweeperInputs() log.Debugf("Received new block: height=%v, attempt "+ - "sweeping %d inputs", epoch.Height, len(inputs)) + "sweeping %d inputs:\n%s", + epoch.Height, len(inputs), + lnutils.NewLogClosure(func() string { + inps := make( + []input.Input, 0, len(inputs), + ) + for _, in := range inputs { + inps = append(inps, in) + } + + return inputTypeSummary(inps) + })) // Attempt to sweep any pending inputs. s.sweepPendingInputs(inputs) @@ -1207,13 +1218,29 @@ func (s *UtxoSweeper) mempoolLookup(op wire.OutPoint) fn.Option[wire.MsgTx] { return s.cfg.Mempool.LookupInputMempoolSpend(op) } -// handleNewInput processes a new input by registering spend notification and -// scheduling sweeping for it. -func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error { +// calculateDefaultDeadline calculates the default deadline height for a sweep +// request that has no deadline height specified. +func (s *UtxoSweeper) calculateDefaultDeadline(pi *SweeperInput) int32 { // Create a default deadline height, which will be used when there's no // DeadlineHeight specified for a given input. defaultDeadline := s.currentHeight + int32(s.cfg.NoDeadlineConfTarget) + // If the input is immature and has a locktime, we'll use the locktime + // height as the starting height. + matured, locktime := pi.isMature(uint32(s.currentHeight)) + if !matured { + defaultDeadline = int32(locktime + s.cfg.NoDeadlineConfTarget) + log.Debugf("Input %v is immature, using locktime=%v instead "+ + "of current height=%d", pi.OutPoint(), locktime, + s.currentHeight) + } + + return defaultDeadline +} + +// handleNewInput processes a new input by registering spend notification and +// scheduling sweeping for it. 
+func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error { outpoint := input.input.OutPoint() pi, pending := s.inputs[outpoint] if pending { @@ -1238,15 +1265,22 @@ func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error { Input: input.input, params: input.params, rbf: rbfInfo, - // Set the acutal deadline height. - DeadlineHeight: input.params.DeadlineHeight.UnwrapOr( - defaultDeadline, - ), } + // Set the acutal deadline height. + pi.DeadlineHeight = input.params.DeadlineHeight.UnwrapOr( + s.calculateDefaultDeadline(pi), + ) + s.inputs[outpoint] = pi log.Tracef("input %v, state=%v, added to inputs", outpoint, pi.state) + log.Infof("Registered sweep request at block %d: out_point=%v, "+ + "witness_type=%v, amount=%v, deadline=%d, params=(%v)", + s.currentHeight, pi.OutPoint(), pi.WitnessType(), + btcutil.Amount(pi.SignDesc().Output.Value), pi.DeadlineHeight, + pi.params) + // Start watching for spend of this input, either by us or the remote // party. cancel, err := s.monitorSpend( @@ -1660,7 +1694,7 @@ func (s *UtxoSweeper) handleBumpEventTxFailed(resp *bumpResp) { r := resp.result tx, err := r.Tx, r.Err - log.Errorf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(), err) + log.Warnf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(), err) // NOTE: When marking the inputs as failed, we are using the input set // instead of the inputs found in the tx. This is fine for current diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 5b827aff89..16a4a46fbe 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -739,7 +739,7 @@ func TestHandleBumpEventTxFailed(t *testing.T) { // Call the method under test. err := s.handleBumpEvent(resp) - require.ErrorIs(t, err, errDummy) + require.NoError(t, err) // Assert the states of the first two inputs are updated. require.Equal(t, PublishFailed, s.inputs[op1].state) From 7d4ccb76060fa611d3ac155778886181bc8f2b46 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Fri, 25 Oct 2024 15:17:41 +0800 Subject: [PATCH 18/59] sweep: remove redundant loopvar assign --- sweep/fee_bumper.go | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index c2902bb60b..c08a0a762b 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -916,11 +916,9 @@ func (t *TxPublisher) processRecords() { // For records that are confirmed, we'll notify the caller about this // result. for requestID, r := range confirmedRecords { - rec := r - log.Debugf("Tx=%v is confirmed", r.tx.TxHash()) t.wg.Add(1) - go t.handleTxConfirmed(rec, requestID) + go t.handleTxConfirmed(r, requestID) } // Get the current height to be used in the following goroutines. @@ -928,22 +926,18 @@ func (t *TxPublisher) processRecords() { // For records that are not confirmed, we perform a fee bump if needed. for requestID, r := range feeBumpRecords { - rec := r - log.Debugf("Attempting to fee bump Tx=%v", r.tx.TxHash()) t.wg.Add(1) - go t.handleFeeBumpTx(requestID, rec, currentHeight) + go t.handleFeeBumpTx(requestID, r, currentHeight) } // For records that are failed, we'll notify the caller about this // result. 
for requestID, r := range failedRecords { - rec := r - log.Debugf("Tx=%v has inputs been spent by a third party, "+ "failing it now", r.tx.TxHash()) t.wg.Add(1) - go t.handleThirdPartySpent(rec, requestID) + go t.handleThirdPartySpent(r, requestID) } } From 383b6a81c72b7210aab917c145ccdf0778c96186 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 7 Nov 2024 15:07:26 +0800 Subject: [PATCH 19/59] sweep: break `initialBroadcast` into two steps With the combination of the following commit we can have a more granular control over the bump result when handling it in the sweeper. --- sweep/fee_bumper.go | 111 ++++++++++++++++++++++++++------------------ 1 file changed, 65 insertions(+), 46 deletions(-) diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index c08a0a762b..181f274839 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -417,31 +417,23 @@ func (t *TxPublisher) storeInitialRecord(req *BumpRequest) ( return requestID, record } -// initialBroadcast initializes a fee function, creates an RBF-compliant tx and -// broadcasts it. -func (t *TxPublisher) initialBroadcast(requestID uint64, - req *BumpRequest) (*BumpResult, error) { - +// initializeTx initializes a fee function and creates an RBF-compliant tx. If +// succeeded, the initial tx is stored in the records map. +func (t *TxPublisher) initializeTx(requestID uint64, req *BumpRequest) error { // Create a fee bumping algorithm to be used for future RBF. feeAlgo, err := t.initializeFeeFunction(req) if err != nil { - return nil, fmt.Errorf("init fee function: %w", err) + return fmt.Errorf("init fee function: %w", err) } // Create the initial tx to be broadcasted. This tx is guaranteed to // comply with the RBF restrictions. err = t.createRBFCompliantTx(requestID, req, feeAlgo) if err != nil { - return nil, fmt.Errorf("create RBF-compliant tx: %w", err) - } - - // Broadcast the tx and return the monitored record. - result, err := t.broadcast(requestID) - if err != nil { - return nil, fmt.Errorf("broadcast sweep tx: %w", err) + return fmt.Errorf("create RBF-compliant tx: %w", err) } - return result, nil + return nil } // initializeFeeFunction initializes a fee function to be used for this request @@ -962,6 +954,50 @@ func (t *TxPublisher) handleTxConfirmed(r *monitorRecord, requestID uint64) { t.handleResult(result) } +// handleInitialTxError takes the error from `initializeTx` and decides the +// bump event. It will construct a BumpResult and handles it. +func (t *TxPublisher) handleInitialTxError(requestID uint64, err error) { + // We now decide what type of event to send. + var event BumpEvent + + switch { + // When the error is due to a dust output, we'll send a TxFailed so + // these inputs can be retried with a different group in the next + // block. + case errors.Is(err, ErrTxNoOutput): + event = TxFailed + + // When the error is due to budget being used up, we'll send a TxFailed + // so these inputs can be retried with a different group in the next + // block. + case errors.Is(err, ErrMaxPosition): + event = TxFailed + + // When the error is due to zero fee rate delta, we'll send a TxFailed + // so these inputs can be retried in the next block. + case errors.Is(err, ErrZeroFeeRateDelta): + event = TxFailed + + // Otherwise this is not a fee-related error and the tx cannot be + // retried. In that case we will fail ALL the inputs in this tx, which + // means they will be removed from the sweeper and never be tried + // again. + // + // TODO(yy): Find out which input is causing the failure and fail that + // one only. 
+ default: + event = TxFatal + } + + result := &BumpResult{ + Event: event, + Err: err, + requestID: requestID, + } + + t.handleResult(result) +} + // handleInitialBroadcast is called when a new request is received. It will // handle the initial tx creation and broadcast. In details, // 1. init a fee function based on the given strategy. @@ -979,44 +1015,27 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord, // Attempt an initial broadcast which is guaranteed to comply with the // RBF rules. - result, err = t.initialBroadcast(requestID, r.req) + // + // Create the initial tx to be broadcasted. + err = t.initializeTx(requestID, r.req) if err != nil { log.Errorf("Initial broadcast failed: %v", err) - // We now decide what type of event to send. - var event BumpEvent + // We now handle the initialization error and exit. + t.handleInitialTxError(requestID, err) - switch { - // When the error is due to a dust output, we'll send a - // TxFailed so these inputs can be retried with a different - // group in the next block. - case errors.Is(err, ErrTxNoOutput): - event = TxFailed - - // When the error is due to budget being used up, we'll send a - // TxFailed so these inputs can be retried with a different - // group in the next block. - case errors.Is(err, ErrMaxPosition): - event = TxFailed - - // When the error is due to zero fee rate delta, we'll send a - // TxFailed so these inputs can be retried in the next block. - case errors.Is(err, ErrZeroFeeRateDelta): - event = TxFailed - - // Otherwise this is not a fee-related error and the tx cannot - // be retried. In that case we will fail ALL the inputs in this - // tx, which means they will be removed from the sweeper and - // never be tried again. - // - // TODO(yy): Find out which input is causing the failure and - // fail that one only. - default: - event = TxFatal - } + return + } + // Successfully created the first tx, now broadcast it. + result, err = t.broadcast(requestID) + if err != nil { + // The broadcast failed, which can only happen if the tx record + // cannot be found or the aux sweeper returns an error. In + // either case, we will send back a TxFail event so these + // inputs can be retried. result = &BumpResult{ - Event: event, + Event: TxFailed, Err: err, requestID: requestID, } From 74b7e5b93601319b926f8d7f2714c771f4131737 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 7 Nov 2024 20:28:50 +0800 Subject: [PATCH 20/59] sweep: make sure nil tx is handled After previous commit, it should be clear that the tx may be failed to created in a `TxFailed` event. We now make sure to catch it to avoid panic. --- sweep/fee_bumper.go | 22 ++++++++++++++++------ sweep/fee_bumper_test.go | 14 +++++++------- sweep/sweeper.go | 13 ++++++++++++- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index 181f274839..13fa1272c6 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -75,7 +75,17 @@ const ( // TxPublished is sent when the broadcast attempt is finished. TxPublished BumpEvent = iota - // TxFailed is sent when the broadcast attempt fails. + // TxFailed is sent when the tx has encountered a fee-related error + // during its creation or broadcast, or an internal error from the fee + // bumper. In either case the inputs in this tx should be retried with + // either a different grouping strategy or an increased budget. 
+ // + // NOTE: We also send this event when there's a third party spend + // event, and the sweeper will handle cleaning this up once it's + // confirmed. + // + // TODO(yy): Remove the above usage once we remove sweeping non-CPFP + // anchors. TxFailed // TxReplaced is sent when the original tx is replaced by a new one. @@ -269,8 +279,10 @@ func (b *BumpResult) String() string { // Validate validates the BumpResult so it's safe to use. func (b *BumpResult) Validate() error { + isFailureEvent := b.Event == TxFailed || b.Event == TxFatal + // Every result must have a tx except the fatal or failed case. - if b.Tx == nil && b.Event != TxFatal { + if b.Tx == nil && !isFailureEvent { return fmt.Errorf("%w: nil tx", ErrInvalidBumpResult) } @@ -285,10 +297,8 @@ func (b *BumpResult) Validate() error { } // If it's a failed or fatal event, it must have an error. - if b.Event == TxFatal || b.Event == TxFailed { - if b.Err == nil { - return fmt.Errorf("%w: nil error", ErrInvalidBumpResult) - } + if isFailureEvent && b.Err == nil { + return fmt.Errorf("%w: nil error", ErrInvalidBumpResult) } // If it's a confirmed event, it must have a fee rate and fee. diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index ba41d65695..54c67dbe28 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -91,13 +91,6 @@ func TestBumpResultValidate(t *testing.T) { } require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) - // A failed event without a tx will give an error. - b = BumpResult{ - Event: TxFailed, - Err: errDummy, - } - require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) - // A fatal event without a failure reason will give an error. b = BumpResult{ Event: TxFailed, @@ -118,6 +111,13 @@ func TestBumpResultValidate(t *testing.T) { } require.NoError(t, b.Validate()) + // Tx is allowed to be nil in a TxFailed event. + b = BumpResult{ + Event: TxFailed, + Err: errDummy, + } + require.NoError(t, b.Validate()) + // Tx is allowed to be nil in a TxFatal event. b = BumpResult{ Event: TxFatal, diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 813d0b0ca0..976fceff31 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -1669,6 +1669,14 @@ func (s *UtxoSweeper) monitorFeeBumpResult(set InputSet, // in sweeper and rely solely on this event to mark // inputs as Swept? if r.Event == TxConfirmed || r.Event == TxFailed { + // Exit if the tx is failed to be created. + if r.Tx == nil { + log.Debugf("Received %v for nil tx, "+ + "exit monitor", r.Event) + + return + } + log.Debugf("Received %v for sweep tx %v, exit "+ "fee bump monitor", r.Event, r.Tx.TxHash()) @@ -1694,7 +1702,10 @@ func (s *UtxoSweeper) handleBumpEventTxFailed(resp *bumpResp) { r := resp.result tx, err := r.Tx, r.Err - log.Warnf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(), err) + if tx != nil { + log.Warnf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(), + err) + } // NOTE: When marking the inputs as failed, we are using the input set // instead of the inputs found in the tx. This is fine for current From 83075349c0e2870027c1fc0e19c67aa043f27c18 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 27 Jun 2024 08:36:19 +0800 Subject: [PATCH 21/59] chainio: introduce `chainio` to handle block synchronization This commit inits the package `chainio` and defines the interface `Blockbeat` and `Consumer`. The `Consumer` must be implemented by other subsystems if it requires block epoch subscription. 
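
As a rough sketch of what a subsystem has to provide, the snippet below implements the two-method `Consumer` interface introduced in this commit against a simplified stand-in `Blockbeat` (the interfaces are re-declared locally only so the example compiles on its own, and the `heightLogger` type and the block height used are made up for illustration). In practice, subsystems embed the `BeatConsumer` helper described in the README that follows rather than hand-rolling this.

```go
package main

import "fmt"

// Blockbeat and Consumer mirror the interfaces added in chainio/interface.go;
// they are repeated here only so this sketch is self-contained.
type Blockbeat interface {
	Height() int32
}

type Consumer interface {
	Name() string
	ProcessBlock(b Blockbeat) error
}

// heightLogger is a toy consumer that just remembers the last height it
// processed. A real subsystem would update its own state from the block and
// only return once it has fully caught up, since returning an error here
// causes lnd to shut down.
type heightLogger struct {
	lastHeight int32
}

func (h *heightLogger) Name() string { return "height-logger" }

func (h *heightLogger) ProcessBlock(b Blockbeat) error {
	h.lastHeight = b.Height()
	fmt.Printf("%s processed block %d\n", h.Name(), h.lastHeight)

	return nil
}

// beat is a trivial Blockbeat used only for this example.
type beat struct {
	height int32
}

func (b beat) Height() int32 { return b.height }

func main() {
	var c Consumer = &heightLogger{}

	// Feed the consumer a made-up block height.
	if err := c.ProcessBlock(beat{height: 840000}); err != nil {
		fmt.Println("process block:", err)
	}
}
```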
---
 chainio/README.md | 152 +++++++++++++++++++++++++++++++++++++++++++
 chainio/interface.go | 40 ++++++++++++
 chainio/log.go | 32 +++++++++
 3 files changed, 224 insertions(+)
 create mode 100644 chainio/README.md
 create mode 100644 chainio/interface.go
 create mode 100644 chainio/log.go

diff --git a/chainio/README.md b/chainio/README.md
new file mode 100644
index 0000000000..b11e38157c
--- /dev/null
+++ b/chainio/README.md
@@ -0,0 +1,152 @@
+# Chainio
+
+`chainio` is a package designed to provide blockchain data access to various
+subsystems within `lnd`. When a new block is received, it is encapsulated in a
+`Blockbeat` object and disseminated to all registered consumers. Consumers may
+receive these updates either concurrently or sequentially, based on their
+registration configuration, ensuring that each subsystem maintains a
+synchronized view of the current block state.
+
+The main components include:
+
+- `Blockbeat`: An interface that provides information about the block.
+
+- `Consumer`: An interface that specifies how subsystems handle the blockbeat.
+
+- `BlockbeatDispatcher`: The core service responsible for receiving each block
+  and distributing it to all consumers.
+
+Additionally, the `BeatConsumer` struct provides a partial implementation of
+the `Consumer` interface. This struct helps reduce code duplication: it
+supplies the `ProcessBlock` implementation and a commonly used
+`NotifyBlockProcessed` method, so subsystems don't have to re-implement them.
+
+
+### Register a Consumer
+
+Consumers within the same queue are notified **sequentially**, while all queues
+are notified **concurrently**. A queue consists of a slice of consumers, which
+are notified in left-to-right order. Developers are responsible for determining
+dependencies in block consumption across subsystems: independent subsystems
+should be notified concurrently, whereas dependent subsystems should be
+notified sequentially.
+
+To notify the consumers concurrently, put them in different queues,
+```go
+// consumer1 and consumer2 will be notified concurrently.
+queue1 := []chainio.Consumer{consumer1}
+blockbeatDispatcher.RegisterQueue(queue1)
+
+queue2 := []chainio.Consumer{consumer2}
+blockbeatDispatcher.RegisterQueue(queue2)
+```
+
+To notify the consumers sequentially, put them in the same queue,
+```go
+// consumers will be notified sequentially via,
+// consumer1 -> consumer2 -> consumer3
+queue := []chainio.Consumer{
+	consumer1,
+	consumer2,
+	consumer3,
+}
+blockbeatDispatcher.RegisterQueue(queue)
+```
+
+### Implement the `Consumer` Interface
+
+Implementing the `Consumer` interface is straightforward. Below is an example
+of how
+[`sweep.TxPublisher`](https://github.com/lightningnetwork/lnd/blob/5cec466fad44c582a64cfaeb91f6d5fd302fcf85/sweep/fee_bumper.go#L310)
+implements this interface.
+
+To start, embed the partial implementation `chainio.BeatConsumer`, which
+already provides the `ProcessBlock` implementation and commonly used
+`NotifyBlockProcessed` method, and exposes `BlockbeatChan` for the consumer to
+receive blockbeats.
+
+```go
+type TxPublisher struct {
+	started atomic.Bool
+	stopped atomic.Bool
+
+	chainio.BeatConsumer
+
+	...
+```
+
+We should also remember to initialize this `BeatConsumer`,
+
+```go
+...
+// Mount the block consumer.
+tp.BeatConsumer = chainio.NewBeatConsumer(tp.quit, tp.Name()) +``` + +Finally, in the main event loop, read from `BlockbeatChan`, process the +received blockbeat, and, crucially, call `tp.NotifyBlockProcessed` to inform +the blockbeat dispatcher that processing is complete. + +```go +for { + select { + case beat := <-tp.BlockbeatChan: + // Consume this blockbeat, usually it means updating the subsystem + // using the new block data. + + // Notify we've processed the block. + tp.NotifyBlockProcessed(beat, nil) + + ... +``` + +### Existing Queues + +Currently, we have a single queue of consumers dedicated to handling force +closures. This queue includes `ChainArbitrator`, `UtxoSweeper`, and +`TxPublisher`, with `ChainArbitrator` managing two internal consumers: +`chainWatcher` and `ChannelArbitrator`. The blockbeat flows sequentially +through the chain as follows: `ChainArbitrator => chainWatcher => +ChannelArbitrator => UtxoSweeper => TxPublisher`. The following diagram +illustrates the flow within the public subsystems. + +```mermaid +sequenceDiagram + autonumber + participant bb as BlockBeat + participant cc as ChainArb + participant us as UtxoSweeper + participant tp as TxPublisher + + note left of bb: 0. received block x,
dispatching... + + note over bb,cc: 1. send block x to ChainArb,
wait for its done signal + bb->>cc: block x + rect rgba(165, 0, 85, 0.8) + critical signal processed + cc->>bb: processed block + option Process error or timeout + bb->>bb: error and exit + end + end + + note over bb,us: 2. send block x to UtxoSweeper, wait for its done signal + bb->>us: block x + rect rgba(165, 0, 85, 0.8) + critical signal processed + us->>bb: processed block + option Process error or timeout + bb->>bb: error and exit + end + end + + note over bb,tp: 3. send block x to TxPublisher, wait for its done signal + bb->>tp: block x + rect rgba(165, 0, 85, 0.8) + critical signal processed + tp->>bb: processed block + option Process error or timeout + bb->>bb: error and exit + end + end +``` diff --git a/chainio/interface.go b/chainio/interface.go new file mode 100644 index 0000000000..70827f2076 --- /dev/null +++ b/chainio/interface.go @@ -0,0 +1,40 @@ +package chainio + +// Blockbeat defines an interface that can be used by subsystems to retrieve +// block data. It is sent by the BlockbeatDispatcher whenever a new block is +// received. Once the subsystem finishes processing the block, it must signal +// it by calling NotifyBlockProcessed. +// +// The blockchain is a state machine - whenever there's a state change, it's +// manifested in a block. The blockbeat is a way to notify subsystems of this +// state change, and to provide them with the data they need to process it. In +// other words, subsystems must react to this state change and should consider +// being driven by the blockbeat in their own state machines. +type Blockbeat interface { + // Height returns the current block height. + Height() int32 +} + +// Consumer defines a blockbeat consumer interface. Subsystems that need block +// info must implement it. +type Consumer interface { + // TODO(yy): We should also define the start methods used by the + // consumers such that when implementing the interface, the consumer + // will always be started with a blockbeat. This cannot be enforced at + // the moment as we need refactor all the start methods to only take a + // beat. + // + // Start(beat Blockbeat) error + + // Name returns a human-readable string for this subsystem. + Name() string + + // ProcessBlock takes a blockbeat and processes it. It should not + // return until the subsystem has updated its state based on the block + // data. + // + // NOTE: The consumer must try its best to NOT return an error. If an + // error is returned from processing the block, it means the subsystem + // cannot react to onchain state changes and lnd will shutdown. + ProcessBlock(b Blockbeat) error +} diff --git a/chainio/log.go b/chainio/log.go new file mode 100644 index 0000000000..2d8c26f7a5 --- /dev/null +++ b/chainio/log.go @@ -0,0 +1,32 @@ +package chainio + +import ( + "github.com/btcsuite/btclog/v2" + "github.com/lightningnetwork/lnd/build" +) + +// Subsystem defines the logging code for this subsystem. +const Subsystem = "CHIO" + +// clog is a logger that is initialized with no output filters. This means the +// package will not perform any logging by default until the caller requests +// it. +var clog btclog.Logger + +// The default amount of logging is none. +func init() { + UseLogger(build.NewSubLogger(Subsystem, nil)) +} + +// DisableLog disables all library log output. Logging output is disabled by +// default until UseLogger is called. +func DisableLog() { + UseLogger(btclog.Disabled) +} + +// UseLogger uses a specified Logger to output package logging info. 
This +// should be used in preference to SetLogWriter if the caller is also using +// btclog. +func UseLogger(logger btclog.Logger) { + clog = logger +} From c366e41a19975d1159799653b61062f62ae8ff9a Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 27 Jun 2024 08:41:53 +0800 Subject: [PATCH 22/59] chainio: implement `Blockbeat` In this commit, a minimal implementation of `Blockbeat` is added to synchronize block heights, which will be used in `ChainArb`, `Sweeper`, and `TxPublisher` so blocks are processed sequentially among them. --- chainio/blockbeat.go | 54 +++++++++++++++++++++++++++++++++++++++ chainio/blockbeat_test.go | 28 ++++++++++++++++++++ chainio/interface.go | 19 +++++++++++--- 3 files changed, 98 insertions(+), 3 deletions(-) create mode 100644 chainio/blockbeat.go create mode 100644 chainio/blockbeat_test.go diff --git a/chainio/blockbeat.go b/chainio/blockbeat.go new file mode 100644 index 0000000000..79188657fe --- /dev/null +++ b/chainio/blockbeat.go @@ -0,0 +1,54 @@ +package chainio + +import ( + "fmt" + + "github.com/btcsuite/btclog/v2" + "github.com/lightningnetwork/lnd/chainntnfs" +) + +// Beat implements the Blockbeat interface. It contains the block epoch and a +// customized logger. +// +// TODO(yy): extend this to check for confirmation status - which serves as the +// single source of truth, to avoid the potential race between receiving blocks +// and `GetTransactionDetails/RegisterSpendNtfn/RegisterConfirmationsNtfn`. +type Beat struct { + // epoch is the current block epoch the blockbeat is aware of. + epoch chainntnfs.BlockEpoch + + // log is the customized logger for the blockbeat which prints the + // block height. + log btclog.Logger +} + +// Compile-time check to ensure Beat satisfies the Blockbeat interface. +var _ Blockbeat = (*Beat)(nil) + +// NewBeat creates a new beat with the specified block epoch and a customized +// logger. +func NewBeat(epoch chainntnfs.BlockEpoch) *Beat { + b := &Beat{ + epoch: epoch, + } + + // Create a customized logger for the blockbeat. + logPrefix := fmt.Sprintf("Height[%6d]:", b.Height()) + b.log = clog.WithPrefix(logPrefix) + + return b +} + +// Height returns the height of the block epoch. +// +// NOTE: Part of the Blockbeat interface. +func (b *Beat) Height() int32 { + return b.epoch.Height +} + +// logger returns the logger for the blockbeat. +// +// NOTE: Part of the private blockbeat interface. +func (b *Beat) logger() btclog.Logger { + return b.log +} diff --git a/chainio/blockbeat_test.go b/chainio/blockbeat_test.go new file mode 100644 index 0000000000..9326651b38 --- /dev/null +++ b/chainio/blockbeat_test.go @@ -0,0 +1,28 @@ +package chainio + +import ( + "errors" + "testing" + + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/stretchr/testify/require" +) + +var errDummy = errors.New("dummy error") + +// TestNewBeat tests the NewBeat and Height functions. +func TestNewBeat(t *testing.T) { + t.Parallel() + + // Create a testing epoch. + epoch := chainntnfs.BlockEpoch{ + Height: 1, + } + + // Create the beat and check the internal state. + beat := NewBeat(epoch) + require.Equal(t, epoch, beat.epoch) + + // Check the height function. + require.Equal(t, epoch.Height, beat.Height()) +} diff --git a/chainio/interface.go b/chainio/interface.go index 70827f2076..03c09faf7c 100644 --- a/chainio/interface.go +++ b/chainio/interface.go @@ -1,9 +1,11 @@ package chainio +import "github.com/btcsuite/btclog/v2" + // Blockbeat defines an interface that can be used by subsystems to retrieve -// block data. 
It is sent by the BlockbeatDispatcher whenever a new block is -// received. Once the subsystem finishes processing the block, it must signal -// it by calling NotifyBlockProcessed. +// block data. It is sent by the BlockbeatDispatcher to all the registered +// consumers whenever a new block is received. Once the consumer finishes +// processing the block, it must signal it by calling `NotifyBlockProcessed`. // // The blockchain is a state machine - whenever there's a state change, it's // manifested in a block. The blockbeat is a way to notify subsystems of this @@ -11,10 +13,21 @@ package chainio // other words, subsystems must react to this state change and should consider // being driven by the blockbeat in their own state machines. type Blockbeat interface { + // blockbeat is a private interface that's only used in this package. + blockbeat + // Height returns the current block height. Height() int32 } +// blockbeat defines a set of private methods used in this package to make +// interaction with the blockbeat easier. +type blockbeat interface { + // logger returns the internal logger used by the blockbeat which has a + // block height prefix. + logger() btclog.Logger +} + // Consumer defines a blockbeat consumer interface. Subsystems that need block // info must implement it. type Consumer interface { From f8e1f2df7ffdfc73787a4798138043e997d3ef26 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Fri, 1 Nov 2024 06:15:34 +0800 Subject: [PATCH 23/59] chainio: add helper methods to dispatch beats This commit adds two methods to handle dispatching beats. These are exported methods so other systems can send beats to their managed subinstances. --- chainio/dispatcher.go | 105 ++++++++++++++++++++++++ chainio/dispatcher_test.go | 161 +++++++++++++++++++++++++++++++++++++ chainio/mocks.go | 50 ++++++++++++ 3 files changed, 316 insertions(+) create mode 100644 chainio/dispatcher.go create mode 100644 chainio/dispatcher_test.go create mode 100644 chainio/mocks.go diff --git a/chainio/dispatcher.go b/chainio/dispatcher.go new file mode 100644 index 0000000000..d6900b8f9c --- /dev/null +++ b/chainio/dispatcher.go @@ -0,0 +1,105 @@ +package chainio + +import ( + "errors" + "fmt" + "time" +) + +// DefaultProcessBlockTimeout is the timeout value used when waiting for one +// consumer to finish processing the new block epoch. +var DefaultProcessBlockTimeout = 60 * time.Second + +// ErrProcessBlockTimeout is the error returned when a consumer takes too long +// to process the block. +var ErrProcessBlockTimeout = errors.New("process block timeout") + +// DispatchSequential takes a list of consumers and notify them about the new +// epoch sequentially. It requires the consumer to finish processing the block +// within the specified time, otherwise a timeout error is returned. +func DispatchSequential(b Blockbeat, consumers []Consumer) error { + for _, c := range consumers { + // Send the beat to the consumer. + err := notifyAndWait(b, c, DefaultProcessBlockTimeout) + if err != nil { + b.logger().Errorf("Failed to process block: %v", err) + + return err + } + } + + return nil +} + +// DispatchConcurrent notifies each consumer concurrently about the blockbeat. +// It requires the consumer to finish processing the block within the specified +// time, otherwise a timeout error is returned. +func DispatchConcurrent(b Blockbeat, consumers []Consumer) error { + // errChans is a map of channels that will be used to receive errors + // returned from notifying the consumers. 
+ errChans := make(map[string]chan error, len(consumers)) + + // Notify each queue in goroutines. + for _, c := range consumers { + // Create a signal chan. + errChan := make(chan error, 1) + errChans[c.Name()] = errChan + + // Notify each consumer concurrently. + go func(c Consumer, beat Blockbeat) { + // Send the beat to the consumer. + errChan <- notifyAndWait( + b, c, DefaultProcessBlockTimeout, + ) + }(c, b) + } + + // Wait for all consumers in each queue to finish. + for name, errChan := range errChans { + err := <-errChan + if err != nil { + b.logger().Errorf("Consumer=%v failed to process "+ + "block: %v", name, err) + + return err + } + } + + return nil +} + +// notifyAndWait sends the blockbeat to the specified consumer. It requires the +// consumer to finish processing the block within the specified time, otherwise +// a timeout error is returned. +func notifyAndWait(b Blockbeat, c Consumer, timeout time.Duration) error { + b.logger().Debugf("Waiting for consumer[%s] to process it", c.Name()) + + // Record the time it takes the consumer to process this block. + start := time.Now() + + errChan := make(chan error, 1) + go func() { + errChan <- c.ProcessBlock(b) + }() + + // We expect the consumer to finish processing this block under 30s, + // otherwise a timeout error is returned. + select { + case err := <-errChan: + if err == nil { + break + } + + return fmt.Errorf("%s got err in ProcessBlock: %w", c.Name(), + err) + + case <-time.After(timeout): + return fmt.Errorf("consumer %s: %w", c.Name(), + ErrProcessBlockTimeout) + } + + b.logger().Debugf("Consumer[%s] processed block in %v", c.Name(), + time.Since(start)) + + return nil +} diff --git a/chainio/dispatcher_test.go b/chainio/dispatcher_test.go new file mode 100644 index 0000000000..c41138fd28 --- /dev/null +++ b/chainio/dispatcher_test.go @@ -0,0 +1,161 @@ +package chainio + +import ( + "testing" + "time" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// TestNotifyAndWaitOnConsumerErr asserts when the consumer returns an error, +// it's returned by notifyAndWait. +func TestNotifyAndWaitOnConsumerErr(t *testing.T) { + t.Parallel() + + // Create a mock consumer. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock ProcessBlock to return an error. + consumer.On("ProcessBlock", mockBeat).Return(errDummy).Once() + + // Call the method under test. + err := notifyAndWait(mockBeat, consumer, DefaultProcessBlockTimeout) + + // We expect the error to be returned. + require.ErrorIs(t, err, errDummy) +} + +// TestNotifyAndWaitOnConsumerErr asserts when the consumer successfully +// processed the beat, no error is returned. +func TestNotifyAndWaitOnConsumerSuccess(t *testing.T) { + t.Parallel() + + // Create a mock consumer. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock ProcessBlock to return nil. + consumer.On("ProcessBlock", mockBeat).Return(nil).Once() + + // Call the method under test. + err := notifyAndWait(mockBeat, consumer, DefaultProcessBlockTimeout) + + // We expect a nil error to be returned. 
+ require.NoError(t, err) +} + +// TestNotifyAndWaitOnConsumerTimeout asserts when the consumer times out +// processing the block, the timeout error is returned. +func TestNotifyAndWaitOnConsumerTimeout(t *testing.T) { + t.Parallel() + + // Set timeout to be 10ms. + processBlockTimeout := 10 * time.Millisecond + + // Create a mock consumer. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock ProcessBlock to return nil but blocks on returning. + consumer.On("ProcessBlock", mockBeat).Return(nil).Run( + func(args mock.Arguments) { + // Sleep one second to block on the method. + time.Sleep(processBlockTimeout * 100) + }).Once() + + // Call the method under test. + err := notifyAndWait(mockBeat, consumer, processBlockTimeout) + + // We expect a timeout error to be returned. + require.ErrorIs(t, err, ErrProcessBlockTimeout) +} + +// TestDispatchSequential checks that the beat is sent to the consumers +// sequentially. +func TestDispatchSequential(t *testing.T) { + t.Parallel() + + // Create three mock consumers. + consumer1 := &MockConsumer{} + defer consumer1.AssertExpectations(t) + consumer1.On("Name").Return("mocker1") + + consumer2 := &MockConsumer{} + defer consumer2.AssertExpectations(t) + consumer2.On("Name").Return("mocker2") + + consumer3 := &MockConsumer{} + defer consumer3.AssertExpectations(t) + consumer3.On("Name").Return("mocker3") + + consumers := []Consumer{consumer1, consumer2, consumer3} + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // prevConsumer specifies the previous consumer that was called. + var prevConsumer string + + // Mock the ProcessBlock on consumers to reutrn immediately. + consumer1.On("ProcessBlock", mockBeat).Return(nil).Run( + func(args mock.Arguments) { + // Check the order of the consumers. + // + // The first consumer should have no previous consumer. + require.Empty(t, prevConsumer) + + // Set the consumer as the previous consumer. + prevConsumer = consumer1.Name() + }).Once() + + consumer2.On("ProcessBlock", mockBeat).Return(nil).Run( + func(args mock.Arguments) { + // Check the order of the consumers. + // + // The second consumer should see consumer1. + require.Equal(t, consumer1.Name(), prevConsumer) + + // Set the consumer as the previous consumer. + prevConsumer = consumer2.Name() + }).Once() + + consumer3.On("ProcessBlock", mockBeat).Return(nil).Run( + func(args mock.Arguments) { + // Check the order of the consumers. + // + // The third consumer should see consumer2. + require.Equal(t, consumer2.Name(), prevConsumer) + + // Set the consumer as the previous consumer. + prevConsumer = consumer3.Name() + }).Once() + + // Call the method under test. + err := DispatchSequential(mockBeat, consumers) + require.NoError(t, err) + + // Check the previous consumer is the last consumer. + require.Equal(t, consumer3.Name(), prevConsumer) +} diff --git a/chainio/mocks.go b/chainio/mocks.go new file mode 100644 index 0000000000..5677734e1d --- /dev/null +++ b/chainio/mocks.go @@ -0,0 +1,50 @@ +package chainio + +import ( + "github.com/btcsuite/btclog/v2" + "github.com/stretchr/testify/mock" +) + +// MockConsumer is a mock implementation of the Consumer interface. 
+type MockConsumer struct { + mock.Mock +} + +// Compile-time constraint to ensure MockConsumer implements Consumer. +var _ Consumer = (*MockConsumer)(nil) + +// Name returns a human-readable string for this subsystem. +func (m *MockConsumer) Name() string { + args := m.Called() + return args.String(0) +} + +// ProcessBlock takes a blockbeat and processes it. A receive-only error chan +// must be returned. +func (m *MockConsumer) ProcessBlock(b Blockbeat) error { + args := m.Called(b) + + return args.Error(0) +} + +// MockBlockbeat is a mock implementation of the Blockbeat interface. +type MockBlockbeat struct { + mock.Mock +} + +// Compile-time constraint to ensure MockBlockbeat implements Blockbeat. +var _ Blockbeat = (*MockBlockbeat)(nil) + +// Height returns the current block height. +func (m *MockBlockbeat) Height() int32 { + args := m.Called() + + return args.Get(0).(int32) +} + +// logger returns the logger for the blockbeat. +func (m *MockBlockbeat) logger() btclog.Logger { + args := m.Called() + + return args.Get(0).(btclog.Logger) +} From dfae94bf0df4c95313b0a6f67e696149d4d65865 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 27 Jun 2024 08:43:26 +0800 Subject: [PATCH 24/59] chainio: add `BlockbeatDispatcher` to dispatch blockbeats This commit adds a blockbeat dispatcher which handles sending new blocks to all subscribed consumers. --- chainio/dispatcher.go | 194 ++++++++++++++++++++++++++++++++ chainio/dispatcher_test.go | 222 +++++++++++++++++++++++++++++++++++++ 2 files changed, 416 insertions(+) diff --git a/chainio/dispatcher.go b/chainio/dispatcher.go index d6900b8f9c..244a3ac8f7 100644 --- a/chainio/dispatcher.go +++ b/chainio/dispatcher.go @@ -3,7 +3,12 @@ package chainio import ( "errors" "fmt" + "sync" + "sync/atomic" "time" + + "github.com/btcsuite/btclog/v2" + "github.com/lightningnetwork/lnd/chainntnfs" ) // DefaultProcessBlockTimeout is the timeout value used when waiting for one @@ -14,6 +19,195 @@ var DefaultProcessBlockTimeout = 60 * time.Second // to process the block. var ErrProcessBlockTimeout = errors.New("process block timeout") +// BlockbeatDispatcher is a service that handles dispatching new blocks to +// `lnd`'s subsystems. During startup, subsystems that are block-driven should +// implement the `Consumer` interface and register themselves via +// `RegisterQueue`. When two subsystems are independent of each other, they +// should be registered in different queues so blocks are notified concurrently. +// Otherwise, when living in the same queue, the subsystems are notified of the +// new blocks sequentially, which means it's critical to understand the +// relationship of these systems to properly handle the order. +type BlockbeatDispatcher struct { + wg sync.WaitGroup + + // notifier is used to receive new block epochs. + notifier chainntnfs.ChainNotifier + + // beat is the latest blockbeat received. + beat Blockbeat + + // consumerQueues is a map of consumers that will receive blocks. Its + // key is a unique counter and its value is a queue of consumers. Each + // queue is notified concurrently, and consumers in the same queue is + // notified sequentially. + consumerQueues map[uint32][]Consumer + + // counter is used to assign a unique id to each queue. + counter atomic.Uint32 + + // quit is used to signal the BlockbeatDispatcher to stop. + quit chan struct{} +} + +// NewBlockbeatDispatcher returns a new blockbeat dispatcher instance. 
+func NewBlockbeatDispatcher(n chainntnfs.ChainNotifier) *BlockbeatDispatcher { + return &BlockbeatDispatcher{ + notifier: n, + quit: make(chan struct{}), + consumerQueues: make(map[uint32][]Consumer), + } +} + +// RegisterQueue takes a list of consumers and registers them in the same +// queue. +// +// NOTE: these consumers are notified sequentially. +func (b *BlockbeatDispatcher) RegisterQueue(consumers []Consumer) { + qid := b.counter.Add(1) + + b.consumerQueues[qid] = append(b.consumerQueues[qid], consumers...) + clog.Infof("Registered queue=%d with %d blockbeat consumers", qid, + len(consumers)) + + for _, c := range consumers { + clog.Debugf("Consumer [%s] registered in queue %d", c.Name(), + qid) + } +} + +// Start starts the blockbeat dispatcher - it registers a block notification +// and monitors and dispatches new blocks in a goroutine. It will refuse to +// start if there are no registered consumers. +func (b *BlockbeatDispatcher) Start() error { + // Make sure consumers are registered. + if len(b.consumerQueues) == 0 { + return fmt.Errorf("no consumers registered") + } + + // Start listening to new block epochs. We should get a notification + // with the current best block immediately. + blockEpochs, err := b.notifier.RegisterBlockEpochNtfn(nil) + if err != nil { + return fmt.Errorf("register block epoch ntfn: %w", err) + } + + clog.Infof("BlockbeatDispatcher is starting with %d consumer queues", + len(b.consumerQueues)) + defer clog.Debug("BlockbeatDispatcher started") + + b.wg.Add(1) + go b.dispatchBlocks(blockEpochs) + + return nil +} + +// Stop shuts down the blockbeat dispatcher. +func (b *BlockbeatDispatcher) Stop() { + clog.Info("BlockbeatDispatcher is stopping") + defer clog.Debug("BlockbeatDispatcher stopped") + + // Signal the dispatchBlocks goroutine to stop. + close(b.quit) + b.wg.Wait() +} + +func (b *BlockbeatDispatcher) log() btclog.Logger { + return b.beat.logger() +} + +// dispatchBlocks listens to new block epoch and dispatches it to all the +// consumers. Each queue is notified concurrently, and the consumers in the +// same queue are notified sequentially. +// +// NOTE: Must be run as a goroutine. +func (b *BlockbeatDispatcher) dispatchBlocks( + blockEpochs *chainntnfs.BlockEpochEvent) { + + defer b.wg.Done() + defer blockEpochs.Cancel() + + for { + select { + case blockEpoch, ok := <-blockEpochs.Epochs: + if !ok { + clog.Debugf("Block epoch channel closed") + + return + } + + clog.Infof("Received new block %v at height %d, "+ + "notifying consumers...", blockEpoch.Hash, + blockEpoch.Height) + + // Record the time it takes the consumer to process + // this block. + start := time.Now() + + // Update the current block epoch. + b.beat = NewBeat(*blockEpoch) + + // Notify all consumers. + err := b.notifyQueues() + if err != nil { + b.log().Errorf("Notify block failed: %v", err) + } + + b.log().Infof("Notified all consumers on new block "+ + "in %v", time.Since(start)) + + case <-b.quit: + b.log().Debugf("BlockbeatDispatcher quit signal " + + "received") + + return + } + } +} + +// notifyQueues notifies each queue concurrently about the latest block epoch. +func (b *BlockbeatDispatcher) notifyQueues() error { + // errChans is a map of channels that will be used to receive errors + // returned from notifying the consumers. + errChans := make(map[uint32]chan error, len(b.consumerQueues)) + + // Notify each queue in goroutines. 
+ for qid, consumers := range b.consumerQueues { + b.log().Debugf("Notifying queue=%d with %d consumers", qid, + len(consumers)) + + // Create a signal chan. + errChan := make(chan error, 1) + errChans[qid] = errChan + + // Notify each queue concurrently. + go func(qid uint32, c []Consumer, beat Blockbeat) { + // Notify each consumer in this queue sequentially. + errChan <- DispatchSequential(beat, c) + }(qid, consumers, b.beat) + } + + // Wait for all consumers in each queue to finish. + for qid, errChan := range errChans { + select { + case err := <-errChan: + if err != nil { + return fmt.Errorf("queue=%d got err: %w", qid, + err) + } + + b.log().Debugf("Notified queue=%d", qid) + + case <-b.quit: + b.log().Debugf("BlockbeatDispatcher quit signal " + + "received, exit notifyQueues") + + return nil + } + } + + return nil +} + // DispatchSequential takes a list of consumers and notify them about the new // epoch sequentially. It requires the consumer to finish processing the block // within the specified time, otherwise a timeout error is returned. diff --git a/chainio/dispatcher_test.go b/chainio/dispatcher_test.go index c41138fd28..88044c0201 100644 --- a/chainio/dispatcher_test.go +++ b/chainio/dispatcher_test.go @@ -4,6 +4,8 @@ import ( "testing" "time" + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -159,3 +161,223 @@ func TestDispatchSequential(t *testing.T) { // Check the previous consumer is the last consumer. require.Equal(t, consumer3.Name(), prevConsumer) } + +// TestRegisterQueue tests the RegisterQueue function. +func TestRegisterQueue(t *testing.T) { + t.Parallel() + + // Create two mock consumers. + consumer1 := &MockConsumer{} + defer consumer1.AssertExpectations(t) + consumer1.On("Name").Return("mocker1") + + consumer2 := &MockConsumer{} + defer consumer2.AssertExpectations(t) + consumer2.On("Name").Return("mocker2") + + consumers := []Consumer{consumer1, consumer2} + + // Create a mock chain notifier. + mockNotifier := &chainntnfs.MockChainNotifier{} + defer mockNotifier.AssertExpectations(t) + + // Create a new dispatcher. + b := NewBlockbeatDispatcher(mockNotifier) + + // Register the consumers. + b.RegisterQueue(consumers) + + // Assert that the consumers have been registered. + // + // We should have one queue. + require.Len(t, b.consumerQueues, 1) + + // The queue should have two consumers. + queue, ok := b.consumerQueues[1] + require.True(t, ok) + require.Len(t, queue, 2) +} + +// TestStartDispatcher tests the Start method. +func TestStartDispatcher(t *testing.T) { + t.Parallel() + + // Create a mock chain notifier. + mockNotifier := &chainntnfs.MockChainNotifier{} + defer mockNotifier.AssertExpectations(t) + + // Create a new dispatcher. + b := NewBlockbeatDispatcher(mockNotifier) + + // Start the dispatcher without consumers should return an error. + err := b.Start() + require.Error(t, err) + + // Create a consumer and register it. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker1") + b.RegisterQueue([]Consumer{consumer}) + + // Mock the chain notifier to return an error. + mockNotifier.On("RegisterBlockEpochNtfn", + mock.Anything).Return(nil, errDummy).Once() + + // Start the dispatcher now should return the error. + err = b.Start() + require.ErrorIs(t, err, errDummy) + + // Mock the chain notifier to return a valid notifier. 
+ blockEpochs := &chainntnfs.BlockEpochEvent{} + mockNotifier.On("RegisterBlockEpochNtfn", + mock.Anything).Return(blockEpochs, nil).Once() + + // Start the dispatcher now should not return an error. + err = b.Start() + require.NoError(t, err) +} + +// TestDispatchBlocks asserts the blocks are properly dispatched to the queues. +func TestDispatchBlocks(t *testing.T) { + t.Parallel() + + // Create a mock chain notifier. + mockNotifier := &chainntnfs.MockChainNotifier{} + defer mockNotifier.AssertExpectations(t) + + // Create a new dispatcher. + b := NewBlockbeatDispatcher(mockNotifier) + + // Create the beat and attach it to the dispatcher. + epoch := chainntnfs.BlockEpoch{Height: 1} + beat := NewBeat(epoch) + b.beat = beat + + // Create a consumer and register it. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker1") + b.RegisterQueue([]Consumer{consumer}) + + // Mock the consumer to return nil error on ProcessBlock. This + // implictly asserts that the step `notifyQueues` is successfully + // reached in the `dispatchBlocks` method. + consumer.On("ProcessBlock", mock.Anything).Return(nil).Once() + + // Create a test epoch chan. + epochChan := make(chan *chainntnfs.BlockEpoch, 1) + blockEpochs := &chainntnfs.BlockEpochEvent{ + Epochs: epochChan, + Cancel: func() {}, + } + + // Call the method in a goroutine. + done := make(chan struct{}) + b.wg.Add(1) + go func() { + defer close(done) + b.dispatchBlocks(blockEpochs) + }() + + // Send an epoch. + epoch = chainntnfs.BlockEpoch{Height: 2} + epochChan <- &epoch + + // Wait for the dispatcher to process the epoch. + time.Sleep(100 * time.Millisecond) + + // Stop the dispatcher. + b.Stop() + + // We expect the dispatcher to stop immediately. + _, err := fn.RecvOrTimeout(done, time.Second) + require.NoError(t, err) +} + +// TestNotifyQueuesSuccess checks when the dispatcher successfully notifies all +// the queues, no error is returned. +func TestNotifyQueuesSuccess(t *testing.T) { + t.Parallel() + + // Create two mock consumers. + consumer1 := &MockConsumer{} + defer consumer1.AssertExpectations(t) + consumer1.On("Name").Return("mocker1") + + consumer2 := &MockConsumer{} + defer consumer2.AssertExpectations(t) + consumer2.On("Name").Return("mocker2") + + // Create two queues. + queue1 := []Consumer{consumer1} + queue2 := []Consumer{consumer2} + + // Create a mock chain notifier. + mockNotifier := &chainntnfs.MockChainNotifier{} + defer mockNotifier.AssertExpectations(t) + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Create a new dispatcher. + b := NewBlockbeatDispatcher(mockNotifier) + + // Register the queues. + b.RegisterQueue(queue1) + b.RegisterQueue(queue2) + + // Attach the blockbeat. + b.beat = mockBeat + + // Mock the consumers to return nil error on ProcessBlock for + // both calls. + consumer1.On("ProcessBlock", mockBeat).Return(nil).Once() + consumer2.On("ProcessBlock", mockBeat).Return(nil).Once() + + // Notify the queues. The mockers will be asserted in the end to + // validate the calls. + err := b.notifyQueues() + require.NoError(t, err) +} + +// TestNotifyQueuesError checks when one of the queue returns an error, this +// error is returned by the method. +func TestNotifyQueuesError(t *testing.T) { + t.Parallel() + + // Create a mock consumer. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker1") + + // Create one queue. 
+ queue := []Consumer{consumer} + + // Create a mock chain notifier. + mockNotifier := &chainntnfs.MockChainNotifier{} + defer mockNotifier.AssertExpectations(t) + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Create a new dispatcher. + b := NewBlockbeatDispatcher(mockNotifier) + + // Register the queues. + b.RegisterQueue(queue) + + // Attach the blockbeat. + b.beat = mockBeat + + // Mock the consumer to return an error on ProcessBlock. + consumer.On("ProcessBlock", mockBeat).Return(errDummy).Once() + + // Notify the queues. The mockers will be asserted in the end to + // validate the calls. + err := b.notifyQueues() + require.ErrorIs(t, err, errDummy) +} From 87f2f840c4609737d0f210a33e777a06b6181325 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 17 Oct 2024 10:58:59 +0800 Subject: [PATCH 25/59] chainio: add partial implementation of `Consumer` interface --- chainio/consumer.go | 113 ++++++++++++++++++++++ chainio/consumer_test.go | 202 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 315 insertions(+) create mode 100644 chainio/consumer.go create mode 100644 chainio/consumer_test.go diff --git a/chainio/consumer.go b/chainio/consumer.go new file mode 100644 index 0000000000..a9ec25745b --- /dev/null +++ b/chainio/consumer.go @@ -0,0 +1,113 @@ +package chainio + +// BeatConsumer defines a supplementary component that should be used by +// subsystems which implement the `Consumer` interface. It partially implements +// the `Consumer` interface by providing the method `ProcessBlock` such that +// subsystems don't need to re-implement it. +// +// While inheritance is not commonly used in Go, subsystems embedding this +// struct cannot pass the interface check for `Consumer` because the `Name` +// method is not implemented, which gives us a "mortise and tenon" structure. +// In addition to reducing code duplication, this design allows `ProcessBlock` +// to work on the concrete type `Beat` to access its internal states. +type BeatConsumer struct { + // BlockbeatChan is a channel to receive blocks from Blockbeat. The + // received block contains the best known height and the txns confirmed + // in this block. + BlockbeatChan chan Blockbeat + + // name is the name of the consumer which embeds the BlockConsumer. + name string + + // quit is a channel that closes when the BlockConsumer is shutting + // down. + // + // NOTE: this quit channel should be mounted to the same quit channel + // used by the subsystem. + quit chan struct{} + + // errChan is a buffered chan that receives an error returned from + // processing this block. + errChan chan error +} + +// NewBeatConsumer creates a new BlockConsumer. +func NewBeatConsumer(quit chan struct{}, name string) BeatConsumer { + // Refuse to start `lnd` if the quit channel is not initialized. We + // treat this case as if we are facing a nil pointer dereference, as + // there's no point to return an error here, which will cause the node + // to fail to be started anyway. + if quit == nil { + panic("quit channel is nil") + } + + b := BeatConsumer{ + BlockbeatChan: make(chan Blockbeat), + name: name, + errChan: make(chan error, 1), + quit: quit, + } + + return b +} + +// ProcessBlock takes a blockbeat and sends it to the consumer's blockbeat +// channel. It will send it to the subsystem's BlockbeatChan, and block until +// the processed result is received from the subsystem. 
The subsystem must call +// `NotifyBlockProcessed` after it has finished processing the block. +// +// NOTE: part of the `chainio.Consumer` interface. +func (b *BeatConsumer) ProcessBlock(beat Blockbeat) error { + // Update the current height. + beat.logger().Tracef("set current height for [%s]", b.name) + + select { + // Send the beat to the blockbeat channel. It's expected that the + // consumer will read from this channel and process the block. Once + // processed, it should return the error or nil to the beat.Err chan. + case b.BlockbeatChan <- beat: + beat.logger().Tracef("Sent blockbeat to [%s]", b.name) + + case <-b.quit: + beat.logger().Debugf("[%s] received shutdown before sending "+ + "beat", b.name) + + return nil + } + + // Check the consumer's err chan. We expect the consumer to call + // `beat.NotifyBlockProcessed` to send the error back here. + select { + case err := <-b.errChan: + beat.logger().Debugf("[%s] processed beat: err=%v", b.name, err) + + return err + + case <-b.quit: + beat.logger().Debugf("[%s] received shutdown", b.name) + } + + return nil +} + +// NotifyBlockProcessed signals that the block has been processed. It takes the +// blockbeat being processed and an error resulted from processing it. This +// error is then sent back to the consumer's err chan to unblock +// `ProcessBlock`. +// +// NOTE: This method must be called by the subsystem after it has finished +// processing the block. +func (b *BeatConsumer) NotifyBlockProcessed(beat Blockbeat, err error) { + // Update the current height. + beat.logger().Debugf("[%s]: notifying beat processed", b.name) + + select { + case b.errChan <- err: + beat.logger().Debugf("[%s]: notified beat processed, err=%v", + b.name, err) + + case <-b.quit: + beat.logger().Debugf("[%s] received shutdown before notifying "+ + "beat processed", b.name) + } +} diff --git a/chainio/consumer_test.go b/chainio/consumer_test.go new file mode 100644 index 0000000000..3ef79b61b4 --- /dev/null +++ b/chainio/consumer_test.go @@ -0,0 +1,202 @@ +package chainio + +import ( + "testing" + "time" + + "github.com/lightningnetwork/lnd/fn" + "github.com/stretchr/testify/require" +) + +// TestNewBeatConsumer tests the NewBeatConsumer function. +func TestNewBeatConsumer(t *testing.T) { + t.Parallel() + + quitChan := make(chan struct{}) + name := "test" + + // Test the NewBeatConsumer function. + b := NewBeatConsumer(quitChan, name) + + // Assert the state. + require.Equal(t, quitChan, b.quit) + require.Equal(t, name, b.name) + require.NotNil(t, b.BlockbeatChan) +} + +// TestProcessBlockSuccess tests when the block is processed successfully, no +// error is returned. +func TestProcessBlockSuccess(t *testing.T) { + t.Parallel() + + // Create a test consumer. + quitChan := make(chan struct{}) + b := NewBeatConsumer(quitChan, "test") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock the consumer's err chan. + consumerErrChan := make(chan error, 1) + b.errChan = consumerErrChan + + // Call the method under test. + resultChan := make(chan error, 1) + go func() { + resultChan <- b.ProcessBlock(mockBeat) + }() + + // Assert the beat is sent to the blockbeat channel. + beat, err := fn.RecvOrTimeout(b.BlockbeatChan, time.Second) + require.NoError(t, err) + require.Equal(t, mockBeat, beat) + + // Send nil to the consumer's error channel. + consumerErrChan <- nil + + // Assert the result of ProcessBlock is nil. 
+ result, err := fn.RecvOrTimeout(resultChan, time.Second) + require.NoError(t, err) + require.Nil(t, result) +} + +// TestProcessBlockConsumerQuitBeforeSend tests when the consumer is quit +// before sending the beat, the method returns immediately. +func TestProcessBlockConsumerQuitBeforeSend(t *testing.T) { + t.Parallel() + + // Create a test consumer. + quitChan := make(chan struct{}) + b := NewBeatConsumer(quitChan, "test") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Call the method under test. + resultChan := make(chan error, 1) + go func() { + resultChan <- b.ProcessBlock(mockBeat) + }() + + // Instead of reading the BlockbeatChan, close the quit channel. + close(quitChan) + + // Assert ProcessBlock returned nil. + result, err := fn.RecvOrTimeout(resultChan, time.Second) + require.NoError(t, err) + require.Nil(t, result) +} + +// TestProcessBlockConsumerQuitAfterSend tests when the consumer is quit after +// sending the beat, the method returns immediately. +func TestProcessBlockConsumerQuitAfterSend(t *testing.T) { + t.Parallel() + + // Create a test consumer. + quitChan := make(chan struct{}) + b := NewBeatConsumer(quitChan, "test") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock the consumer's err chan. + consumerErrChan := make(chan error, 1) + b.errChan = consumerErrChan + + // Call the method under test. + resultChan := make(chan error, 1) + go func() { + resultChan <- b.ProcessBlock(mockBeat) + }() + + // Assert the beat is sent to the blockbeat channel. + beat, err := fn.RecvOrTimeout(b.BlockbeatChan, time.Second) + require.NoError(t, err) + require.Equal(t, mockBeat, beat) + + // Instead of sending nil to the consumer's error channel, close the + // quit chanel. + close(quitChan) + + // Assert ProcessBlock returned nil. + result, err := fn.RecvOrTimeout(resultChan, time.Second) + require.NoError(t, err) + require.Nil(t, result) +} + +// TestNotifyBlockProcessedSendErr asserts the error can be sent and read by +// the beat via NotifyBlockProcessed. +func TestNotifyBlockProcessedSendErr(t *testing.T) { + t.Parallel() + + // Create a test consumer. + quitChan := make(chan struct{}) + b := NewBeatConsumer(quitChan, "test") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock the consumer's err chan. + consumerErrChan := make(chan error, 1) + b.errChan = consumerErrChan + + // Call the method under test. + done := make(chan error) + go func() { + defer close(done) + b.NotifyBlockProcessed(mockBeat, errDummy) + }() + + // Assert the error is sent to the beat's err chan. + result, err := fn.RecvOrTimeout(consumerErrChan, time.Second) + require.NoError(t, err) + require.ErrorIs(t, result, errDummy) + + // Assert the done channel is closed. + result, err = fn.RecvOrTimeout(done, time.Second) + require.NoError(t, err) + require.Nil(t, result) +} + +// TestNotifyBlockProcessedOnQuit asserts NotifyBlockProcessed exits +// immediately when the quit channel is closed. +func TestNotifyBlockProcessedOnQuit(t *testing.T) { + t.Parallel() + + // Create a test consumer. + quitChan := make(chan struct{}) + b := NewBeatConsumer(quitChan, "test") + + // Create a mock beat. 
+ mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock the consumer's err chan - we don't buffer it so it will block + // on sending the error. + consumerErrChan := make(chan error) + b.errChan = consumerErrChan + + // Call the method under test. + done := make(chan error) + go func() { + defer close(done) + b.NotifyBlockProcessed(mockBeat, errDummy) + }() + + // Close the quit channel so the method will return. + close(b.quit) + + // Assert the done channel is closed. + result, err := fn.RecvOrTimeout(done, time.Second) + require.NoError(t, err) + require.Nil(t, result) +} From 5ac6198c73c2de7e3cc344b6803f1006c7ad7258 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 29 Oct 2024 21:23:29 +0800 Subject: [PATCH 26/59] multi: implement `Consumer` on subsystems This commit implements `Consumer` on `TxPublisher`, `UtxoSweeper`, `ChainArbitrator` and `ChannelArbitrator`. --- contractcourt/chain_arbitrator.go | 20 +++++++++++++++++++- contractcourt/channel_arbitrator.go | 22 +++++++++++++++++++++- sweep/fee_bumper.go | 20 +++++++++++++++++++- sweep/sweeper.go | 20 +++++++++++++++++++- 4 files changed, 78 insertions(+), 4 deletions(-) diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 646d68b869..e8ba2907aa 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/walletdb" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" @@ -244,6 +245,10 @@ type ChainArbitrator struct { started int32 // To be used atomically. stopped int32 // To be used atomically. + // Embed the blockbeat consumer struct to get access to the method + // `NotifyBlockProcessed` and the `BlockbeatChan`. + chainio.BeatConsumer + sync.Mutex // activeChannels is a map of all the active contracts that are still @@ -272,15 +277,23 @@ type ChainArbitrator struct { func NewChainArbitrator(cfg ChainArbitratorConfig, db *channeldb.DB) *ChainArbitrator { - return &ChainArbitrator{ + c := &ChainArbitrator{ cfg: cfg, activeChannels: make(map[wire.OutPoint]*ChannelArbitrator), activeWatchers: make(map[wire.OutPoint]*chainWatcher), chanSource: db, quit: make(chan struct{}), } + + // Mount the block consumer. + c.BeatConsumer = chainio.NewBeatConsumer(c.quit, c.Name()) + + return c } +// Compile-time check for the chainio.Consumer interface. +var _ chainio.Consumer = (*ChainArbitrator)(nil) + // arbChannel is a wrapper around an open channel that channel arbitrators // interact with. type arbChannel struct { @@ -1361,3 +1374,8 @@ func (c *ChainArbitrator) FindOutgoingHTLCDeadline(scid lnwire.ShortChannelID, // TODO(roasbeef): arbitration reports // * types: contested, waiting for success conf, etc + +// NOTE: part of the `chainio.Consumer` interface. 
+func (c *ChainArbitrator) Name() string { + return "ChainArbitrator" +} diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index ef1246f760..65cba2170a 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -14,6 +14,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" @@ -330,6 +331,10 @@ type ChannelArbitrator struct { started int32 // To be used atomically. stopped int32 // To be used atomically. + // Embed the blockbeat consumer struct to get access to the method + // `NotifyBlockProcessed` and the `BlockbeatChan`. + chainio.BeatConsumer + // startTimestamp is the time when this ChannelArbitrator was started. startTimestamp time.Time @@ -404,7 +409,7 @@ func NewChannelArbitrator(cfg ChannelArbitratorConfig, unmerged[RemotePendingHtlcSet] = htlcSets[RemotePendingHtlcSet] } - return &ChannelArbitrator{ + c := &ChannelArbitrator{ log: log, blocks: make(chan int32, arbitratorBlockBufferSize), signalUpdates: make(chan *signalUpdateMsg), @@ -415,8 +420,16 @@ func NewChannelArbitrator(cfg ChannelArbitratorConfig, cfg: cfg, quit: make(chan struct{}), } + + // Mount the block consumer. + c.BeatConsumer = chainio.NewBeatConsumer(c.quit, c.Name()) + + return c } +// Compile-time check for the chainio.Consumer interface. +var _ chainio.Consumer = (*ChannelArbitrator)(nil) + // chanArbStartState contains the information from disk that we need to start // up a channel arbitrator. type chanArbStartState struct { @@ -3131,6 +3144,13 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32, } } +// Name returns a human-readable string for this subsystem. +// +// NOTE: Part of chainio.Consumer interface. +func (c *ChannelArbitrator) Name() string { + return fmt.Sprintf("ChannelArbitrator(%v)", c.cfg.ChanPoint) +} + // checkLegacyBreach returns StateFullyResolved if the channel was closed with // a breach transaction before the channel arbitrator launched its own breach // resolver. StateContractClosed is returned if this is a modern breach close diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index 13fa1272c6..3c973eb119 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/chain" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" @@ -344,6 +345,10 @@ type TxPublisher struct { started atomic.Bool stopped atomic.Bool + // Embed the blockbeat consumer struct to get access to the method + // `NotifyBlockProcessed` and the `BlockbeatChan`. + chainio.BeatConsumer + wg sync.WaitGroup // cfg specifies the configuration of the TxPublisher. @@ -371,14 +376,22 @@ type TxPublisher struct { // Compile-time constraint to ensure TxPublisher implements Bumper. var _ Bumper = (*TxPublisher)(nil) +// Compile-time check for the chainio.Consumer interface. +var _ chainio.Consumer = (*TxPublisher)(nil) + // NewTxPublisher creates a new TxPublisher. 
func NewTxPublisher(cfg TxPublisherConfig) *TxPublisher { - return &TxPublisher{ + tp := &TxPublisher{ cfg: &cfg, records: lnutils.SyncMap[uint64, *monitorRecord]{}, subscriberChans: lnutils.SyncMap[uint64, chan *BumpResult]{}, quit: make(chan struct{}), } + + // Mount the block consumer. + tp.BeatConsumer = chainio.NewBeatConsumer(tp.quit, tp.Name()) + + return tp } // isNeutrinoBackend checks if the wallet backend is neutrino. @@ -427,6 +440,11 @@ func (t *TxPublisher) storeInitialRecord(req *BumpRequest) ( return requestID, record } +// NOTE: part of the `chainio.Consumer` interface. +func (t *TxPublisher) Name() string { + return "TxPublisher" +} + // initializeTx initializes a fee function and creates an RBF-compliant tx. If // succeeded, the initial tx is stored in the records map. func (t *TxPublisher) initializeTx(requestID uint64, req *BumpRequest) error { diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 976fceff31..9b06c05495 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" @@ -308,6 +309,10 @@ type UtxoSweeper struct { started uint32 // To be used atomically. stopped uint32 // To be used atomically. + // Embed the blockbeat consumer struct to get access to the method + // `NotifyBlockProcessed` and the `BlockbeatChan`. + chainio.BeatConsumer + cfg *UtxoSweeperConfig newInputs chan *sweepInputMessage @@ -342,6 +347,9 @@ type UtxoSweeper struct { bumpRespChan chan *bumpResp } +// Compile-time check for the chainio.Consumer interface. +var _ chainio.Consumer = (*UtxoSweeper)(nil) + // UtxoSweeperConfig contains dependencies of UtxoSweeper. type UtxoSweeperConfig struct { // GenSweepScript generates a P2WKH script belonging to the wallet where @@ -415,7 +423,7 @@ type sweepInputMessage struct { // New returns a new Sweeper instance. func New(cfg *UtxoSweeperConfig) *UtxoSweeper { - return &UtxoSweeper{ + s := &UtxoSweeper{ cfg: cfg, newInputs: make(chan *sweepInputMessage), spendChan: make(chan *chainntnfs.SpendDetail), @@ -425,6 +433,11 @@ func New(cfg *UtxoSweeperConfig) *UtxoSweeper { inputs: make(InputsMap), bumpRespChan: make(chan *bumpResp, 100), } + + // Mount the block consumer. + s.BeatConsumer = chainio.NewBeatConsumer(s.quit, s.Name()) + + return s } // Start starts the process of constructing and publish sweep txes. @@ -508,6 +521,11 @@ func (s *UtxoSweeper) Stop() error { return nil } +// NOTE: part of the `chainio.Consumer` interface. +func (s *UtxoSweeper) Name() string { + return "UtxoSweeper" +} + // SweepInput sweeps inputs back into the wallet. The inputs will be batched and // swept after the batch time window ends. A custom fee preference can be // provided to determine what fee rate should be used for the input. Note that From 798629d0e1dd92c07c7721f0de244907ea6fc3d8 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 4 Jun 2024 20:31:03 +0800 Subject: [PATCH 27/59] sweep: remove block subscription in `UtxoSweeper` and `TxPublisher` This commit removes the independent block subscriptions in `UtxoSweeper` and `TxPublisher`. These subsystems now listen to the `BlockbeatChan` for new blocks. 
--- sweep/fee_bumper.go | 31 +++++++++---------------------- sweep/sweeper.go | 42 +++++++++--------------------------------- 2 files changed, 18 insertions(+), 55 deletions(-) diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index 3c973eb119..ad02ac6194 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -802,13 +802,8 @@ func (t *TxPublisher) Start() error { return fmt.Errorf("TxPublisher started more than once") } - blockEvent, err := t.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return fmt.Errorf("register block epoch ntfn: %w", err) - } - t.wg.Add(1) - go t.monitor(blockEvent) + go t.monitor() log.Debugf("TxPublisher started") @@ -836,33 +831,25 @@ func (t *TxPublisher) Stop() error { // to be bumped. If so, it will attempt to bump the fee of the tx. // // NOTE: Must be run as a goroutine. -func (t *TxPublisher) monitor(blockEvent *chainntnfs.BlockEpochEvent) { - defer blockEvent.Cancel() +func (t *TxPublisher) monitor() { defer t.wg.Done() for { select { - case epoch, ok := <-blockEvent.Epochs: - if !ok { - // We should stop the publisher before stopping - // the chain service. Otherwise it indicates an - // error. - log.Error("Block epoch channel closed, exit " + - "monitor") - - return - } - - log.Debugf("TxPublisher received new block: %v", - epoch.Height) + case beat := <-t.BlockbeatChan: + height := beat.Height() + log.Debugf("TxPublisher received new block: %v", height) // Update the best known height for the publisher. - t.currentHeight.Store(epoch.Height) + t.currentHeight.Store(height) // Check all monitored txns to see if any of them needs // to be bumped. t.processRecords() + // Notify we've processed the block. + t.NotifyBlockProcessed(beat, nil) + case <-t.quit: log.Debug("Fee bumper stopped, exit monitor") return diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 9b06c05495..64306935c9 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -452,21 +452,12 @@ func (s *UtxoSweeper) Start() error { // not change from here on. s.relayFeeRate = s.cfg.FeeEstimator.RelayFeePerKW() - // We need to register for block epochs and retry sweeping every block. - // We should get a notification with the current best block immediately - // if we don't provide any epoch. We'll wait for that in the collector. - blockEpochs, err := s.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return fmt.Errorf("register block epoch ntfn: %w", err) - } - // Start sweeper main loop. s.wg.Add(1) go func() { - defer blockEpochs.Cancel() defer s.wg.Done() - s.collector(blockEpochs.Epochs) + s.collector() // The collector exited and won't longer handle incoming // requests. This can happen on shutdown, when the block @@ -657,17 +648,8 @@ func (s *UtxoSweeper) removeConflictSweepDescendants( // collector is the sweeper main loop. It processes new inputs, spend // notifications and counts down to publication of the sweep tx. -func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { - // We registered for the block epochs with a nil request. The notifier - // should send us the current best block immediately. So we need to wait - // for it here because we need to know the current best height. 
- select { - case bestBlock := <-blockEpochs: - s.currentHeight = bestBlock.Height - - case <-s.quit: - return - } +func (s *UtxoSweeper) collector() { + defer s.wg.Done() for { // Clean inputs, which will remove inputs that are swept, @@ -737,25 +719,16 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { // A new block comes in, update the bestHeight, perform a check // over all pending inputs and publish sweeping txns if needed. - case epoch, ok := <-blockEpochs: - if !ok { - // We should stop the sweeper before stopping - // the chain service. Otherwise it indicates an - // error. - log.Error("Block epoch channel closed") - - return - } - + case beat := <-s.BlockbeatChan: // Update the sweeper to the best height. - s.currentHeight = epoch.Height + s.currentHeight = beat.Height() // Update the inputs with the latest height. inputs := s.updateSweeperInputs() log.Debugf("Received new block: height=%v, attempt "+ "sweeping %d inputs:\n%s", - epoch.Height, len(inputs), + s.currentHeight, len(inputs), lnutils.NewLogClosure(func() string { inps := make( []input.Input, 0, len(inputs), @@ -770,6 +743,9 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { // Attempt to sweep any pending inputs. s.sweepPendingInputs(inputs) + // Notify we've processed the block. + s.NotifyBlockProcessed(beat, nil) + case <-s.quit: return } From 9ecd07227eb6621181829990be0ac6f2a05c8bc5 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Mon, 18 Nov 2024 11:09:21 +0800 Subject: [PATCH 28/59] sweep: remove redundant notifications during shutdown This commit removes the hack introduced in #4851. Previously we had this issue because the chain notifier was stopped before the sweeper, which was changed a while back and we now always stop the chain notifier last. In addition, since we no longer subscribe to the block epoch chan directly, this issue can no longer happen. --- sweep/sweeper.go | 33 +-------------------------------- 1 file changed, 1 insertion(+), 32 deletions(-) diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 64306935c9..b91a6808fb 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -454,38 +454,7 @@ func (s *UtxoSweeper) Start() error { // Start sweeper main loop. s.wg.Add(1) - go func() { - defer s.wg.Done() - - s.collector() - - // The collector exited and won't longer handle incoming - // requests. This can happen on shutdown, when the block - // notifier shuts down before the sweeper and its clients. In - // order to not deadlock the clients waiting for their requests - // being handled, we handle them here and immediately return an - // error. When the sweeper finally is shut down we can exit as - // the clients will be notified. - for { - select { - case inp := <-s.newInputs: - inp.resultChan <- Result{ - Err: ErrSweeperShuttingDown, - } - - case req := <-s.pendingSweepsReqs: - req.errChan <- ErrSweeperShuttingDown - - case req := <-s.updateReqs: - req.responseChan <- &updateResp{ - err: ErrSweeperShuttingDown, - } - - case <-s.quit: - return - } - } - }() + go s.collector() return nil } From 4622db28b2c73840381f6ff5855675fe0a703d2e Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 4 Jun 2024 20:53:33 +0800 Subject: [PATCH 29/59] contractcourt: remove `waitForHeight` in resolvers The sweeper can handle the waiting so there's no need to wait for blocks inside the resolvers. By offering the inputs prior to their mature heights also guarantees the inputs with the same deadline are aggregated. 
--- contractcourt/commit_sweep_resolver.go | 63 --------------------- contractcourt/commit_sweep_resolver_test.go | 26 ++------- contractcourt/htlc_success_resolver.go | 26 +-------- contractcourt/htlc_success_resolver_test.go | 4 -- contractcourt/htlc_timeout_resolver.go | 26 +-------- contractcourt/htlc_timeout_resolver_test.go | 5 -- 6 files changed, 6 insertions(+), 144 deletions(-) diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index 6019a0dbc6..8e89d33fdc 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -101,36 +101,6 @@ func (c *commitSweepResolver) ResolverKey() []byte { return key[:] } -// waitForHeight registers for block notifications and waits for the provided -// block height to be reached. -func waitForHeight(waitHeight uint32, notifier chainntnfs.ChainNotifier, - quit <-chan struct{}) error { - - // Register for block epochs. After registration, the current height - // will be sent on the channel immediately. - blockEpochs, err := notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return err - } - defer blockEpochs.Cancel() - - for { - select { - case newBlock, ok := <-blockEpochs.Epochs: - if !ok { - return errResolverShuttingDown - } - height := newBlock.Height - if height >= int32(waitHeight) { - return nil - } - - case <-quit: - return errResolverShuttingDown - } - } -} - // waitForSpend waits for the given outpoint to be spent, and returns the // details of the spending tx. func waitForSpend(op *wire.OutPoint, pkScript []byte, heightHint uint32, @@ -225,39 +195,6 @@ func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) { c.currentReport.MaturityHeight = unlockHeight c.reportLock.Unlock() - // If there is a csv/cltv lock, we'll wait for that. - if c.commitResolution.MaturityDelay > 0 || c.hasCLTV() { - // Determine what height we should wait until for the locks to - // expire. - var waitHeight uint32 - switch { - // If we have both a csv and cltv lock, we'll need to look at - // both and see which expires later. - case c.commitResolution.MaturityDelay > 0 && c.hasCLTV(): - c.log.Debugf("waiting for CSV and CLTV lock to expire "+ - "at height %v", unlockHeight) - // If the CSV expires after the CLTV, or there is no - // CLTV, then we can broadcast a sweep a block before. - // Otherwise, we need to broadcast at our expected - // unlock height. - waitHeight = uint32(math.Max( - float64(unlockHeight-1), float64(c.leaseExpiry), - )) - - // If we only have a csv lock, wait for the height before the - // lock expires as the spend path should be unlocked by then. 
- case c.commitResolution.MaturityDelay > 0: - c.log.Debugf("waiting for CSV lock to expire at "+ - "height %v", unlockHeight) - waitHeight = unlockHeight - 1 - } - - err := waitForHeight(waitHeight, c.Notifier, c.quit) - if err != nil { - return nil, err - } - } - var ( isLocalCommitTx bool diff --git a/contractcourt/commit_sweep_resolver_test.go b/contractcourt/commit_sweep_resolver_test.go index 077fb8f82c..85f5ba4e48 100644 --- a/contractcourt/commit_sweep_resolver_test.go +++ b/contractcourt/commit_sweep_resolver_test.go @@ -90,12 +90,6 @@ func (i *commitSweepResolverTestContext) resolve() { }() } -func (i *commitSweepResolverTestContext) notifyEpoch(height int32) { - i.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: height, - } -} - func (i *commitSweepResolverTestContext) waitForResult() { i.t.Helper() @@ -292,22 +286,10 @@ func testCommitSweepResolverDelay(t *testing.T, sweepErr error) { t.Fatal("report maturity height incorrect") } - // Notify initial block height. The csv lock is still in effect, so we - // don't expect any sweep to happen yet. - ctx.notifyEpoch(testInitialBlockHeight) - - select { - case <-ctx.sweeper.sweptInputs: - t.Fatal("no sweep expected") - case <-time.After(sweepProcessInterval): - } - - // A new block arrives. The commit tx confirmed at height -1 and the csv - // is 3, so a spend will be valid in the first block after height +1. - ctx.notifyEpoch(testInitialBlockHeight + 1) - - <-ctx.sweeper.sweptInputs - + // Notify initial block height. Although the csv lock is still in + // effect, we expect the input being sent to the sweeper before the csv + // lock expires. + // // Set the resolution report outcome based on whether our sweep // succeeded. outcome := channeldb.ResolverOutcomeClaimed diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 9d09f844dc..9e655cc48a 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -359,30 +359,6 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( "height %v", h, h.htlc.RHash[:], waitHeight) } - // Deduct one block so this input is offered to the sweeper one block - // earlier since the sweeper will wait for one block to trigger the - // sweeping. - // - // TODO(yy): this is done so the outputs can be aggregated - // properly. Suppose CSV locks of five 2nd-level outputs all - // expire at height 840000, there is a race in block digestion - // between contractcourt and sweeper: - // - G1: block 840000 received in contractcourt, it now offers - // the outputs to the sweeper. - // - G2: block 840000 received in sweeper, it now starts to - // sweep the received outputs - there's no guarantee all - // fives have been received. - // To solve this, we either offer the outputs earlier, or - // implement `blockbeat`, and force contractcourt and sweeper - // to consume each block sequentially. - waitHeight-- - - // TODO(yy): let sweeper handles the wait? - err := waitForHeight(waitHeight, h.Notifier, h.quit) - if err != nil { - return nil, err - } - // We'll use this input index to determine the second-level output // index on the transaction, as the signatures requires the indexes to // be the same. 
We don't look for the second-level output script @@ -421,7 +397,7 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( h.htlc.RHash[:], budget, waitHeight) // TODO(roasbeef): need to update above for leased types - _, err = h.Sweeper.SweepInput( + _, err := h.Sweeper.SweepInput( inp, sweep.Params{ Budget: budget, diff --git a/contractcourt/htlc_success_resolver_test.go b/contractcourt/htlc_success_resolver_test.go index c0206d8f14..ae7f1b390a 100644 --- a/contractcourt/htlc_success_resolver_test.go +++ b/contractcourt/htlc_success_resolver_test.go @@ -437,10 +437,6 @@ func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) { } } - ctx.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: 13, - } - // We expect it to sweep the second-level // transaction we notfied about above. resolver := ctx.resolver.(*htlcSuccessResolver) diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index 545e7c6135..0f019a36a1 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -789,30 +789,6 @@ func (h *htlcTimeoutResolver) handleCommitSpend( "height %v", h, h.htlc.RHash[:], waitHeight) } - // Deduct one block so this input is offered to the sweeper one - // block earlier since the sweeper will wait for one block to - // trigger the sweeping. - // - // TODO(yy): this is done so the outputs can be aggregated - // properly. Suppose CSV locks of five 2nd-level outputs all - // expire at height 840000, there is a race in block digestion - // between contractcourt and sweeper: - // - G1: block 840000 received in contractcourt, it now offers - // the outputs to the sweeper. - // - G2: block 840000 received in sweeper, it now starts to - // sweep the received outputs - there's no guarantee all - // fives have been received. - // To solve this, we either offer the outputs earlier, or - // implement `blockbeat`, and force contractcourt and sweeper - // to consume each block sequentially. - waitHeight-- - - // TODO(yy): let sweeper handles the wait? - err := waitForHeight(waitHeight, h.Notifier, h.quit) - if err != nil { - return nil, err - } - // We'll use this input index to determine the second-level // output index on the transaction, as the signatures requires // the indexes to be the same. We don't look for the @@ -853,7 +829,7 @@ func (h *htlcTimeoutResolver) handleCommitSpend( "sweeper with no deadline and budget=%v at height=%v", h, h.htlc.RHash[:], budget, waitHeight) - _, err = h.Sweeper.SweepInput( + _, err := h.Sweeper.SweepInput( inp, sweep.Params{ Budget: budget, diff --git a/contractcourt/htlc_timeout_resolver_test.go b/contractcourt/htlc_timeout_resolver_test.go index 0e4f1336c2..0b6ccd5b39 100644 --- a/contractcourt/htlc_timeout_resolver_test.go +++ b/contractcourt/htlc_timeout_resolver_test.go @@ -1120,11 +1120,6 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { t.Fatalf("resolution not sent") } - // Mimic CSV lock expiring. - ctx.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: 13, - } - // The timeout tx output should now be given to // the sweeper. resolver := ctx.resolver.(*htlcTimeoutResolver) From 6f986a918959311b0dfdb52b8db3f96299747d46 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 29 Oct 2024 21:48:13 +0800 Subject: [PATCH 30/59] contractcourt: remove block subscription in chain arbitrator This commit removes the block subscriptions used in `ChainArbitrator` and replaced them with the blockbeat managed by `BlockbeatDispatcher`. 
--- contractcourt/breach_arbitrator_test.go | 2 +- contractcourt/chain_arbitrator.go | 124 +++++++----------------- contractcourt/chain_arbitrator_test.go | 2 - 3 files changed, 38 insertions(+), 90 deletions(-) diff --git a/contractcourt/breach_arbitrator_test.go b/contractcourt/breach_arbitrator_test.go index c387c21797..99ed852696 100644 --- a/contractcourt/breach_arbitrator_test.go +++ b/contractcourt/breach_arbitrator_test.go @@ -36,7 +36,7 @@ import ( ) var ( - defaultTimeout = 30 * time.Second + defaultTimeout = 10 * time.Second breachOutPoints = []wire.OutPoint{ { diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index e8ba2907aa..0a42dd57b7 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -267,6 +267,9 @@ type ChainArbitrator struct { // active channels that it must still watch over. chanSource *channeldb.DB + // beat is the current best known blockbeat. + beat chainio.Blockbeat + quit chan struct{} wg sync.WaitGroup @@ -797,18 +800,11 @@ func (c *ChainArbitrator) Start() error { } } - // Subscribe to a single stream of block epoch notifications that we - // will dispatch to all active arbitrators. - blockEpoch, err := c.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return err - } - // Start our goroutine which will dispatch blocks to each arbitrator. c.wg.Add(1) go func() { defer c.wg.Done() - c.dispatchBlocks(blockEpoch) + c.dispatchBlocks() }() // TODO(roasbeef): eventually move all breach watching here @@ -816,94 +812,22 @@ func (c *ChainArbitrator) Start() error { return nil } -// blockRecipient contains the information we need to dispatch a block to a -// channel arbitrator. -type blockRecipient struct { - // chanPoint is the funding outpoint of the channel. - chanPoint wire.OutPoint - - // blocks is the channel that new block heights are sent into. This - // channel should be sufficiently buffered as to not block the sender. - blocks chan<- int32 - - // quit is closed if the receiving entity is shutting down. - quit chan struct{} -} - // dispatchBlocks consumes a block epoch notification stream and dispatches // blocks to each of the chain arb's active channel arbitrators. This function // must be run in a goroutine. -func (c *ChainArbitrator) dispatchBlocks( - blockEpoch *chainntnfs.BlockEpochEvent) { - - // getRecipients is a helper function which acquires the chain arb - // lock and returns a set of block recipients which can be used to - // dispatch blocks. - getRecipients := func() []blockRecipient { - c.Lock() - blocks := make([]blockRecipient, 0, len(c.activeChannels)) - for _, channel := range c.activeChannels { - blocks = append(blocks, blockRecipient{ - chanPoint: channel.cfg.ChanPoint, - blocks: channel.blocks, - quit: channel.quit, - }) - } - c.Unlock() - - return blocks - } - - // On exit, cancel our blocks subscription and close each block channel - // so that the arbitrators know they will no longer be receiving blocks. - defer func() { - blockEpoch.Cancel() - - recipients := getRecipients() - for _, recipient := range recipients { - close(recipient.blocks) - } - }() - +func (c *ChainArbitrator) dispatchBlocks() { // Consume block epochs until we receive the instruction to shutdown. for { select { // Consume block epochs, exiting if our subscription is // terminated. - case block, ok := <-blockEpoch.Epochs: - if !ok { - log.Trace("dispatchBlocks block epoch " + - "cancelled") - return - } + case beat := <-c.BlockbeatChan: + // Set the current blockbeat. 
+ c.beat = beat - // Get the set of currently active channels block - // subscription channels and dispatch the block to - // each. - for _, recipient := range getRecipients() { - select { - // Deliver the block to the arbitrator. - case recipient.blocks <- block.Height: - - // If the recipient is shutting down, exit - // without delivering the block. This may be - // the case when two blocks are mined in quick - // succession, and the arbitrator resolves - // after the first block, and does not need to - // consume the second block. - case <-recipient.quit: - log.Debugf("channel: %v exit without "+ - "receiving block: %v", - recipient.chanPoint, - block.Height) - - // If the chain arb is shutting down, we don't - // need to deliver any more blocks (everything - // will be shutting down). - case <-c.quit: - return - } - } + // Send this blockbeat to all the active channels and + // wait for them to finish processing it. + c.handleBlockbeat(beat) // Exit if the chain arbitrator is shutting down. case <-c.quit: @@ -912,6 +836,32 @@ func (c *ChainArbitrator) dispatchBlocks( } } +// handleBlockbeat sends the blockbeat to all active channel arbitrator in +// parallel and wait for them to finish processing it. +func (c *ChainArbitrator) handleBlockbeat(beat chainio.Blockbeat) { + // Read the active channels in a lock. + c.Lock() + + // Create a slice to record active channel arbitrator. + channels := make([]chainio.Consumer, 0, len(c.activeChannels)) + + // Copy the active channels to the slice. + for _, channel := range c.activeChannels { + channels = append(channels, channel) + } + + c.Unlock() + + // Iterate all the copied channels and send the blockbeat to them. + // + // NOTE: This method will timeout if the processing of blocks of the + // subsystems is too long (60s). + err := chainio.DispatchConcurrent(beat, channels) + + // Notify the chain arbitrator has processed the block. + c.NotifyBlockProcessed(beat, err) +} + // republishClosingTxs will load any stored cooperative or unilateral closing // transactions and republish them. This helps ensure propagation of the // transactions in the event that prior publications failed. diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index fe2603ca5a..de6d69900b 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -77,7 +77,6 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { ChainIO: &mock.ChainIO{}, Notifier: &mock.ChainNotifier{ SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), ConfChan: make(chan *chainntnfs.TxConfirmation), }, PublishTx: func(tx *wire.MsgTx, _ string) error { @@ -158,7 +157,6 @@ func TestResolveContract(t *testing.T) { ChainIO: &mock.ChainIO{}, Notifier: &mock.ChainNotifier{ SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), ConfChan: make(chan *chainntnfs.TxConfirmation), }, PublishTx: func(tx *wire.MsgTx, _ string) error { From 0d65d785f71ed06950fed706a357ce74966ca7b1 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 29 Oct 2024 21:48:58 +0800 Subject: [PATCH 31/59] contractcourt: remove block subscription in channel arbitrator This commit removes the block subscriptions used in `ChannelArbitrator`, replaced them with the blockbeat managed by `BlockbeatDispatcher`. 
--- contractcourt/channel_arbitrator.go | 53 +++++++++++++----------- contractcourt/channel_arbitrator_test.go | 48 +++++++++++++++++---- 2 files changed, 68 insertions(+), 33 deletions(-) diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 65cba2170a..4ddcd43e13 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -357,11 +357,6 @@ type ChannelArbitrator struct { // to do its duty. cfg ChannelArbitratorConfig - // blocks is a channel that the arbitrator will receive new blocks on. - // This channel should be buffered by so that it does not block the - // sender. - blocks chan int32 - // signalUpdates is a channel that any new live signals for the channel // we're watching over will be sent. signalUpdates chan *signalUpdateMsg @@ -411,7 +406,6 @@ func NewChannelArbitrator(cfg ChannelArbitratorConfig, c := &ChannelArbitrator{ log: log, - blocks: make(chan int32, arbitratorBlockBufferSize), signalUpdates: make(chan *signalUpdateMsg), resolutionSignal: make(chan struct{}), forceCloseReqs: make(chan *forceCloseReq), @@ -2769,31 +2763,21 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32, // A new block has arrived, we'll examine all the active HTLC's // to see if any of them have expired, and also update our // track of the best current height. - case blockHeight, ok := <-c.blocks: - if !ok { - return - } - bestHeight = blockHeight + case beat := <-c.BlockbeatChan: + bestHeight = beat.Height() - // If we're not in the default state, then we can - // ignore this signal as we're waiting for contract - // resolution. - if c.state != StateDefault { - continue - } + log.Debugf("ChannelArbitrator(%v): new block height=%v", + c.cfg.ChanPoint, bestHeight) - // Now that a new block has arrived, we'll attempt to - // advance our state forward. - nextState, _, err := c.advanceState( - uint32(bestHeight), chainTrigger, nil, - ) + err := c.handleBlockbeat(beat) if err != nil { - log.Errorf("Unable to advance state: %v", err) + log.Errorf("Handle block=%v got err: %v", + bestHeight, err) } // If as a result of this trigger, the contract is // fully resolved, then well exit. - if nextState == StateFullyResolved { + if c.state == StateFullyResolved { return } @@ -3144,6 +3128,27 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32, } } +// handleBlockbeat processes a newly received blockbeat by advancing the +// arbitrator's internal state using the received block height. +func (c *ChannelArbitrator) handleBlockbeat(beat chainio.Blockbeat) error { + // Notify we've processed the block. + defer c.NotifyBlockProcessed(beat, nil) + + // Try to advance the state if we are in StateDefault. + if c.state == StateDefault { + // Now that a new block has arrived, we'll attempt to advance + // our state forward. + _, _, err := c.advanceState( + uint32(beat.Height()), chainTrigger, nil, + ) + if err != nil { + return fmt.Errorf("unable to advance state: %w", err) + } + } + + return nil +} + // Name returns a human-readable string for this subsystem. // // NOTE: Part of chainio.Consumer interface. 
diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 7150e5aea4..8715e1e63d 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" @@ -227,6 +228,15 @@ func (c *chanArbTestCtx) CleanUp() { } } +// receiveBlockbeat mocks the behavior of a blockbeat being sent by the +// BlockbeatDispatcher, which essentially mocks the method `ProcessBlock`. +func (c *chanArbTestCtx) receiveBlockbeat(height int) { + go func() { + beat := newBeatFromHeight(int32(height)) + c.chanArb.BlockbeatChan <- beat + }() +} + // AssertStateTransitions asserts that the state machine steps through the // passed states in order. func (c *chanArbTestCtx) AssertStateTransitions(expectedStates ...ArbitratorState) { @@ -1037,7 +1047,7 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { } require.Equal(t, expectedFinalHtlcs, chanArbCtx.finalHtlcs) - // We'll no re-create the resolver, notice that we use the existing + // We'll now re-create the resolver, notice that we use the existing // arbLog so it carries over the same on-disk state. chanArbCtxNew, err := chanArbCtx.Restart(nil) require.NoError(t, err, "unable to create ChannelArbitrator") @@ -1084,7 +1094,12 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { } // Send a notification that the expiry height has been reached. + // + // TODO(yy): remove the EpochChan and use the blockbeat below once + // resolvers are hooked with the blockbeat. oldNotifier.EpochChan <- &chainntnfs.BlockEpoch{Height: 10} + // beat := chainio.NewBlockbeatFromHeight(10) + // chanArb.BlockbeatChan <- beat // htlcOutgoingContestResolver is now transforming into a // htlcTimeoutResolver and should send the contract off for incubation. @@ -1924,7 +1939,8 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { // now mine a block (height 5), which is 5 blocks away // (our grace delta) from the expiry of that HTLC. case testCase.htlcExpired: - chanArbCtx.chanArb.blocks <- 5 + beat := newBeatFromHeight(5) + chanArbCtx.chanArb.BlockbeatChan <- beat // Otherwise, we'll just trigger a regular force close // request. @@ -2036,8 +2052,7 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { // so instead, we'll mine another block which'll cause // it to re-examine its state and realize there're no // more HTLCs. - chanArbCtx.chanArb.blocks <- 6 - chanArbCtx.AssertStateTransitions(StateFullyResolved) + chanArbCtx.receiveBlockbeat(6) }) } } @@ -2108,13 +2123,15 @@ func TestChannelArbitratorPendingExpiredHTLC(t *testing.T) { // We will advance the uptime to 10 seconds which should be still within // the grace period and should not trigger going to chain. testClock.SetTime(startTime.Add(time.Second * 10)) - chanArbCtx.chanArb.blocks <- 5 + beat := newBeatFromHeight(5) + chanArbCtx.chanArb.BlockbeatChan <- beat chanArbCtx.AssertState(StateDefault) // We will advance the uptime to 16 seconds which should trigger going // to chain. 
testClock.SetTime(startTime.Add(time.Second * 16)) - chanArbCtx.chanArb.blocks <- 6 + beat = newBeatFromHeight(6) + chanArbCtx.chanArb.BlockbeatChan <- beat chanArbCtx.AssertStateTransitions( StateBroadcastCommit, StateCommitmentBroadcasted, @@ -2482,7 +2499,7 @@ func TestSweepAnchors(t *testing.T) { // Set current block height. heightHint := uint32(1000) - chanArbCtx.chanArb.blocks <- int32(heightHint) + chanArbCtx.receiveBlockbeat(int(heightHint)) htlcIndexBase := uint64(99) deadlineDelta := uint32(10) @@ -2645,7 +2662,7 @@ func TestSweepLocalAnchor(t *testing.T) { // Set current block height. heightHint := uint32(1000) - chanArbCtx.chanArb.blocks <- int32(heightHint) + chanArbCtx.receiveBlockbeat(int(heightHint)) htlcIndex := uint64(99) deadlineDelta := uint32(10) @@ -2793,7 +2810,8 @@ func TestChannelArbitratorAnchors(t *testing.T) { // Set current block height. heightHint := uint32(1000) - chanArbCtx.chanArb.blocks <- int32(heightHint) + beat := newBeatFromHeight(int32(heightHint)) + chanArbCtx.chanArb.BlockbeatChan <- beat htlcAmt := lnwire.MilliSatoshi(1_000_000) @@ -2960,10 +2978,14 @@ func TestChannelArbitratorAnchors(t *testing.T) { // to htlcWithPreimage's CLTV. require.Equal(t, 2, len(chanArbCtx.sweeper.deadlines)) require.EqualValues(t, + heightHint+deadlinePreimageDelta/2, + chanArbCtx.sweeper.deadlines[0], "want %d, got %d", heightHint+deadlinePreimageDelta/2, chanArbCtx.sweeper.deadlines[0], ) require.EqualValues(t, + heightHint+deadlinePreimageDelta/2, + chanArbCtx.sweeper.deadlines[1], "want %d, got %d", heightHint+deadlinePreimageDelta/2, chanArbCtx.sweeper.deadlines[1], ) @@ -3146,3 +3168,11 @@ func (m *mockChannel) ForceCloseChan() (*wire.MsgTx, error) { return &wire.MsgTx{}, nil } + +func newBeatFromHeight(height int32) *chainio.Beat { + epoch := chainntnfs.BlockEpoch{ + Height: height, + } + + return chainio.NewBeat(epoch) +} From d61f054c0c9f1cdeb0f9913ab1cae3d2b8025e33 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 5 Jun 2024 00:56:39 +0800 Subject: [PATCH 32/59] contractcourt: remove the `immediate` param used in `Resolve` This `immediate` flag was added as a hack so during a restart, the pending resolvers would offer the inputs to the sweeper and ask it to sweep them immediately. This is no longer need due to `blockbeat`, as now during restart, a block is always sent to all subsystems via the flow `ChainArb` -> `ChannelArb` -> resolvers -> sweeper. Thus, when there are pending inputs offered, they will be processed by the sweeper immediately. 
--- contractcourt/anchor_resolver.go | 2 +- contractcourt/breach_resolver.go | 2 +- contractcourt/channel_arbitrator.go | 20 +++++++--------- contractcourt/commit_sweep_resolver.go | 4 +--- contractcourt/commit_sweep_resolver_test.go | 2 +- contractcourt/contract_resolver.go | 2 +- .../htlc_incoming_contest_resolver.go | 4 +--- .../htlc_incoming_contest_resolver_test.go | 2 +- .../htlc_outgoing_contest_resolver.go | 4 +--- .../htlc_outgoing_contest_resolver_test.go | 2 +- contractcourt/htlc_success_resolver.go | 24 +++++++------------ contractcourt/htlc_success_resolver_test.go | 2 +- contractcourt/htlc_timeout_resolver.go | 20 +++++++--------- contractcourt/htlc_timeout_resolver_test.go | 2 +- 14 files changed, 36 insertions(+), 56 deletions(-) diff --git a/contractcourt/anchor_resolver.go b/contractcourt/anchor_resolver.go index e482c4c713..af7ac76462 100644 --- a/contractcourt/anchor_resolver.go +++ b/contractcourt/anchor_resolver.go @@ -84,7 +84,7 @@ func (c *anchorResolver) ResolverKey() []byte { } // Resolve offers the anchor output to the sweeper and waits for it to be swept. -func (c *anchorResolver) Resolve(_ bool) (ContractResolver, error) { +func (c *anchorResolver) Resolve() (ContractResolver, error) { // Attempt to update the sweep parameters to the post-confirmation // situation. We don't want to force sweep anymore, because the anchor // lost its special purpose to get the commitment confirmed. It is just diff --git a/contractcourt/breach_resolver.go b/contractcourt/breach_resolver.go index 740b4471d5..63395651cc 100644 --- a/contractcourt/breach_resolver.go +++ b/contractcourt/breach_resolver.go @@ -47,7 +47,7 @@ func (b *breachResolver) ResolverKey() []byte { // been broadcast. // // TODO(yy): let sweeper handle the breach inputs. -func (b *breachResolver) Resolve(_ bool) (ContractResolver, error) { +func (b *breachResolver) Resolve() (ContractResolver, error) { if !b.subscribed { complete, err := b.SubscribeBreachComplete( &b.ChanPoint, b.replyChan, diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 4ddcd43e13..dfbf6c97f1 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -816,7 +816,7 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet, // TODO(roasbeef): this isn't re-launched? } - c.launchResolvers(unresolvedContracts, true) + c.launchResolvers(unresolvedContracts) return nil } @@ -1355,7 +1355,7 @@ func (c *ChannelArbitrator) stateStep( // Finally, we'll launch all the required contract resolvers. // Once they're all resolved, we're no longer needed. - c.launchResolvers(resolvers, false) + c.launchResolvers(resolvers) nextState = StateWaitingFullResolution @@ -1579,16 +1579,14 @@ func (c *ChannelArbitrator) findCommitmentDeadlineAndValue(heightHint uint32, } // launchResolvers updates the activeResolvers list and starts the resolvers. -func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver, - immediate bool) { - +func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver) { c.activeResolversLock.Lock() - defer c.activeResolversLock.Unlock() - c.activeResolvers = resolvers + c.activeResolversLock.Unlock() + for _, contract := range resolvers { c.wg.Add(1) - go c.resolveContract(contract, immediate) + go c.resolveContract(contract) } } @@ -2560,9 +2558,7 @@ func (c *ChannelArbitrator) replaceResolver(oldResolver, // contracts. // // NOTE: This MUST be run as a goroutine. 
-func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver, - immediate bool) { - +func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver) { defer c.wg.Done() log.Debugf("ChannelArbitrator(%v): attempting to resolve %T", @@ -2583,7 +2579,7 @@ func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver, default: // Otherwise, we'll attempt to resolve the current // contract. - nextContract, err := currentContract.Resolve(immediate) + nextContract, err := currentContract.Resolve() if err != nil { if err == errResolverShuttingDown { return diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index 8e89d33fdc..9a3bb058ae 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -165,9 +165,7 @@ func (c *commitSweepResolver) getCommitTxConfHeight() (uint32, error) { // returned. // // NOTE: This function MUST be run as a goroutine. -// -//nolint:funlen -func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) { +func (c *commitSweepResolver) Resolve() (ContractResolver, error) { // If we're already resolved, then we can exit early. if c.resolved { return nil, nil diff --git a/contractcourt/commit_sweep_resolver_test.go b/contractcourt/commit_sweep_resolver_test.go index 85f5ba4e48..2195e33779 100644 --- a/contractcourt/commit_sweep_resolver_test.go +++ b/contractcourt/commit_sweep_resolver_test.go @@ -82,7 +82,7 @@ func (i *commitSweepResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false) + nextResolver, err := i.resolver.Resolve() i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, diff --git a/contractcourt/contract_resolver.go b/contractcourt/contract_resolver.go index f5a88f24e6..3629c1bc3c 100644 --- a/contractcourt/contract_resolver.go +++ b/contractcourt/contract_resolver.go @@ -42,7 +42,7 @@ type ContractResolver interface { // resolution, then another resolve is returned. // // NOTE: This function MUST be run as a goroutine. - Resolve(immediate bool) (ContractResolver, error) + Resolve() (ContractResolver, error) // SupplementState allows the user of a ContractResolver to supplement // it with state required for the proper resolution of a contract. diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go index e5be63cbf7..ebac495835 100644 --- a/contractcourt/htlc_incoming_contest_resolver.go +++ b/contractcourt/htlc_incoming_contest_resolver.go @@ -90,9 +90,7 @@ func (h *htlcIncomingContestResolver) processFinalHtlcFail() error { // as we have no remaining actions left at our disposal. // // NOTE: Part of the ContractResolver interface. -func (h *htlcIncomingContestResolver) Resolve( - _ bool) (ContractResolver, error) { - +func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { // If we're already full resolved, then we don't have anything further // to do. 
if h.resolved { diff --git a/contractcourt/htlc_incoming_contest_resolver_test.go b/contractcourt/htlc_incoming_contest_resolver_test.go index 22280f953e..e8b8eac0c9 100644 --- a/contractcourt/htlc_incoming_contest_resolver_test.go +++ b/contractcourt/htlc_incoming_contest_resolver_test.go @@ -395,7 +395,7 @@ func (i *incomingResolverTestContext) resolve() { i.resolveErr = make(chan error, 1) go func() { var err error - i.nextResolver, err = i.resolver.Resolve(false) + i.nextResolver, err = i.resolver.Resolve() i.resolveErr <- err }() diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go index 1303d0af60..ec32ff7f17 100644 --- a/contractcourt/htlc_outgoing_contest_resolver.go +++ b/contractcourt/htlc_outgoing_contest_resolver.go @@ -49,9 +49,7 @@ func newOutgoingContestResolver(res lnwallet.OutgoingHtlcResolution, // When either of these two things happens, we'll create a new resolver which // is able to handle the final resolution of the contract. We're only the pivot // point. -func (h *htlcOutgoingContestResolver) Resolve( - _ bool) (ContractResolver, error) { - +func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) { // If we're already full resolved, then we don't have anything further // to do. if h.resolved { diff --git a/contractcourt/htlc_outgoing_contest_resolver_test.go b/contractcourt/htlc_outgoing_contest_resolver_test.go index 6608a6fb51..4fa3a6874f 100644 --- a/contractcourt/htlc_outgoing_contest_resolver_test.go +++ b/contractcourt/htlc_outgoing_contest_resolver_test.go @@ -209,7 +209,7 @@ func (i *outgoingResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false) + nextResolver, err := i.resolver.Resolve() i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 9e655cc48a..06ebf4edc4 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -115,9 +115,7 @@ func (h *htlcSuccessResolver) ResolverKey() []byte { // TODO(roasbeef): create multi to batch // // NOTE: Part of the ContractResolver interface. -func (h *htlcSuccessResolver) Resolve( - immediate bool) (ContractResolver, error) { - +func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { // If we're already resolved, then we can exit early. if h.resolved { return nil, nil @@ -126,12 +124,12 @@ func (h *htlcSuccessResolver) Resolve( // If we don't have a success transaction, then this means that this is // an output on the remote party's commitment transaction. if h.htlcResolution.SignedSuccessTx == nil { - return h.resolveRemoteCommitOutput(immediate) + return h.resolveRemoteCommitOutput() } // Otherwise this an output on our own commitment, and we must start by // broadcasting the second-level success transaction. - secondLevelOutpoint, err := h.broadcastSuccessTx(immediate) + secondLevelOutpoint, err := h.broadcastSuccessTx() if err != nil { return nil, err } @@ -165,8 +163,8 @@ func (h *htlcSuccessResolver) Resolve( // broadcasting the second-level success transaction. It returns the ultimate // outpoint of the second-level tx, that we must wait to be spent for the // resolver to be fully resolved. 
-func (h *htlcSuccessResolver) broadcastSuccessTx( - immediate bool) (*wire.OutPoint, error) { +func (h *htlcSuccessResolver) broadcastSuccessTx() ( + *wire.OutPoint, error) { // If we have non-nil SignDetails, this means that have a 2nd level // HTLC transaction that is signed using sighash SINGLE|ANYONECANPAY @@ -175,7 +173,7 @@ func (h *htlcSuccessResolver) broadcastSuccessTx( // the checkpointed outputIncubating field to determine if we already // swept the HTLC output into the second level transaction. if h.htlcResolution.SignDetails != nil { - return h.broadcastReSignedSuccessTx(immediate) + return h.broadcastReSignedSuccessTx() } // Otherwise we'll publish the second-level transaction directly and @@ -225,10 +223,8 @@ func (h *htlcSuccessResolver) broadcastSuccessTx( // broadcastReSignedSuccessTx handles the case where we have non-nil // SignDetails, and offers the second level transaction to the Sweeper, that // will re-sign it and attach fees at will. -// -//nolint:funlen -func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( - *wire.OutPoint, error) { +func (h *htlcSuccessResolver) broadcastReSignedSuccessTx() (*wire.OutPoint, + error) { // Keep track of the tx spending the HTLC output on the commitment, as // this will be the confirmed second-level tx we'll ultimately sweep. @@ -287,7 +283,6 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( sweep.Params{ Budget: budget, DeadlineHeight: deadline, - Immediate: immediate, }, ) if err != nil { @@ -419,7 +414,7 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( // resolveRemoteCommitOutput handles sweeping an HTLC output on the remote // commitment with the preimage. In this case we can sweep the output directly, // and don't have to broadcast a second-level transaction. -func (h *htlcSuccessResolver) resolveRemoteCommitOutput(immediate bool) ( +func (h *htlcSuccessResolver) resolveRemoteCommitOutput() ( ContractResolver, error) { isTaproot := txscript.IsPayToTaproot( @@ -471,7 +466,6 @@ func (h *htlcSuccessResolver) resolveRemoteCommitOutput(immediate bool) ( sweep.Params{ Budget: budget, DeadlineHeight: deadline, - Immediate: immediate, }, ) if err != nil { diff --git a/contractcourt/htlc_success_resolver_test.go b/contractcourt/htlc_success_resolver_test.go index ae7f1b390a..75c733638f 100644 --- a/contractcourt/htlc_success_resolver_test.go +++ b/contractcourt/htlc_success_resolver_test.go @@ -134,7 +134,7 @@ func (i *htlcResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false) + nextResolver, err := i.resolver.Resolve() i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index 0f019a36a1..81d8e85d21 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -418,9 +418,7 @@ func checkSizeAndIndex(witness wire.TxWitness, size, index int) bool { // see a direct sweep via the timeout clause. // // NOTE: Part of the ContractResolver interface. -func (h *htlcTimeoutResolver) Resolve( - immediate bool) (ContractResolver, error) { - +func (h *htlcTimeoutResolver) Resolve() (ContractResolver, error) { // If we're already resolved, then we can exit early. 
if h.resolved { return nil, nil @@ -429,7 +427,7 @@ func (h *htlcTimeoutResolver) Resolve( // Start by spending the HTLC output, either by broadcasting the // second-level timeout transaction, or directly if this is the remote // commitment. - commitSpend, err := h.spendHtlcOutput(immediate) + commitSpend, err := h.spendHtlcOutput() if err != nil { return nil, err } @@ -477,7 +475,7 @@ func (h *htlcTimeoutResolver) Resolve( // sweepSecondLevelTx sends a second level timeout transaction to the sweeper. // This transaction uses the SINLGE|ANYONECANPAY flag. -func (h *htlcTimeoutResolver) sweepSecondLevelTx(immediate bool) error { +func (h *htlcTimeoutResolver) sweepSecondLevelTx() error { log.Infof("%T(%x): offering second-layer timeout tx to sweeper: %v", h, h.htlc.RHash[:], spew.Sdump(h.htlcResolution.SignedTimeoutTx)) @@ -538,7 +536,6 @@ func (h *htlcTimeoutResolver) sweepSecondLevelTx(immediate bool) error { sweep.Params{ Budget: budget, DeadlineHeight: h.incomingHTLCExpiryHeight, - Immediate: immediate, }, ) if err != nil { @@ -572,7 +569,7 @@ func (h *htlcTimeoutResolver) sendSecondLevelTxLegacy() error { // sweeper. This is used when the remote party goes on chain, and we're able to // sweep an HTLC we offered after a timeout. Only the CLTV encumbered outputs // are resolved via this path. -func (h *htlcTimeoutResolver) sweepDirectHtlcOutput(immediate bool) error { +func (h *htlcTimeoutResolver) sweepDirectHtlcOutput() error { var htlcWitnessType input.StandardWitnessType if h.isTaproot() { htlcWitnessType = input.TaprootHtlcOfferedRemoteTimeout @@ -612,7 +609,6 @@ func (h *htlcTimeoutResolver) sweepDirectHtlcOutput(immediate bool) error { // This is an outgoing HTLC, so we want to make sure // that we sweep it before the incoming HTLC expires. DeadlineHeight: h.incomingHTLCExpiryHeight, - Immediate: immediate, }, ) if err != nil { @@ -627,8 +623,8 @@ func (h *htlcTimeoutResolver) sweepDirectHtlcOutput(immediate bool) error { // used to spend the output into the next stage. If this is the remote // commitment, the output will be swept directly without the timeout // transaction. -func (h *htlcTimeoutResolver) spendHtlcOutput( - immediate bool) (*chainntnfs.SpendDetail, error) { +func (h *htlcTimeoutResolver) spendHtlcOutput() ( + *chainntnfs.SpendDetail, error) { switch { // If we have non-nil SignDetails, this means that have a 2nd level @@ -636,7 +632,7 @@ func (h *htlcTimeoutResolver) spendHtlcOutput( // (the case for anchor type channels). In this case we can re-sign it // and attach fees at will. We let the sweeper handle this job. case h.htlcResolution.SignDetails != nil && !h.outputIncubating: - if err := h.sweepSecondLevelTx(immediate); err != nil { + if err := h.sweepSecondLevelTx(); err != nil { log.Errorf("Sending timeout tx to sweeper: %v", err) return nil, err @@ -645,7 +641,7 @@ func (h *htlcTimeoutResolver) spendHtlcOutput( // If this is a remote commitment there's no second level timeout txn, // and we can just send this directly to the sweeper. 
case h.htlcResolution.SignedTimeoutTx == nil && !h.outputIncubating: - if err := h.sweepDirectHtlcOutput(immediate); err != nil { + if err := h.sweepDirectHtlcOutput(); err != nil { log.Errorf("Sending direct spend to sweeper: %v", err) return nil, err diff --git a/contractcourt/htlc_timeout_resolver_test.go b/contractcourt/htlc_timeout_resolver_test.go index 0b6ccd5b39..17341f0c20 100644 --- a/contractcourt/htlc_timeout_resolver_test.go +++ b/contractcourt/htlc_timeout_resolver_test.go @@ -390,7 +390,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { go func() { defer wg.Done() - _, err := resolver.Resolve(false) + _, err := resolver.Resolve() if err != nil { resolveErr <- err } From 8261d2a941b80ea3a44257dc3dd8eb667fccc19e Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 29 Oct 2024 22:01:16 +0800 Subject: [PATCH 33/59] contractcourt: start channel arbitrator with blockbeat To avoid calling GetBestBlock again. --- contractcourt/chain_arbitrator.go | 4 +- contractcourt/channel_arbitrator.go | 10 ++-- contractcourt/channel_arbitrator_test.go | 65 +++++++++++++++--------- 3 files changed, 48 insertions(+), 31 deletions(-) diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 0a42dd57b7..34ed7d7ff5 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -794,7 +794,7 @@ func (c *ChainArbitrator) Start() error { arbitrator.cfg.ChanPoint) } - if err := arbitrator.Start(startState); err != nil { + if err := arbitrator.Start(startState, c.beat); err != nil { stopAndLog() return err } @@ -1211,7 +1211,7 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error // arbitrators, then launch it. c.activeChannels[chanPoint] = channelArb - if err := channelArb.Start(nil); err != nil { + if err := channelArb.Start(nil, c.beat); err != nil { return err } diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index dfbf6c97f1..8856b98d3c 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -462,7 +462,9 @@ func (c *ChannelArbitrator) getStartState(tx kvdb.RTx) (*chanArbStartState, // Start starts all the goroutines that the ChannelArbitrator needs to operate. // If takes a start state, which will be looked up on disk if it is not // provided. -func (c *ChannelArbitrator) Start(state *chanArbStartState) error { +func (c *ChannelArbitrator) Start(state *chanArbStartState, + beat chainio.Blockbeat) error { + if !atomic.CompareAndSwapInt32(&c.started, 0, 1) { return nil } @@ -484,10 +486,8 @@ func (c *ChannelArbitrator) Start(state *chanArbStartState) error { // Set our state from our starting state. c.state = state.currentState - _, bestHeight, err := c.cfg.ChainIO.GetBestBlock() - if err != nil { - return err - } + // Get the starting height. 
+ bestHeight := beat.Height() c.wg.Add(1) go c.channelAttendant(bestHeight, state.commitSet) diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 8715e1e63d..a3bdf6ba3b 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -296,7 +296,8 @@ func (c *chanArbTestCtx) Restart(restartClosure func(*chanArbTestCtx)) (*chanArb restartClosure(newCtx) } - if err := newCtx.chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := newCtx.chanArb.Start(nil, beat); err != nil { return nil, err } @@ -523,7 +524,8 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) { chanArbCtx, err := createTestChannelArbitrator(t, log) require.NoError(t, err, "unable to create ChannelArbitrator") - if err := chanArbCtx.chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArbCtx.chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } t.Cleanup(func() { @@ -581,7 +583,8 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -634,7 +637,8 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -746,7 +750,8 @@ func TestChannelArbitratorBreachClose(t *testing.T) { chanArb.cfg.PreimageDB = newMockWitnessBeacon() chanArb.cfg.Registry = &mockRegistry{} - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } t.Cleanup(func() { @@ -873,7 +878,8 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { chanArb.cfg.PreimageDB = newMockWitnessBeacon() chanArb.cfg.Registry = &mockRegistry{} - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1163,7 +1169,8 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1270,7 +1277,8 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1376,7 +1384,8 @@ func TestChannelArbitratorPersistence(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := 
chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1494,7 +1503,8 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1681,7 +1691,8 @@ func TestChannelArbitratorCommitFailure(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1765,7 +1776,8 @@ func TestChannelArbitratorEmptyResolutions(t *testing.T) { chanArb.cfg.ClosingHeight = 100 chanArb.cfg.CloseType = channeldb.RemoteForceClose - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(100) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1795,7 +1807,8 @@ func TestChannelArbitratorAlreadyForceClosed(t *testing.T) { chanArbCtx, err := createTestChannelArbitrator(t, log) require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1893,9 +1906,10 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { t.Fatalf("unable to create ChannelArbitrator: %v", err) } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { - t.Fatalf("unable to start ChannelArbitrator: %v", err) - } + beat := newBeatFromHeight(0) + err = chanArb.Start(nil, beat) + require.NoError(t, err) + defer chanArb.Stop() // Now that our channel arb has started, we'll set up @@ -2089,7 +2103,8 @@ func TestChannelArbitratorPendingExpiredHTLC(t *testing.T) { return false } - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } t.Cleanup(func() { @@ -2123,7 +2138,7 @@ func TestChannelArbitratorPendingExpiredHTLC(t *testing.T) { // We will advance the uptime to 10 seconds which should be still within // the grace period and should not trigger going to chain. testClock.SetTime(startTime.Add(time.Second * 10)) - beat := newBeatFromHeight(5) + beat = newBeatFromHeight(5) chanArbCtx.chanArb.BlockbeatChan <- beat chanArbCtx.AssertState(StateDefault) @@ -2244,8 +2259,8 @@ func TestRemoteCloseInitiator(t *testing.T) { "ChannelArbitrator: %v", err) } chanArb := chanArbCtx.chanArb - - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start "+ "ChannelArbitrator: %v", err) } @@ -2796,7 +2811,9 @@ func TestChannelArbitratorAnchors(t *testing.T) { }, } - if err := chanArb.Start(nil); err != nil { + heightHint := uint32(1000) + beat := newBeatFromHeight(int32(heightHint)) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } t.Cleanup(func() { @@ -2809,8 +2826,7 @@ func TestChannelArbitratorAnchors(t *testing.T) { chanArb.UpdateContractSignals(signals) // Set current block height. 
- heightHint := uint32(1000) - beat := newBeatFromHeight(int32(heightHint)) + beat = newBeatFromHeight(int32(heightHint)) chanArbCtx.chanArb.BlockbeatChan <- beat htlcAmt := lnwire.MilliSatoshi(1_000_000) @@ -3089,7 +3105,8 @@ func TestChannelArbitratorStartForceCloseFail(t *testing.T) { return test.broadcastErr } - err = chanArb.Start(nil) + beat := newBeatFromHeight(0) + err = chanArb.Start(nil, beat) if !test.expectedStartup { require.ErrorIs(t, err, test.broadcastErr) From c1ec4365786624b0f02e2fcb97f4eb4f8d554e3a Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 29 Oct 2024 22:02:13 +0800 Subject: [PATCH 34/59] multi: start consumers with a starting blockbeat This is needed so the consumers have an initial state about the current block. --- contractcourt/chain_arbitrator.go | 9 ++++-- contractcourt/chain_arbitrator_test.go | 6 ++-- server.go | 45 ++++++++++++++++++++++++-- sweep/fee_bumper.go | 5 ++- sweep/sweeper.go | 5 ++- 5 files changed, 60 insertions(+), 10 deletions(-) diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 34ed7d7ff5..524958f48f 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -570,13 +570,16 @@ func (c *ChainArbitrator) ResolveContract(chanPoint wire.OutPoint) error { } // Start launches all goroutines that the ChainArbitrator needs to operate. -func (c *ChainArbitrator) Start() error { +func (c *ChainArbitrator) Start(beat chainio.Blockbeat) error { if !atomic.CompareAndSwapInt32(&c.started, 0, 1) { return nil } - log.Infof("ChainArbitrator starting with config: budget=[%v]", - &c.cfg.Budget) + // Set the current beat. + c.beat = beat + + log.Infof("ChainArbitrator starting at height %d with budget=[%v]", + &c.cfg.Budget, c.beat.Height()) // First, we'll fetch all the channels that are still open, in order to // collect them within our set of active contracts. diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index de6d69900b..622686f76c 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -90,7 +90,8 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { chainArbCfg, db, ) - if err := chainArb.Start(); err != nil { + beat := newBeatFromHeight(0) + if err := chainArb.Start(beat); err != nil { t.Fatal(err) } t.Cleanup(func() { @@ -173,7 +174,8 @@ func TestResolveContract(t *testing.T) { chainArb := NewChainArbitrator( chainArbCfg, db, ) - if err := chainArb.Start(); err != nil { + beat := newBeatFromHeight(0) + if err := chainArb.Start(beat); err != nil { t.Fatal(err) } t.Cleanup(func() { diff --git a/server.go b/server.go index 799bfde6b8..f83333bc68 100644 --- a/server.go +++ b/server.go @@ -28,6 +28,7 @@ import ( "github.com/lightningnetwork/lnd/aliasmgr" "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/brontide" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/chanacceptor" "github.com/lightningnetwork/lnd/chanbackup" @@ -2073,6 +2074,12 @@ func (c cleaner) run() { // //nolint:funlen func (s *server) Start() error { + // Get the current blockbeat. 
+ beat, err := s.getStartingBeat() + if err != nil { + return err + } + var startErr error // If one sub system fails to start, the following code ensures that the @@ -2167,13 +2174,13 @@ func (s *server) Start() error { } cleanup = cleanup.add(s.txPublisher.Stop) - if err := s.txPublisher.Start(); err != nil { + if err := s.txPublisher.Start(beat); err != nil { startErr = err return } cleanup = cleanup.add(s.sweeper.Stop) - if err := s.sweeper.Start(); err != nil { + if err := s.sweeper.Start(beat); err != nil { startErr = err return } @@ -2218,7 +2225,7 @@ func (s *server) Start() error { } cleanup = cleanup.add(s.chainArb.Stop) - if err := s.chainArb.Start(); err != nil { + if err := s.chainArb.Start(beat); err != nil { startErr = err return } @@ -5152,3 +5159,35 @@ func (s *server) fetchClosedChannelSCIDs() map[lnwire.ShortChannelID]struct{} { return closedSCIDs } + +// getStartingBeat returns the current beat. This is used during the startup to +// initialize blockbeat consumers. +func (s *server) getStartingBeat() (*chainio.Beat, error) { + // beat is the current blockbeat. + var beat *chainio.Beat + + // We should get a notification with the current best block immediately + // by passing a nil block. + blockEpochs, err := s.cc.ChainNotifier.RegisterBlockEpochNtfn(nil) + if err != nil { + return beat, fmt.Errorf("register block epoch ntfn: %w", err) + } + defer blockEpochs.Cancel() + + // We registered for the block epochs with a nil request. The notifier + // should send us the current best block immediately. So we need to + // wait for it here because we need to know the current best height. + select { + case bestBlock := <-blockEpochs.Epochs: + srvrLog.Infof("Received initial block %v at height %d", + bestBlock.Hash, bestBlock.Height) + + // Update the current blockbeat. + beat = chainio.NewBeat(*bestBlock) + + case <-s.quit: + srvrLog.Debug("LND shutting down") + } + + return beat, nil +} diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index ad02ac6194..a43f875850 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -795,13 +795,16 @@ type monitorRecord struct { // Start starts the publisher by subscribing to block epoch updates and kicking // off the monitor loop. -func (t *TxPublisher) Start() error { +func (t *TxPublisher) Start(beat chainio.Blockbeat) error { log.Info("TxPublisher starting...") if t.started.Swap(true) { return fmt.Errorf("TxPublisher started more than once") } + // Set the current height. + t.currentHeight.Store(beat.Height()) + t.wg.Add(1) go t.monitor() diff --git a/sweep/sweeper.go b/sweep/sweeper.go index b91a6808fb..0f2675e8f8 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -441,7 +441,7 @@ func New(cfg *UtxoSweeperConfig) *UtxoSweeper { } // Start starts the process of constructing and publish sweep txes. -func (s *UtxoSweeper) Start() error { +func (s *UtxoSweeper) Start(beat chainio.Blockbeat) error { if !atomic.CompareAndSwapUint32(&s.started, 0, 1) { return nil } @@ -452,6 +452,9 @@ func (s *UtxoSweeper) Start() error { // not change from here on. s.relayFeeRate = s.cfg.FeeEstimator.RelayFeePerKW() + // Set the current height. + s.currentHeight = beat.Height() + // Start sweeper main loop. 
s.wg.Add(1) go s.collector() From 947286c35fb0d432a5f0be7b9debe4717c31ae04 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 17 Oct 2024 09:58:48 +0800 Subject: [PATCH 35/59] lnd: add new method `startLowLevelServices` In this commit we start to break up the starting process into smaller pieces, which is needed in the following commit to initialize blockbeat consumers. --- server.go | 40 ++++++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/server.go b/server.go index f83333bc68..5b64ade348 100644 --- a/server.go +++ b/server.go @@ -666,6 +666,17 @@ func newServer(cfg *Config, listenAddrs []net.Addr, quit: make(chan struct{}), } + // Start the low-level services once they are initialized. + // + // TODO(yy): break the server startup into four steps, + // 1. init the low-level services. + // 2. start the low-level services. + // 3. init the high-level services. + // 4. start the high-level services. + if err := s.startLowLevelServices(); err != nil { + return nil, err + } + currentHash, currentHeight, err := s.cc.ChainIO.GetBestBlock() if err != nil { return nil, err @@ -2068,6 +2079,29 @@ func (c cleaner) run() { } } +// startLowLevelServices starts the low-level services of the server. These +// services must be started successfully before running the main server. The +// services are, +// 1. the chain notifier. +// +// TODO(yy): identify and add more low-level services here. +func (s *server) startLowLevelServices() error { + var startErr error + + cleanup := cleaner{} + + cleanup = cleanup.add(s.cc.ChainNotifier.Stop) + if err := s.cc.ChainNotifier.Start(); err != nil { + startErr = err + } + + if startErr != nil { + cleanup.run() + } + + return startErr +} + // Start starts the main daemon server, all requested listeners, and any helper // goroutines. // NOTE: This function is safe for concurrent access. @@ -2133,12 +2167,6 @@ func (s *server) Start() error { return } - cleanup = cleanup.add(s.cc.ChainNotifier.Stop) - if err := s.cc.ChainNotifier.Start(); err != nil { - startErr = err - return - } - cleanup = cleanup.add(s.cc.BestBlockTracker.Stop) if err := s.cc.BestBlockTracker.Start(); err != nil { startErr = err From da68b66f3c590032287c48587cca2f9d3716e2bc Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 17 Oct 2024 10:14:00 +0800 Subject: [PATCH 36/59] lnd: start `blockbeatDispatcher` and register consumers --- log.go | 2 ++ server.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/log.go b/log.go index a3efd03335..46047fb56c 100644 --- a/log.go +++ b/log.go @@ -9,6 +9,7 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/chanacceptor" @@ -196,6 +197,7 @@ func SetupLoggers(root *build.SubLoggerManager, interceptor signal.Interceptor) root, blindedpath.Subsystem, interceptor, blindedpath.UseLogger, ) AddV1SubLogger(root, graphdb.Subsystem, interceptor, graphdb.UseLogger) + AddSubLogger(root, chainio.Subsystem, interceptor, chainio.UseLogger) } // AddSubLogger is a helper method to conveniently create and register the diff --git a/server.go b/server.go index 5b64ade348..9fd0b7a006 100644 --- a/server.go +++ b/server.go @@ -357,6 +357,10 @@ type server struct { // txPublisher is a publisher with fee-bumping capability. 
txPublisher *sweep.TxPublisher + // blockbeatDispatcher is a block dispatcher that notifies subscribers + // of new blocks. + blockbeatDispatcher *chainio.BlockbeatDispatcher + quit chan struct{} wg sync.WaitGroup @@ -624,6 +628,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, readPool: readPool, chansToRestore: chansToRestore, + blockbeatDispatcher: chainio.NewBlockbeatDispatcher( + cc.ChainNotifier, + ), channelNotifier: channelnotifier.New( dbs.ChanStateDB.ChannelStateDB(), ), @@ -1825,6 +1832,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, } s.connMgr = cmgr + // Finally, register the subsystems in blockbeat. + s.registerBlockConsumers() + return s, nil } @@ -1857,6 +1867,25 @@ func (s *server) UpdateRoutingConfig(cfg *routing.MissionControlConfig) { routerCfg.MaxMcHistory = cfg.MaxMcHistory } +// registerBlockConsumers registers the subsystems that consume block events. +// By calling `RegisterQueue`, a list of subsystems are registered in the +// blockbeat for block notifications. When a new block arrives, the subsystems +// in the same queue are notified sequentially, and different queues are +// notified concurrently. +// +// NOTE: To put a subsystem in a different queue, create a slice and pass it to +// a new `RegisterQueue` call. +func (s *server) registerBlockConsumers() { + // In this queue, when a new block arrives, it will be received and + // processed in this order: chainArb -> sweeper -> txPublisher. + consumers := []chainio.Consumer{ + s.chainArb, + s.sweeper, + s.txPublisher, + } + s.blockbeatDispatcher.RegisterQueue(consumers) +} + // signAliasUpdate takes a ChannelUpdate and returns the signature. This is // used for option_scid_alias channels where the ChannelUpdate to be sent back // may differ from what is on disk. @@ -2494,6 +2523,17 @@ func (s *server) Start() error { srvrLog.Infof("Auto peer bootstrapping is disabled") } + // Start the blockbeat after all other subsystems have been + // started so they are ready to receive new blocks. + cleanup = cleanup.add(func() error { + s.blockbeatDispatcher.Stop() + return nil + }) + if err := s.blockbeatDispatcher.Start(); err != nil { + startErr = err + return + } + // Set the active flag now that we've completed the full // startup. atomic.StoreInt32(&s.active, 1) @@ -2518,6 +2558,9 @@ func (s *server) Stop() error { // Shutdown connMgr first to prevent conns during shutdown. s.connMgr.Stop() + // Stop dispatching blocks to other systems immediately. + s.blockbeatDispatcher.Stop() + // Shutdown the wallet, funding manager, and the rpc server. 
if err := s.chanStatusMgr.Stop(); err != nil { srvrLog.Warnf("failed to stop chanStatusMgr: %v", err) From 0579c435c83f6ccae464c69fd5d486a7cacfbf77 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 30 Oct 2024 04:38:27 +0800 Subject: [PATCH 37/59] contractcourt: fix linter `funlen` Refactor the `Start` method to fix the linter error: ``` contractcourt/chain_arbitrator.go:568: Function 'Start' is too long (242 > 200) (funlen) ``` --- contractcourt/chain_arbitrator.go | 270 ++++++++++++++----------- contractcourt/commit_sweep_resolver.go | 4 + 2 files changed, 151 insertions(+), 123 deletions(-) diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 524958f48f..011b5225cd 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -583,137 +583,17 @@ func (c *ChainArbitrator) Start(beat chainio.Blockbeat) error { // First, we'll fetch all the channels that are still open, in order to // collect them within our set of active contracts. - openChannels, err := c.chanSource.ChannelStateDB().FetchAllChannels() - if err != nil { + if err := c.loadOpenChannels(); err != nil { return err } - if len(openChannels) > 0 { - log.Infof("Creating ChannelArbitrators for %v active channels", - len(openChannels)) - } - - // For each open channel, we'll configure then launch a corresponding - // ChannelArbitrator. - for _, channel := range openChannels { - chanPoint := channel.FundingOutpoint - channel := channel - - // First, we'll create an active chainWatcher for this channel - // to ensure that we detect any relevant on chain events. - breachClosure := func(ret *lnwallet.BreachRetribution) error { - return c.cfg.ContractBreach(chanPoint, ret) - } - - chainWatcher, err := newChainWatcher( - chainWatcherConfig{ - chanState: channel, - notifier: c.cfg.Notifier, - signer: c.cfg.Signer, - isOurAddr: c.cfg.IsOurAddress, - contractBreach: breachClosure, - extractStateNumHint: lnwallet.GetStateNumHint, - auxLeafStore: c.cfg.AuxLeafStore, - auxResolver: c.cfg.AuxResolver, - }, - ) - if err != nil { - return err - } - - c.activeWatchers[chanPoint] = chainWatcher - channelArb, err := newActiveChannelArbitrator( - channel, c, chainWatcher.SubscribeChannelEvents(), - ) - if err != nil { - return err - } - - c.activeChannels[chanPoint] = channelArb - - // Republish any closing transactions for this channel. - err = c.republishClosingTxs(channel) - if err != nil { - log.Errorf("Failed to republish closing txs for "+ - "channel %v", chanPoint) - } - } - // In addition to the channels that we know to be open, we'll also // launch arbitrators to finishing resolving any channels that are in // the pending close state. - closingChannels, err := c.chanSource.ChannelStateDB().FetchClosedChannels( - true, - ) - if err != nil { + if err := c.loadPendingCloseChannels(); err != nil { return err } - if len(closingChannels) > 0 { - log.Infof("Creating ChannelArbitrators for %v closing channels", - len(closingChannels)) - } - - // Next, for each channel is the closing state, we'll launch a - // corresponding more restricted resolver, as we don't have to watch - // the chain any longer, only resolve the contracts on the confirmed - // commitment. - //nolint:ll - for _, closeChanInfo := range closingChannels { - // We can leave off the CloseContract and ForceCloseChan - // methods as the channel is already closed at this point. 
- chanPoint := closeChanInfo.ChanPoint - arbCfg := ChannelArbitratorConfig{ - ChanPoint: chanPoint, - ShortChanID: closeChanInfo.ShortChanID, - ChainArbitratorConfig: c.cfg, - ChainEvents: &ChainEventSubscription{}, - IsPendingClose: true, - ClosingHeight: closeChanInfo.CloseHeight, - CloseType: closeChanInfo.CloseType, - PutResolverReport: func(tx kvdb.RwTx, - report *channeldb.ResolverReport) error { - - return c.chanSource.PutResolverReport( - tx, c.cfg.ChainHash, &chanPoint, report, - ) - }, - FetchHistoricalChannel: func() (*channeldb.OpenChannel, error) { - chanStateDB := c.chanSource.ChannelStateDB() - return chanStateDB.FetchHistoricalChannel(&chanPoint) - }, - FindOutgoingHTLCDeadline: func( - htlc channeldb.HTLC) fn.Option[int32] { - - return c.FindOutgoingHTLCDeadline( - closeChanInfo.ShortChanID, htlc, - ) - }, - } - chanLog, err := newBoltArbitratorLog( - c.chanSource.Backend, arbCfg, c.cfg.ChainHash, chanPoint, - ) - if err != nil { - return err - } - arbCfg.MarkChannelResolved = func() error { - if c.cfg.NotifyFullyResolvedChannel != nil { - c.cfg.NotifyFullyResolvedChannel(chanPoint) - } - - return c.ResolveContract(chanPoint) - } - - // We create an empty map of HTLC's here since it's possible - // that the channel is in StateDefault and updateActiveHTLCs is - // called. We want to avoid writing to an empty map. Since the - // channel is already in the process of being resolved, no new - // HTLCs will be added. - c.activeChannels[chanPoint] = NewChannelArbitrator( - arbCfg, make(map[HtlcSetKey]htlcSet), chanLog, - ) - } - // Now, we'll start all chain watchers in parallel to shorten start up // duration. In neutrino mode, this allows spend registrations to take // advantage of batch spend reporting, instead of doing a single rescan @@ -765,7 +645,7 @@ func (c *ChainArbitrator) Start(beat chainio.Blockbeat) error { // transaction. var startStates map[wire.OutPoint]*chanArbStartState - err = kvdb.View(c.chanSource, func(tx walletdb.ReadTx) error { + err := kvdb.View(c.chanSource, func(tx walletdb.ReadTx) error { for _, arbitrator := range c.activeChannels { startState, err := arbitrator.getStartState(tx) if err != nil { @@ -1332,3 +1212,147 @@ func (c *ChainArbitrator) FindOutgoingHTLCDeadline(scid lnwire.ShortChannelID, func (c *ChainArbitrator) Name() string { return "ChainArbitrator" } + +// loadOpenChannels loads all channels that are currently open in the database +// and registers them with the chainWatcher for future notification. +func (c *ChainArbitrator) loadOpenChannels() error { + openChannels, err := c.chanSource.ChannelStateDB().FetchAllChannels() + if err != nil { + return err + } + + if len(openChannels) == 0 { + return nil + } + + log.Infof("Creating ChannelArbitrators for %v active channels", + len(openChannels)) + + // For each open channel, we'll configure then launch a corresponding + // ChannelArbitrator. + for _, channel := range openChannels { + chanPoint := channel.FundingOutpoint + channel := channel + + // First, we'll create an active chainWatcher for this channel + // to ensure that we detect any relevant on chain events. 
+ breachClosure := func(ret *lnwallet.BreachRetribution) error { + return c.cfg.ContractBreach(chanPoint, ret) + } + + chainWatcher, err := newChainWatcher( + chainWatcherConfig{ + chanState: channel, + notifier: c.cfg.Notifier, + signer: c.cfg.Signer, + isOurAddr: c.cfg.IsOurAddress, + contractBreach: breachClosure, + extractStateNumHint: lnwallet.GetStateNumHint, + auxLeafStore: c.cfg.AuxLeafStore, + auxResolver: c.cfg.AuxResolver, + }, + ) + if err != nil { + return err + } + + c.activeWatchers[chanPoint] = chainWatcher + channelArb, err := newActiveChannelArbitrator( + channel, c, chainWatcher.SubscribeChannelEvents(), + ) + if err != nil { + return err + } + + c.activeChannels[chanPoint] = channelArb + + // Republish any closing transactions for this channel. + err = c.republishClosingTxs(channel) + if err != nil { + log.Errorf("Failed to republish closing txs for "+ + "channel %v", chanPoint) + } + } + + return nil +} + +// loadPendingCloseChannels loads all channels that are currently pending +// closure in the database and registers them with the ChannelArbitrator to +// continue the resolution process. +func (c *ChainArbitrator) loadPendingCloseChannels() error { + chanStateDB := c.chanSource.ChannelStateDB() + + closingChannels, err := chanStateDB.FetchClosedChannels(true) + if err != nil { + return err + } + + if len(closingChannels) == 0 { + return nil + } + + log.Infof("Creating ChannelArbitrators for %v closing channels", + len(closingChannels)) + + // Next, for each channel is the closing state, we'll launch a + // corresponding more restricted resolver, as we don't have to watch + // the chain any longer, only resolve the contracts on the confirmed + // commitment. + //nolint:ll + for _, closeChanInfo := range closingChannels { + // We can leave off the CloseContract and ForceCloseChan + // methods as the channel is already closed at this point. + chanPoint := closeChanInfo.ChanPoint + arbCfg := ChannelArbitratorConfig{ + ChanPoint: chanPoint, + ShortChanID: closeChanInfo.ShortChanID, + ChainArbitratorConfig: c.cfg, + ChainEvents: &ChainEventSubscription{}, + IsPendingClose: true, + ClosingHeight: closeChanInfo.CloseHeight, + CloseType: closeChanInfo.CloseType, + PutResolverReport: func(tx kvdb.RwTx, + report *channeldb.ResolverReport) error { + + return c.chanSource.PutResolverReport( + tx, c.cfg.ChainHash, &chanPoint, report, + ) + }, + FetchHistoricalChannel: func() (*channeldb.OpenChannel, error) { + return chanStateDB.FetchHistoricalChannel(&chanPoint) + }, + FindOutgoingHTLCDeadline: func( + htlc channeldb.HTLC) fn.Option[int32] { + + return c.FindOutgoingHTLCDeadline( + closeChanInfo.ShortChanID, htlc, + ) + }, + } + chanLog, err := newBoltArbitratorLog( + c.chanSource.Backend, arbCfg, c.cfg.ChainHash, chanPoint, + ) + if err != nil { + return err + } + arbCfg.MarkChannelResolved = func() error { + if c.cfg.NotifyFullyResolvedChannel != nil { + c.cfg.NotifyFullyResolvedChannel(chanPoint) + } + + return c.ResolveContract(chanPoint) + } + + // We create an empty map of HTLC's here since it's possible + // that the channel is in StateDefault and updateActiveHTLCs is + // called. We want to avoid writing to an empty map. Since the + // channel is already in the process of being resolved, no new + // HTLCs will be added. 
+ c.activeChannels[chanPoint] = NewChannelArbitrator( + arbCfg, make(map[HtlcSetKey]htlcSet), chanLog, + ) + } + + return nil +} diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index 9a3bb058ae..423d235dbf 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -165,6 +165,10 @@ func (c *commitSweepResolver) getCommitTxConfHeight() (uint32, error) { // returned. // // NOTE: This function MUST be run as a goroutine. + +// TODO(yy): fix the funlen in the next PR. +// +//nolint:funlen func (c *commitSweepResolver) Resolve() (ContractResolver, error) { // If we're already resolved, then we can exit early. if c.resolved { From 266cb6eebbde3c89c2723b21de1c809888787b57 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 22 May 2024 16:51:36 +0800 Subject: [PATCH 38/59] multi: improve loggings --- contractcourt/channel_arbitrator.go | 12 +++++++----- contractcourt/htlc_lease_resolver.go | 6 +++--- contractcourt/utxonursery.go | 2 +- htlcswitch/switch.go | 2 +- lnwallet/wallet.go | 2 +- sweep/sweeper.go | 16 ++++++++++------ 6 files changed, 23 insertions(+), 17 deletions(-) diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 8856b98d3c..ca70f733ff 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -1610,8 +1610,8 @@ func (c *ChannelArbitrator) advanceState( for { priorState = c.state log.Debugf("ChannelArbitrator(%v): attempting state step with "+ - "trigger=%v from state=%v", c.cfg.ChanPoint, trigger, - priorState) + "trigger=%v from state=%v at height=%v", + c.cfg.ChanPoint, trigger, priorState, triggerHeight) nextState, closeTx, err := c.stateStep( triggerHeight, trigger, confCommitSet, @@ -2822,14 +2822,12 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32, // We have broadcasted our commitment, and it is now confirmed // on-chain. case closeInfo := <-c.cfg.ChainEvents.LocalUnilateralClosure: - log.Infof("ChannelArbitrator(%v): local on-chain "+ - "channel close", c.cfg.ChanPoint) - if c.state != StateCommitmentBroadcasted { log.Errorf("ChannelArbitrator(%v): unexpected "+ "local on-chain channel close", c.cfg.ChanPoint) } + closeTx := closeInfo.CloseTx resolutions, err := closeInfo.ContractResolutions. 
@@ -2857,6 +2855,10 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32, return } + log.Infof("ChannelArbitrator(%v): local force close "+ + "tx=%v confirmed", c.cfg.ChanPoint, + closeTx.TxHash()) + contractRes := &ContractResolutions{ CommitHash: closeTx.TxHash(), CommitResolution: resolutions.CommitResolution, diff --git a/contractcourt/htlc_lease_resolver.go b/contractcourt/htlc_lease_resolver.go index 9c5da6ee49..6230f96777 100644 --- a/contractcourt/htlc_lease_resolver.go +++ b/contractcourt/htlc_lease_resolver.go @@ -57,10 +57,10 @@ func (h *htlcLeaseResolver) makeSweepInput(op *wire.OutPoint, signDesc *input.SignDescriptor, csvDelay, broadcastHeight uint32, payHash [32]byte, resBlob fn.Option[tlv.Blob]) *input.BaseInput { - if h.hasCLTV() { - log.Infof("%T(%x): CSV and CLTV locks expired, offering "+ - "second-layer output to sweeper: %v", h, payHash, op) + log.Infof("%T(%x): offering second-layer output to sweeper: %v", h, + payHash, op) + if h.hasCLTV() { return input.NewCsvInputWithCltv( op, cltvWtype, signDesc, broadcastHeight, csvDelay, diff --git a/contractcourt/utxonursery.go b/contractcourt/utxonursery.go index a870683746..f78be9fa49 100644 --- a/contractcourt/utxonursery.go +++ b/contractcourt/utxonursery.go @@ -794,7 +794,7 @@ func (u *UtxoNursery) graduateClass(classHeight uint32) error { return err } - utxnLog.Infof("Attempting to graduate height=%v: num_kids=%v, "+ + utxnLog.Debugf("Attempting to graduate height=%v: num_kids=%v, "+ "num_babies=%v", classHeight, len(kgtnOutputs), len(cribOutputs)) // Offer the outputs to the sweeper and set up notifications that will diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 4c54fab0a5..c94c677966 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -1606,7 +1606,7 @@ out: } } - log.Infof("Received outside contract resolution, "+ + log.Debugf("Received outside contract resolution, "+ "mapping to: %v", spew.Sdump(pkt)) // We don't check the error, as the only failure we can diff --git a/lnwallet/wallet.go b/lnwallet/wallet.go index 96ea85cf9e..2646d7c8f0 100644 --- a/lnwallet/wallet.go +++ b/lnwallet/wallet.go @@ -733,7 +733,7 @@ func (l *LightningWallet) RegisterFundingIntent(expectedID [32]byte, } if _, ok := l.fundingIntents[expectedID]; ok { - return fmt.Errorf("%w: already has intent registered: %v", + return fmt.Errorf("%w: already has intent registered: %x", ErrDuplicatePendingChanID, expectedID[:]) } diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 0f2675e8f8..d49f104af0 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -242,8 +242,9 @@ func (p *SweeperInput) isMature(currentHeight uint32) (bool, uint32) { // currentHeight plus one. 
locktime = p.BlocksToMaturity() + p.HeightHint() if currentHeight+1 < locktime { - log.Debugf("Input %v has CSV expiry=%v, current height is %v", - p.OutPoint(), locktime, currentHeight) + log.Debugf("Input %v has CSV expiry=%v, current height is %v, "+ + "skipped sweeping", p.OutPoint(), locktime, + currentHeight) return false, locktime } @@ -1197,8 +1198,8 @@ func (s *UtxoSweeper) calculateDefaultDeadline(pi *SweeperInput) int32 { if !matured { defaultDeadline = int32(locktime + s.cfg.NoDeadlineConfTarget) log.Debugf("Input %v is immature, using locktime=%v instead "+ - "of current height=%d", pi.OutPoint(), locktime, - s.currentHeight) + "of current height=%d as starting height", + pi.OutPoint(), locktime, s.currentHeight) } return defaultDeadline @@ -1210,7 +1211,8 @@ func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error { outpoint := input.input.OutPoint() pi, pending := s.inputs[outpoint] if pending { - log.Debugf("Already has pending input %v received", outpoint) + log.Infof("Already has pending input %v received, old params: "+ + "%v, new params %v", outpoint, pi.params, input.params) s.handleExistingInput(input, pi) @@ -1492,6 +1494,8 @@ func (s *UtxoSweeper) updateSweeperInputs() InputsMap { // turn this inputs map into a SyncMap in case we wanna add concurrent // access to the map in the future. for op, input := range s.inputs { + log.Tracef("Checking input: %s, state=%v", input, input.state) + // If the input has reached a final state, that it's either // been swept, or failed, or excluded, we will remove it from // our sweeper. @@ -1521,7 +1525,7 @@ func (s *UtxoSweeper) updateSweeperInputs() InputsMap { // skip this input and wait for the locktime to be reached. mature, locktime := input.isMature(uint32(s.currentHeight)) if !mature { - log.Infof("Skipping input %v due to locktime=%v not "+ + log.Debugf("Skipping input %v due to locktime=%v not "+ "reached, current height is %v", op, locktime, s.currentHeight) From 837d5b050eadcd90a550c6a4cde9df9692c6888d Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 19 Nov 2024 17:34:09 +0800 Subject: [PATCH 39/59] chainio: use `errgroup` to limit num of goroutines --- chainio/dispatcher.go | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/chainio/dispatcher.go b/chainio/dispatcher.go index 244a3ac8f7..87bc21fbaa 100644 --- a/chainio/dispatcher.go +++ b/chainio/dispatcher.go @@ -9,6 +9,7 @@ import ( "github.com/btcsuite/btclog/v2" "github.com/lightningnetwork/lnd/chainntnfs" + "golang.org/x/sync/errgroup" ) // DefaultProcessBlockTimeout is the timeout value used when waiting for one @@ -229,34 +230,30 @@ func DispatchSequential(b Blockbeat, consumers []Consumer) error { // It requires the consumer to finish processing the block within the specified // time, otherwise a timeout error is returned. func DispatchConcurrent(b Blockbeat, consumers []Consumer) error { - // errChans is a map of channels that will be used to receive errors - // returned from notifying the consumers. - errChans := make(map[string]chan error, len(consumers)) + eg := &errgroup.Group{} // Notify each queue in goroutines. for _, c := range consumers { - // Create a signal chan. - errChan := make(chan error, 1) - errChans[c.Name()] = errChan - // Notify each consumer concurrently. - go func(c Consumer, beat Blockbeat) { + eg.Go(func() error { // Send the beat to the consumer. 
- errChan <- notifyAndWait( - b, c, DefaultProcessBlockTimeout, - ) - }(c, b) - } + err := notifyAndWait(b, c, DefaultProcessBlockTimeout) + + // Exit early if there's no error. + if err == nil { + return nil + } - // Wait for all consumers in each queue to finish. - for name, errChan := range errChans { - err := <-errChan - if err != nil { b.logger().Errorf("Consumer=%v failed to process "+ - "block: %v", name, err) + "block: %v", c.Name(), err) return err - } + }) + } + + // Wait for all consumers in each queue to finish. + if err := eg.Wait(); err != nil { + return err } return nil From 5f37d9394626f4aa3b318122dedce630921e2c38 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 10 Dec 2024 15:13:44 +0800 Subject: [PATCH 40/59] chainio: update `fn` to `v2` --- chainio/consumer_test.go | 2 +- chainio/dispatcher_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/chainio/consumer_test.go b/chainio/consumer_test.go index 3ef79b61b4..d1cabf3168 100644 --- a/chainio/consumer_test.go +++ b/chainio/consumer_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/stretchr/testify/require" ) diff --git a/chainio/dispatcher_test.go b/chainio/dispatcher_test.go index 88044c0201..11abbeb65e 100--- a/chainio/dispatcher_test.go +++ b/chainio/dispatcher_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) From d2c96afa95a0d72970cae973a7c681ef4762336b Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 20 Jun 2024 21:56:52 +0800 Subject: [PATCH 41/59] contractcourt: add verbose logging in resolvers We now put the outpoint in the resolver's logging so it's easier to debug.
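For illustration only (not part of the diff below; the type and the string values are placeholders rather than the real lnd resolvers), a minimal sketch of how a resolver's log prefix is now composed, assuming the initLogger signature introduced in this patch:

    package main

    import "fmt"

    // fakeResolver stands in for a contract resolver; it exists only to
    // show how %T renders inside the log prefix.
    type fakeResolver struct{}

    func main() {
        chanPoint := "funding_txid:0" // placeholder channel point
        outpoint := "htlc_txid:1"     // placeholder HTLC outpoint

        r := &fakeResolver{}

        // Each resolver now passes its own identifier, typically its
        // outpoint, and initLogger prepends the channel point.
        prefix := fmt.Sprintf("%T(%v)", r, outpoint)
        logPrefix := fmt.Sprintf("ChannelArbitrator(%v): %s:", chanPoint, prefix)

        fmt.Println(logPrefix, "offering output to sweeper")
    }

With this prefix, every resolver log line carries both the channel point and the resolver's own outpoint, which is what makes force-close debugging easier.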
--- contractcourt/anchor_resolver.go | 3 ++- contractcourt/breach_resolver.go | 5 +++-- contractcourt/commit_sweep_resolver.go | 4 ++-- contractcourt/contract_resolver.go | 6 ++++-- contractcourt/htlc_success_resolver.go | 25 +++++++++++++++---------- contractcourt/htlc_timeout_resolver.go | 24 ++++++++++++++---------- 6 files changed, 40 insertions(+), 27 deletions(-) diff --git a/contractcourt/anchor_resolver.go b/contractcourt/anchor_resolver.go index af7ac76462..b50e061f44 100644 --- a/contractcourt/anchor_resolver.go +++ b/contractcourt/anchor_resolver.go @@ -2,6 +2,7 @@ package contractcourt import ( "errors" + "fmt" "io" "sync" @@ -71,7 +72,7 @@ func newAnchorResolver(anchorSignDescriptor input.SignDescriptor, currentReport: report, } - r.initLogger(r) + r.initLogger(fmt.Sprintf("%T(%v)", r, r.anchor)) return r } diff --git a/contractcourt/breach_resolver.go b/contractcourt/breach_resolver.go index 63395651cc..9a5f4bbe08 100644 --- a/contractcourt/breach_resolver.go +++ b/contractcourt/breach_resolver.go @@ -2,6 +2,7 @@ package contractcourt import ( "encoding/binary" + "fmt" "io" "github.com/lightningnetwork/lnd/channeldb" @@ -32,7 +33,7 @@ func newBreachResolver(resCfg ResolverConfig) *breachResolver { replyChan: make(chan struct{}), } - r.initLogger(r) + r.initLogger(fmt.Sprintf("%T(%v)", r, r.ChanPoint)) return r } @@ -114,7 +115,7 @@ func newBreachResolverFromReader(r io.Reader, resCfg ResolverConfig) ( return nil, err } - b.initLogger(b) + b.initLogger(fmt.Sprintf("%T(%v)", b, b.ChanPoint)) return b, nil } diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index 423d235dbf..7b101f80e3 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -88,7 +88,7 @@ func newCommitSweepResolver(res lnwallet.CommitOutputResolution, chanPoint: chanPoint, } - r.initLogger(r) + r.initLogger(fmt.Sprintf("%T(%v)", r, r.commitResolution.SelfOutPoint)) r.initReport() return r @@ -484,7 +484,7 @@ func newCommitSweepResolverFromReader(r io.Reader, resCfg ResolverConfig) ( // removed this, but keep in mind that this data may still be present in // the database. - c.initLogger(c) + c.initLogger(fmt.Sprintf("%T(%v)", c, c.commitResolution.SelfOutPoint)) c.initReport() return c, nil diff --git a/contractcourt/contract_resolver.go b/contractcourt/contract_resolver.go index 3629c1bc3c..ff52ce9760 100644 --- a/contractcourt/contract_resolver.go +++ b/contractcourt/contract_resolver.go @@ -120,8 +120,10 @@ func newContractResolverKit(cfg ResolverConfig) *contractResolverKit { } // initLogger initializes the resolver-specific logger. 
-func (r *contractResolverKit) initLogger(resolver ContractResolver) { - logPrefix := fmt.Sprintf("%T(%v):", resolver, r.ChanPoint) +func (r *contractResolverKit) initLogger(prefix string) { + logPrefix := fmt.Sprintf("ChannelArbitrator(%v): %s:", r.ChanPoint, + prefix) + r.log = log.WithPrefix(logPrefix) } diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 06ebf4edc4..363a5cc044 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -2,6 +2,7 @@ package contractcourt import ( "encoding/binary" + "fmt" "io" "sync" @@ -81,27 +82,30 @@ func newSuccessResolver(res lnwallet.IncomingHtlcResolution, } h.initReport() + h.initLogger(fmt.Sprintf("%T(%v)", h, h.outpoint())) return h } -// ResolverKey returns an identifier which should be globally unique for this -// particular resolver within the chain the original contract resides within. -// -// NOTE: Part of the ContractResolver interface. -func (h *htlcSuccessResolver) ResolverKey() []byte { +// outpoint returns the outpoint of the HTLC output we're attempting to sweep. +func (h *htlcSuccessResolver) outpoint() wire.OutPoint { // The primary key for this resolver will be the outpoint of the HTLC // on the commitment transaction itself. If this is our commitment, // then the output can be found within the signed success tx, // otherwise, it's just the ClaimOutpoint. - var op wire.OutPoint if h.htlcResolution.SignedSuccessTx != nil { - op = h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint - } else { - op = h.htlcResolution.ClaimOutpoint + return h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint } - key := newResolverID(op) + return h.htlcResolution.ClaimOutpoint +} + +// ResolverKey returns an identifier which should be globally unique for this +// particular resolver within the chain the original contract resides within. +// +// NOTE: Part of the ContractResolver interface. +func (h *htlcSuccessResolver) ResolverKey() []byte { + key := newResolverID(h.outpoint()) return key[:] } @@ -679,6 +683,7 @@ func newSuccessResolverFromReader(r io.Reader, resCfg ResolverConfig) ( } h.initReport() + h.initLogger(fmt.Sprintf("%T(%v)", h, h.outpoint())) return h, nil } diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index 81d8e85d21..ca456ec4c8 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -82,6 +82,7 @@ func newTimeoutResolver(res lnwallet.OutgoingHtlcResolution, } h.initReport() + h.initLogger(fmt.Sprintf("%T(%v)", h, h.outpoint())) return h } @@ -93,23 +94,25 @@ func (h *htlcTimeoutResolver) isTaproot() bool { ) } -// ResolverKey returns an identifier which should be globally unique for this -// particular resolver within the chain the original contract resides within. -// -// NOTE: Part of the ContractResolver interface. -func (h *htlcTimeoutResolver) ResolverKey() []byte { +// outpoint returns the outpoint of the HTLC output we're attempting to sweep. +func (h *htlcTimeoutResolver) outpoint() wire.OutPoint { // The primary key for this resolver will be the outpoint of the HTLC // on the commitment transaction itself. If this is our commitment, // then the output can be found within the signed timeout tx, // otherwise, it's just the ClaimOutpoint. 
- var op wire.OutPoint if h.htlcResolution.SignedTimeoutTx != nil { - op = h.htlcResolution.SignedTimeoutTx.TxIn[0].PreviousOutPoint - } else { - op = h.htlcResolution.ClaimOutpoint + return h.htlcResolution.SignedTimeoutTx.TxIn[0].PreviousOutPoint } - key := newResolverID(op) + return h.htlcResolution.ClaimOutpoint +} + +// ResolverKey returns an identifier which should be globally unique for this +// particular resolver within the chain the original contract resides within. +// +// NOTE: Part of the ContractResolver interface. +func (h *htlcTimeoutResolver) ResolverKey() []byte { + key := newResolverID(h.outpoint()) return key[:] } @@ -1038,6 +1041,7 @@ func newTimeoutResolverFromReader(r io.Reader, resCfg ResolverConfig) ( } h.initReport() + h.initLogger(fmt.Sprintf("%T(%v)", h, h.outpoint())) return h, nil } From 8e69c43675f5f4056e34b5b1cee06e9b16fec432 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 14 Nov 2024 02:28:56 +0800 Subject: [PATCH 42/59] contractcourt: add spend path helpers in timeout/success resolver This commit adds a few helper methods to decide how the htlc output should be spent. --- contractcourt/htlc_success_resolver.go | 43 +++++++++++++++++++------- contractcourt/htlc_timeout_resolver.go | 40 ++++++++++++++++++++---- 2 files changed, 65 insertions(+), 18 deletions(-) diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 363a5cc044..2dc61b18a0 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -127,7 +127,7 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { // If we don't have a success transaction, then this means that this is // an output on the remote party's commitment transaction. - if h.htlcResolution.SignedSuccessTx == nil { + if h.isRemoteCommitOutput() { return h.resolveRemoteCommitOutput() } @@ -176,7 +176,7 @@ func (h *htlcSuccessResolver) broadcastSuccessTx() ( // and attach fees at will. We let the sweeper handle this job. We use // the checkpointed outputIncubating field to determine if we already // swept the HTLC output into the second level transaction. - if h.htlcResolution.SignDetails != nil { + if h.isZeroFeeOutput() { return h.broadcastReSignedSuccessTx() } @@ -236,12 +236,9 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx() (*wire.OutPoint, // We will have to let the sweeper re-sign the success tx and wait for // it to confirm, if we haven't already. - isTaproot := txscript.IsPayToTaproot( - h.htlcResolution.SweepSignDesc.Output.PkScript, - ) if !h.outputIncubating { var secondLevelInput input.HtlcSecondLevelAnchorInput - if isTaproot { + if h.isTaproot() { //nolint:ll secondLevelInput = input.MakeHtlcSecondLevelSuccessTaprootInput( h.htlcResolution.SignedSuccessTx, @@ -371,7 +368,7 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx() (*wire.OutPoint, // Let the sweeper sweep the second-level output now that the // CSV/CLTV locks have expired. 
var witType input.StandardWitnessType - if isTaproot { + if h.isTaproot() { witType = input.TaprootHtlcAcceptedSuccessSecondLevel } else { witType = input.HtlcAcceptedSuccessSecondLevel @@ -421,16 +418,12 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx() (*wire.OutPoint, func (h *htlcSuccessResolver) resolveRemoteCommitOutput() ( ContractResolver, error) { - isTaproot := txscript.IsPayToTaproot( - h.htlcResolution.SweepSignDesc.Output.PkScript, - ) - // Before we can craft out sweeping transaction, we need to // create an input which contains all the items required to add // this input to a sweeping transaction, and generate a // witness. var inp input.Input - if isTaproot { + if h.isTaproot() { inp = lnutils.Ptr(input.MakeTaprootHtlcSucceedInput( &h.htlcResolution.ClaimOutpoint, &h.htlcResolution.SweepSignDesc, @@ -712,3 +705,29 @@ func (h *htlcSuccessResolver) SupplementDeadline(_ fn.Option[int32]) { // A compile time assertion to ensure htlcSuccessResolver meets the // ContractResolver interface. var _ htlcContractResolver = (*htlcSuccessResolver)(nil) + +// isRemoteCommitOutput returns a bool to indicate whether the htlc output is +// on the remote commitment. +func (h *htlcSuccessResolver) isRemoteCommitOutput() bool { + // If we don't have a success transaction, then this means that this is + // an output on the remote party's commitment transaction. + return h.htlcResolution.SignedSuccessTx == nil +} + +// isZeroFeeOutput returns a boolean indicating whether the htlc output is from +// a anchor-enabled channel, which uses the sighash SINGLE|ANYONECANPAY. +func (h *htlcSuccessResolver) isZeroFeeOutput() bool { + // If we have non-nil SignDetails, this means it has a 2nd level HTLC + // transaction that is signed using sighash SINGLE|ANYONECANPAY (the + // case for anchor type channels). In this case we can re-sign it and + // attach fees at will. + return h.htlcResolution.SignedSuccessTx != nil && + h.htlcResolution.SignDetails != nil +} + +// isTaproot returns true if the resolver is for a taproot output. +func (h *htlcSuccessResolver) isTaproot() bool { + return txscript.IsPayToTaproot( + h.htlcResolution.SweepSignDesc.Output.PkScript, + ) +} diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index ca456ec4c8..698b4ae21e 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -634,7 +634,7 @@ func (h *htlcTimeoutResolver) spendHtlcOutput() ( // HTLC transaction that is signed using sighash SINGLE|ANYONECANPAY // (the case for anchor type channels). In this case we can re-sign it // and attach fees at will. We let the sweeper handle this job. - case h.htlcResolution.SignDetails != nil && !h.outputIncubating: + case h.isZeroFeeOutput() && !h.outputIncubating: if err := h.sweepSecondLevelTx(); err != nil { log.Errorf("Sending timeout tx to sweeper: %v", err) @@ -643,7 +643,7 @@ func (h *htlcTimeoutResolver) spendHtlcOutput() ( // If this is a remote commitment there's no second level timeout txn, // and we can just send this directly to the sweeper. 
- case h.htlcResolution.SignedTimeoutTx == nil && !h.outputIncubating: + case h.isRemoteCommitOutput() && !h.outputIncubating: if err := h.sweepDirectHtlcOutput(); err != nil { log.Errorf("Sending direct spend to sweeper: %v", err) @@ -653,7 +653,7 @@ // If we have a SignedTimeoutTx but no SignDetails, this is a local // commitment for a non-anchor channel, so we'll send it to the utxo // nursery. - case h.htlcResolution.SignDetails == nil && !h.outputIncubating: + case h.isLegacyOutput() && !h.outputIncubating: if err := h.sendSecondLevelTxLegacy(); err != nil { log.Errorf("Sending timeout tx to nursery: %v", err) @@ -769,7 +769,7 @@ func (h *htlcTimeoutResolver) handleCommitSpend( // If the sweeper is handling the second level transaction, wait for // the CSV and possible CLTV lock to expire, before sweeping the output // on the second-level. - case h.htlcResolution.SignDetails != nil: + case h.isZeroFeeOutput(): waitHeight := h.deriveWaitHeight( h.htlcResolution.CsvDelay, commitSpend, ) @@ -851,7 +851,7 @@ func (h *htlcTimeoutResolver) handleCommitSpend( // Finally, if this was an output on our commitment transaction, we'll // wait for the second-level HTLC output to be spent, and for that // transaction itself to confirm. - case h.htlcResolution.SignedTimeoutTx != nil: + case !h.isRemoteCommitOutput(): log.Infof("%T(%v): waiting for nursery/sweeper to spend CSV "+ "delayed output", h, claimOutpoint) @@ -1232,7 +1232,7 @@ func (h *htlcTimeoutResolver) consumeSpendEvents(resultChan chan *spendResult, // continue the loop. hasPreimage := isPreimageSpend( h.isTaproot(), spendDetail, - h.htlcResolution.SignedTimeoutTx != nil, + !h.isRemoteCommitOutput(), ) if !hasPreimage { log.Debugf("HTLC output %s spent doesn't "+ @@ -1260,3 +1260,31 @@ func (h *htlcTimeoutResolver) consumeSpendEvents(resultChan chan *spendResult, } } } + +// isRemoteCommitOutput returns a bool to indicate whether the htlc output is +// on the remote commitment. +func (h *htlcTimeoutResolver) isRemoteCommitOutput() bool { + // If we don't have a timeout transaction, then this means that this is + // an output on the remote party's commitment transaction. + return h.htlcResolution.SignedTimeoutTx == nil +} + +// isZeroFeeOutput returns a boolean indicating whether the htlc output is from +// a anchor-enabled channel, which uses the sighash SINGLE|ANYONECANPAY. +func (h *htlcTimeoutResolver) isZeroFeeOutput() bool { + // If we have non-nil SignDetails, this means it has a 2nd level HTLC + // transaction that is signed using sighash SINGLE|ANYONECANPAY (the + // case for anchor type channels). In this case we can re-sign it and + // attach fees at will. + return h.htlcResolution.SignedTimeoutTx != nil && + h.htlcResolution.SignDetails != nil +} + +// isLegacyOutput returns a boolean indicating whether the htlc output is from +// a non-anchor-enabled channel. +func (h *htlcTimeoutResolver) isLegacyOutput() bool { + // If we have a SignedTimeoutTx but no SignDetails, this is a local + // commitment for a non-anchor channel. + return h.htlcResolution.SignedTimeoutTx != nil && + h.htlcResolution.SignDetails == nil +} From 5d75a0ef271c374fde76daf66c2d7b24fe52e02f Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 14 Nov 2024 21:59:16 +0800 Subject: [PATCH 43/59] contractcourt: add sweep senders in `htlcSuccessResolver` This commit is a pure refactor that moves the sweep handling logic into the new methods.
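As a rough sketch of the resulting control flow only (simplified stand-in types below, not the real htlcSuccessResolver), the sweep-request code that used to be inlined in broadcastReSignedSuccessTx is assumed to end up in small helpers that the method merely sequences:

    package main

    import "fmt"

    // successResolver is a stand-in for htlcSuccessResolver; it keeps only
    // the flag needed to show the control flow.
    type successResolver struct {
        outputIncubating bool
    }

    // sweepSuccessTx stands in for offering the signed second-level
    // success tx to the sweeper.
    func (h *successResolver) sweepSuccessTx() error {
        fmt.Println("offering second-level success tx to sweeper")
        return nil
    }

    // sweepSuccessTxOutput stands in for offering the confirmed
    // second-level output to the sweeper once its CSV/CLTV locks expire.
    func (h *successResolver) sweepSuccessTxOutput() error {
        fmt.Println("offering second-level output to sweeper")
        return nil
    }

    // broadcastReSignedSuccessTx shows the post-refactor flow: it only
    // sequences the helpers instead of building the sweep inputs inline.
    func (h *successResolver) broadcastReSignedSuccessTx() error {
        if !h.outputIncubating {
            if err := h.sweepSuccessTx(); err != nil {
                return err
            }
        }

        return h.sweepSuccessTxOutput()
    }

    func main() {
        h := &successResolver{}
        if err := h.broadcastReSignedSuccessTx(); err != nil {
            fmt.Println("sweep failed:", err)
        }
    }

Each spend path then builds exactly one sweep request, mirroring the spend-path predicates added in the previous commit.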
--- contractcourt/htlc_success_resolver.go | 381 +++++++++++++------------ 1 file changed, 199 insertions(+), 182 deletions(-) diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 2dc61b18a0..3531d1fc92 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -237,55 +237,7 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx() (*wire.OutPoint, // We will have to let the sweeper re-sign the success tx and wait for // it to confirm, if we haven't already. if !h.outputIncubating { - var secondLevelInput input.HtlcSecondLevelAnchorInput - if h.isTaproot() { - //nolint:ll - secondLevelInput = input.MakeHtlcSecondLevelSuccessTaprootInput( - h.htlcResolution.SignedSuccessTx, - h.htlcResolution.SignDetails, h.htlcResolution.Preimage, - h.broadcastHeight, - input.WithResolutionBlob( - h.htlcResolution.ResolutionBlob, - ), - ) - } else { - //nolint:ll - secondLevelInput = input.MakeHtlcSecondLevelSuccessAnchorInput( - h.htlcResolution.SignedSuccessTx, - h.htlcResolution.SignDetails, h.htlcResolution.Preimage, - h.broadcastHeight, - ) - } - - // Calculate the budget for this sweep. - value := btcutil.Amount( - secondLevelInput.SignDesc().Output.Value, - ) - budget := calculateBudget( - value, h.Budget.DeadlineHTLCRatio, - h.Budget.DeadlineHTLC, - ) - - // The deadline would be the CLTV in this HTLC output. If we - // are the initiator of this force close, with the default - // `IncomingBroadcastDelta`, it means we have 10 blocks left - // when going onchain. Given we need to mine one block to - // confirm the force close tx, and one more block to trigger - // the sweep, we have 8 blocks left to sweep the HTLC. - deadline := fn.Some(int32(h.htlc.RefundTimeout)) - - log.Infof("%T(%x): offering second-level HTLC success tx to "+ - "sweeper with deadline=%v, budget=%v", h, - h.htlc.RHash[:], h.htlc.RefundTimeout, budget) - - // We'll now offer the second-level transaction to the sweeper. - _, err := h.Sweeper.SweepInput( - &secondLevelInput, - sweep.Params{ - Budget: budget, - DeadlineHeight: deadline, - }, - ) + err := h.sweepSuccessTx() if err != nil { return nil, err } @@ -316,99 +268,18 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx() (*wire.OutPoint, "confirmed!", h, h.htlc.RHash[:]) } - // If we ended up here after a restart, we must again get the - // spend notification. - if commitSpend == nil { - var err error - commitSpend, err = waitForSpend( - &h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint, - h.htlcResolution.SignDetails.SignDesc.Output.PkScript, - h.broadcastHeight, h.Notifier, h.quit, - ) - if err != nil { - return nil, err - } - } - - // The HTLC success tx has a CSV lock that we must wait for, and if - // this is a lease enforced channel and we're the imitator, we may need - // to wait for longer. - waitHeight := h.deriveWaitHeight( - h.htlcResolution.CsvDelay, commitSpend, - ) - - // Now that the sweeper has broadcasted the second-level transaction, - // it has confirmed, and we have checkpointed our state, we'll sweep - // the second level output. We report the resolver has moved the next - // stage. 
- h.reportLock.Lock() - h.currentReport.Stage = 2 - h.currentReport.MaturityHeight = waitHeight - h.reportLock.Unlock() - - if h.hasCLTV() { - log.Infof("%T(%x): waiting for CSV and CLTV lock to "+ - "expire at height %v", h, h.htlc.RHash[:], - waitHeight) - } else { - log.Infof("%T(%x): waiting for CSV lock to expire at "+ - "height %v", h, h.htlc.RHash[:], waitHeight) + err := h.sweepSuccessTxOutput() + if err != nil { + return nil, err } - // We'll use this input index to determine the second-level output - // index on the transaction, as the signatures requires the indexes to - // be the same. We don't look for the second-level output script - // directly, as there might be more than one HTLC output to the same - // pkScript. + // Will return this outpoint, when this is spent the resolver is fully + // resolved. op := &wire.OutPoint{ Hash: *commitSpend.SpenderTxHash, Index: commitSpend.SpenderInputIndex, } - // Let the sweeper sweep the second-level output now that the - // CSV/CLTV locks have expired. - var witType input.StandardWitnessType - if h.isTaproot() { - witType = input.TaprootHtlcAcceptedSuccessSecondLevel - } else { - witType = input.HtlcAcceptedSuccessSecondLevel - } - inp := h.makeSweepInput( - op, witType, - input.LeaseHtlcAcceptedSuccessSecondLevel, - &h.htlcResolution.SweepSignDesc, - h.htlcResolution.CsvDelay, uint32(commitSpend.SpendingHeight), - h.htlc.RHash, h.htlcResolution.ResolutionBlob, - ) - - // Calculate the budget for this sweep. - budget := calculateBudget( - btcutil.Amount(inp.SignDesc().Output.Value), - h.Budget.NoDeadlineHTLCRatio, - h.Budget.NoDeadlineHTLC, - ) - - log.Infof("%T(%x): offering second-level success tx output to sweeper "+ - "with no deadline and budget=%v at height=%v", h, - h.htlc.RHash[:], budget, waitHeight) - - // TODO(roasbeef): need to update above for leased types - _, err := h.Sweeper.SweepInput( - inp, - sweep.Params{ - Budget: budget, - - // For second level success tx, there's no rush to get - // it confirmed, so we use a nil deadline. - DeadlineHeight: fn.None[int32](), - }, - ) - if err != nil { - return nil, err - } - - // Will return this outpoint, when this is spent the resolver is fully - // resolved. return op, nil } @@ -418,53 +289,7 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx() (*wire.OutPoint, func (h *htlcSuccessResolver) resolveRemoteCommitOutput() ( ContractResolver, error) { - // Before we can craft out sweeping transaction, we need to - // create an input which contains all the items required to add - // this input to a sweeping transaction, and generate a - // witness. - var inp input.Input - if h.isTaproot() { - inp = lnutils.Ptr(input.MakeTaprootHtlcSucceedInput( - &h.htlcResolution.ClaimOutpoint, - &h.htlcResolution.SweepSignDesc, - h.htlcResolution.Preimage[:], - h.broadcastHeight, - h.htlcResolution.CsvDelay, - input.WithResolutionBlob( - h.htlcResolution.ResolutionBlob, - ), - )) - } else { - inp = lnutils.Ptr(input.MakeHtlcSucceedInput( - &h.htlcResolution.ClaimOutpoint, - &h.htlcResolution.SweepSignDesc, - h.htlcResolution.Preimage[:], - h.broadcastHeight, - h.htlcResolution.CsvDelay, - )) - } - - // Calculate the budget for this sweep. 
- budget := calculateBudget( - btcutil.Amount(inp.SignDesc().Output.Value), - h.Budget.DeadlineHTLCRatio, - h.Budget.DeadlineHTLC, - ) - - deadline := fn.Some(int32(h.htlc.RefundTimeout)) - - log.Infof("%T(%x): offering direct-preimage HTLC output to sweeper "+ - "with deadline=%v, budget=%v", h, h.htlc.RHash[:], - h.htlc.RefundTimeout, budget) - - // We'll now offer the direct preimage HTLC to the sweeper. - _, err := h.Sweeper.SweepInput( - inp, - sweep.Params{ - Budget: budget, - DeadlineHeight: deadline, - }, - ) + err := h.sweepRemoteCommitOutput() if err != nil { return nil, err } @@ -731,3 +556,195 @@ func (h *htlcSuccessResolver) isTaproot() bool { h.htlcResolution.SweepSignDesc.Output.PkScript, ) } + +// sweepRemoteCommitOutput creates a sweep request to sweep the HTLC output on +// the remote commitment via the direct preimage-spend. +func (h *htlcSuccessResolver) sweepRemoteCommitOutput() error { + // Before we can craft out sweeping transaction, we need to create an + // input which contains all the items required to add this input to a + // sweeping transaction, and generate a witness. + var inp input.Input + + if h.isTaproot() { + inp = lnutils.Ptr(input.MakeTaprootHtlcSucceedInput( + &h.htlcResolution.ClaimOutpoint, + &h.htlcResolution.SweepSignDesc, + h.htlcResolution.Preimage[:], + h.broadcastHeight, + h.htlcResolution.CsvDelay, + input.WithResolutionBlob( + h.htlcResolution.ResolutionBlob, + ), + )) + } else { + inp = lnutils.Ptr(input.MakeHtlcSucceedInput( + &h.htlcResolution.ClaimOutpoint, + &h.htlcResolution.SweepSignDesc, + h.htlcResolution.Preimage[:], + h.broadcastHeight, + h.htlcResolution.CsvDelay, + )) + } + + // Calculate the budget for this sweep. + budget := calculateBudget( + btcutil.Amount(inp.SignDesc().Output.Value), + h.Budget.DeadlineHTLCRatio, + h.Budget.DeadlineHTLC, + ) + + deadline := fn.Some(int32(h.htlc.RefundTimeout)) + + log.Infof("%T(%x): offering direct-preimage HTLC output to sweeper "+ + "with deadline=%v, budget=%v", h, h.htlc.RHash[:], + h.htlc.RefundTimeout, budget) + + // We'll now offer the direct preimage HTLC to the sweeper. + _, err := h.Sweeper.SweepInput( + inp, + sweep.Params{ + Budget: budget, + DeadlineHeight: deadline, + }, + ) + + return err +} + +// sweepSuccessTx attempts to sweep the second level success tx. +func (h *htlcSuccessResolver) sweepSuccessTx() error { + var secondLevelInput input.HtlcSecondLevelAnchorInput + if h.isTaproot() { + secondLevelInput = input.MakeHtlcSecondLevelSuccessTaprootInput( + h.htlcResolution.SignedSuccessTx, + h.htlcResolution.SignDetails, h.htlcResolution.Preimage, + h.broadcastHeight, input.WithResolutionBlob( + h.htlcResolution.ResolutionBlob, + ), + ) + } else { + secondLevelInput = input.MakeHtlcSecondLevelSuccessAnchorInput( + h.htlcResolution.SignedSuccessTx, + h.htlcResolution.SignDetails, h.htlcResolution.Preimage, + h.broadcastHeight, + ) + } + + // Calculate the budget for this sweep. + value := btcutil.Amount(secondLevelInput.SignDesc().Output.Value) + budget := calculateBudget( + value, h.Budget.DeadlineHTLCRatio, h.Budget.DeadlineHTLC, + ) + + // The deadline would be the CLTV in this HTLC output. If we are the + // initiator of this force close, with the default + // `IncomingBroadcastDelta`, it means we have 10 blocks left when going + // onchain. 
+ deadline := fn.Some(int32(h.htlc.RefundTimeout)) + + h.log.Infof("offering second-level HTLC success tx to sweeper with "+ + "deadline=%v, budget=%v", h.htlc.RefundTimeout, budget) + + // We'll now offer the second-level transaction to the sweeper. + _, err := h.Sweeper.SweepInput( + &secondLevelInput, + sweep.Params{ + Budget: budget, + DeadlineHeight: deadline, + }, + ) + + return err +} + +// sweepSuccessTxOutput attempts to sweep the output of the second level +// success tx. +func (h *htlcSuccessResolver) sweepSuccessTxOutput() error { + h.log.Debugf("sweeping output %v from 2nd-level HTLC success tx", + h.htlcResolution.ClaimOutpoint) + + // This should be non-blocking as we will only attempt to sweep the + // output when the second level tx has already been confirmed. In other + // words, waitForSpend will return immediately. + commitSpend, err := waitForSpend( + &h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint, + h.htlcResolution.SignDetails.SignDesc.Output.PkScript, + h.broadcastHeight, h.Notifier, h.quit, + ) + if err != nil { + return err + } + + // The HTLC success tx has a CSV lock that we must wait for, and if + // this is a lease enforced channel and we're the imitator, we may need + // to wait for longer. + waitHeight := h.deriveWaitHeight(h.htlcResolution.CsvDelay, commitSpend) + + // Now that the sweeper has broadcasted the second-level transaction, + // it has confirmed, and we have checkpointed our state, we'll sweep + // the second level output. We report the resolver has moved the next + // stage. + h.reportLock.Lock() + h.currentReport.Stage = 2 + h.currentReport.MaturityHeight = waitHeight + h.reportLock.Unlock() + + if h.hasCLTV() { + log.Infof("%T(%x): waiting for CSV and CLTV lock to expire at "+ + "height %v", h, h.htlc.RHash[:], waitHeight) + } else { + log.Infof("%T(%x): waiting for CSV lock to expire at height %v", + h, h.htlc.RHash[:], waitHeight) + } + + // We'll use this input index to determine the second-level output + // index on the transaction, as the signatures requires the indexes to + // be the same. We don't look for the second-level output script + // directly, as there might be more than one HTLC output to the same + // pkScript. + op := &wire.OutPoint{ + Hash: *commitSpend.SpenderTxHash, + Index: commitSpend.SpenderInputIndex, + } + + // Let the sweeper sweep the second-level output now that the + // CSV/CLTV locks have expired. + var witType input.StandardWitnessType + if h.isTaproot() { + witType = input.TaprootHtlcAcceptedSuccessSecondLevel + } else { + witType = input.HtlcAcceptedSuccessSecondLevel + } + inp := h.makeSweepInput( + op, witType, + input.LeaseHtlcAcceptedSuccessSecondLevel, + &h.htlcResolution.SweepSignDesc, + h.htlcResolution.CsvDelay, uint32(commitSpend.SpendingHeight), + h.htlc.RHash, h.htlcResolution.ResolutionBlob, + ) + + // Calculate the budget for this sweep. + budget := calculateBudget( + btcutil.Amount(inp.SignDesc().Output.Value), + h.Budget.NoDeadlineHTLCRatio, + h.Budget.NoDeadlineHTLC, + ) + + log.Infof("%T(%x): offering second-level success tx output to sweeper "+ + "with no deadline and budget=%v at height=%v", h, + h.htlc.RHash[:], budget, waitHeight) + + // TODO(yy): use the result chan returned from SweepInput. + _, err = h.Sweeper.SweepInput( + inp, + sweep.Params{ + Budget: budget, + + // For second level success tx, there's no rush to get + // it confirmed, so we use a nil deadline. 
+ DeadlineHeight: fn.None[int32](), + }, + ) + + return err +} From 1d89884a3c6f6d949469fa48f4df09347809ea90 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 14 Nov 2024 21:59:55 +0800 Subject: [PATCH 44/59] contractcourt: add resolver handlers in `htlcSuccessResolver` This commit refactors the `Resolve` method by adding two resolver handlers to handle waiting for spending confirmations. --- contractcourt/htlc_success_resolver.go | 229 ++++++++++++++++--------- contractcourt/htlc_timeout_resolver.go | 8 +- 2 files changed, 149 insertions(+), 88 deletions(-) diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 3531d1fc92..b436ccda89 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -140,27 +140,7 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { // To wrap this up, we'll wait until the second-level transaction has // been spent, then fully resolve the contract. - log.Infof("%T(%x): waiting for second-level HTLC output to be spent "+ - "after csv_delay=%v", h, h.htlc.RHash[:], h.htlcResolution.CsvDelay) - - spend, err := waitForSpend( - secondLevelOutpoint, - h.htlcResolution.SweepSignDesc.Output.PkScript, - h.broadcastHeight, h.Notifier, h.quit, - ) - if err != nil { - return nil, err - } - - h.reportLock.Lock() - h.currentReport.RecoveredBalance = h.currentReport.LimboBalance - h.currentReport.LimboBalance = 0 - h.reportLock.Unlock() - - h.resolved = true - return nil, h.checkpointClaim( - spend.SpenderTxHash, channeldb.ResolverOutcomeClaimed, - ) + return nil, h.resolveSuccessTxOutput(*secondLevelOutpoint) } // broadcastSuccessTx handles an HTLC output on our local commitment by @@ -187,40 +167,11 @@ func (h *htlcSuccessResolver) broadcastSuccessTx() ( // We'll now broadcast the second layer transaction so we can kick off // the claiming process. - // - // TODO(roasbeef): after changing sighashes send to tx bundler - label := labels.MakeLabel( - labels.LabelTypeChannelClose, &h.ShortChanID, - ) - err := h.PublishTx(h.htlcResolution.SignedSuccessTx, label) + err := h.resolveLegacySuccessTx() if err != nil { return nil, err } - // Otherwise, this is an output on our commitment transaction. In this - // case, we'll send it to the incubator, but only if we haven't already - // done so. - if !h.outputIncubating { - log.Infof("%T(%x): incubating incoming htlc output", - h, h.htlc.RHash[:]) - - err := h.IncubateOutputs( - h.ChanPoint, fn.None[lnwallet.OutgoingHtlcResolution](), - fn.Some(h.htlcResolution), - h.broadcastHeight, fn.Some(int32(h.htlc.RefundTimeout)), - ) - if err != nil { - return nil, err - } - - h.outputIncubating = true - - if err := h.Checkpoint(h); err != nil { - log.Errorf("unable to Checkpoint: %v", err) - return nil, err - } - } - return &h.htlcResolution.ClaimOutpoint, nil } @@ -242,33 +193,25 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx() (*wire.OutPoint, return nil, err } - log.Infof("%T(%x): waiting for second-level HTLC success "+ - "transaction to confirm", h, h.htlc.RHash[:]) - - // Wait for the second level transaction to confirm. - commitSpend, err = waitForSpend( - &h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint, - h.htlcResolution.SignDetails.SignDesc.Output.PkScript, - h.broadcastHeight, h.Notifier, h.quit, - ) + err = h.resolveSuccessTx() if err != nil { return nil, err } + } - // Now that the second-level transaction has confirmed, we - // checkpoint the state so we'll go to the next stage in case - // of restarts. 
- h.outputIncubating = true - if err := h.Checkpoint(h); err != nil { - log.Errorf("unable to Checkpoint: %v", err) - return nil, err - } - - log.Infof("%T(%x): second-level HTLC success transaction "+ - "confirmed!", h, h.htlc.RHash[:]) + // This should be non-blocking as we will only attempt to sweep the + // output when the second level tx has already been confirmed. In other + // words, waitForSpend will return immediately. + commitSpend, err := waitForSpend( + &h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint, + h.htlcResolution.SignDetails.SignDesc.Output.PkScript, + h.broadcastHeight, h.Notifier, h.quit, + ) + if err != nil { + return nil, err } - err := h.sweepSuccessTxOutput() + err = h.sweepSuccessTxOutput() if err != nil { return nil, err } @@ -304,23 +247,14 @@ func (h *htlcSuccessResolver) resolveRemoteCommitOutput() ( return nil, err } - // Once the transaction has received a sufficient number of - // confirmations, we'll mark ourselves as fully resolved and exit. - h.resolved = true - // Checkpoint the resolver, and write the outcome to disk. - return nil, h.checkpointClaim( - sweepTxDetails.SpenderTxHash, - channeldb.ResolverOutcomeClaimed, - ) + return nil, h.checkpointClaim(sweepTxDetails.SpenderTxHash) } // checkpointClaim checkpoints the success resolver with the reports it needs. // If this htlc was claimed two stages, it will write reports for both stages, // otherwise it will just write for the single htlc claim. -func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash, - outcome channeldb.ResolverOutcome) error { - +func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash) error { // Mark the htlc as final settled. err := h.ChainArbitratorConfig.PutFinalHtlcOutcome( h.ChannelArbitratorConfig.ShortChanID, h.htlc.HtlcIndex, true, @@ -348,7 +282,7 @@ func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash, OutPoint: h.htlcResolution.ClaimOutpoint, Amount: amt, ResolverType: channeldb.ResolverTypeIncomingHtlc, - ResolverOutcome: outcome, + ResolverOutcome: channeldb.ResolverOutcomeClaimed, SpendTxID: spendTx, }, } @@ -373,6 +307,7 @@ func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash, } // Finally, we checkpoint the resolver with our report(s). + h.resolved = true return h.Checkpoint(h, reports...) } @@ -748,3 +683,129 @@ func (h *htlcSuccessResolver) sweepSuccessTxOutput() error { return err } + +// resolveLegacySuccessTx handles an HTLC output from a pre-anchor type channel +// by broadcasting the second-level success transaction. +func (h *htlcSuccessResolver) resolveLegacySuccessTx() error { + // Otherwise we'll publish the second-level transaction directly and + // offer the resolution to the nursery to handle. + h.log.Infof("broadcasting legacy second-level success tx: %v", + h.htlcResolution.SignedSuccessTx.TxHash()) + + // We'll now broadcast the second layer transaction so we can kick off + // the claiming process. + // + // TODO(yy): offer it to the sweeper instead. + label := labels.MakeLabel( + labels.LabelTypeChannelClose, &h.ShortChanID, + ) + err := h.PublishTx(h.htlcResolution.SignedSuccessTx, label) + if err != nil { + return err + } + + // Fast-forward to resolve the output from the success tx if the it has + // already been sent to the UtxoNursery. + if h.outputIncubating { + return h.resolveSuccessTxOutput(h.htlcResolution.ClaimOutpoint) + } + + h.log.Infof("incubating incoming htlc output") + + // Send the output to the incubator. 
+ err = h.IncubateOutputs( + h.ChanPoint, fn.None[lnwallet.OutgoingHtlcResolution](), + fn.Some(h.htlcResolution), + h.broadcastHeight, fn.Some(int32(h.htlc.RefundTimeout)), + ) + if err != nil { + return err + } + + // Mark the output as incubating and checkpoint it. + h.outputIncubating = true + if err := h.Checkpoint(h); err != nil { + return err + } + + // Move to resolve the output. + return h.resolveSuccessTxOutput(h.htlcResolution.ClaimOutpoint) +} + +// resolveSuccessTx waits for the sweeping tx of the second-level success tx to +// confirm and offers the output from the success tx to the sweeper. +func (h *htlcSuccessResolver) resolveSuccessTx() error { + h.log.Infof("waiting for 2nd-level HTLC success transaction to confirm") + + // Create aliases to make the code more readable. + outpoint := h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint + pkScript := h.htlcResolution.SignDetails.SignDesc.Output.PkScript + + // Wait for the second level transaction to confirm. + commitSpend, err := waitForSpend( + &outpoint, pkScript, h.broadcastHeight, h.Notifier, h.quit, + ) + if err != nil { + return err + } + + // We'll use this input index to determine the second-level output + // index on the transaction, as the signatures requires the indexes to + // be the same. We don't look for the second-level output script + // directly, as there might be more than one HTLC output to the same + // pkScript. + op := wire.OutPoint{ + Hash: *commitSpend.SpenderTxHash, + Index: commitSpend.SpenderInputIndex, + } + + // If the 2nd-stage sweeping has already been started, we can + // fast-forward to start the resolving process for the stage two + // output. + if h.outputIncubating { + return h.resolveSuccessTxOutput(op) + } + + // Now that the second-level transaction has confirmed, we checkpoint + // the state so we'll go to the next stage in case of restarts. + h.outputIncubating = true + if err := h.Checkpoint(h); err != nil { + log.Errorf("unable to Checkpoint: %v", err) + return err + } + + h.log.Infof("2nd-level HTLC success tx=%v confirmed", + commitSpend.SpenderTxHash) + + // Send the sweep request for the output from the success tx. + if err := h.sweepSuccessTxOutput(); err != nil { + return err + } + + return h.resolveSuccessTxOutput(op) +} + +// resolveSuccessTxOutput waits for the spend of the output from the 2nd-level +// success tx. +func (h *htlcSuccessResolver) resolveSuccessTxOutput(op wire.OutPoint) error { + // To wrap this up, we'll wait until the second-level transaction has + // been spent, then fully resolve the contract. + log.Infof("%T(%x): waiting for second-level HTLC output to be spent "+ + "after csv_delay=%v", h, h.htlc.RHash[:], + h.htlcResolution.CsvDelay) + + spend, err := waitForSpend( + &op, h.htlcResolution.SweepSignDesc.Output.PkScript, + h.broadcastHeight, h.Notifier, h.quit, + ) + if err != nil { + return err + } + + h.reportLock.Lock() + h.currentReport.RecoveredBalance = h.currentReport.LimboBalance + h.currentReport.LimboBalance = 0 + h.reportLock.Unlock() + + return h.checkpointClaim(spend.SpenderTxHash) +} diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index 698b4ae21e..24167e70a4 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -548,9 +548,9 @@ func (h *htlcTimeoutResolver) sweepSecondLevelTx() error { return err } -// sendSecondLevelTxLegacy sends a second level timeout transaction to the utxo -// nursery. 
This transaction uses the legacy SIGHASH_ALL flag. -func (h *htlcTimeoutResolver) sendSecondLevelTxLegacy() error { +// resolveSecondLevelTxLegacy sends a second level timeout transaction to the +// utxo nursery. This transaction uses the legacy SIGHASH_ALL flag. +func (h *htlcTimeoutResolver) resolveSecondLevelTxLegacy() error { log.Debugf("%T(%v): incubating htlc output", h, h.htlcResolution.ClaimOutpoint) @@ -654,7 +654,7 @@ func (h *htlcTimeoutResolver) spendHtlcOutput() ( // commitment for a non-anchor channel, so we'll send it to the utxo // nursery. case h.isLegacyOutput() && !h.outputIncubating: - if err := h.sendSecondLevelTxLegacy(); err != nil { + if err := h.resolveSecondLevelTxLegacy(); err != nil { log.Errorf("Sending timeout tx to nursery: %v", err) return nil, err From ba016c25d7e281bb8cdc2b823e86480c4a6a22d2 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 14 Nov 2024 22:07:32 +0800 Subject: [PATCH 45/59] contractcourt: remove redundant return value in `claimCleanUp` --- contractcourt/htlc_outgoing_contest_resolver.go | 4 ++-- contractcourt/htlc_timeout_resolver.go | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go index ec32ff7f17..7adce5a689 100644 --- a/contractcourt/htlc_outgoing_contest_resolver.go +++ b/contractcourt/htlc_outgoing_contest_resolver.go @@ -87,7 +87,7 @@ func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) { } // TODO(roasbeef): Checkpoint? - return h.claimCleanUp(commitSpend) + return nil, h.claimCleanUp(commitSpend) // If it hasn't, then we'll watch for both the expiration, and the // sweeping out this output. @@ -144,7 +144,7 @@ func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) { // party is by revealing the preimage. So we'll perform // our duties to clean up the contract once it has been // claimed. - return h.claimCleanUp(commitSpend) + return nil, h.claimCleanUp(commitSpend) case <-h.quit: return nil, fmt.Errorf("resolver canceled") diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index 24167e70a4..b0a57352b8 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -160,7 +160,7 @@ const ( // by the remote party. It'll extract the preimage, add it to the global cache, // and finally send the appropriate clean up message. func (h *htlcTimeoutResolver) claimCleanUp( - commitSpend *chainntnfs.SpendDetail) (ContractResolver, error) { + commitSpend *chainntnfs.SpendDetail) error { // Depending on if this is our commitment or not, then we'll be looking // for a different witness pattern. @@ -195,7 +195,7 @@ func (h *htlcTimeoutResolver) claimCleanUp( // element, then we're actually on the losing side of a breach // attempt... case h.isTaproot() && len(spendingInput.Witness) == 1: - return nil, fmt.Errorf("breach attempt failed") + return fmt.Errorf("breach attempt failed") // Otherwise, they'll be spending directly from our commitment output. 
// In which case the witness stack looks like: @@ -212,8 +212,8 @@ func (h *htlcTimeoutResolver) claimCleanUp( preimage, err := lntypes.MakePreimage(preimageBytes) if err != nil { - return nil, fmt.Errorf("unable to create pre-image from "+ - "witness: %v", err) + return fmt.Errorf("unable to create pre-image from witness: %w", + err) } log.Infof("%T(%v): extracting preimage=%v from on-chain "+ @@ -235,7 +235,7 @@ func (h *htlcTimeoutResolver) claimCleanUp( HtlcIndex: h.htlc.HtlcIndex, PreImage: &pre, }); err != nil { - return nil, err + return err } h.resolved = true @@ -250,7 +250,7 @@ func (h *htlcTimeoutResolver) claimCleanUp( SpendTxID: commitSpend.SpenderTxHash, } - return nil, h.Checkpoint(h, report) + return h.Checkpoint(h, report) } // chainDetailsToWatch returns the output and script which we use to watch for @@ -448,7 +448,7 @@ func (h *htlcTimeoutResolver) Resolve() (ContractResolver, error) { "witness cache", h, h.htlc.RHash[:], h.htlcResolution.ClaimOutpoint) - return h.claimCleanUp(commitSpend) + return nil, h.claimCleanUp(commitSpend) } // At this point, the second-level transaction is sufficiently From 2bec82c848a55ec9876db24fe7562fd2fc89c80f Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 14 Nov 2024 22:08:00 +0800 Subject: [PATCH 46/59] contractcourt: add sweep senders in `htlcTimeoutResolver` This commit adds new methods to handle making sweep requests based on the spending path used by the outgoing htlc output. --- contractcourt/htlc_timeout_resolver.go | 251 ++++++++++++++----------- 1 file changed, 140 insertions(+), 111 deletions(-) diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index b0a57352b8..d17059364c 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -476,13 +476,9 @@ func (h *htlcTimeoutResolver) Resolve() (ContractResolver, error) { return h.handleCommitSpend(commitSpend) } -// sweepSecondLevelTx sends a second level timeout transaction to the sweeper. +// sweepTimeoutTx sends a second level timeout transaction to the sweeper. // This transaction uses the SINLGE|ANYONECANPAY flag. -func (h *htlcTimeoutResolver) sweepSecondLevelTx() error { - log.Infof("%T(%x): offering second-layer timeout tx to sweeper: %v", - h, h.htlc.RHash[:], - spew.Sdump(h.htlcResolution.SignedTimeoutTx)) - +func (h *htlcTimeoutResolver) sweepTimeoutTx() error { var inp input.Input if h.isTaproot() { inp = lnutils.Ptr(input.MakeHtlcSecondLevelTimeoutTaprootInput( @@ -513,27 +509,12 @@ func (h *htlcTimeoutResolver) sweepSecondLevelTx() error { btcutil.Amount(inp.SignDesc().Output.Value), 2, 0, ) + h.log.Infof("offering 2nd-level HTLC timeout tx to sweeper "+ + "with deadline=%v, budget=%v", h.incomingHTLCExpiryHeight, + budget) + // For an outgoing HTLC, it must be swept before the RefundTimeout of // its incoming HTLC is reached. - // - // TODO(yy): we may end up mixing inputs with different time locks. - // Suppose we have two outgoing HTLCs, - // - HTLC1: nLocktime is 800000, CLTV delta is 80. - // - HTLC2: nLocktime is 800001, CLTV delta is 79. - // This means they would both have an incoming HTLC that expires at - // 800080, hence they share the same deadline but different locktimes. - // However, with current design, when we are at block 800000, HTLC1 is - // offered to the sweeper. When block 800001 is reached, HTLC1's - // sweeping process is already started, while HTLC2 is being offered to - // the sweeper, so they won't be mixed. 
This can become an issue tho, - // if we decide to sweep per X blocks. Or the contractcourt sees the - // block first while the sweeper is only aware of the last block. To - // properly fix it, we need `blockbeat` to make sure subsystems are in - // sync. - log.Infof("%T(%x): offering second-level HTLC timeout tx to sweeper "+ - "with deadline=%v, budget=%v", h, h.htlc.RHash[:], - h.incomingHTLCExpiryHeight, budget) - _, err := h.Sweeper.SweepInput( inp, sweep.Params{ @@ -551,21 +532,15 @@ func (h *htlcTimeoutResolver) sweepSecondLevelTx() error { // resolveSecondLevelTxLegacy sends a second level timeout transaction to the // utxo nursery. This transaction uses the legacy SIGHASH_ALL flag. func (h *htlcTimeoutResolver) resolveSecondLevelTxLegacy() error { - log.Debugf("%T(%v): incubating htlc output", h, - h.htlcResolution.ClaimOutpoint) + h.log.Debug("incubating htlc output") - err := h.IncubateOutputs( + // The utxo nursery will take care of broadcasting the second-level + // timeout tx and sweeping its output once it confirms. + return h.IncubateOutputs( h.ChanPoint, fn.Some(h.htlcResolution), fn.None[lnwallet.IncomingHtlcResolution](), h.broadcastHeight, h.incomingHTLCExpiryHeight, ) - if err != nil { - return err - } - - h.outputIncubating = true - - return h.Checkpoint(h) } // sweepDirectHtlcOutput sends the direct spend of the HTLC output to the @@ -635,7 +610,7 @@ func (h *htlcTimeoutResolver) spendHtlcOutput() ( // (the case for anchor type channels). In this case we can re-sign it // and attach fees at will. We let the sweeper handle this job. case h.isZeroFeeOutput() && !h.outputIncubating: - if err := h.sweepSecondLevelTx(); err != nil { + if err := h.sweepTimeoutTx(); err != nil { log.Errorf("Sending timeout tx to sweeper: %v", err) return nil, err @@ -696,9 +671,6 @@ func (h *htlcTimeoutResolver) watchHtlcSpend() (*chainntnfs.SpendDetail, func (h *htlcTimeoutResolver) waitForConfirmedSpend(op *wire.OutPoint, pkScript []byte) (*chainntnfs.SpendDetail, error) { - log.Infof("%T(%v): waiting for spent of HTLC output %v to be "+ - "fully confirmed", h, h.htlcResolution.ClaimOutpoint, op) - // We'll block here until either we exit, or the HTLC output on the // commitment transaction has been spent. spend, err := waitForSpend( @@ -770,82 +742,11 @@ func (h *htlcTimeoutResolver) handleCommitSpend( // the CSV and possible CLTV lock to expire, before sweeping the output // on the second-level. case h.isZeroFeeOutput(): - waitHeight := h.deriveWaitHeight( - h.htlcResolution.CsvDelay, commitSpend, - ) - - h.reportLock.Lock() - h.currentReport.Stage = 2 - h.currentReport.MaturityHeight = waitHeight - h.reportLock.Unlock() - - if h.hasCLTV() { - log.Infof("%T(%x): waiting for CSV and CLTV lock to "+ - "expire at height %v", h, h.htlc.RHash[:], - waitHeight) - } else { - log.Infof("%T(%x): waiting for CSV lock to expire at "+ - "height %v", h, h.htlc.RHash[:], waitHeight) - } - - // We'll use this input index to determine the second-level - // output index on the transaction, as the signatures requires - // the indexes to be the same. We don't look for the - // second-level output script directly, as there might be more - // than one HTLC output to the same pkScript. 
- op := &wire.OutPoint{ - Hash: *commitSpend.SpenderTxHash, - Index: commitSpend.SpenderInputIndex, - } - - var csvWitnessType input.StandardWitnessType - if h.isTaproot() { - //nolint:ll - csvWitnessType = input.TaprootHtlcOfferedTimeoutSecondLevel - } else { - csvWitnessType = input.HtlcOfferedTimeoutSecondLevel - } - - // Let the sweeper sweep the second-level output now that the - // CSV/CLTV locks have expired. - inp := h.makeSweepInput( - op, csvWitnessType, - input.LeaseHtlcOfferedTimeoutSecondLevel, - &h.htlcResolution.SweepSignDesc, - h.htlcResolution.CsvDelay, - uint32(commitSpend.SpendingHeight), h.htlc.RHash, - h.htlcResolution.ResolutionBlob, - ) - - // Calculate the budget for this sweep. - budget := calculateBudget( - btcutil.Amount(inp.SignDesc().Output.Value), - h.Budget.NoDeadlineHTLCRatio, - h.Budget.NoDeadlineHTLC, - ) - - log.Infof("%T(%x): offering second-level timeout tx output to "+ - "sweeper with no deadline and budget=%v at height=%v", - h, h.htlc.RHash[:], budget, waitHeight) - - _, err := h.Sweeper.SweepInput( - inp, - sweep.Params{ - Budget: budget, - - // For second level success tx, there's no rush - // to get it confirmed, so we use a nil - // deadline. - DeadlineHeight: fn.None[int32](), - }, - ) + err := h.sweepTimeoutTxOutput() if err != nil { return nil, err } - // Update the claim outpoint to point to the second-level - // transaction created by the sweeper. - claimOutpoint = *op fallthrough // Finally, if this was an output on our commitment transaction, we'll @@ -1288,3 +1189,131 @@ func (h *htlcTimeoutResolver) isLegacyOutput() bool { return h.htlcResolution.SignedTimeoutTx != nil && h.htlcResolution.SignDetails == nil } + +// waitHtlcSpendAndCheckPreimage waits for the htlc output to be spent and +// checks whether the spending reveals the preimage. If the preimage is found, +// it will be added to the preimage beacon to settle the incoming link, and a +// nil spend details will be returned. Otherwise, the spend details will be +// returned, indicating this is a non-preimage spend. +func (h *htlcTimeoutResolver) waitHtlcSpendAndCheckPreimage() ( + *chainntnfs.SpendDetail, error) { + + // Wait for the htlc output to be spent, which can happen in one of the + // paths, + // 1. The remote party spends the htlc output using the preimage. + // 2. The local party spends the htlc timeout tx from the local + // commitment. + // 3. The local party spends the htlc output directlt from the remote + // commitment. + spend, err := h.watchHtlcSpend() + if err != nil { + return nil, err + } + + // If the spend reveals the pre-image, then we'll enter the clean up + // workflow to pass the preimage back to the incoming link, add it to + // the witness cache, and exit. + if isPreimageSpend(h.isTaproot(), spend, !h.isRemoteCommitOutput()) { + return nil, h.claimCleanUp(spend) + } + + return spend, nil +} + +// sweepTimeoutTxOutput attempts to sweep the output of the second level +// timeout tx. +func (h *htlcTimeoutResolver) sweepTimeoutTxOutput() error { + h.log.Debugf("sweeping output %v from 2nd-level HTLC timeout tx", + h.htlcResolution.ClaimOutpoint) + + // This should be non-blocking as we will only attempt to sweep the + // output when the second level tx has already been confirmed. In other + // words, waitHtlcSpendAndCheckPreimage will return immediately. 
+ commitSpend, err := h.waitHtlcSpendAndCheckPreimage() + if err != nil { + return err + } + + // Exit early if the spend is nil, as this means it's a remote spend + // using the preimage path, which is handled in claimCleanUp. + if commitSpend == nil { + h.log.Infof("preimage spend detected, skipping 2nd-level " + + "HTLC output sweep") + + return nil + } + + waitHeight := h.deriveWaitHeight(h.htlcResolution.CsvDelay, commitSpend) + + // Now that the sweeper has broadcasted the second-level transaction, + // it has confirmed, and we have checkpointed our state, we'll sweep + // the second level output. We report the resolver has moved the next + // stage. + h.reportLock.Lock() + h.currentReport.Stage = 2 + h.currentReport.MaturityHeight = waitHeight + h.reportLock.Unlock() + + if h.hasCLTV() { + h.log.Infof("waiting for CSV and CLTV lock to expire at "+ + "height %v", waitHeight) + } else { + h.log.Infof("waiting for CSV lock to expire at height %v", + waitHeight) + } + + // We'll use this input index to determine the second-level output + // index on the transaction, as the signatures requires the indexes to + // be the same. We don't look for the second-level output script + // directly, as there might be more than one HTLC output to the same + // pkScript. + op := &wire.OutPoint{ + Hash: *commitSpend.SpenderTxHash, + Index: commitSpend.SpenderInputIndex, + } + + var witType input.StandardWitnessType + if h.isTaproot() { + witType = input.TaprootHtlcOfferedTimeoutSecondLevel + } else { + witType = input.HtlcOfferedTimeoutSecondLevel + } + + // Let the sweeper sweep the second-level output now that the CSV/CLTV + // locks have expired. + inp := h.makeSweepInput( + op, witType, + input.LeaseHtlcOfferedTimeoutSecondLevel, + &h.htlcResolution.SweepSignDesc, + h.htlcResolution.CsvDelay, uint32(commitSpend.SpendingHeight), + h.htlc.RHash, h.htlcResolution.ResolutionBlob, + ) + + // Calculate the budget for this sweep. + budget := calculateBudget( + btcutil.Amount(inp.SignDesc().Output.Value), + h.Budget.NoDeadlineHTLCRatio, + h.Budget.NoDeadlineHTLC, + ) + + h.log.Infof("offering output from 2nd-level timeout tx to sweeper "+ + "with no deadline and budget=%v", budget) + + // TODO(yy): use the result chan returned from SweepInput to get the + // confirmation status of this sweeping tx so we don't need to make + // anothe subscription via `RegisterSpendNtfn` for this outpoint here + // in the resolver. + _, err = h.Sweeper.SweepInput( + inp, + sweep.Params{ + Budget: budget, + + // For second level success tx, there's no rush + // to get it confirmed, so we use a nil + // deadline. + DeadlineHeight: fn.None[int32](), + }, + ) + + return err +} From cd2d7098d987d80931d51ddf0aa5f3aab45e0378 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 16 Jul 2024 08:44:53 +0800 Subject: [PATCH 47/59] contractcourt: add methods to checkpoint states This commit adds checkpoint methods in `htlcTimeoutResolver`, which are similar to those used in `htlcSuccessResolver`. 
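A rough sketch of the intended shape only (stand-in types below; the real helpers persist channeldb.ResolverReport records via Checkpoint): a first-stage checkpoint once the timeout tx confirms, and a final claim checkpoint once the HTLC output is swept:

    package main

    import "fmt"

    // report is a stand-in for channeldb.ResolverReport.
    type report struct {
        outcome string
        spendTx string
    }

    // timeoutResolver is a stand-in for htlcTimeoutResolver.
    type timeoutResolver struct {
        resolved bool
    }

    // checkpoint is a stand-in for the real Checkpoint call, which
    // persists the resolver state and its reports to disk.
    func (h *timeoutResolver) checkpoint(reports ...report) error {
        fmt.Printf("persisting %d report(s), resolved=%v\n", len(reports),
            h.resolved)

        return nil
    }

    // checkpointStageOne records that the first-stage timeout tx has
    // confirmed, so a restarted resolver can resume waiting for the
    // second-stage spend.
    func (h *timeoutResolver) checkpointStageOne(spendTxID string) error {
        return h.checkpoint(report{outcome: "first-stage", spendTx: spendTxID})
    }

    // checkpointClaim marks the resolver as fully resolved and records the
    // final sweep of the HTLC output.
    func (h *timeoutResolver) checkpointClaim(spendTxID string) error {
        h.resolved = true
        return h.checkpoint(report{outcome: "timeout", spendTx: spendTxID})
    }

    func main() {
        h := &timeoutResolver{}
        _ = h.checkpointStageOne("stage_one_txid") // placeholder txid
        _ = h.checkpointClaim("sweep_txid")        // placeholder txid
    }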
--- contractcourt/htlc_timeout_resolver.go | 168 +++++++++++-------------- 1 file changed, 71 insertions(+), 97 deletions(-) diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index d17059364c..f2f7e76998 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" @@ -451,26 +452,6 @@ func (h *htlcTimeoutResolver) Resolve() (ContractResolver, error) { return nil, h.claimCleanUp(commitSpend) } - // At this point, the second-level transaction is sufficiently - // confirmed, or a transaction directly spending the output is. - // Therefore, we can now send back our clean up message, failing the - // HTLC on the incoming link. - // - // NOTE: This can be called twice if the outgoing resolver restarts - // before the second-stage timeout transaction is confirmed. - log.Infof("%T(%v): resolving htlc with incoming fail msg, "+ - "fully confirmed", h, h.htlcResolution.ClaimOutpoint) - - failureMsg := &lnwire.FailPermanentChannelFailure{} - err = h.DeliverResolutionMsg(ResolutionMsg{ - SourceChan: h.ShortChanID, - HtlcIndex: h.htlc.HtlcIndex, - Failure: failureMsg, - }) - if err != nil { - return nil, err - } - // Depending on whether this was a local or remote commit, we must // handle the spending transaction accordingly. return h.handleCommitSpend(commitSpend) @@ -680,30 +661,9 @@ func (h *htlcTimeoutResolver) waitForConfirmedSpend(op *wire.OutPoint, return nil, err } - // Once confirmed, persist the state on disk. - if err := h.checkPointSecondLevelTx(); err != nil { - return nil, err - } - return spend, err } -// checkPointSecondLevelTx persists the state of a second level HTLC tx to disk -// if it's published by the sweeper. -func (h *htlcTimeoutResolver) checkPointSecondLevelTx() error { - // If this was the second level transaction published by the sweeper, - // we can checkpoint the resolver now that it's confirmed. - if h.htlcResolution.SignDetails != nil && !h.outputIncubating { - h.outputIncubating = true - if err := h.Checkpoint(h); err != nil { - log.Errorf("unable to Checkpoint: %v", err) - return err - } - } - - return nil -} - // handleCommitSpend handles the spend of the HTLC output on the commitment // transaction. If this was our local commitment, the spend will be he // confirmed second-level timeout transaction, and we'll sweep that into our @@ -727,7 +687,8 @@ func (h *htlcTimeoutResolver) handleCommitSpend( // accordingly. spendTxID = commitSpend.SpenderTxHash - reports []*channeldb.ResolverReport + sweepTx *chainntnfs.SpendDetail + err error ) switch { @@ -756,7 +717,7 @@ func (h *htlcTimeoutResolver) handleCommitSpend( log.Infof("%T(%v): waiting for nursery/sweeper to spend CSV "+ "delayed output", h, claimOutpoint) - sweepTx, err := waitForSpend( + sweepTx, err = waitForSpend( &claimOutpoint, h.htlcResolution.SweepSignDesc.Output.PkScript, h.broadcastHeight, h.Notifier, h.quit, @@ -770,38 +731,16 @@ func (h *htlcTimeoutResolver) handleCommitSpend( // Once our sweep of the timeout tx has confirmed, we add a // resolution for our timeoutTx tx first stage transaction. 
- timeoutTx := commitSpend.SpendingTx - index := commitSpend.SpenderInputIndex - spendHash := commitSpend.SpenderTxHash - - reports = append(reports, &channeldb.ResolverReport{ - OutPoint: timeoutTx.TxIn[index].PreviousOutPoint, - Amount: h.htlc.Amt.ToSatoshis(), - ResolverType: channeldb.ResolverTypeOutgoingHtlc, - ResolverOutcome: channeldb.ResolverOutcomeFirstStage, - SpendTxID: spendHash, - }) + err = h.checkpointStageOne(*spendTxID) + if err != nil { + return nil, err + } } // With the clean up message sent, we'll now mark the contract // resolved, update the recovered balance, record the timeout and the // sweep txid on disk, and wait. - h.resolved = true - h.reportLock.Lock() - h.currentReport.RecoveredBalance = h.currentReport.LimboBalance - h.currentReport.LimboBalance = 0 - h.reportLock.Unlock() - - amt := btcutil.Amount(h.htlcResolution.SweepSignDesc.Output.Value) - reports = append(reports, &channeldb.ResolverReport{ - OutPoint: claimOutpoint, - Amount: amt, - ResolverType: channeldb.ResolverTypeOutgoingHtlc, - ResolverOutcome: channeldb.ResolverOutcomeTimeout, - SpendTxID: spendTxID, - }) - - return nil, h.Checkpoint(h, reports...) + return nil, h.checkpointClaim(sweepTx) } // Stop signals the resolver to cancel any current resolution processes, and @@ -1050,12 +989,6 @@ func (h *htlcTimeoutResolver) consumeSpendEvents(resultChan chan *spendResult, // Create a result chan to hold the results. result := &spendResult{} - // hasMempoolSpend is a flag that indicates whether we have found a - // preimage spend from the mempool. This is used to determine whether - // to checkpoint the resolver or not when later we found the - // corresponding block spend. - hasMempoolSpent := false - // Wait for a spend event to arrive. for { select { @@ -1083,23 +1016,6 @@ func (h *htlcTimeoutResolver) consumeSpendEvents(resultChan chan *spendResult, // Once confirmed, persist the state on disk if // we haven't seen the output's spending tx in // mempool before. - // - // NOTE: we don't checkpoint the resolver if - // it's spending tx has already been found in - // mempool - the resolver will take care of the - // checkpoint in its `claimCleanUp`. If we do - // checkpoint here, however, we'd create a new - // record in db for the same htlc resolver - // which won't be cleaned up later, resulting - // the channel to stay in unresolved state. - // - // TODO(yy): when fee bumper is implemented, we - // need to further check whether this is a - // preimage spend. Also need to refactor here - // to save us some indentation. - if !hasMempoolSpent { - result.err = h.checkPointSecondLevelTx() - } } // Send the result and exit the loop. @@ -1146,10 +1062,6 @@ func (h *htlcTimeoutResolver) consumeSpendEvents(resultChan chan *spendResult, result.spend = spendDetail resultChan <- result - // Set the hasMempoolSpent flag to true so we won't - // checkpoint the resolver again in db. - hasMempoolSpent = true - continue // If the resolver exits, we exit the goroutine. @@ -1317,3 +1229,65 @@ func (h *htlcTimeoutResolver) sweepTimeoutTxOutput() error { return err } + +// checkpointStageOne creates a checkpoint for the first stage of the htlc +// timeout transaction. This is used to ensure that the resolver can resume +// watching for the second stage spend in case of a restart. 
+func (h *htlcTimeoutResolver) checkpointStageOne( + spendTxid chainhash.Hash) error { + + h.log.Debugf("checkpoint stage one spend of HTLC output %v, spent "+ + "in tx %v", h.outpoint(), spendTxid) + + // Now that the second-level transaction has confirmed, we checkpoint + // the state so we'll go to the next stage in case of restarts. + h.outputIncubating = true + + // Create stage-one report. + report := &channeldb.ResolverReport{ + OutPoint: h.outpoint(), + Amount: h.htlc.Amt.ToSatoshis(), + ResolverType: channeldb.ResolverTypeOutgoingHtlc, + ResolverOutcome: channeldb.ResolverOutcomeFirstStage, + SpendTxID: &spendTxid, + } + + // At this point, the second-level transaction is sufficiently + // confirmed. We can now send back our clean up message, failing the + // HTLC on the incoming link. + failureMsg := &lnwire.FailPermanentChannelFailure{} + err := h.DeliverResolutionMsg(ResolutionMsg{ + SourceChan: h.ShortChanID, + HtlcIndex: h.htlc.HtlcIndex, + Failure: failureMsg, + }) + if err != nil { + return err + } + + return h.Checkpoint(h, report) +} + +// checkpointClaim checkpoints the timeout resolver with the reports it needs. +func (h *htlcTimeoutResolver) checkpointClaim( + spendDetail *chainntnfs.SpendDetail) error { + + h.log.Infof("resolving htlc with incoming fail msg, output=%v "+ + "confirmed in tx=%v", spendDetail.SpentOutPoint, + spendDetail.SpenderTxHash) + + // Create a resolver report for the claiming of the HTLC. + amt := btcutil.Amount(h.htlcResolution.SweepSignDesc.Output.Value) + report := &channeldb.ResolverReport{ + OutPoint: *spendDetail.SpentOutPoint, + Amount: amt, + ResolverType: channeldb.ResolverTypeOutgoingHtlc, + ResolverOutcome: channeldb.ResolverOutcomeTimeout, + SpendTxID: spendDetail.SpenderTxHash, + } + + // Finally, we checkpoint the resolver with our report(s). + h.resolved = true + + return h.Checkpoint(h, report) +} From 60bfafcd837bede5877c67a543084727109ce196 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 16 Jul 2024 08:53:00 +0800 Subject: [PATCH 48/59] contractcourt: add resolve handlers in `htlcTimeoutResolver` This commit adds more methods to handle resolving the spending of the output based on different spending paths. --- contractcourt/htlc_timeout_resolver.go | 201 +++++++++++++++---------- 1 file changed, 121 insertions(+), 80 deletions(-) diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index f2f7e76998..9904f37d3d 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -454,7 +454,11 @@ func (h *htlcTimeoutResolver) Resolve() (ContractResolver, error) { // Depending on whether this was a local or remote commit, we must // handle the spending transaction accordingly. - return h.handleCommitSpend(commitSpend) + if h.isRemoteCommitOutput() { + return nil, h.resolveRemoteCommitOutput() + } + + return nil, h.resolveTimeoutTx() } // sweepTimeoutTx sends a second level timeout transaction to the sweeper. @@ -664,85 +668,6 @@ func (h *htlcTimeoutResolver) waitForConfirmedSpend(op *wire.OutPoint, return spend, err } -// handleCommitSpend handles the spend of the HTLC output on the commitment -// transaction. If this was our local commitment, the spend will be he -// confirmed second-level timeout transaction, and we'll sweep that into our -// wallet. If the was a remote commitment, the resolver will resolve -// immetiately. 
-func (h *htlcTimeoutResolver) handleCommitSpend( - commitSpend *chainntnfs.SpendDetail) (ContractResolver, error) { - - var ( - // claimOutpoint will be the outpoint of the second level - // transaction, or on the remote commitment directly. It will - // start out as set in the resolution, but we'll update it if - // the second-level goes through the sweeper and changes its - // txid. - claimOutpoint = h.htlcResolution.ClaimOutpoint - - // spendTxID will be the ultimate spend of the claimOutpoint. - // We set it to the commit spend for now, as this is the - // ultimate spend in case this is a remote commitment. If we go - // through the second-level transaction, we'll update this - // accordingly. - spendTxID = commitSpend.SpenderTxHash - - sweepTx *chainntnfs.SpendDetail - err error - ) - - switch { - - // If we swept an HTLC directly off the remote party's commitment - // transaction, then we can exit here as there's no second level sweep - // to do. - case h.htlcResolution.SignedTimeoutTx == nil: - break - - // If the sweeper is handling the second level transaction, wait for - // the CSV and possible CLTV lock to expire, before sweeping the output - // on the second-level. - case h.isZeroFeeOutput(): - err := h.sweepTimeoutTxOutput() - if err != nil { - return nil, err - } - - fallthrough - - // Finally, if this was an output on our commitment transaction, we'll - // wait for the second-level HTLC output to be spent, and for that - // transaction itself to confirm. - case !h.isRemoteCommitOutput(): - log.Infof("%T(%v): waiting for nursery/sweeper to spend CSV "+ - "delayed output", h, claimOutpoint) - - sweepTx, err = waitForSpend( - &claimOutpoint, - h.htlcResolution.SweepSignDesc.Output.PkScript, - h.broadcastHeight, h.Notifier, h.quit, - ) - if err != nil { - return nil, err - } - - // Update the spend txid to the hash of the sweep transaction. - spendTxID = sweepTx.SpenderTxHash - - // Once our sweep of the timeout tx has confirmed, we add a - // resolution for our timeoutTx tx first stage transaction. - err = h.checkpointStageOne(*spendTxID) - if err != nil { - return nil, err - } - } - - // With the clean up message sent, we'll now mark the contract - // resolved, update the recovered balance, record the timeout and the - // sweep txid on disk, and wait. - return nil, h.checkpointClaim(sweepTx) -} - // Stop signals the resolver to cancel any current resolution processes, and // suspend. // @@ -1291,3 +1216,119 @@ func (h *htlcTimeoutResolver) checkpointClaim( return h.Checkpoint(h, report) } + +// resolveRemoteCommitOutput handles sweeping an HTLC output on the remote +// commitment with via the timeout path. In this case we can sweep the output +// directly, and don't have to broadcast a second-level transaction. +func (h *htlcTimeoutResolver) resolveRemoteCommitOutput() error { + h.log.Debug("waiting for direct-timeout spend of the htlc to confirm") + + // Wait for the direct-timeout HTLC sweep tx to confirm. + spend, err := h.watchHtlcSpend() + if err != nil { + return err + } + + // If the spend reveals the preimage, then we'll enter the clean up + // workflow to pass the preimage back to the incoming link, add it to + // the witness cache, and exit. + if isPreimageSpend(h.isTaproot(), spend, !h.isRemoteCommitOutput()) { + return h.claimCleanUp(spend) + } + + // Send the clean up msg to fail the incoming HTLC. 
+ failureMsg := &lnwire.FailPermanentChannelFailure{} + err = h.DeliverResolutionMsg(ResolutionMsg{ + SourceChan: h.ShortChanID, + HtlcIndex: h.htlc.HtlcIndex, + Failure: failureMsg, + }) + if err != nil { + return err + } + + // TODO(yy): should also update the `RecoveredBalance` and + // `LimboBalance` like other paths? + + // Checkpoint the resolver, and write the outcome to disk. + return h.checkpointClaim(spend) +} + +// resolveTimeoutTx waits for the sweeping tx of the second-level +// timeout tx to confirm and offers the output from the timeout tx to the +// sweeper. +func (h *htlcTimeoutResolver) resolveTimeoutTx() error { + h.log.Debug("waiting for first-stage 2nd-level HTLC timeout tx to " + + "confirm") + + // Wait for the second level transaction to confirm. + spend, err := h.watchHtlcSpend() + if err != nil { + return err + } + + // If the spend reveals the preimage, then we'll enter the clean up + // workflow to pass the preimage back to the incoming link, add it to + // the witness cache, and exit. + if isPreimageSpend(h.isTaproot(), spend, !h.isRemoteCommitOutput()) { + return h.claimCleanUp(spend) + } + + op := h.htlcResolution.ClaimOutpoint + spenderTxid := *spend.SpenderTxHash + + // If the timeout tx is a re-signed tx, we will need to find the actual + // spent outpoint from the spending tx. + if h.isZeroFeeOutput() { + op = wire.OutPoint{ + Hash: spenderTxid, + Index: spend.SpenderInputIndex, + } + } + + // If the 2nd-stage sweeping has already been started, we can + // fast-forward to start the resolving process for the stage two + // output. + if h.outputIncubating { + return h.resolveTimeoutTxOutput(op) + } + + h.log.Infof("2nd-level HTLC timeout tx=%v confirmed", spenderTxid) + + // Start the process to sweep the output from the timeout tx. + err = h.sweepTimeoutTxOutput() + if err != nil { + return err + } + + // Create a checkpoint since the timeout tx is confirmed and the sweep + // request has been made. + if err := h.checkpointStageOne(spenderTxid); err != nil { + return err + } + + // Start the resolving process for the stage two output. + return h.resolveTimeoutTxOutput(op) +} + +// resolveTimeoutTxOutput waits for the spend of the output from the 2nd-level +// timeout tx. +func (h *htlcTimeoutResolver) resolveTimeoutTxOutput(op wire.OutPoint) error { + h.log.Debugf("waiting for second-stage 2nd-level timeout tx output %v "+ + "to be spent after csv_delay=%v", op, h.htlcResolution.CsvDelay) + + spend, err := waitForSpend( + &op, h.htlcResolution.SweepSignDesc.Output.PkScript, + h.broadcastHeight, h.Notifier, h.quit, + ) + if err != nil { + return err + } + + h.reportLock.Lock() + h.currentReport.RecoveredBalance = h.currentReport.LimboBalance + h.currentReport.LimboBalance = 0 + h.reportLock.Unlock() + + return h.checkpointClaim(spend) +} From 03b6bc65bec1c7ee79518d041d57e5763e9db106 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Mon, 24 Jun 2024 21:49:21 +0800 Subject: [PATCH 49/59] contractcourt: add `Launch` method to anchor/breach resolver We will use this and its following commits to break the original `Resolve` methods into two parts - the first part is moved to a new method `Launch`, which handles sending a sweep request to the sweeper. The second part remains in `Resolve`, which is mainly waiting for a spending tx. Breach resolver currently doesn't do anything in its `Launch` since the sweeping of justice outputs are not handled by the sweeper yet. 
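[Editor's note] As a rough illustration of the Launch/Resolve split described above, the sketch below shows one way a caller might use it: fan out every sweep request first, then block on confirmations concurrently. This is a toy only, with an invented resolver interface and toyResolver type rather than lnd's ContractResolver, but it captures why separating the non-blocking "send the sweep request" half from the blocking "wait for the spend" half is useful, since all inputs reach the sweeper (and can be batched) before anything waits.

package main

import (
	"fmt"
	"sync"
)

// resolver mirrors the split: Launch offers the input, Resolve waits for the
// resulting spend.
type resolver interface {
	Launch() error
	Resolve() error
}

type toyResolver struct{ name string }

func (r *toyResolver) Launch() error {
	fmt.Println(r.name, "offered its input to the sweeper")
	return nil
}

func (r *toyResolver) Resolve() error {
	fmt.Println(r.name, "waiting for its sweep to confirm")
	return nil
}

func main() {
	resolvers := []resolver{&toyResolver{"anchor"}, &toyResolver{"htlc"}}

	// Launch everything up front so all sweep requests are registered
	// before anyone blocks on confirmation.
	for _, r := range resolvers {
		if err := r.Launch(); err != nil {
			fmt.Println("launch failed:", err)
			return
		}
	}

	// Then wait on each resolver concurrently.
	var wg sync.WaitGroup
	for _, r := range resolvers {
		wg.Add(1)
		go func(r resolver) {
			defer wg.Done()
			if err := r.Resolve(); err != nil {
				fmt.Println("resolve failed:", err)
			}
		}(r)
	}
	wg.Wait()
}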
--- contractcourt/anchor_resolver.go | 119 ++++++++++++++++++----------- contractcourt/breach_resolver.go | 14 ++++ contractcourt/contract_resolver.go | 10 +++ 3 files changed, 100 insertions(+), 43 deletions(-) diff --git a/contractcourt/anchor_resolver.go b/contractcourt/anchor_resolver.go index b50e061f44..84f2a216e4 100644 --- a/contractcourt/anchor_resolver.go +++ b/contractcourt/anchor_resolver.go @@ -84,49 +84,12 @@ func (c *anchorResolver) ResolverKey() []byte { return nil } -// Resolve offers the anchor output to the sweeper and waits for it to be swept. +// Resolve waits for the output to be swept. func (c *anchorResolver) Resolve() (ContractResolver, error) { - // Attempt to update the sweep parameters to the post-confirmation - // situation. We don't want to force sweep anymore, because the anchor - // lost its special purpose to get the commitment confirmed. It is just - // an output that we want to sweep only if it is economical to do so. - // - // An exclusive group is not necessary anymore, because we know that - // this is the only anchor that can be swept. - // - // We also clear the parent tx information for cpfp, because the - // commitment tx is confirmed. - // - // After a restart or when the remote force closes, the sweeper is not - // yet aware of the anchor. In that case, it will be added as new input - // to the sweeper. - witnessType := input.CommitmentAnchor - - // For taproot channels, we need to use the proper witness type. - if c.chanType.IsTaproot() { - witnessType = input.TaprootAnchorSweepSpend - } - - anchorInput := input.MakeBaseInput( - &c.anchor, witnessType, &c.anchorSignDescriptor, - c.broadcastHeight, nil, - ) - - resultChan, err := c.Sweeper.SweepInput( - &anchorInput, - sweep.Params{ - // For normal anchor sweeping, the budget is 330 sats. - Budget: btcutil.Amount( - anchorInput.SignDesc().Output.Value, - ), - - // There's no rush to sweep the anchor, so we use a nil - // deadline here. - DeadlineHeight: fn.None[int32](), - }, - ) - if err != nil { - return nil, err + // If we're already resolved, then we can exit early. + if c.resolved { + c.log.Errorf("already resolved") + return nil, nil } var ( @@ -135,7 +98,7 @@ func (c *anchorResolver) Resolve() (ContractResolver, error) { ) select { - case sweepRes := <-resultChan: + case sweepRes := <-c.sweepResultChan: switch sweepRes.Err { // Anchor was swept successfully. case nil: @@ -161,6 +124,8 @@ func (c *anchorResolver) Resolve() (ContractResolver, error) { return nil, errResolverShuttingDown } + c.log.Infof("resolved in tx %v", spendTx) + // Update report to reflect that funds are no longer in limbo. c.reportLock.Lock() if outcome == channeldb.ResolverOutcomeClaimed { @@ -181,6 +146,9 @@ func (c *anchorResolver) Resolve() (ContractResolver, error) { // // NOTE: Part of the ContractResolver interface. func (c *anchorResolver) Stop() { + c.log.Debugf("stopping...") + defer c.log.Debugf("stopped") + close(c.quit) } @@ -216,3 +184,68 @@ func (c *anchorResolver) Encode(w io.Writer) error { // A compile time assertion to ensure anchorResolver meets the // ContractResolver interface. var _ ContractResolver = (*anchorResolver)(nil) + +// Launch offers the anchor output to the sweeper. +func (c *anchorResolver) Launch() error { + if c.launched { + c.log.Tracef("already launched") + return nil + } + + c.log.Debugf("launching resolver...") + c.launched = true + + // If we're already resolved, then we can exit early. 
+ if c.resolved { + c.log.Errorf("already resolved") + return nil + } + + // Attempt to update the sweep parameters to the post-confirmation + // situation. We don't want to force sweep anymore, because the anchor + // lost its special purpose to get the commitment confirmed. It is just + // an output that we want to sweep only if it is economical to do so. + // + // An exclusive group is not necessary anymore, because we know that + // this is the only anchor that can be swept. + // + // We also clear the parent tx information for cpfp, because the + // commitment tx is confirmed. + // + // After a restart or when the remote force closes, the sweeper is not + // yet aware of the anchor. In that case, it will be added as new input + // to the sweeper. + witnessType := input.CommitmentAnchor + + // For taproot channels, we need to use the proper witness type. + if c.chanType.IsTaproot() { + witnessType = input.TaprootAnchorSweepSpend + } + + anchorInput := input.MakeBaseInput( + &c.anchor, witnessType, &c.anchorSignDescriptor, + c.broadcastHeight, nil, + ) + + resultChan, err := c.Sweeper.SweepInput( + &anchorInput, + sweep.Params{ + // For normal anchor sweeping, the budget is 330 sats. + Budget: btcutil.Amount( + anchorInput.SignDesc().Output.Value, + ), + + // There's no rush to sweep the anchor, so we use a nil + // deadline here. + DeadlineHeight: fn.None[int32](), + }, + ) + + if err != nil { + return err + } + + c.sweepResultChan = resultChan + + return nil +} diff --git a/contractcourt/breach_resolver.go b/contractcourt/breach_resolver.go index 9a5f4bbe08..75944fa6f7 100644 --- a/contractcourt/breach_resolver.go +++ b/contractcourt/breach_resolver.go @@ -83,6 +83,7 @@ func (b *breachResolver) Resolve() (ContractResolver, error) { // Stop signals the breachResolver to stop. func (b *breachResolver) Stop() { + b.log.Debugf("stopping...") close(b.quit) } @@ -123,3 +124,16 @@ func newBreachResolverFromReader(r io.Reader, resCfg ResolverConfig) ( // A compile time assertion to ensure breachResolver meets the ContractResolver // interface. var _ ContractResolver = (*breachResolver)(nil) + +// TODO(yy): implement it once the outputs are offered to the sweeper. +func (b *breachResolver) Launch() error { + if b.launched { + b.log.Tracef("already launched") + return nil + } + + b.log.Debugf("launching resolver...") + b.launched = true + + return nil +} diff --git a/contractcourt/contract_resolver.go b/contractcourt/contract_resolver.go index ff52ce9760..814c02ff56 100644 --- a/contractcourt/contract_resolver.go +++ b/contractcourt/contract_resolver.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btclog/v2" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/sweep" ) var ( @@ -109,6 +110,15 @@ type contractResolverKit struct { log btclog.Logger quit chan struct{} + + // sweepResultChan is the result chan returned from calling + // `SweepInput`. It should be mounted to the specific resolver once the + // input has been offered to the sweeper. + sweepResultChan chan sweep.Result + + // launched specifies whether the resolver has been launched. Calling + // `Launch` will be a no-op if this is true. + launched bool } // newContractResolverKit instantiates the mix-in struct. 
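[Editor's note] To make the new mix-in fields concrete, here is a toy version of the pattern the kit now supports: Launch mounts the result channel exactly once (guarded by the launched flag), and Resolve later blocks on that channel or on shutdown. The names (toyKit, resultChan of chan error) are hypothetical; the real field is sweepResultChan of type chan sweep.Result and is filled by Sweeper.SweepInput.

package main

import (
	"errors"
	"fmt"
)

// toyKit mimics the mix-in fields added above: a one-shot launch guard and a
// channel the sweep result will arrive on.
type toyKit struct {
	launched   bool
	resultChan chan error
	quit       chan struct{}
}

// launch is a no-op if called twice, mirroring the `launched` guard.
func (k *toyKit) launch() {
	if k.launched {
		return
	}
	k.launched = true

	// In the real code this channel comes back from the sweeper.
	k.resultChan = make(chan error, 1)
	k.resultChan <- nil // pretend the sweep succeeded immediately
}

// resolve blocks on the result channel that launch mounted, or on shutdown.
func (k *toyKit) resolve() error {
	select {
	case err := <-k.resultChan:
		return err
	case <-k.quit:
		return errors.New("resolver shutting down")
	}
}

func main() {
	k := &toyKit{quit: make(chan struct{})}
	k.launch()
	k.launch() // second call is a no-op
	fmt.Println("resolve result:", k.resolve())
}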
From 107a0682cf71f9fab424f0ba84e0ad7b9017e7d0 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 20 Jun 2024 22:05:33 +0800 Subject: [PATCH 50/59] contractcourt: add `Launch` method to commit resolver --- contractcourt/commit_sweep_resolver.go | 339 +++++++++++--------- contractcourt/commit_sweep_resolver_test.go | 4 + 2 files changed, 186 insertions(+), 157 deletions(-) diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index 7b101f80e3..b3323158db 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -172,165 +172,10 @@ func (c *commitSweepResolver) getCommitTxConfHeight() (uint32, error) { func (c *commitSweepResolver) Resolve() (ContractResolver, error) { // If we're already resolved, then we can exit early. if c.resolved { + c.log.Errorf("already resolved") return nil, nil } - confHeight, err := c.getCommitTxConfHeight() - if err != nil { - return nil, err - } - - // Wait up until the CSV expires, unless we also have a CLTV that - // expires after. - unlockHeight := confHeight + c.commitResolution.MaturityDelay - if c.hasCLTV() { - unlockHeight = uint32(math.Max( - float64(unlockHeight), float64(c.leaseExpiry), - )) - } - - c.log.Debugf("commit conf_height=%v, unlock_height=%v", - confHeight, unlockHeight) - - // Update report now that we learned the confirmation height. - c.reportLock.Lock() - c.currentReport.MaturityHeight = unlockHeight - c.reportLock.Unlock() - - var ( - isLocalCommitTx bool - - signDesc = c.commitResolution.SelfOutputSignDesc - ) - - switch { - // For taproot channels, we'll know if this is the local commit based - // on the timelock value. For remote commitment transactions, the - // witness script has a timelock of 1. - case c.chanType.IsTaproot(): - delayKey := c.localChanCfg.DelayBasePoint.PubKey - nonDelayKey := c.localChanCfg.PaymentBasePoint.PubKey - - signKey := c.commitResolution.SelfOutputSignDesc.KeyDesc.PubKey - - // If the key in the script is neither of these, we shouldn't - // proceed. This should be impossible. - if !signKey.IsEqual(delayKey) && !signKey.IsEqual(nonDelayKey) { - return nil, fmt.Errorf("unknown sign key %v", signKey) - } - - // The commitment transaction is ours iff the signing key is - // the delay key. - isLocalCommitTx = signKey.IsEqual(delayKey) - - // The output is on our local commitment if the script starts with - // OP_IF for the revocation clause. On the remote commitment it will - // either be a regular P2WKH or a simple sig spend with a CSV delay. - default: - isLocalCommitTx = signDesc.WitnessScript[0] == txscript.OP_IF - } - isDelayedOutput := c.commitResolution.MaturityDelay != 0 - - c.log.Debugf("isDelayedOutput=%v, isLocalCommitTx=%v", isDelayedOutput, - isLocalCommitTx) - - // There're three types of commitments, those that have tweaks for the - // remote key (us in this case), those that don't, and a third where - // there is no tweak and the output is delayed. On the local commitment - // our output will always be delayed. We'll rely on the presence of the - // commitment tweak to discern which type of commitment this is. - var witnessType input.WitnessType - switch { - // The local delayed output for a taproot channel. - case isLocalCommitTx && c.chanType.IsTaproot(): - witnessType = input.TaprootLocalCommitSpend - - // The CSV 1 delayed output for a taproot channel. 
- case !isLocalCommitTx && c.chanType.IsTaproot(): - witnessType = input.TaprootRemoteCommitSpend - - // Delayed output to us on our local commitment for a channel lease in - // which we are the initiator. - case isLocalCommitTx && c.hasCLTV(): - witnessType = input.LeaseCommitmentTimeLock - - // Delayed output to us on our local commitment. - case isLocalCommitTx: - witnessType = input.CommitmentTimeLock - - // A confirmed output to us on the remote commitment for a channel lease - // in which we are the initiator. - case isDelayedOutput && c.hasCLTV(): - witnessType = input.LeaseCommitmentToRemoteConfirmed - - // A confirmed output to us on the remote commitment. - case isDelayedOutput: - witnessType = input.CommitmentToRemoteConfirmed - - // A non-delayed output on the remote commitment where the key is - // tweakless. - case c.commitResolution.SelfOutputSignDesc.SingleTweak == nil: - witnessType = input.CommitSpendNoDelayTweakless - - // A non-delayed output on the remote commitment where the key is - // tweaked. - default: - witnessType = input.CommitmentNoDelay - } - - c.log.Infof("Sweeping with witness type: %v", witnessType) - - // We'll craft an input with all the information required for the - // sweeper to create a fully valid sweeping transaction to recover - // these coins. - var inp *input.BaseInput - if c.hasCLTV() { - inp = input.NewCsvInputWithCltv( - &c.commitResolution.SelfOutPoint, witnessType, - &c.commitResolution.SelfOutputSignDesc, - c.broadcastHeight, c.commitResolution.MaturityDelay, - c.leaseExpiry, - input.WithResolutionBlob( - c.commitResolution.ResolutionBlob, - ), - ) - } else { - inp = input.NewCsvInput( - &c.commitResolution.SelfOutPoint, witnessType, - &c.commitResolution.SelfOutputSignDesc, - c.broadcastHeight, c.commitResolution.MaturityDelay, - input.WithResolutionBlob( - c.commitResolution.ResolutionBlob, - ), - ) - } - - // TODO(roasbeef): instead of ading ctrl block to the sign desc, make - // new input type, have sweeper set it? - - // Calculate the budget for the sweeping this input. - budget := calculateBudget( - btcutil.Amount(inp.SignDesc().Output.Value), - c.Budget.ToLocalRatio, c.Budget.ToLocal, - ) - c.log.Infof("Sweeping commit output using budget=%v", budget) - - // With our input constructed, we'll now offer it to the sweeper. - resultChan, err := c.Sweeper.SweepInput( - inp, sweep.Params{ - Budget: budget, - - // Specify a nil deadline here as there's no time - // pressure. - DeadlineHeight: fn.None[int32](), - }, - ) - if err != nil { - c.log.Errorf("unable to sweep input: %v", err) - - return nil, err - } - var sweepTxID chainhash.Hash // Sweeper is going to join this input with other inputs if possible @@ -339,7 +184,7 @@ func (c *commitSweepResolver) Resolve() (ContractResolver, error) { // happen. outcome := channeldb.ResolverOutcomeClaimed select { - case sweepResult := <-resultChan: + case sweepResult := <-c.sweepResultChan: switch sweepResult.Err { // If the remote party was able to sweep this output it's // likely what we sent was actually a revoked commitment. @@ -391,6 +236,8 @@ func (c *commitSweepResolver) Resolve() (ContractResolver, error) { // // NOTE: Part of the ContractResolver interface. func (c *commitSweepResolver) Stop() { + c.log.Debugf("stopping...") + defer c.log.Debugf("stopped") close(c.quit) } @@ -524,3 +371,181 @@ func (c *commitSweepResolver) initReport() { // A compile time assertion to ensure commitSweepResolver meets the // ContractResolver interface. 
var _ reportingContractResolver = (*commitSweepResolver)(nil) + +// Launch constructs a commit input and offers it to the sweeper. +func (c *commitSweepResolver) Launch() error { + if c.launched { + c.log.Tracef("already launched") + return nil + } + + c.log.Debugf("launching resolver...") + c.launched = true + + // If we're already resolved, then we can exit early. + if c.resolved { + c.log.Errorf("already resolved") + return nil + } + + confHeight, err := c.getCommitTxConfHeight() + if err != nil { + return err + } + + // Wait up until the CSV expires, unless we also have a CLTV that + // expires after. + unlockHeight := confHeight + c.commitResolution.MaturityDelay + if c.hasCLTV() { + unlockHeight = uint32(math.Max( + float64(unlockHeight), float64(c.leaseExpiry), + )) + } + + // Update report now that we learned the confirmation height. + c.reportLock.Lock() + c.currentReport.MaturityHeight = unlockHeight + c.reportLock.Unlock() + + // Derive the witness type for this input. + witnessType, err := c.decideWitnessType() + if err != nil { + return err + } + + // We'll craft an input with all the information required for the + // sweeper to create a fully valid sweeping transaction to recover + // these coins. + var inp *input.BaseInput + if c.hasCLTV() { + inp = input.NewCsvInputWithCltv( + &c.commitResolution.SelfOutPoint, witnessType, + &c.commitResolution.SelfOutputSignDesc, + c.broadcastHeight, c.commitResolution.MaturityDelay, + c.leaseExpiry, + ) + } else { + inp = input.NewCsvInput( + &c.commitResolution.SelfOutPoint, witnessType, + &c.commitResolution.SelfOutputSignDesc, + c.broadcastHeight, c.commitResolution.MaturityDelay, + ) + } + + // TODO(roasbeef): instead of ading ctrl block to the sign desc, make + // new input type, have sweeper set it? + + // Calculate the budget for the sweeping this input. + budget := calculateBudget( + btcutil.Amount(inp.SignDesc().Output.Value), + c.Budget.ToLocalRatio, c.Budget.ToLocal, + ) + c.log.Infof("sweeping commit output %v using budget=%v", witnessType, + budget) + + // With our input constructed, we'll now offer it to the sweeper. + resultChan, err := c.Sweeper.SweepInput( + inp, sweep.Params{ + Budget: budget, + + // Specify a nil deadline here as there's no time + // pressure. + DeadlineHeight: fn.None[int32](), + }, + ) + if err != nil { + c.log.Errorf("unable to sweep input: %v", err) + + return err + } + + c.sweepResultChan = resultChan + + return nil +} + +// decideWitnessType returns the witness type for the input. +func (c *commitSweepResolver) decideWitnessType() (input.WitnessType, error) { + var ( + isLocalCommitTx bool + signDesc = c.commitResolution.SelfOutputSignDesc + ) + + switch { + // For taproot channels, we'll know if this is the local commit based + // on the timelock value. For remote commitment transactions, the + // witness script has a timelock of 1. + case c.chanType.IsTaproot(): + delayKey := c.localChanCfg.DelayBasePoint.PubKey + nonDelayKey := c.localChanCfg.PaymentBasePoint.PubKey + + signKey := c.commitResolution.SelfOutputSignDesc.KeyDesc.PubKey + + // If the key in the script is neither of these, we shouldn't + // proceed. This should be impossible. + if !signKey.IsEqual(delayKey) && !signKey.IsEqual(nonDelayKey) { + return nil, fmt.Errorf("unknown sign key %v", signKey) + } + + // The commitment transaction is ours iff the signing key is + // the delay key. 
+ isLocalCommitTx = signKey.IsEqual(delayKey) + + // The output is on our local commitment if the script starts with + // OP_IF for the revocation clause. On the remote commitment it will + // either be a regular P2WKH or a simple sig spend with a CSV delay. + default: + isLocalCommitTx = signDesc.WitnessScript[0] == txscript.OP_IF + } + + isDelayedOutput := c.commitResolution.MaturityDelay != 0 + + c.log.Debugf("isDelayedOutput=%v, isLocalCommitTx=%v", isDelayedOutput, + isLocalCommitTx) + + // There're three types of commitments, those that have tweaks for the + // remote key (us in this case), those that don't, and a third where + // there is no tweak and the output is delayed. On the local commitment + // our output will always be delayed. We'll rely on the presence of the + // commitment tweak to discern which type of commitment this is. + var witnessType input.WitnessType + switch { + // The local delayed output for a taproot channel. + case isLocalCommitTx && c.chanType.IsTaproot(): + witnessType = input.TaprootLocalCommitSpend + + // The CSV 1 delayed output for a taproot channel. + case !isLocalCommitTx && c.chanType.IsTaproot(): + witnessType = input.TaprootRemoteCommitSpend + + // Delayed output to us on our local commitment for a channel lease in + // which we are the initiator. + case isLocalCommitTx && c.hasCLTV(): + witnessType = input.LeaseCommitmentTimeLock + + // Delayed output to us on our local commitment. + case isLocalCommitTx: + witnessType = input.CommitmentTimeLock + + // A confirmed output to us on the remote commitment for a channel lease + // in which we are the initiator. + case isDelayedOutput && c.hasCLTV(): + witnessType = input.LeaseCommitmentToRemoteConfirmed + + // A confirmed output to us on the remote commitment. + case isDelayedOutput: + witnessType = input.CommitmentToRemoteConfirmed + + // A non-delayed output on the remote commitment where the key is + // tweakless. + case c.commitResolution.SelfOutputSignDesc.SingleTweak == nil: + witnessType = input.CommitSpendNoDelayTweakless + + // A non-delayed output on the remote commitment where the key is + // tweaked. + default: + witnessType = input.CommitmentNoDelay + } + + return witnessType, nil +} diff --git a/contractcourt/commit_sweep_resolver_test.go b/contractcourt/commit_sweep_resolver_test.go index 2195e33779..6855fddcd3 100644 --- a/contractcourt/commit_sweep_resolver_test.go +++ b/contractcourt/commit_sweep_resolver_test.go @@ -15,6 +15,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/sweep" + "github.com/stretchr/testify/require" ) type commitSweepResolverTestContext struct { @@ -82,6 +83,9 @@ func (i *commitSweepResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { + err := i.resolver.Launch() + require.NoError(i.t, err) + nextResolver, err := i.resolver.Resolve() i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, From 05f5b7dfa404eda72a606534c554b412aae78ccc Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 16 Jul 2024 07:24:45 +0800 Subject: [PATCH 51/59] contractcourt: add `Launch` method to htlc success resolver This commit breaks the `Resolve` into two parts - the first part is moved into a `Launch` method that handles sending sweep requests, and the second part remains in `Resolve` which handles waiting for the spend. 
Since we are using both utxo nursery and sweeper at the same time, to make sure this change doesn't break the existing behavior, we implement the `Launch` as following, - zero-fee htlc - handled by the sweeper - direct output from the remote commit - handled by the sweeper - legacy htlc - handled by the utxo nursery --- contractcourt/htlc_success_resolver.go | 186 ++++++++------------ contractcourt/htlc_success_resolver_test.go | 87 +++++++-- 2 files changed, 152 insertions(+), 121 deletions(-) diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index b436ccda89..20b4516634 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -10,8 +10,6 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" @@ -116,139 +114,60 @@ func (h *htlcSuccessResolver) ResolverKey() []byte { // anymore. Every HTLC has already passed through the incoming contest resolver // and in there the invoice was already marked as settled. // -// TODO(roasbeef): create multi to batch -// // NOTE: Part of the ContractResolver interface. +// +// TODO(yy): refactor the interface method to return an error only. func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { - // If we're already resolved, then we can exit early. - if h.resolved { - return nil, nil - } - - // If we don't have a success transaction, then this means that this is - // an output on the remote party's commitment transaction. - if h.isRemoteCommitOutput() { - return h.resolveRemoteCommitOutput() - } - - // Otherwise this an output on our own commitment, and we must start by - // broadcasting the second-level success transaction. - secondLevelOutpoint, err := h.broadcastSuccessTx() - if err != nil { - return nil, err - } - - // To wrap this up, we'll wait until the second-level transaction has - // been spent, then fully resolve the contract. - return nil, h.resolveSuccessTxOutput(*secondLevelOutpoint) -} - -// broadcastSuccessTx handles an HTLC output on our local commitment by -// broadcasting the second-level success transaction. It returns the ultimate -// outpoint of the second-level tx, that we must wait to be spent for the -// resolver to be fully resolved. -func (h *htlcSuccessResolver) broadcastSuccessTx() ( - *wire.OutPoint, error) { - - // If we have non-nil SignDetails, this means that have a 2nd level - // HTLC transaction that is signed using sighash SINGLE|ANYONECANPAY - // (the case for anchor type channels). In this case we can re-sign it - // and attach fees at will. We let the sweeper handle this job. We use - // the checkpointed outputIncubating field to determine if we already - // swept the HTLC output into the second level transaction. - if h.isZeroFeeOutput() { - return h.broadcastReSignedSuccessTx() - } + var err error - // Otherwise we'll publish the second-level transaction directly and - // offer the resolution to the nursery to handle. - log.Infof("%T(%x): broadcasting second-layer transition tx: %v", - h, h.htlc.RHash[:], spew.Sdump(h.htlcResolution.SignedSuccessTx)) - - // We'll now broadcast the second layer transaction so we can kick off - // the claiming process. 
- err := h.resolveLegacySuccessTx() - if err != nil { - return nil, err - } - - return &h.htlcResolution.ClaimOutpoint, nil -} + switch { + // If we're already resolved, then we can exit early. + case h.resolved: + h.log.Errorf("already resolved") -// broadcastReSignedSuccessTx handles the case where we have non-nil -// SignDetails, and offers the second level transaction to the Sweeper, that -// will re-sign it and attach fees at will. -func (h *htlcSuccessResolver) broadcastReSignedSuccessTx() (*wire.OutPoint, - error) { - - // Keep track of the tx spending the HTLC output on the commitment, as - // this will be the confirmed second-level tx we'll ultimately sweep. - var commitSpend *chainntnfs.SpendDetail - - // We will have to let the sweeper re-sign the success tx and wait for - // it to confirm, if we haven't already. - if !h.outputIncubating { - err := h.sweepSuccessTx() - if err != nil { - return nil, err - } + // If this is an output on the remote party's commitment transaction, + // use the direct-spend path to sweep the htlc. + case h.isRemoteCommitOutput(): + err = h.resolveRemoteCommitOutput() + // If this is an output on our commitment transaction using post-anchor + // channel type, it will be handled by the sweeper. + case h.isZeroFeeOutput(): err = h.resolveSuccessTx() - if err != nil { - return nil, err - } - } - - // This should be non-blocking as we will only attempt to sweep the - // output when the second level tx has already been confirmed. In other - // words, waitForSpend will return immediately. - commitSpend, err := waitForSpend( - &h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint, - h.htlcResolution.SignDetails.SignDesc.Output.PkScript, - h.broadcastHeight, h.Notifier, h.quit, - ) - if err != nil { - return nil, err - } - err = h.sweepSuccessTxOutput() - if err != nil { - return nil, err + // If this is an output on our own commitment using pre-anchor channel + // type, we will publish the success tx and offer the output to the + // nursery. + default: + err = h.resolveLegacySuccessTx() } - // Will return this outpoint, when this is spent the resolver is fully - // resolved. - op := &wire.OutPoint{ - Hash: *commitSpend.SpenderTxHash, - Index: commitSpend.SpenderInputIndex, - } - - return op, nil + return nil, err } // resolveRemoteCommitOutput handles sweeping an HTLC output on the remote // commitment with the preimage. In this case we can sweep the output directly, // and don't have to broadcast a second-level transaction. -func (h *htlcSuccessResolver) resolveRemoteCommitOutput() ( - ContractResolver, error) { - - err := h.sweepRemoteCommitOutput() - if err != nil { - return nil, err - } +func (h *htlcSuccessResolver) resolveRemoteCommitOutput() error { + h.log.Info("waiting for direct-preimage spend of the htlc to confirm") // Wait for the direct-preimage HTLC sweep tx to confirm. + // + // TODO(yy): use the result chan returned from `SweepInput`. sweepTxDetails, err := waitForSpend( &h.htlcResolution.ClaimOutpoint, h.htlcResolution.SweepSignDesc.Output.PkScript, h.broadcastHeight, h.Notifier, h.quit, ) if err != nil { - return nil, err + return err } + // TODO(yy): should also update the `RecoveredBalance` and + // `LimboBalance` like other paths? + // Checkpoint the resolver, and write the outcome to disk. - return nil, h.checkpointClaim(sweepTxDetails.SpenderTxHash) + return h.checkpointClaim(sweepTxDetails.SpenderTxHash) } // checkpointClaim checkpoints the success resolver with the reports it needs. 
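[Editor's note] To summarize the dispatch that the reworked Resolve (and the Launch method added later in this patch) performs for the success resolver, here is a tiny stand-alone classifier. The field names are invented; the two booleans only loosely mirror SignedSuccessTx != nil and SignDetails != nil in the real resolution, and the strings restate the routing described in the commit message rather than calling any real subsystem.

package main

import "fmt"

// toyResolution is a stand-in for the incoming HTLC resolution.
type toyResolution struct {
	hasSuccessTx   bool // roughly: SignedSuccessTx != nil
	hasSignDetails bool // roughly: SignDetails != nil (zero-fee/anchor channels)
}

// handler names the subsystem each case is routed to in this patch series.
func handler(r toyResolution) string {
	switch {
	case !r.hasSuccessTx:
		return "sweeper: direct preimage spend of the remote commitment output"
	case r.hasSignDetails:
		return "sweeper: re-signable (zero-fee) second-level success tx"
	default:
		return "utxo nursery: legacy second-level success tx"
	}
}

func main() {
	fmt.Println(handler(toyResolution{}))
	fmt.Println(handler(toyResolution{hasSuccessTx: true, hasSignDetails: true}))
	fmt.Println(handler(toyResolution{hasSuccessTx: true}))
}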
@@ -316,6 +235,9 @@ func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash) error { // // NOTE: Part of the ContractResolver interface. func (h *htlcSuccessResolver) Stop() { + h.log.Debugf("stopping...") + defer h.log.Debugf("stopped") + close(h.quit) } @@ -809,3 +731,47 @@ func (h *htlcSuccessResolver) resolveSuccessTxOutput(op wire.OutPoint) error { return h.checkpointClaim(spend.SpenderTxHash) } + +// Launch creates an input based on the details of the incoming htlc resolution +// and offers it to the sweeper. +func (h *htlcSuccessResolver) Launch() error { + if h.launched { + h.log.Tracef("already launched") + return nil + } + + h.log.Debugf("launching resolver...") + h.launched = true + + switch { + // If we're already resolved, then we can exit early. + case h.resolved: + h.log.Errorf("already resolved") + return nil + + // If this is an output on the remote party's commitment transaction, + // use the direct-spend path. + case h.isRemoteCommitOutput(): + return h.sweepRemoteCommitOutput() + + // If this is an anchor type channel, we now sweep either the + // second-level success tx or the output from the second-level success + // tx. + case h.isZeroFeeOutput(): + // If the second-level success tx has already been swept, we + // can go ahead and sweep its output. + if h.outputIncubating { + return h.sweepSuccessTxOutput() + } + + // Otherwise, sweep the second level tx. + return h.sweepSuccessTx() + + // If this is a legacy channel type, the output is handled by the + // nursery via the Resolve so we do nothing here. + // + // TODO(yy): handle the legacy output by offering it to the sweeper. + default: + return nil + } +} diff --git a/contractcourt/htlc_success_resolver_test.go b/contractcourt/htlc_success_resolver_test.go index 75c733638f..f395d67215 100644 --- a/contractcourt/htlc_success_resolver_test.go +++ b/contractcourt/htlc_success_resolver_test.go @@ -5,6 +5,7 @@ import ( "fmt" "reflect" "testing" + "time" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -20,6 +21,7 @@ import ( "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" ) var testHtlcAmt = lnwire.MilliSatoshi(200000) @@ -39,6 +41,15 @@ type htlcResolverTestContext struct { t *testing.T } +func newHtlcResolverTestContextFromReader(t *testing.T, + newResolver func(htlc channeldb.HTLC, + cfg ResolverConfig) ContractResolver) *htlcResolverTestContext { + + ctx := newHtlcResolverTestContext(t, newResolver) + + return ctx +} + func newHtlcResolverTestContext(t *testing.T, newResolver func(htlc channeldb.HTLC, cfg ResolverConfig) ContractResolver) *htlcResolverTestContext { @@ -133,6 +144,7 @@ func newHtlcResolverTestContext(t *testing.T, func (i *htlcResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) + go func() { nextResolver, err := i.resolver.Resolve() i.resolverResultChan <- resolveResult{ @@ -192,6 +204,7 @@ func TestHtlcSuccessSingleStage(t *testing.T) { // sweeper. details := &chainntnfs.SpendDetail{ SpendingTx: sweepTx, + SpentOutPoint: &htlcOutpoint, SpenderTxHash: &sweepTxid, } ctx.notifier.SpendChan <- details @@ -215,8 +228,8 @@ func TestHtlcSuccessSingleStage(t *testing.T) { ) } -// TestSecondStageResolution tests successful sweep of a second stage htlc -// claim, going through the Nursery. 
+// TestHtlcSuccessSecondStageResolution tests successful sweep of a second +// stage htlc claim, going through the Nursery. func TestHtlcSuccessSecondStageResolution(t *testing.T) { commitOutpoint := wire.OutPoint{Index: 2} htlcOutpoint := wire.OutPoint{Index: 3} @@ -279,6 +292,7 @@ func TestHtlcSuccessSecondStageResolution(t *testing.T) { ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ SpendingTx: sweepTx, + SpentOutPoint: &htlcOutpoint, SpenderTxHash: &sweepHash, } @@ -302,6 +316,8 @@ func TestHtlcSuccessSecondStageResolution(t *testing.T) { // TestHtlcSuccessSecondStageResolutionSweeper test that a resolver with // non-nil SignDetails will offer the second-level transaction to the sweeper // for re-signing. +// +//nolint:ll func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) { commitOutpoint := wire.OutPoint{Index: 2} htlcOutpoint := wire.OutPoint{Index: 3} @@ -399,7 +415,20 @@ func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) { _ bool) error { resolver := ctx.resolver.(*htlcSuccessResolver) - inp := <-resolver.Sweeper.(*mockSweeper).sweptInputs + + var ( + inp input.Input + ok bool + ) + + select { + case inp, ok = <-resolver.Sweeper.(*mockSweeper).sweptInputs: + require.True(t, ok) + + case <-time.After(1 * time.Second): + t.Fatal("expected input to be swept") + } + op := inp.OutPoint() if op != commitOutpoint { return fmt.Errorf("outpoint %v swept, "+ @@ -412,6 +441,7 @@ func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) { SpenderTxHash: &reSignedHash, SpenderInputIndex: 1, SpendingHeight: 10, + SpentOutPoint: &commitOutpoint, } return nil }, @@ -434,13 +464,37 @@ func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) { SpenderTxHash: &reSignedHash, SpenderInputIndex: 1, SpendingHeight: 10, + SpentOutPoint: &commitOutpoint, } } // We expect it to sweep the second-level // transaction we notfied about above. resolver := ctx.resolver.(*htlcSuccessResolver) - inp := <-resolver.Sweeper.(*mockSweeper).sweptInputs + + // Mock `waitForSpend` to return the commit + // spend. + ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ + SpendingTx: reSignedSuccessTx, + SpenderTxHash: &reSignedHash, + SpenderInputIndex: 1, + SpendingHeight: 10, + SpentOutPoint: &commitOutpoint, + } + + var ( + inp input.Input + ok bool + ) + + select { + case inp, ok = <-resolver.Sweeper.(*mockSweeper).sweptInputs: + require.True(t, ok) + + case <-time.After(1 * time.Second): + t.Fatal("expected input to be swept") + } + op := inp.OutPoint() exp := wire.OutPoint{ Hash: reSignedHash, @@ -457,6 +511,7 @@ func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) { SpendingTx: sweepTx, SpenderTxHash: &sweepHash, SpendingHeight: 14, + SpentOutPoint: &op, } return nil @@ -504,11 +559,14 @@ func testHtlcSuccess(t *testing.T, resolution lnwallet.IncomingHtlcResolution, // for the next portion of the test. 
ctx := newHtlcResolverTestContext(t, func(htlc channeldb.HTLC, cfg ResolverConfig) ContractResolver { - return &htlcSuccessResolver{ + r := &htlcSuccessResolver{ contractResolverKit: *newContractResolverKit(cfg), htlc: htlc, htlcResolution: resolution, } + r.initLogger("htlcSuccessResolver") + + return r }, ) @@ -606,7 +664,12 @@ func runFromCheckpoint(t *testing.T, ctx *htlcResolverTestContext, checkpointedState = append(checkpointedState, b.Bytes()) nextCheckpoint++ - checkpointChan <- struct{}{} + select { + case checkpointChan <- struct{}{}: + case <-time.After(1 * time.Second): + t.Fatal("checkpoint timeout") + } + return nil } @@ -617,6 +680,8 @@ func runFromCheckpoint(t *testing.T, ctx *htlcResolverTestContext, // preCheckpoint logic if needed. resumed := true for i, cp := range expectedCheckpoints { + t.Logf("Running checkpoint %d", i) + if cp.preCheckpoint != nil { if err := cp.preCheckpoint(ctx, resumed); err != nil { t.Fatalf("failure at stage %d: %v", i, err) @@ -625,15 +690,15 @@ func runFromCheckpoint(t *testing.T, ctx *htlcResolverTestContext, resumed = false // Wait for the resolver to have checkpointed its state. - <-checkpointChan + select { + case <-checkpointChan: + case <-time.After(1 * time.Second): + t.Fatalf("resolver did not checkpoint at stage %d", i) + } } // Wait for the resolver to fully complete. ctx.waitForResult() - if nextCheckpoint < len(expectedCheckpoints) { - t.Fatalf("not all checkpoints hit") - } - return checkpointedState } From 7b99e49cf095c218964fc0436f2d7c32712fdfb5 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 16 Jul 2024 09:05:47 +0800 Subject: [PATCH 52/59] contractcourt: add `Launch` method to htlc timeout resolver This commit breaks the `Resolve` into two parts - the first part is moved into a `Launch` method that handles sending sweep requests, and the second part remains in `Resolve` which handles waiting for the spend. Since we are using both utxo nursery and sweeper at the same time, to make sure this change doesn't break the existing behavior, we implement the `Launch` as following, - zero-fee htlc - handled by the sweeper - direct output from the remote commit - handled by the sweeper - legacy htlc - handled by the utxo nursery --- contractcourt/channel_arbitrator_test.go | 7 +- contractcourt/htlc_timeout_resolver.go | 162 ++++++------ contractcourt/htlc_timeout_resolver_test.go | 272 +++++++++++++------- 3 files changed, 254 insertions(+), 187 deletions(-) diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index a3bdf6ba3b..d1ef2993d5 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -982,6 +982,7 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { }, }, } + closeTxid := closeTx.TxHash() htlcOp := wire.OutPoint{ Hash: closeTx.TxHash(), @@ -1117,7 +1118,11 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { // Notify resolver that the HTLC output of the commitment has been // spent. - oldNotifier.SpendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx} + oldNotifier.SpendChan <- &chainntnfs.SpendDetail{ + SpendingTx: closeTx, + SpentOutPoint: &wire.OutPoint{}, + SpenderTxHash: &closeTxid, + } // Finally, we should also receive a resolution message instructing the // switch to cancel back the HTLC. 
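[Editor's note] The test updates in this series follow a recurring pattern: drive Launch and then Resolve in a background goroutine, feed mocked spend notifications through buffered channels so the test goroutine never blocks while publishing them, and guard every wait with a timeout. The sketch below illustrates that pattern with stand-in types only (spendDetail and toyResolver are invented, not lnd's mocks).

package main

import (
	"fmt"
	"time"
)

// spendDetail is a stand-in for a chain notifier spend notification.
type spendDetail struct{ txid string }

type toyResolver struct {
	spendChan chan *spendDetail
	done      chan string
}

func (r *toyResolver) launch() error { return nil }

func (r *toyResolver) resolve() {
	spend := <-r.spendChan
	r.done <- spend.txid
}

func main() {
	r := &toyResolver{
		// Buffered channels let the test publish the mocked spend
		// without racing the resolver goroutine.
		spendChan: make(chan *spendDetail, 1),
		done:      make(chan string, 1),
	}

	// Drive launch and resolve in the background, as the updated tests do.
	go func() {
		if err := r.launch(); err != nil {
			panic(err)
		}
		r.resolve()
	}()

	r.spendChan <- &spendDetail{txid: "deadbeef"}

	select {
	case txid := <-r.done:
		fmt.Println("resolver saw spend:", txid)
	case <-time.After(time.Second):
		fmt.Println("timeout waiting for resolver")
	}
}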
diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index 9904f37d3d..eae52255f2 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -425,40 +425,25 @@ func checkSizeAndIndex(witness wire.TxWitness, size, index int) bool { func (h *htlcTimeoutResolver) Resolve() (ContractResolver, error) { // If we're already resolved, then we can exit early. if h.resolved { + h.log.Errorf("already resolved") return nil, nil } - // Start by spending the HTLC output, either by broadcasting the - // second-level timeout transaction, or directly if this is the remote - // commitment. - commitSpend, err := h.spendHtlcOutput() - if err != nil { - return nil, err - } - - // If the spend reveals the pre-image, then we'll enter the clean up - // workflow to pass the pre-image back to the incoming link, add it to - // the witness cache, and exit. - if isPreimageSpend( - h.isTaproot(), commitSpend, - h.htlcResolution.SignedTimeoutTx != nil, - ) { - - log.Infof("%T(%v): HTLC has been swept with pre-image by "+ - "remote party during timeout flow! Adding pre-image to "+ - "witness cache", h, h.htlc.RHash[:], - h.htlcResolution.ClaimOutpoint) - - return nil, h.claimCleanUp(commitSpend) - } - - // Depending on whether this was a local or remote commit, we must - // handle the spending transaction accordingly. + // If this is an output on the remote party's commitment transaction, + // use the direct-spend path to sweep the htlc. if h.isRemoteCommitOutput() { return nil, h.resolveRemoteCommitOutput() } - return nil, h.resolveTimeoutTx() + // If this is a zero-fee HTLC, we now handle the spend from our + // commitment transaction. + if h.isZeroFeeOutput() { + return nil, h.resolveTimeoutTx() + } + + // If this is an output on our own commitment using pre-anchor channel + // type, we will let the utxo nursery handle it. + return nil, h.resolveSecondLevelTxLegacy() } // sweepTimeoutTx sends a second level timeout transaction to the sweeper. @@ -521,11 +506,16 @@ func (h *htlcTimeoutResolver) resolveSecondLevelTxLegacy() error { // The utxo nursery will take care of broadcasting the second-level // timeout tx and sweeping its output once it confirms. - return h.IncubateOutputs( + err := h.IncubateOutputs( h.ChanPoint, fn.Some(h.htlcResolution), fn.None[lnwallet.IncomingHtlcResolution](), h.broadcastHeight, h.incomingHTLCExpiryHeight, ) + if err != nil { + return err + } + + return h.resolveTimeoutTx() } // sweepDirectHtlcOutput sends the direct spend of the HTLC output to the @@ -581,53 +571,6 @@ func (h *htlcTimeoutResolver) sweepDirectHtlcOutput() error { return nil } -// spendHtlcOutput handles the initial spend of an HTLC output via the timeout -// clause. If this is our local commitment, the second-level timeout TX will be -// used to spend the output into the next stage. If this is the remote -// commitment, the output will be swept directly without the timeout -// transaction. -func (h *htlcTimeoutResolver) spendHtlcOutput() ( - *chainntnfs.SpendDetail, error) { - - switch { - // If we have non-nil SignDetails, this means that have a 2nd level - // HTLC transaction that is signed using sighash SINGLE|ANYONECANPAY - // (the case for anchor type channels). In this case we can re-sign it - // and attach fees at will. We let the sweeper handle this job. 
- case h.isZeroFeeOutput() && !h.outputIncubating: - if err := h.sweepTimeoutTx(); err != nil { - log.Errorf("Sending timeout tx to sweeper: %v", err) - - return nil, err - } - - // If this is a remote commitment there's no second level timeout txn, - // and we can just send this directly to the sweeper. - case h.isRemoteCommitOutput() && !h.outputIncubating: - if err := h.sweepDirectHtlcOutput(); err != nil { - log.Errorf("Sending direct spend to sweeper: %v", err) - - return nil, err - } - - // If we have a SignedTimeoutTx but no SignDetails, this is a local - // commitment for a non-anchor channel, so we'll send it to the utxo - // nursery. - case h.isLegacyOutput() && !h.outputIncubating: - if err := h.resolveSecondLevelTxLegacy(); err != nil { - log.Errorf("Sending timeout tx to nursery: %v", err) - - return nil, err - } - } - - // Now that we've handed off the HTLC to the nursery or sweeper, we'll - // watch for a spend of the output, and make our next move off of that. - // Depending on if this is our commitment, or the remote party's - // commitment, we'll be watching a different outpoint and script. - return h.watchHtlcSpend() -} - // watchHtlcSpend watches for a spend of the HTLC output. For neutrino backend, // it will check blocks for the confirmed spend. For btcd and bitcoind, it will // check both the mempool and the blocks. @@ -673,6 +616,9 @@ func (h *htlcTimeoutResolver) waitForConfirmedSpend(op *wire.OutPoint, // // NOTE: Part of the ContractResolver interface. func (h *htlcTimeoutResolver) Stop() { + h.log.Debugf("stopping...") + defer h.log.Debugf("stopped") + close(h.quit) } @@ -1018,15 +964,6 @@ func (h *htlcTimeoutResolver) isZeroFeeOutput() bool { h.htlcResolution.SignDetails != nil } -// isLegacyOutput returns a boolean indicating whether the htlc output is from -// a non-anchor-enabled channel. -func (h *htlcTimeoutResolver) isLegacyOutput() bool { - // If we have a SignedTimeoutTx but no SignDetails, this is a local - // commitment for a non-anchor channel. - return h.htlcResolution.SignedTimeoutTx != nil && - h.htlcResolution.SignDetails == nil -} - // waitHtlcSpendAndCheckPreimage waits for the htlc output to be spent and // checks whether the spending reveals the preimage. If the preimage is found, // it will be added to the preimage beacon to settle the incoming link, and a @@ -1296,9 +1233,11 @@ func (h *htlcTimeoutResolver) resolveTimeoutTx() error { h.log.Infof("2nd-level HTLC timeout tx=%v confirmed", spenderTxid) // Start the process to sweep the output from the timeout tx. - err = h.sweepTimeoutTxOutput() - if err != nil { - return err + if h.isZeroFeeOutput() { + err = h.sweepTimeoutTxOutput() + if err != nil { + return err + } } // Create a checkpoint since the timeout tx is confirmed and the sweep @@ -1332,3 +1271,52 @@ func (h *htlcTimeoutResolver) resolveTimeoutTxOutput(op wire.OutPoint) error { return h.checkpointClaim(spend) } + +// Launch creates an input based on the details of the outgoing htlc resolution +// and offers it to the sweeper. +func (h *htlcTimeoutResolver) Launch() error { + if h.launched { + h.log.Tracef("already launched") + return nil + } + + h.log.Debugf("launching resolver...") + h.launched = true + + switch { + // If we're already resolved, then we can exit early. + case h.resolved: + h.log.Errorf("already resolved") + return nil + + // If this is an output on the remote party's commitment transaction, + // use the direct timeout spend path. 
+ // + // NOTE: When the outputIncubating is false, it means that the output + // has been offered to the utxo nursery as starting in 0.18.4, we + // stopped marking this flag for direct timeout spends (#9062). In that + // case, we will do nothing and let the utxo nursery handle it. + case h.isRemoteCommitOutput() && !h.outputIncubating: + return h.sweepDirectHtlcOutput() + + // If this is an anchor type channel, we now sweep either the + // second-level timeout tx or the output from the second-level timeout + // tx. + case h.isZeroFeeOutput(): + // If the second-level timeout tx has already been swept, we + // can go ahead and sweep its output. + if h.outputIncubating { + return h.sweepTimeoutTxOutput() + } + + // Otherwise, sweep the second level tx. + return h.sweepTimeoutTx() + + // If this is an output on our own commitment using pre-anchor channel + // type, we will let the utxo nursery handle it via Resolve. + // + // TODO(yy): handle the legacy output by offering it to the sweeper. + default: + return nil + } +} diff --git a/contractcourt/htlc_timeout_resolver_test.go b/contractcourt/htlc_timeout_resolver_test.go index 17341f0c20..89f334aa83 100644 --- a/contractcourt/htlc_timeout_resolver_test.go +++ b/contractcourt/htlc_timeout_resolver_test.go @@ -40,7 +40,7 @@ type mockWitnessBeacon struct { func newMockWitnessBeacon() *mockWitnessBeacon { return &mockWitnessBeacon{ preImageUpdates: make(chan lntypes.Preimage, 1), - newPreimages: make(chan []lntypes.Preimage), + newPreimages: make(chan []lntypes.Preimage, 1), lookupPreimage: make(map[lntypes.Hash]lntypes.Preimage), } } @@ -280,7 +280,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { notifier := &mock.ChainNotifier{ EpochChan: make(chan *chainntnfs.BlockEpoch), - SpendChan: make(chan *chainntnfs.SpendDetail), + SpendChan: make(chan *chainntnfs.SpendDetail, 1), ConfChan: make(chan *chainntnfs.TxConfirmation), } @@ -321,6 +321,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { return nil }, + HtlcNotifier: &mockHTLCNotifier{}, }, PutResolverReport: func(_ kvdb.RwTx, _ *channeldb.ResolverReport) error { @@ -356,6 +357,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { Amt: testHtlcAmt, }, } + resolver.initLogger("timeoutResolver") var reports []*channeldb.ResolverReport @@ -390,7 +392,12 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { go func() { defer wg.Done() - _, err := resolver.Resolve() + err := resolver.Launch() + if err != nil { + resolveErr <- err + } + + _, err = resolver.Resolve() if err != nil { resolveErr <- err } @@ -406,8 +413,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { sweepChan = mockSweeper.sweptInputs } - // The output should be offered to either the sweeper or - // the nursery. + // The output should be offered to either the sweeper or the nursery. 
select { case <-incubateChan: case <-sweepChan: @@ -431,6 +437,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { case notifier.SpendChan <- &chainntnfs.SpendDetail{ SpendingTx: spendingTx, SpenderTxHash: &spendTxHash, + SpentOutPoint: &testChanPoint2, }: case <-time.After(time.Second * 5): t.Fatalf("failed to request spend ntfn") @@ -487,6 +494,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { case notifier.SpendChan <- &chainntnfs.SpendDetail{ SpendingTx: spendingTx, SpenderTxHash: &spendTxHash, + SpentOutPoint: &testChanPoint2, }: case <-time.After(time.Second * 5): t.Fatalf("failed to request spend ntfn") @@ -549,6 +557,8 @@ func TestHtlcTimeoutResolver(t *testing.T) { // TestHtlcTimeoutSingleStage tests a remote commitment confirming, and the // local node sweeping the HTLC output directly after timeout. +// +//nolint:ll func TestHtlcTimeoutSingleStage(t *testing.T) { commitOutpoint := wire.OutPoint{Index: 3} @@ -573,6 +583,12 @@ func TestHtlcTimeoutSingleStage(t *testing.T) { SpendTxID: &sweepTxid, } + sweepSpend := &chainntnfs.SpendDetail{ + SpendingTx: sweepTx, + SpentOutPoint: &commitOutpoint, + SpenderTxHash: &sweepTxid, + } + checkpoints := []checkpoint{ { // We send a confirmation the sweep tx from published @@ -582,9 +598,10 @@ func TestHtlcTimeoutSingleStage(t *testing.T) { // The nursery will create and publish a sweep // tx. - ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ - SpendingTx: sweepTx, - SpenderTxHash: &sweepTxid, + select { + case ctx.notifier.SpendChan <- sweepSpend: + case <-time.After(time.Second * 5): + t.Fatalf("failed to send spend ntfn") } // The resolver should deliver a failure @@ -620,7 +637,9 @@ func TestHtlcTimeoutSingleStage(t *testing.T) { // TestHtlcTimeoutSecondStage tests a local commitment being confirmed, and the // local node claiming the HTLC output using the second-level timeout tx. -func TestHtlcTimeoutSecondStage(t *testing.T) { +// +//nolint:ll +func TestHtlcTimeoutSecondStagex(t *testing.T) { commitOutpoint := wire.OutPoint{Index: 2} htlcOutpoint := wire.OutPoint{Index: 3} @@ -678,23 +697,57 @@ func TestHtlcTimeoutSecondStage(t *testing.T) { SpendTxID: &sweepHash, } + timeoutSpend := &chainntnfs.SpendDetail{ + SpendingTx: timeoutTx, + SpentOutPoint: &commitOutpoint, + SpenderTxHash: &timeoutTxid, + } + + sweepSpend := &chainntnfs.SpendDetail{ + SpendingTx: sweepTx, + SpentOutPoint: &htlcOutpoint, + SpenderTxHash: &sweepHash, + } + checkpoints := []checkpoint{ { + preCheckpoint: func(ctx *htlcResolverTestContext, + _ bool) error { + + // Deliver spend of timeout tx. + ctx.notifier.SpendChan <- timeoutSpend + + return nil + }, + // Output should be handed off to the nursery. incubating: true, + reports: []*channeldb.ResolverReport{ + firstStage, + }, }, { // We send a confirmation for our sweep tx to indicate // that our sweep succeeded. preCheckpoint: func(ctx *htlcResolverTestContext, - _ bool) error { + resumed bool) error { - // The nursery will publish the timeout tx. - ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ - SpendingTx: timeoutTx, - SpenderTxHash: &timeoutTxid, + // When it's reloaded from disk, we need to + // re-send the notification to mock the first + // `watchHtlcSpend`. + if resumed { + // Deliver spend of timeout tx. + ctx.notifier.SpendChan <- timeoutSpend + + // Deliver spend of timeout tx output. + ctx.notifier.SpendChan <- sweepSpend + + return nil } + // Deliver spend of timeout tx output. 
+ ctx.notifier.SpendChan <- sweepSpend + // The resolver should deliver a failure // resolution message (indicating we // successfully timed out the HTLC). @@ -707,12 +760,6 @@ func TestHtlcTimeoutSecondStage(t *testing.T) { t.Fatalf("resolution not sent") } - // Deliver spend of timeout tx. - ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ - SpendingTx: sweepTx, - SpenderTxHash: &sweepHash, - } - return nil }, @@ -722,7 +769,7 @@ func TestHtlcTimeoutSecondStage(t *testing.T) { incubating: true, resolved: true, reports: []*channeldb.ResolverReport{ - firstStage, secondState, + secondState, }, }, } @@ -796,10 +843,6 @@ func TestHtlcTimeoutSingleStageRemoteSpend(t *testing.T) { } checkpoints := []checkpoint{ - { - // Output should be handed off to the nursery. - incubating: true, - }, { // We send a spend notification for a remote spend with // the preimage. @@ -812,6 +855,7 @@ func TestHtlcTimeoutSingleStageRemoteSpend(t *testing.T) { // the preimage. ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ SpendingTx: spendTx, + SpentOutPoint: &commitOutpoint, SpenderTxHash: &spendTxHash, } @@ -847,7 +891,7 @@ func TestHtlcTimeoutSingleStageRemoteSpend(t *testing.T) { // After the success tx has confirmed, we expect the // checkpoint to be resolved, and with the above // report. - incubating: true, + incubating: false, resolved: true, reports: []*channeldb.ResolverReport{ claim, @@ -914,6 +958,7 @@ func TestHtlcTimeoutSecondStageRemoteSpend(t *testing.T) { ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ SpendingTx: remoteSuccessTx, + SpentOutPoint: &commitOutpoint, SpenderTxHash: &successTxid, } @@ -967,20 +1012,15 @@ func TestHtlcTimeoutSecondStageRemoteSpend(t *testing.T) { // TestHtlcTimeoutSecondStageSweeper tests that for anchor channels, when a // local commitment confirms, the timeout tx is handed to the sweeper to claim // the HTLC output. +// +//nolint:ll func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { - commitOutpoint := wire.OutPoint{Index: 2} htlcOutpoint := wire.OutPoint{Index: 3} - sweepTx := &wire.MsgTx{ - TxIn: []*wire.TxIn{{}}, - TxOut: []*wire.TxOut{{}}, - } - sweepHash := sweepTx.TxHash() - timeoutTx := &wire.MsgTx{ TxIn: []*wire.TxIn{ { - PreviousOutPoint: commitOutpoint, + PreviousOutPoint: htlcOutpoint, }, }, TxOut: []*wire.TxOut{ @@ -1027,11 +1067,16 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { }, } reSignedHash := reSignedTimeoutTx.TxHash() - reSignedOutPoint := wire.OutPoint{ + + timeoutTxOutpoint := wire.OutPoint{ Hash: reSignedHash, Index: 1, } + // Make a copy so `isPreimageSpend` can easily pass. + sweepTx := reSignedTimeoutTx.Copy() + sweepHash := sweepTx.TxHash() + // twoStageResolution is a resolution for a htlc on the local // party's commitment, where the timeout tx can be re-signed. 
twoStageResolution := lnwallet.OutgoingHtlcResolution{ @@ -1045,7 +1090,7 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { } firstStage := &channeldb.ResolverReport{ - OutPoint: commitOutpoint, + OutPoint: htlcOutpoint, Amount: testHtlcAmt.ToSatoshis(), ResolverType: channeldb.ResolverTypeOutgoingHtlc, ResolverOutcome: channeldb.ResolverOutcomeFirstStage, @@ -1053,12 +1098,45 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { } secondState := &channeldb.ResolverReport{ - OutPoint: reSignedOutPoint, + OutPoint: timeoutTxOutpoint, Amount: btcutil.Amount(testSignDesc.Output.Value), ResolverType: channeldb.ResolverTypeOutgoingHtlc, ResolverOutcome: channeldb.ResolverOutcomeTimeout, SpendTxID: &sweepHash, } + // mockTimeoutTxSpend is a helper closure to mock `waitForSpend` to + // return the commit spend in `sweepTimeoutTxOutput`. + mockTimeoutTxSpend := func(ctx *htlcResolverTestContext) { + select { + case ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ + SpendingTx: reSignedTimeoutTx, + SpenderInputIndex: 1, + SpenderTxHash: &reSignedHash, + SpendingHeight: 10, + SpentOutPoint: &htlcOutpoint, + }: + + case <-time.After(time.Second * 1): + t.Fatalf("spend not sent") + } + } + + // mockSweepTxSpend is a helper closure to mock `waitForSpend` to + // return the commit spend in `sweepTimeoutTxOutput`. + mockSweepTxSpend := func(ctx *htlcResolverTestContext) { + select { + case ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ + SpendingTx: sweepTx, + SpenderInputIndex: 1, + SpenderTxHash: &sweepHash, + SpendingHeight: 10, + SpentOutPoint: &timeoutTxOutpoint, + }: + + case <-time.After(time.Second * 1): + t.Fatalf("spend not sent") + } + } checkpoints := []checkpoint{ { @@ -1067,28 +1145,40 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { _ bool) error { resolver := ctx.resolver.(*htlcTimeoutResolver) - inp := <-resolver.Sweeper.(*mockSweeper).sweptInputs + + var ( + inp input.Input + ok bool + ) + + select { + case inp, ok = <-resolver.Sweeper.(*mockSweeper).sweptInputs: + require.True(t, ok) + + case <-time.After(1 * time.Second): + t.Fatal("expected input to be swept") + } + op := inp.OutPoint() - if op != commitOutpoint { + if op != htlcOutpoint { return fmt.Errorf("outpoint %v swept, "+ - "expected %v", op, - commitOutpoint) + "expected %v", op, htlcOutpoint) } - // Emulat the sweeper spending using the - // re-signed timeout tx. - ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ - SpendingTx: reSignedTimeoutTx, - SpenderInputIndex: 1, - SpenderTxHash: &reSignedHash, - SpendingHeight: 10, - } + // Mock `waitForSpend` twice, called in, + // - `resolveReSignedTimeoutTx` + // - `sweepTimeoutTxOutput`. + mockTimeoutTxSpend(ctx) + mockTimeoutTxSpend(ctx) return nil }, // incubating=true is used to signal that the // second-level transaction was confirmed. incubating: true, + reports: []*channeldb.ResolverReport{ + firstStage, + }, }, { // We send a confirmation for our sweep tx to indicate @@ -1096,18 +1186,18 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { preCheckpoint: func(ctx *htlcResolverTestContext, resumed bool) error { - // If we are resuming from a checkpoint, we - // expect the resolver to re-subscribe to a - // spend, hence we must resend it. + // Mock `waitForSpend` to return the commit + // spend. 
if resumed { - ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ - SpendingTx: reSignedTimeoutTx, - SpenderInputIndex: 1, - SpenderTxHash: &reSignedHash, - SpendingHeight: 10, - } + mockTimeoutTxSpend(ctx) + mockTimeoutTxSpend(ctx) + mockSweepTxSpend(ctx) + + return nil } + mockSweepTxSpend(ctx) + // The resolver should deliver a failure // resolution message (indicating we // successfully timed out the HTLC). @@ -1123,7 +1213,20 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { // The timeout tx output should now be given to // the sweeper. resolver := ctx.resolver.(*htlcTimeoutResolver) - inp := <-resolver.Sweeper.(*mockSweeper).sweptInputs + + var ( + inp input.Input + ok bool + ) + + select { + case inp, ok = <-resolver.Sweeper.(*mockSweeper).sweptInputs: + require.True(t, ok) + + case <-time.After(1 * time.Second): + t.Fatal("expected input to be swept") + } + op := inp.OutPoint() exp := wire.OutPoint{ Hash: reSignedHash, @@ -1133,14 +1236,6 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { return fmt.Errorf("wrong outpoint swept") } - // Notify about the spend, which should resolve - // the resolver. - ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ - SpendingTx: sweepTx, - SpenderTxHash: &sweepHash, - SpendingHeight: 14, - } - return nil }, @@ -1150,7 +1245,6 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { incubating: true, resolved: true, reports: []*channeldb.ResolverReport{ - firstStage, secondState, }, }, @@ -1231,33 +1325,6 @@ func TestHtlcTimeoutSecondStageSweeperRemoteSpend(t *testing.T) { } checkpoints := []checkpoint{ - { - // The output should be given to the sweeper. - preCheckpoint: func(ctx *htlcResolverTestContext, - _ bool) error { - - resolver := ctx.resolver.(*htlcTimeoutResolver) - inp := <-resolver.Sweeper.(*mockSweeper).sweptInputs - op := inp.OutPoint() - if op != commitOutpoint { - return fmt.Errorf("outpoint %v swept, "+ - "expected %v", op, - commitOutpoint) - } - - // Emulate the remote sweeping the output with the preimage. - // re-signed timeout tx. - ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ - SpendingTx: spendTx, - SpenderTxHash: &spendTxHash, - } - - return nil - }, - // incubating=true is used to signal that the - // second-level transaction was confirmed. - incubating: true, - }, { // We send a confirmation for our sweep tx to indicate // that our sweep succeeded. @@ -1272,6 +1339,7 @@ func TestHtlcTimeoutSecondStageSweeperRemoteSpend(t *testing.T) { ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ SpendingTx: spendTx, SpenderTxHash: &spendTxHash, + SpentOutPoint: &commitOutpoint, } } @@ -1309,7 +1377,7 @@ func TestHtlcTimeoutSecondStageSweeperRemoteSpend(t *testing.T) { // After the sweep has confirmed, we expect the // checkpoint to be resolved, and with the above // reports. - incubating: true, + incubating: false, resolved: true, reports: []*channeldb.ResolverReport{ claim, @@ -1334,21 +1402,26 @@ func testHtlcTimeout(t *testing.T, resolution lnwallet.OutgoingHtlcResolution, // for the next portion of the test. 
ctx := newHtlcResolverTestContext(t, func(htlc channeldb.HTLC, cfg ResolverConfig) ContractResolver { - return &htlcTimeoutResolver{ + r := &htlcTimeoutResolver{ contractResolverKit: *newContractResolverKit(cfg), htlc: htlc, htlcResolution: resolution, } + r.initLogger("htlcTimeoutResolver") + + return r }, ) checkpointedState := runFromCheckpoint(t, ctx, checkpoints) + t.Log("Running resolver to completion after restart") + // Now, from every checkpoint created, we re-create the resolver, and // run the test from that checkpoint. for i := range checkpointedState { cp := bytes.NewReader(checkpointedState[i]) - ctx := newHtlcResolverTestContext(t, + ctx := newHtlcResolverTestContextFromReader(t, func(htlc channeldb.HTLC, cfg ResolverConfig) ContractResolver { resolver, err := newTimeoutResolverFromReader(cp, cfg) if err != nil { @@ -1356,7 +1429,8 @@ func testHtlcTimeout(t *testing.T, resolution lnwallet.OutgoingHtlcResolution, } resolver.Supplement(htlc) - resolver.htlcResolution = resolution + resolver.initLogger("htlcTimeoutResolver") + return resolver }, ) From 3b4e937390d62d0b3ef4b0c383ab1a37c7bec366 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Sun, 17 Nov 2024 10:48:16 +0800 Subject: [PATCH 53/59] invoices: exit early when the subscriber chan is nil When calling `NotifyExitHopHtlc` it is allowed to pass a chan to subscribe to the HTLC's resolution when it's settled. However, this method will also return immediately if there's already a resolution, which means it behaves like a notifier and a getter. If the caller decides to only use the getter to do a non-blocking lookup, it can pass a nil subscriber chan to bypass the notification. --- contractcourt/mock_registry_test.go | 5 +++++ invoices/invoiceregistry.go | 6 +++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/contractcourt/mock_registry_test.go b/contractcourt/mock_registry_test.go index 5bba11afcb..0530ab51dd 100644 --- a/contractcourt/mock_registry_test.go +++ b/contractcourt/mock_registry_test.go @@ -29,6 +29,11 @@ func (r *mockRegistry) NotifyExitHopHtlc(payHash lntypes.Hash, wireCustomRecords lnwire.CustomRecords, payload invoices.Payload) (invoices.HtlcResolution, error) { + // Exit early if the notification channel is nil. + if hodlChan == nil { + return r.notifyResolution, r.notifyErr + } + r.notifyChan <- notifyExitHopData{ hodlChan: hodlChan, payHash: payHash, diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index f5a6c6a95f..cc76d5aefa 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -1275,7 +1275,11 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked( invoiceToExpire = makeInvoiceExpiry(ctx.hash, invoice) } - i.hodlSubscribe(hodlChan, ctx.circuitKey) + // Subscribe to the resolution if the caller specified a + // notification channel. + if hodlChan != nil { + i.hodlSubscribe(hodlChan, ctx.circuitKey) + } default: panic("unknown action") From 018a1083e5188ce7aab07a9e881241da7088e6f7 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Sun, 17 Nov 2024 10:47:23 +0800 Subject: [PATCH 54/59] contractcourt: add `Launch` method to incoming contest resolver A minor refactor is done to support implementing `Launch`. 
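Before the diff itself, a rough sketch of the intended flow may help: at launch time the contest resolver performs a non-blocking preimage lookup (made possible by the nil-subscriber behaviour added in the previous commit) and only hands off to the inner success resolver once the preimage is already known; otherwise it stays a contest resolver and `Resolve` keeps waiting. The types below are placeholders, not the real htlcIncomingContestResolver.

    package main

    import "fmt"

    // innerResolver is a stand-in for the wrapped success resolver.
    type innerResolver struct{ launched bool }

    func (r *innerResolver) launch() { r.launched = true }

    // contestResolver models the incoming contest resolver: it only launches
    // the inner resolver once the preimage can be found.
    type contestResolver struct {
        inner  *innerResolver
        lookup func() (preimage string, ok bool) // non-blocking lookup (preimage db / invoice registry)
    }

    func (c *contestResolver) launch() {
        // No preimage yet: do nothing and let Resolve wait for either the
        // preimage or the HTLC expiry.
        if _, ok := c.lookup(); !ok {
            return
        }

        // Preimage known: transform into the success path immediately.
        c.inner.launch()
    }

    func main() {
        c := &contestResolver{
            inner:  &innerResolver{},
            lookup: func() (string, bool) { return "", false },
        }
        c.launch()
        fmt.Println(c.inner.launched) // false: nothing to do yet

        c.lookup = func() (string, bool) { return "deadbeef", true }
        c.launch()
        fmt.Println(c.inner.launched) // true: success resolver takes over
    }
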
--- .../htlc_incoming_contest_resolver.go | 259 +++++++++++++----- .../htlc_incoming_contest_resolver_test.go | 16 +- 2 files changed, 210 insertions(+), 65 deletions(-) diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go index ebac495835..0e0c975248 100644 --- a/contractcourt/htlc_incoming_contest_resolver.go +++ b/contractcourt/htlc_incoming_contest_resolver.go @@ -78,6 +78,37 @@ func (h *htlcIncomingContestResolver) processFinalHtlcFail() error { return nil } +// Launch will call the inner resolver's launch method if the preimage can be +// found, otherwise it's a no-op. +func (h *htlcIncomingContestResolver) Launch() error { + // NOTE: we don't mark this resolver as launched as the inner resolver + // will set it when it's launched. + if h.launched { + h.log.Tracef("already launched") + return nil + } + + h.log.Debugf("launching contest resolver...") + + // Query the preimage and apply it if we already know it. + applied, err := h.findAndapplyPreimage() + if err != nil { + return err + } + + // No preimage found, leave it to be handled by the resolver. + if !applied { + return nil + } + + h.log.Debugf("found preimage for htlc=%x, transforming into success "+ + "resolver and launching it", h.htlc.RHash) + + // Once we've applied the preimage, we'll launch the inner resolver to + // attempt to claim the HTLC. + return h.htlcSuccessResolver.Launch() +} + // Resolve attempts to resolve this contract. As we don't yet know of the // preimage for the contract, we'll wait for one of two things to happen: // @@ -94,6 +125,7 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { // If we're already full resolved, then we don't have anything further // to do. if h.resolved { + h.log.Errorf("already resolved") return nil, nil } @@ -101,8 +133,7 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { // now. payload, nextHopOnionBlob, err := h.decodePayload() if err != nil { - log.Debugf("ChannelArbitrator(%v): cannot decode payload of "+ - "htlc %v", h.ChanPoint, h.HtlcPoint()) + h.log.Debugf("cannot decode payload of htlc %v", h.HtlcPoint()) // If we've locked in an htlc with an invalid payload on our // commitment tx, we don't need to resolve it. The other party @@ -177,65 +208,6 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { return nil, h.Checkpoint(h, report) } - // applyPreimage is a helper function that will populate our internal - // resolver with the preimage we learn of. This should be called once - // the preimage is revealed so the inner resolver can properly complete - // its duties. The error return value indicates whether the preimage - // was properly applied. - applyPreimage := func(preimage lntypes.Preimage) error { - // Sanity check to see if this preimage matches our htlc. At - // this point it should never happen that it does not match. - if !preimage.Matches(h.htlc.RHash) { - return errors.New("preimage does not match hash") - } - - // Update htlcResolution with the matching preimage. - h.htlcResolution.Preimage = preimage - - log.Infof("%T(%v): applied preimage=%v", h, - h.htlcResolution.ClaimOutpoint, preimage) - - isSecondLevel := h.htlcResolution.SignedSuccessTx != nil - - // If we didn't have to go to the second level to claim (this - // is the remote commitment transaction), then we don't need to - // modify our canned witness. 
- if !isSecondLevel { - return nil - } - - isTaproot := txscript.IsPayToTaproot( - h.htlcResolution.SignedSuccessTx.TxOut[0].PkScript, - ) - - // If this is our commitment transaction, then we'll need to - // populate the witness for the second-level HTLC transaction. - switch { - // For taproot channels, the witness for sweeping with success - // looks like: - // - - // - // - // So we'll insert it at the 3rd index of the witness. - case isTaproot: - //nolint:ll - h.htlcResolution.SignedSuccessTx.TxIn[0].Witness[2] = preimage[:] - - // Within the witness for the success transaction, the - // preimage is the 4th element as it looks like: - // - // * <0> - // - // We'll populate it within the witness, as since this - // was a "contest" resolver, we didn't yet know of the - // preimage. - case !isTaproot: - h.htlcResolution.SignedSuccessTx.TxIn[0].Witness[3] = preimage[:] - } - - return nil - } - // Define a closure to process htlc resolutions either directly or // triggered by future notifications. processHtlcResolution := func(e invoices.HtlcResolution) ( @@ -247,7 +219,7 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { // If the htlc resolution was a settle, apply the // preimage and return a success resolver. case *invoices.HtlcSettleResolution: - err := applyPreimage(resolution.Preimage) + err := h.applyPreimage(resolution.Preimage) if err != nil { return nil, err } @@ -312,6 +284,9 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { return nil, err } + h.log.Debugf("received resolution from registry: %v", + resolution) + defer func() { h.Registry.HodlUnsubscribeAll(hodlQueue.ChanIn()) @@ -369,7 +344,9 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { // However, we don't know how to ourselves, so we'll // return our inner resolver which has the knowledge to // do so. - if err := applyPreimage(preimage); err != nil { + h.log.Debugf("Found preimage for htlc=%x", h.htlc.RHash) + + if err := h.applyPreimage(preimage); err != nil { return nil, err } @@ -388,7 +365,10 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { continue } - if err := applyPreimage(preimage); err != nil { + h.log.Debugf("Received preimage for htlc=%x", + h.htlc.RHash) + + if err := h.applyPreimage(preimage); err != nil { return nil, err } @@ -435,6 +415,76 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { } } +// applyPreimage is a helper function that will populate our internal resolver +// with the preimage we learn of. This should be called once the preimage is +// revealed so the inner resolver can properly complete its duties. The error +// return value indicates whether the preimage was properly applied. +func (h *htlcIncomingContestResolver) applyPreimage( + preimage lntypes.Preimage) error { + + // Sanity check to see if this preimage matches our htlc. At this point + // it should never happen that it does not match. + if !preimage.Matches(h.htlc.RHash) { + return errors.New("preimage does not match hash") + } + + // We may already have the preimage since both the `Launch` and + // `Resolve` methods will look for it. + if h.htlcResolution.Preimage != lntypes.ZeroHash { + h.log.Debugf("already applied preimage for htlc=%x", + h.htlc.RHash) + + return nil + } + + // Update htlcResolution with the matching preimage. 
+ h.htlcResolution.Preimage = preimage + + log.Infof("%T(%v): applied preimage=%v", h, + h.htlcResolution.ClaimOutpoint, preimage) + + isSecondLevel := h.htlcResolution.SignedSuccessTx != nil + + // If we didn't have to go to the second level to claim (this + // is the remote commitment transaction), then we don't need to + // modify our canned witness. + if !isSecondLevel { + return nil + } + + isTaproot := txscript.IsPayToTaproot( + h.htlcResolution.SignedSuccessTx.TxOut[0].PkScript, + ) + + // If this is our commitment transaction, then we'll need to + // populate the witness for the second-level HTLC transaction. + switch { + // For taproot channels, the witness for sweeping with success + // looks like: + // - + // + // + // So we'll insert it at the 3rd index of the witness. + case isTaproot: + //nolint:ll + h.htlcResolution.SignedSuccessTx.TxIn[0].Witness[2] = preimage[:] + + // Within the witness for the success transaction, the + // preimage is the 4th element as it looks like: + // + // * <0> + // + // We'll populate it within the witness, as since this + // was a "contest" resolver, we didn't yet know of the + // preimage. + case !isTaproot: + //nolint:ll + h.htlcResolution.SignedSuccessTx.TxIn[0].Witness[3] = preimage[:] + } + + return nil +} + // report returns a report on the resolution state of the contract. func (h *htlcIncomingContestResolver) report() *ContractReport { // No locking needed as these values are read-only. @@ -461,6 +511,8 @@ func (h *htlcIncomingContestResolver) report() *ContractReport { // // NOTE: Part of the ContractResolver interface. func (h *htlcIncomingContestResolver) Stop() { + h.log.Debugf("stopping...") + defer h.log.Debugf("stopped") close(h.quit) } @@ -560,3 +612,82 @@ func (h *htlcIncomingContestResolver) decodePayload() (*hop.Payload, // A compile time assertion to ensure htlcIncomingContestResolver meets the // ContractResolver interface. var _ htlcContractResolver = (*htlcIncomingContestResolver)(nil) + +// findAndapplyPreimage performs a non-blocking read to find the preimage for +// the incoming HTLC. If found, it will be applied to the resolver. This method +// is used for the resolver to decide whether it wants to transform into a +// success resolver during launching. +// +// NOTE: Since we have two places to query the preimage, we need to check both +// the preimage db and the invoice db to look up the preimage. +func (h *htlcIncomingContestResolver) findAndapplyPreimage() (bool, error) { + // Query to see if we already know the preimage. + preimage, ok := h.PreimageDB.LookupPreimage(h.htlc.RHash) + + // If the preimage is known, we'll apply it. + if ok { + if err := h.applyPreimage(preimage); err != nil { + return false, err + } + + // Successfully applied the preimage, we can now return. + return true, nil + } + + // First try to parse the payload. + payload, _, err := h.decodePayload() + if err != nil { + h.log.Errorf("Cannot decode payload of htlc %v", h.HtlcPoint()) + + // If we cannot decode the payload, we will return a nil error + // and let it to be handled in `Resolve`. + return false, nil + } + + // Exit early if this is not the exit hop, which means we are not the + // payment receiver and don't have preimage. + if payload.FwdInfo.NextHop != hop.Exit { + return false, nil + } + + // Notify registry that we are potentially resolving as an exit hop + // on-chain. If this HTLC indeed pays to an existing invoice, the + // invoice registry will tell us what to do with the HTLC. 
This is + // identical to HTLC resolution in the link. + circuitKey := models.CircuitKey{ + ChanID: h.ShortChanID, + HtlcID: h.htlc.HtlcIndex, + } + + // Try get the resolution - if it doesn't give us a resolution + // immediately, we'll assume we don't know it yet and let the `Resolve` + // handle the waiting. + // + // NOTE: we use a nil subscriber here and a zero current height as we + // are only interested in the settle resolution. + // + // TODO(yy): move this logic to link and let the preimage be accessed + // via the preimage beacon. + resolution, err := h.Registry.NotifyExitHopHtlc( + h.htlc.RHash, h.htlc.Amt, h.htlcExpiry, 0, + circuitKey, nil, nil, payload, + ) + if err != nil { + return false, err + } + + res, ok := resolution.(*invoices.HtlcSettleResolution) + + // Exit early if it's not a settle resolution. + if !ok { + return false, nil + } + + // Otherwise we have a settle resolution, apply the preimage. + err = h.applyPreimage(res.Preimage) + if err != nil { + return false, err + } + + return true, nil +} diff --git a/contractcourt/htlc_incoming_contest_resolver_test.go b/contractcourt/htlc_incoming_contest_resolver_test.go index e8b8eac0c9..f17190e96e 100644 --- a/contractcourt/htlc_incoming_contest_resolver_test.go +++ b/contractcourt/htlc_incoming_contest_resolver_test.go @@ -5,11 +5,13 @@ import ( "io" "testing" + "github.com/btcsuite/btcd/wire" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" + "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnmock" @@ -356,6 +358,7 @@ func newIncomingResolverTestContext(t *testing.T, isExit bool) *incomingResolver return nil }, + Sweeper: newMockSweeper(), }, PutResolverReport: func(_ kvdb.RwTx, _ *channeldb.ResolverReport) error { @@ -374,10 +377,16 @@ func newIncomingResolverTestContext(t *testing.T, isExit bool) *incomingResolver }, } + res := lnwallet.IncomingHtlcResolution{ + SweepSignDesc: input.SignDescriptor{ + Output: &wire.TxOut{}, + }, + } + c.resolver = &htlcIncomingContestResolver{ htlcSuccessResolver: &htlcSuccessResolver{ contractResolverKit: *newContractResolverKit(cfg), - htlcResolution: lnwallet.IncomingHtlcResolution{}, + htlcResolution: res, htlc: channeldb.HTLC{ Amt: lnwire.MilliSatoshi(testHtlcAmount), RHash: testResHash, @@ -386,6 +395,7 @@ func newIncomingResolverTestContext(t *testing.T, isExit bool) *incomingResolver }, htlcExpiry: testHtlcExpiry, } + c.resolver.initLogger("htlcIncomingContestResolver") return c } @@ -395,6 +405,10 @@ func (i *incomingResolverTestContext) resolve() { i.resolveErr = make(chan error, 1) go func() { var err error + + err = i.resolver.Launch() + require.NoError(i.t, err) + i.nextResolver, err = i.resolver.Resolve() i.resolveErr <- err }() From 7b8cf8a3da9219f0ed7413aafca9898a28a2c359 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 20 Jun 2024 22:11:47 +0800 Subject: [PATCH 55/59] contractcourt: add `Launch` method to outgoing contest resolver --- .../htlc_outgoing_contest_resolver.go | 44 ++++++++++++++++--- .../htlc_outgoing_contest_resolver_test.go | 6 +++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go index 7adce5a689..b23259b5bb 100644 --- 
a/contractcourt/htlc_outgoing_contest_resolver.go +++ b/contractcourt/htlc_outgoing_contest_resolver.go @@ -1,7 +1,6 @@ package contractcourt import ( - "fmt" "io" "github.com/btcsuite/btcd/btcutil" @@ -36,6 +35,37 @@ func newOutgoingContestResolver(res lnwallet.OutgoingHtlcResolution, } } +// Launch will call the inner resolver's launch method if the expiry height has +// been reached, otherwise it's a no-op. +func (h *htlcOutgoingContestResolver) Launch() error { + // NOTE: we don't mark this resolver as launched as the inner resolver + // will set it when it's launched. + if h.launched { + h.log.Tracef("already launched") + return nil + } + + h.log.Debugf("launching contest resolver...") + + _, bestHeight, err := h.ChainIO.GetBestBlock() + if err != nil { + return err + } + + if uint32(bestHeight) < h.htlcResolution.Expiry { + return nil + } + + // If the current height is >= expiry, then a timeout path spend will + // be valid to be included in the next block, and we can immediately + // return the resolver. + h.log.Infof("expired (height=%v, expiry=%v), transforming into "+ + "timeout resolver and launching it", bestHeight, + h.htlcResolution.Expiry) + + return h.htlcTimeoutResolver.Launch() +} + // Resolve commences the resolution of this contract. As this contract hasn't // yet timed out, we'll wait for one of two things to happen // @@ -53,6 +83,7 @@ func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) { // If we're already full resolved, then we don't have anything further // to do. if h.resolved { + h.log.Errorf("already resolved") return nil, nil } @@ -86,7 +117,6 @@ func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) { return nil, errResolverShuttingDown } - // TODO(roasbeef): Checkpoint? return nil, h.claimCleanUp(commitSpend) // If it hasn't, then we'll watch for both the expiration, and the @@ -124,12 +154,14 @@ func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) { // finalized` will be returned and the broadcast will // fail. newHeight := uint32(newBlock.Height) - if newHeight >= h.htlcResolution.Expiry { - log.Infof("%T(%v): HTLC has expired "+ + expiry := h.htlcResolution.Expiry + if newHeight >= expiry { + h.log.Infof("HTLC about to expire "+ "(height=%v, expiry=%v), transforming "+ "into timeout resolver", h, h.htlcResolution.ClaimOutpoint, newHeight, h.htlcResolution.Expiry) + return h.htlcTimeoutResolver, nil } @@ -147,7 +179,7 @@ func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) { return nil, h.claimCleanUp(commitSpend) case <-h.quit: - return nil, fmt.Errorf("resolver canceled") + return nil, errResolverShuttingDown } } } @@ -178,6 +210,8 @@ func (h *htlcOutgoingContestResolver) report() *ContractReport { // // NOTE: Part of the ContractResolver interface. 
func (h *htlcOutgoingContestResolver) Stop() { + h.log.Debugf("stopping...") + defer h.log.Debugf("stopped") close(h.quit) } diff --git a/contractcourt/htlc_outgoing_contest_resolver_test.go b/contractcourt/htlc_outgoing_contest_resolver_test.go index 4fa3a6874f..625df60bf1 100644 --- a/contractcourt/htlc_outgoing_contest_resolver_test.go +++ b/contractcourt/htlc_outgoing_contest_resolver_test.go @@ -15,6 +15,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" ) const ( @@ -159,6 +160,7 @@ func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext { return nil }, + ChainIO: &mock.ChainIO{}, }, PutResolverReport: func(_ kvdb.RwTx, _ *channeldb.ResolverReport) error { @@ -195,6 +197,7 @@ func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext { }, }, } + resolver.initLogger("htlcOutgoingContestResolver") return &outgoingResolverTestContext{ resolver: resolver, @@ -209,6 +212,9 @@ func (i *outgoingResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { + err := i.resolver.Launch() + require.NoError(i.t, err) + nextResolver, err := i.resolver.Resolve() i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, From 5d9f5deee706e9ac20fe36160ed2f3a3e91834d9 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 10 Jul 2024 18:08:23 +0800 Subject: [PATCH 56/59] contractcourt: fix concurrent access to `resolved` This commit makes `resolved` an atomic bool to avoid data race. This field is now defined in `contractResolverKit` to avoid code duplication. --- contractcourt/anchor_resolver.go | 17 +---- contractcourt/breach_resolver.go | 22 +++--- contractcourt/briefcase_test.go | 67 +++++++++++-------- contractcourt/commit_sweep_resolver.go | 26 +++---- contractcourt/contract_resolver.go | 17 +++++ .../htlc_incoming_contest_resolver.go | 19 ++---- .../htlc_outgoing_contest_resolver.go | 10 +-- contractcourt/htlc_success_resolver.go | 27 +++----- contractcourt/htlc_success_resolver_test.go | 4 +- contractcourt/htlc_timeout_resolver.go | 29 ++++---- contractcourt/htlc_timeout_resolver_test.go | 2 +- 11 files changed, 111 insertions(+), 129 deletions(-) diff --git a/contractcourt/anchor_resolver.go b/contractcourt/anchor_resolver.go index 84f2a216e4..5a51113881 100644 --- a/contractcourt/anchor_resolver.go +++ b/contractcourt/anchor_resolver.go @@ -24,9 +24,6 @@ type anchorResolver struct { // anchor is the outpoint on the commitment transaction. anchor wire.OutPoint - // resolved reflects if the contract has been fully resolved or not. - resolved bool - // broadcastHeight is the height that the original contract was // broadcast to the main-chain at. We'll use this value to bound any // historical queries to the chain for spends/confirmations. @@ -87,7 +84,7 @@ func (c *anchorResolver) ResolverKey() []byte { // Resolve waits for the output to be swept. func (c *anchorResolver) Resolve() (ContractResolver, error) { // If we're already resolved, then we can exit early. 
- if c.resolved { + if c.IsResolved() { c.log.Errorf("already resolved") return nil, nil } @@ -137,7 +134,7 @@ func (c *anchorResolver) Resolve() (ContractResolver, error) { ) c.reportLock.Unlock() - c.resolved = true + c.markResolved() return nil, c.PutResolverReport(nil, report) } @@ -152,14 +149,6 @@ func (c *anchorResolver) Stop() { close(c.quit) } -// IsResolved returns true if the stored state in the resolve is fully -// resolved. In this case the target output can be forgotten. -// -// NOTE: Part of the ContractResolver interface. -func (c *anchorResolver) IsResolved() bool { - return c.resolved -} - // SupplementState allows the user of a ContractResolver to supplement it with // state required for the proper resolution of a contract. // @@ -196,7 +185,7 @@ func (c *anchorResolver) Launch() error { c.launched = true // If we're already resolved, then we can exit early. - if c.resolved { + if c.IsResolved() { c.log.Errorf("already resolved") return nil } diff --git a/contractcourt/breach_resolver.go b/contractcourt/breach_resolver.go index 75944fa6f7..a89ccf7e23 100644 --- a/contractcourt/breach_resolver.go +++ b/contractcourt/breach_resolver.go @@ -12,9 +12,6 @@ import ( // future, this will likely take over the duties the current BreachArbitrator // has. type breachResolver struct { - // resolved reflects if the contract has been fully resolved or not. - resolved bool - // subscribed denotes whether or not the breach resolver has subscribed // to the BreachArbitrator for breach resolution. subscribed bool @@ -60,7 +57,7 @@ func (b *breachResolver) Resolve() (ContractResolver, error) { // If the breach resolution process is already complete, then // we can cleanup and checkpoint the resolved state. if complete { - b.resolved = true + b.markResolved() return nil, b.Checkpoint(b) } @@ -73,8 +70,9 @@ func (b *breachResolver) Resolve() (ContractResolver, error) { // The replyChan has been closed, signalling that the breach // has been fully resolved. Checkpoint the resolved state and // exit. - b.resolved = true + b.markResolved() return nil, b.Checkpoint(b) + case <-b.quit: } @@ -87,19 +85,13 @@ func (b *breachResolver) Stop() { close(b.quit) } -// IsResolved returns true if the breachResolver is fully resolved and cleanup -// can occur. -func (b *breachResolver) IsResolved() bool { - return b.resolved -} - // SupplementState adds additional state to the breachResolver. func (b *breachResolver) SupplementState(_ *channeldb.OpenChannel) { } // Encode encodes the breachResolver to the passed writer. 
func (b *breachResolver) Encode(w io.Writer) error { - return binary.Write(w, endian, b.resolved) + return binary.Write(w, endian, b.IsResolved()) } // newBreachResolverFromReader attempts to decode an encoded breachResolver @@ -112,9 +104,13 @@ func newBreachResolverFromReader(r io.Reader, resCfg ResolverConfig) ( replyChan: make(chan struct{}), } - if err := binary.Read(r, endian, &b.resolved); err != nil { + var resolved bool + if err := binary.Read(r, endian, &resolved); err != nil { return nil, err } + if resolved { + b.markResolved() + } b.initLogger(fmt.Sprintf("%T(%v)", b, b.ChanPoint)) diff --git a/contractcourt/briefcase_test.go b/contractcourt/briefcase_test.go index 533d0eff78..aa2e711efc 100644 --- a/contractcourt/briefcase_test.go +++ b/contractcourt/briefcase_test.go @@ -206,8 +206,8 @@ func assertResolversEqual(t *testing.T, originalResolver ContractResolver, ogRes.outputIncubating, diskRes.outputIncubating) } if ogRes.resolved != diskRes.resolved { - t.Fatalf("expected %v, got %v", ogRes.resolved, - diskRes.resolved) + t.Fatalf("expected %v, got %v", ogRes.resolved.Load(), + diskRes.resolved.Load()) } if ogRes.broadcastHeight != diskRes.broadcastHeight { t.Fatalf("expected %v, got %v", @@ -229,8 +229,8 @@ func assertResolversEqual(t *testing.T, originalResolver ContractResolver, ogRes.outputIncubating, diskRes.outputIncubating) } if ogRes.resolved != diskRes.resolved { - t.Fatalf("expected %v, got %v", ogRes.resolved, - diskRes.resolved) + t.Fatalf("expected %v, got %v", ogRes.resolved.Load(), + diskRes.resolved.Load()) } if ogRes.broadcastHeight != diskRes.broadcastHeight { t.Fatalf("expected %v, got %v", @@ -275,8 +275,8 @@ func assertResolversEqual(t *testing.T, originalResolver ContractResolver, ogRes.commitResolution, diskRes.commitResolution) } if ogRes.resolved != diskRes.resolved { - t.Fatalf("expected %v, got %v", ogRes.resolved, - diskRes.resolved) + t.Fatalf("expected %v, got %v", ogRes.resolved.Load(), + diskRes.resolved.Load()) } if ogRes.broadcastHeight != diskRes.broadcastHeight { t.Fatalf("expected %v, got %v", @@ -312,13 +312,14 @@ func TestContractInsertionRetrieval(t *testing.T) { SweepSignDesc: testSignDesc, }, outputIncubating: true, - resolved: true, broadcastHeight: 102, htlc: channeldb.HTLC{ HtlcIndex: 12, }, } - successResolver := htlcSuccessResolver{ + timeoutResolver.resolved.Store(true) + + successResolver := &htlcSuccessResolver{ htlcResolution: lnwallet.IncomingHtlcResolution{ Preimage: testPreimage, SignedSuccessTx: nil, @@ -327,40 +328,49 @@ func TestContractInsertionRetrieval(t *testing.T) { SweepSignDesc: testSignDesc, }, outputIncubating: true, - resolved: true, broadcastHeight: 109, htlc: channeldb.HTLC{ RHash: testPreimage, }, } - resolvers := []ContractResolver{ - &timeoutResolver, - &successResolver, - &commitSweepResolver{ - commitResolution: lnwallet.CommitOutputResolution{ - SelfOutPoint: testChanPoint2, - SelfOutputSignDesc: testSignDesc, - MaturityDelay: 99, - }, - resolved: false, - broadcastHeight: 109, - chanPoint: testChanPoint1, + successResolver.resolved.Store(true) + + commitResolver := &commitSweepResolver{ + commitResolution: lnwallet.CommitOutputResolution{ + SelfOutPoint: testChanPoint2, + SelfOutputSignDesc: testSignDesc, + MaturityDelay: 99, }, + broadcastHeight: 109, + chanPoint: testChanPoint1, + } + commitResolver.resolved.Store(false) + + resolvers := []ContractResolver{ + &timeoutResolver, successResolver, commitResolver, } // All resolvers require a unique ResolverKey() output. 
To achieve this // for the composite resolvers, we'll mutate the underlying resolver // with a new outpoint. - contestTimeout := timeoutResolver - contestTimeout.htlcResolution.ClaimOutpoint = randOutPoint() + contestTimeout := htlcTimeoutResolver{ + htlcResolution: lnwallet.OutgoingHtlcResolution{ + ClaimOutpoint: randOutPoint(), + SweepSignDesc: testSignDesc, + }, + } resolvers = append(resolvers, &htlcOutgoingContestResolver{ htlcTimeoutResolver: &contestTimeout, }) - contestSuccess := successResolver - contestSuccess.htlcResolution.ClaimOutpoint = randOutPoint() + contestSuccess := &htlcSuccessResolver{ + htlcResolution: lnwallet.IncomingHtlcResolution{ + ClaimOutpoint: randOutPoint(), + SweepSignDesc: testSignDesc, + }, + } resolvers = append(resolvers, &htlcIncomingContestResolver{ htlcExpiry: 100, - htlcSuccessResolver: &contestSuccess, + htlcSuccessResolver: contestSuccess, }) // For quick lookup during the test, we'll create this map which allow @@ -438,12 +448,12 @@ func TestContractResolution(t *testing.T) { SweepSignDesc: testSignDesc, }, outputIncubating: true, - resolved: true, broadcastHeight: 192, htlc: channeldb.HTLC{ HtlcIndex: 9912, }, } + timeoutResolver.resolved.Store(true) // First, we'll insert the resolver into the database and ensure that // we get the same resolver out the other side. We do not need to apply @@ -491,12 +501,13 @@ func TestContractSwapping(t *testing.T) { SweepSignDesc: testSignDesc, }, outputIncubating: true, - resolved: true, broadcastHeight: 102, htlc: channeldb.HTLC{ HtlcIndex: 12, }, } + timeoutResolver.resolved.Store(true) + contestResolver := &htlcOutgoingContestResolver{ htlcTimeoutResolver: timeoutResolver, } diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index b3323158db..84efdeb067 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -39,9 +39,6 @@ type commitSweepResolver struct { // this HTLC on-chain. commitResolution lnwallet.CommitOutputResolution - // resolved reflects if the contract has been fully resolved or not. - resolved bool - // broadcastHeight is the height that the original contract was // broadcast to the main-chain at. We'll use this value to bound any // historical queries to the chain for spends/confirmations. @@ -171,7 +168,7 @@ func (c *commitSweepResolver) getCommitTxConfHeight() (uint32, error) { //nolint:funlen func (c *commitSweepResolver) Resolve() (ContractResolver, error) { // If we're already resolved, then we can exit early. - if c.resolved { + if c.IsResolved() { c.log.Errorf("already resolved") return nil, nil } @@ -224,7 +221,7 @@ func (c *commitSweepResolver) Resolve() (ContractResolver, error) { report := c.currentReport.resolverReport( &sweepTxID, channeldb.ResolverTypeCommit, outcome, ) - c.resolved = true + c.markResolved() // Checkpoint the resolver with a closure that will write the outcome // of the resolver and its sweep transaction to disk. @@ -241,14 +238,6 @@ func (c *commitSweepResolver) Stop() { close(c.quit) } -// IsResolved returns true if the stored state in the resolve is fully -// resolved. In this case the target output can be forgotten. -// -// NOTE: Part of the ContractResolver interface. -func (c *commitSweepResolver) IsResolved() bool { - return c.resolved -} - // SupplementState allows the user of a ContractResolver to supplement it with // state required for the proper resolution of a contract. 
// @@ -277,7 +266,7 @@ func (c *commitSweepResolver) Encode(w io.Writer) error { return err } - if err := binary.Write(w, endian, c.resolved); err != nil { + if err := binary.Write(w, endian, c.IsResolved()); err != nil { return err } if err := binary.Write(w, endian, c.broadcastHeight); err != nil { @@ -312,9 +301,14 @@ func newCommitSweepResolverFromReader(r io.Reader, resCfg ResolverConfig) ( return nil, err } - if err := binary.Read(r, endian, &c.resolved); err != nil { + var resolved bool + if err := binary.Read(r, endian, &resolved); err != nil { return nil, err } + if resolved { + c.markResolved() + } + if err := binary.Read(r, endian, &c.broadcastHeight); err != nil { return nil, err } @@ -383,7 +377,7 @@ func (c *commitSweepResolver) Launch() error { c.launched = true // If we're already resolved, then we can exit early. - if c.resolved { + if c.IsResolved() { c.log.Errorf("already resolved") return nil } diff --git a/contractcourt/contract_resolver.go b/contractcourt/contract_resolver.go index 814c02ff56..5311d466ae 100644 --- a/contractcourt/contract_resolver.go +++ b/contractcourt/contract_resolver.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "sync/atomic" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btclog/v2" @@ -119,6 +120,9 @@ type contractResolverKit struct { // launched specifies whether the resolver has been launched. Calling // `Launch` will be a no-op if this is true. launched bool + + // resolved reflects if the contract has been fully resolved or not. + resolved atomic.Bool } // newContractResolverKit instantiates the mix-in struct. @@ -137,6 +141,19 @@ func (r *contractResolverKit) initLogger(prefix string) { r.log = log.WithPrefix(logPrefix) } +// IsResolved returns true if the stored state in the resolve is fully +// resolved. In this case the target output can be forgotten. +// +// NOTE: Part of the ContractResolver interface. +func (r *contractResolverKit) IsResolved() bool { + return r.resolved.Load() +} + +// markResolved marks the resolver as resolved. +func (r *contractResolverKit) markResolved() { + r.resolved.Store(true) +} + var ( // errResolverShuttingDown is returned when the resolver stops // progressing because it received the quit signal. diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go index 0e0c975248..3179010225 100644 --- a/contractcourt/htlc_incoming_contest_resolver.go +++ b/contractcourt/htlc_incoming_contest_resolver.go @@ -124,7 +124,7 @@ func (h *htlcIncomingContestResolver) Launch() error { func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { // If we're already full resolved, then we don't have anything further // to do. - if h.resolved { + if h.IsResolved() { h.log.Errorf("already resolved") return nil, nil } @@ -140,7 +140,7 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { // will time it out and get their funds back. This situation // can present itself when we crash before processRemoteAdds in // the link has ran. 
- h.resolved = true + h.markResolved() if err := h.processFinalHtlcFail(); err != nil { return nil, err @@ -193,7 +193,7 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { log.Infof("%T(%v): HTLC has timed out (expiry=%v, height=%v), "+ "abandoning", h, h.htlcResolution.ClaimOutpoint, h.htlcExpiry, currentHeight) - h.resolved = true + h.markResolved() if err := h.processFinalHtlcFail(); err != nil { return nil, err @@ -234,7 +234,7 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { h.htlcResolution.ClaimOutpoint, h.htlcExpiry, currentHeight) - h.resolved = true + h.markResolved() if err := h.processFinalHtlcFail(); err != nil { return nil, err @@ -395,7 +395,8 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { "(expiry=%v, height=%v), abandoning", h, h.htlcResolution.ClaimOutpoint, h.htlcExpiry, currentHeight) - h.resolved = true + + h.markResolved() if err := h.processFinalHtlcFail(); err != nil { return nil, err @@ -516,14 +517,6 @@ func (h *htlcIncomingContestResolver) Stop() { close(h.quit) } -// IsResolved returns true if the stored state in the resolve is fully -// resolved. In this case the target output can be forgotten. -// -// NOTE: Part of the ContractResolver interface. -func (h *htlcIncomingContestResolver) IsResolved() bool { - return h.resolved -} - // Encode writes an encoded version of the ContractResolver into the passed // Writer. // diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go index b23259b5bb..8763dfb27e 100644 --- a/contractcourt/htlc_outgoing_contest_resolver.go +++ b/contractcourt/htlc_outgoing_contest_resolver.go @@ -82,7 +82,7 @@ func (h *htlcOutgoingContestResolver) Launch() error { func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) { // If we're already full resolved, then we don't have anything further // to do. - if h.resolved { + if h.IsResolved() { h.log.Errorf("already resolved") return nil, nil } @@ -215,14 +215,6 @@ func (h *htlcOutgoingContestResolver) Stop() { close(h.quit) } -// IsResolved returns true if the stored state in the resolve is fully -// resolved. In this case the target output can be forgotten. -// -// NOTE: Part of the ContractResolver interface. -func (h *htlcOutgoingContestResolver) IsResolved() bool { - return h.resolved -} - // Encode writes an encoded version of the ContractResolver into the passed // Writer. // diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 20b4516634..b4c6db2b91 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -42,9 +42,6 @@ type htlcSuccessResolver struct { // second-level output (true). outputIncubating bool - // resolved reflects if the contract has been fully resolved or not. - resolved bool - // broadcastHeight is the height that the original contract was // broadcast to the main-chain at. We'll use this value to bound any // historical queries to the chain for spends/confirmations. @@ -122,7 +119,7 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { switch { // If we're already resolved, then we can exit early. - case h.resolved: + case h.IsResolved(): h.log.Errorf("already resolved") // If this is an output on the remote party's commitment transaction, @@ -226,7 +223,7 @@ func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash) error { } // Finally, we checkpoint the resolver with our report(s). 
- h.resolved = true + h.markResolved() return h.Checkpoint(h, reports...) } @@ -241,14 +238,6 @@ func (h *htlcSuccessResolver) Stop() { close(h.quit) } -// IsResolved returns true if the stored state in the resolve is fully -// resolved. In this case the target output can be forgotten. -// -// NOTE: Part of the ContractResolver interface. -func (h *htlcSuccessResolver) IsResolved() bool { - return h.resolved -} - // report returns a report on the resolution state of the contract. func (h *htlcSuccessResolver) report() *ContractReport { // If the sign details are nil, the report will be created by handled @@ -298,7 +287,7 @@ func (h *htlcSuccessResolver) Encode(w io.Writer) error { if err := binary.Write(w, endian, h.outputIncubating); err != nil { return err } - if err := binary.Write(w, endian, h.resolved); err != nil { + if err := binary.Write(w, endian, h.IsResolved()); err != nil { return err } if err := binary.Write(w, endian, h.broadcastHeight); err != nil { @@ -337,9 +326,15 @@ func newSuccessResolverFromReader(r io.Reader, resCfg ResolverConfig) ( if err := binary.Read(r, endian, &h.outputIncubating); err != nil { return nil, err } - if err := binary.Read(r, endian, &h.resolved); err != nil { + + var resolved bool + if err := binary.Read(r, endian, &resolved); err != nil { return nil, err } + if resolved { + h.markResolved() + } + if err := binary.Read(r, endian, &h.broadcastHeight); err != nil { return nil, err } @@ -745,7 +740,7 @@ func (h *htlcSuccessResolver) Launch() error { switch { // If we're already resolved, then we can exit early. - case h.resolved: + case h.IsResolved(): h.log.Errorf("already resolved") return nil diff --git a/contractcourt/htlc_success_resolver_test.go b/contractcourt/htlc_success_resolver_test.go index f395d67215..417ef907ed 100644 --- a/contractcourt/htlc_success_resolver_test.go +++ b/contractcourt/htlc_success_resolver_test.go @@ -616,11 +616,11 @@ func runFromCheckpoint(t *testing.T, ctx *htlcResolverTestContext, var resolved, incubating bool if h, ok := resolver.(*htlcSuccessResolver); ok { - resolved = h.resolved + resolved = h.resolved.Load() incubating = h.outputIncubating } if h, ok := resolver.(*htlcTimeoutResolver); ok { - resolved = h.resolved + resolved = h.resolved.Load() incubating = h.outputIncubating } diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index eae52255f2..125d205477 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -38,9 +38,6 @@ type htlcTimeoutResolver struct { // incubator (utxo nursery). outputIncubating bool - // resolved reflects if the contract has been fully resolved or not. - resolved bool - // broadcastHeight is the height that the original contract was // broadcast to the main-chain at. We'll use this value to bound any // historical queries to the chain for spends/confirmations. @@ -238,7 +235,7 @@ func (h *htlcTimeoutResolver) claimCleanUp( }); err != nil { return err } - h.resolved = true + h.markResolved() // Checkpoint our resolver with a report which reflects the preimage // claim by the remote party. @@ -424,7 +421,7 @@ func checkSizeAndIndex(witness wire.TxWitness, size, index int) bool { // NOTE: Part of the ContractResolver interface. func (h *htlcTimeoutResolver) Resolve() (ContractResolver, error) { // If we're already resolved, then we can exit early. 
- if h.resolved { + if h.IsResolved() { h.log.Errorf("already resolved") return nil, nil } @@ -622,14 +619,6 @@ func (h *htlcTimeoutResolver) Stop() { close(h.quit) } -// IsResolved returns true if the stored state in the resolve is fully -// resolved. In this case the target output can be forgotten. -// -// NOTE: Part of the ContractResolver interface. -func (h *htlcTimeoutResolver) IsResolved() bool { - return h.resolved -} - // report returns a report on the resolution state of the contract. func (h *htlcTimeoutResolver) report() *ContractReport { // If we have a SignedTimeoutTx but no SignDetails, this is a local @@ -689,7 +678,7 @@ func (h *htlcTimeoutResolver) Encode(w io.Writer) error { if err := binary.Write(w, endian, h.outputIncubating); err != nil { return err } - if err := binary.Write(w, endian, h.resolved); err != nil { + if err := binary.Write(w, endian, h.IsResolved()); err != nil { return err } if err := binary.Write(w, endian, h.broadcastHeight); err != nil { @@ -730,9 +719,15 @@ func newTimeoutResolverFromReader(r io.Reader, resCfg ResolverConfig) ( if err := binary.Read(r, endian, &h.outputIncubating); err != nil { return nil, err } - if err := binary.Read(r, endian, &h.resolved); err != nil { + + var resolved bool + if err := binary.Read(r, endian, &resolved); err != nil { return nil, err } + if resolved { + h.markResolved() + } + if err := binary.Read(r, endian, &h.broadcastHeight); err != nil { return nil, err } @@ -1149,7 +1144,7 @@ func (h *htlcTimeoutResolver) checkpointClaim( } // Finally, we checkpoint the resolver with our report(s). - h.resolved = true + h.markResolved() return h.Checkpoint(h, report) } @@ -1285,7 +1280,7 @@ func (h *htlcTimeoutResolver) Launch() error { switch { // If we're already resolved, then we can exit early. - case h.resolved: + case h.IsResolved(): h.log.Errorf("already resolved") return nil diff --git a/contractcourt/htlc_timeout_resolver_test.go b/contractcourt/htlc_timeout_resolver_test.go index 89f334aa83..017d3d3886 100644 --- a/contractcourt/htlc_timeout_resolver_test.go +++ b/contractcourt/htlc_timeout_resolver_test.go @@ -532,7 +532,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { wg.Wait() // Finally, the resolver should be marked as resolved. - if !resolver.resolved { + if !resolver.resolved.Load() { t.Fatalf("resolver should be marked as resolved") } } From 8ee2b2767d16356e52299c562574c1d0eddf3a6f Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 11 Jul 2024 16:19:01 +0800 Subject: [PATCH 57/59] contractcourt: fix concurrent access to `launched` --- contractcourt/anchor_resolver.go | 4 ++-- contractcourt/breach_resolver.go | 4 ++-- contractcourt/commit_sweep_resolver.go | 4 ++-- contractcourt/contract_resolver.go | 17 +++++++++++++++-- contractcourt/htlc_incoming_contest_resolver.go | 2 +- contractcourt/htlc_outgoing_contest_resolver.go | 2 +- contractcourt/htlc_success_resolver.go | 4 ++-- contractcourt/htlc_timeout_resolver.go | 4 ++-- 8 files changed, 27 insertions(+), 14 deletions(-) diff --git a/contractcourt/anchor_resolver.go b/contractcourt/anchor_resolver.go index 5a51113881..eb42ad3cb4 100644 --- a/contractcourt/anchor_resolver.go +++ b/contractcourt/anchor_resolver.go @@ -176,13 +176,13 @@ var _ ContractResolver = (*anchorResolver)(nil) // Launch offers the anchor output to the sweeper. 
func (c *anchorResolver) Launch() error { - if c.launched { + if c.isLaunched() { c.log.Tracef("already launched") return nil } c.log.Debugf("launching resolver...") - c.launched = true + c.markLaunched() // If we're already resolved, then we can exit early. if c.IsResolved() { diff --git a/contractcourt/breach_resolver.go b/contractcourt/breach_resolver.go index a89ccf7e23..5644e60fad 100644 --- a/contractcourt/breach_resolver.go +++ b/contractcourt/breach_resolver.go @@ -123,13 +123,13 @@ var _ ContractResolver = (*breachResolver)(nil) // TODO(yy): implement it once the outputs are offered to the sweeper. func (b *breachResolver) Launch() error { - if b.launched { + if b.isLaunched() { b.log.Tracef("already launched") return nil } b.log.Debugf("launching resolver...") - b.launched = true + b.markLaunched() return nil } diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index 84efdeb067..55ee08e5d8 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -368,13 +368,13 @@ var _ reportingContractResolver = (*commitSweepResolver)(nil) // Launch constructs a commit input and offers it to the sweeper. func (c *commitSweepResolver) Launch() error { - if c.launched { + if c.isLaunched() { c.log.Tracef("already launched") return nil } c.log.Debugf("launching resolver...") - c.launched = true + c.markLaunched() // If we're already resolved, then we can exit early. if c.IsResolved() { diff --git a/contractcourt/contract_resolver.go b/contractcourt/contract_resolver.go index 5311d466ae..293c849662 100644 --- a/contractcourt/contract_resolver.go +++ b/contractcourt/contract_resolver.go @@ -118,8 +118,11 @@ type contractResolverKit struct { sweepResultChan chan sweep.Result // launched specifies whether the resolver has been launched. Calling - // `Launch` will be a no-op if this is true. - launched bool + // `Launch` will be a no-op if this is true. This value is not saved to + // db, as it's fine to relaunch a resolver after a restart. It's only + // used to avoid resending requests to the sweeper when a new blockbeat + // is received. + launched atomic.Bool // resolved reflects if the contract has been fully resolved or not. resolved atomic.Bool @@ -154,6 +157,16 @@ func (r *contractResolverKit) markResolved() { r.resolved.Store(true) } +// isLaunched returns true if the resolver has been launched. +func (r *contractResolverKit) isLaunched() bool { + return r.launched.Load() +} + +// markLaunched marks the resolver as launched. +func (r *contractResolverKit) markLaunched() { + r.launched.Store(true) +} + var ( // errResolverShuttingDown is returned when the resolver stops // progressing because it received the quit signal. diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go index 3179010225..bc5948487d 100644 --- a/contractcourt/htlc_incoming_contest_resolver.go +++ b/contractcourt/htlc_incoming_contest_resolver.go @@ -83,7 +83,7 @@ func (h *htlcIncomingContestResolver) processFinalHtlcFail() error { func (h *htlcIncomingContestResolver) Launch() error { // NOTE: we don't mark this resolver as launched as the inner resolver // will set it when it's launched. 
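Since this part of the series also turns launched into an atomic, a compact sketch of the resulting guard may help; the names below are illustrative and not the lnd types. Note that a Load-then-Store guard does not by itself make Launch exactly-once if two goroutines race past the check at the same instant; the atomic mainly removes the data race between the writer and concurrent readers. The CompareAndSwap variant shown alongside is an assumption about an alternative design, not what lnd does, and would tighten that guarantee if it were ever needed.

package main

import (
	"fmt"
	"sync/atomic"
)

type resolverKit struct {
	launched atomic.Bool
}

func (r *resolverKit) isLaunched() bool { return r.launched.Load() }
func (r *resolverKit) markLaunched()    { r.launched.Store(true) }

type sketchResolver struct {
	resolverKit
}

// Launch returns early on repeat calls; readers observe the flag through the
// atomic Load, so there is no read/write race on the flag itself.
func (s *sketchResolver) Launch() error {
	if s.isLaunched() {
		fmt.Println("already launched")
		return nil
	}
	s.markLaunched()

	fmt.Println("launching resolver...")

	return nil
}

// launchOnce is the stricter alternative: CompareAndSwap flips the flag and
// admits exactly one caller into the body, even under a race.
func (s *sketchResolver) launchOnce() error {
	if !s.launched.CompareAndSwap(false, true) {
		return nil
	}

	fmt.Println("launching resolver (exactly once)...")

	return nil
}

func main() {
	var r sketchResolver
	_ = r.Launch()
	_ = r.Launch() // second call is a no-op

	var r2 sketchResolver
	_ = r2.launchOnce()
	_ = r2.launchOnce() // returns immediately; the body ran exactly once
}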
- if h.launched { + if h.isLaunched() { h.log.Tracef("already launched") return nil } diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go index 8763dfb27e..19574dee27 100644 --- a/contractcourt/htlc_outgoing_contest_resolver.go +++ b/contractcourt/htlc_outgoing_contest_resolver.go @@ -40,7 +40,7 @@ func newOutgoingContestResolver(res lnwallet.OutgoingHtlcResolution, func (h *htlcOutgoingContestResolver) Launch() error { // NOTE: we don't mark this resolver as launched as the inner resolver // will set it when it's launched. - if h.launched { + if h.isLaunched() { h.log.Tracef("already launched") return nil } diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index b4c6db2b91..a4d27ba4e8 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -730,13 +730,13 @@ func (h *htlcSuccessResolver) resolveSuccessTxOutput(op wire.OutPoint) error { // Launch creates an input based on the details of the incoming htlc resolution // and offers it to the sweeper. func (h *htlcSuccessResolver) Launch() error { - if h.launched { + if h.isLaunched() { h.log.Tracef("already launched") return nil } h.log.Debugf("launching resolver...") - h.launched = true + h.markLaunched() switch { // If we're already resolved, then we can exit early. diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index 125d205477..1782cfb3ba 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -1270,13 +1270,13 @@ func (h *htlcTimeoutResolver) resolveTimeoutTxOutput(op wire.OutPoint) error { // Launch creates an input based on the details of the outgoing htlc resolution // and offers it to the sweeper. func (h *htlcTimeoutResolver) Launch() error { - if h.launched { + if h.isLaunched() { h.log.Tracef("already launched") return nil } h.log.Debugf("launching resolver...") - h.launched = true + h.markLaunched() switch { // If we're already resolved, then we can exit early. From 1fee03d10006559587d4548ec92ce444f2bdc3c0 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 26 Jun 2024 07:41:51 +0800 Subject: [PATCH 58/59] contractcourt: break `launchResolvers` into two steps In this commit, we break the old `launchResolvers` into two steps - step one is to launch the resolvers synchronously, and step two is to actually wait for the resolvers to be resolved. This is critical as in the following commit we will require the resolvers to be launched at the same blockbeat when a force close event is sent by the chain watcher. --- contractcourt/channel_arbitrator.go | 75 +++++++++++++++++++-- contractcourt/channel_arbitrator_test.go | 54 ++++++++------- contractcourt/contract_resolver.go | 11 +++ contractcourt/htlc_success_resolver_test.go | 3 + 4 files changed, 113 insertions(+), 30 deletions(-) diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index ca70f733ff..9b554cb648 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -816,7 +816,7 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet, // TODO(roasbeef): this isn't re-launched? } - c.launchResolvers(unresolvedContracts) + c.resolveContracts(unresolvedContracts) return nil } @@ -1355,7 +1355,7 @@ func (c *ChannelArbitrator) stateStep( // Finally, we'll launch all the required contract resolvers. // Once they're all resolved, we're no longer needed.
- c.launchResolvers(resolvers) + c.resolveContracts(resolvers) nextState = StateWaitingFullResolution @@ -1578,18 +1578,75 @@ func (c *ChannelArbitrator) findCommitmentDeadlineAndValue(heightHint uint32, return fn.Some(int32(deadline)), valueLeft, nil } -// launchResolvers updates the activeResolvers list and starts the resolvers. -func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver) { +// resolveContracts updates the activeResolvers list and starts to resolve each +// contract concurrently, and launches them. +func (c *ChannelArbitrator) resolveContracts(resolvers []ContractResolver) { c.activeResolversLock.Lock() c.activeResolvers = resolvers c.activeResolversLock.Unlock() + // Launch all resolvers. + c.launchResolvers() + for _, contract := range resolvers { c.wg.Add(1) go c.resolveContract(contract) } } +// launchResolvers launches all the active resolvers concurrently. +func (c *ChannelArbitrator) launchResolvers() { + c.activeResolversLock.Lock() + resolvers := c.activeResolvers + c.activeResolversLock.Unlock() + + // errChans is a map of channels that will be used to receive errors + // returned from launching the resolvers. + errChans := make(map[ContractResolver]chan error, len(resolvers)) + + // Launch each resolver in goroutines. + for _, r := range resolvers { + // If the contract is already resolved, there's no need to + // launch it again. + if r.IsResolved() { + log.Debugf("ChannelArbitrator(%v): skipping resolver "+ + "%T as it's already resolved", c.cfg.ChanPoint, + r) + + continue + } + + // Create a signal chan. + errChan := make(chan error, 1) + errChans[r] = errChan + + go func() { + err := r.Launch() + errChan <- err + }() + } + + // Wait for all resolvers to finish launching. + for r, errChan := range errChans { + select { + case err := <-errChan: + if err == nil { + continue + } + + log.Errorf("ChannelArbitrator(%v): unable to launch "+ + "contract resolver(%T): %v", c.cfg.ChanPoint, r, + err) + + case <-c.quit: + log.Debugf("ChannelArbitrator quit signal received, " + + "exit launchResolvers") + + return + } + } +} + // advanceState is the main driver of our state machine. This method is an // iterative function which repeatedly attempts to advance the internal state // of the channel arbitrator. The state will be advanced until we reach a @@ -2628,6 +2685,13 @@ func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver) { // loop. currentContract = nextContract + // Launch the new contract. + err = currentContract.Launch() + if err != nil { + log.Errorf("Failed to launch %T: %v", + currentContract, err) + } + // If this contract is actually fully resolved, then // we'll mark it as such within the database. case currentContract.IsResolved(): @@ -3144,6 +3208,9 @@ func (c *ChannelArbitrator) handleBlockbeat(beat chainio.Blockbeat) error { } } + // Launch all active resolvers when a new blockbeat is received. 
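The launchResolvers hunk above is essentially a fan-out/collect pattern: start every unresolved resolver's Launch in its own goroutine, then wait for each attempt to report back unless the arbitrator is quitting. The sketch below restates just that skeleton in isolation; launcher, quit and the log output are illustrative stand-ins rather than lnd types, and errors are only logged so that one failing resolver cannot block the others.

package main

import (
	"fmt"
	"time"
)

type launcher interface {
	Launch() error
	IsResolved() bool
}

// launchAll mirrors the two-step idea: the fan-out is synchronous from the
// caller's point of view (it returns once every Launch has reported back or
// quit is closed), while the actual resolution happens elsewhere.
func launchAll(launchers []launcher, quit <-chan struct{}) {
	errChans := make(map[launcher]chan error, len(launchers))

	for _, l := range launchers {
		// Skip resolvers that are already done.
		if l.IsResolved() {
			continue
		}

		errChan := make(chan error, 1)
		errChans[l] = errChan

		go func(l launcher, errChan chan error) {
			errChan <- l.Launch()
		}(l, errChan)
	}

	// Collect one result per launched resolver, bailing out on quit.
	for l, errChan := range errChans {
		select {
		case err := <-errChan:
			if err != nil {
				fmt.Printf("launch of %T failed: %v\n", l, err)
			}

		case <-quit:
			return
		}
	}
}

type noopResolver struct{ id int }

func (n noopResolver) Launch() error {
	time.Sleep(10 * time.Millisecond)
	return nil
}

func (n noopResolver) IsResolved() bool { return false }

func main() {
	quit := make(chan struct{})
	launchAll([]launcher{noopResolver{id: 1}, noopResolver{id: 2}}, quit)
	fmt.Println("all resolvers launched")
}

This is also what makes the handleBlockbeat change below safe: Launch is guarded by the launched flag introduced earlier in the series, so re-offering the active resolvers on every new blockbeat is an idempotent operation.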
+ c.launchResolvers() + return nil } diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index d1ef2993d5..827e10b5c7 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" @@ -1101,12 +1102,7 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { } // Send a notification that the expiry height has been reached. - // - // TODO(yy): remove the EpochChan and use the blockbeat below once - // resolvers are hooked with the blockbeat. oldNotifier.EpochChan <- &chainntnfs.BlockEpoch{Height: 10} - // beat := chainio.NewBlockbeatFromHeight(10) - // chanArb.BlockbeatChan <- beat // htlcOutgoingContestResolver is now transforming into a // htlcTimeoutResolver and should send the contract off for incubation. @@ -1149,8 +1145,12 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { default: } - // Notify resolver that the second level transaction is spent. - oldNotifier.SpendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx} + // Notify resolver that the output of the timeout tx has been spent. + oldNotifier.SpendChan <- &chainntnfs.SpendDetail{ + SpendingTx: closeTx, + SpentOutPoint: &wire.OutPoint{}, + SpenderTxHash: &closeTxid, + } // At this point channel should be marked as resolved. chanArbCtxNew.AssertStateTransitions(StateFullyResolved) @@ -2830,27 +2830,28 @@ func TestChannelArbitratorAnchors(t *testing.T) { } chanArb.UpdateContractSignals(signals) - // Set current block height. - beat = newBeatFromHeight(int32(heightHint)) - chanArbCtx.chanArb.BlockbeatChan <- beat - htlcAmt := lnwire.MilliSatoshi(1_000_000) // Create testing HTLCs. - deadlineDelta := uint32(10) - deadlinePreimageDelta := deadlineDelta + 2 + spendingHeight := uint32(beat.Height()) + deadlineDelta := uint32(100) + + deadlinePreimageDelta := deadlineDelta htlcWithPreimage := channeldb.HTLC{ - HtlcIndex: 99, - RefundTimeout: heightHint + deadlinePreimageDelta, + HtlcIndex: 99, + // RefundTimeout is 101. + RefundTimeout: spendingHeight + deadlinePreimageDelta, RHash: rHash, Incoming: true, Amt: htlcAmt, } + expectedDeadline := deadlineDelta/2 + spendingHeight - deadlineHTLCdelta := deadlineDelta + 3 + deadlineHTLCdelta := deadlineDelta + 40 htlc := channeldb.HTLC{ - HtlcIndex: 100, - RefundTimeout: heightHint + deadlineHTLCdelta, + HtlcIndex: 100, + // RefundTimeout is 141. + RefundTimeout: spendingHeight + deadlineHTLCdelta, Amt: htlcAmt, } @@ -2935,7 +2936,9 @@ func TestChannelArbitratorAnchors(t *testing.T) { //nolint:ll chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{ - SpendDetail: &chainntnfs.SpendDetail{}, + SpendDetail: &chainntnfs.SpendDetail{ + SpendingHeight: int32(spendingHeight), + }, LocalForceCloseSummary: &lnwallet.LocalForceCloseSummary{ CloseTx: closeTx, ContractResolutions: fn.Some(lnwallet.ContractResolutions{ @@ -2999,16 +3002,14 @@ func TestChannelArbitratorAnchors(t *testing.T) { // to htlcWithPreimage's CLTV. 
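To make the updated numbers in TestChannelArbitratorAnchors easier to follow, here is the arithmetic behind expectedDeadline, spelled out with the values the test's comments imply (a spending height of 1, since RefundTimeout is stated to be 101). The snippet is only a worked example of the test's own expressions, not lnd code.

package main

import "fmt"

func main() {
	spendingHeight := uint32(1)  // implied by the "RefundTimeout is 101" comment
	deadlineDelta := uint32(100) // delta used for the preimage HTLC

	refundTimeout := spendingHeight + deadlineDelta // 101
	// The assertions expect the sweep deadline at half the delta past the
	// spending height, for both the anchor and the HTLC output.
	expectedDeadline := deadlineDelta/2 + spendingHeight // 51

	htlcTimeout := spendingHeight + deadlineDelta + 40 // 141, the other HTLC

	fmt.Println(refundTimeout, expectedDeadline, htlcTimeout) // 101 51 141
}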
require.Equal(t, 2, len(chanArbCtx.sweeper.deadlines)) require.EqualValues(t, - heightHint+deadlinePreimageDelta/2, + expectedDeadline, chanArbCtx.sweeper.deadlines[0], "want %d, got %d", - heightHint+deadlinePreimageDelta/2, - chanArbCtx.sweeper.deadlines[0], + expectedDeadline, chanArbCtx.sweeper.deadlines[0], ) require.EqualValues(t, - heightHint+deadlinePreimageDelta/2, + expectedDeadline, chanArbCtx.sweeper.deadlines[1], "want %d, got %d", - heightHint+deadlinePreimageDelta/2, - chanArbCtx.sweeper.deadlines[1], + expectedDeadline, chanArbCtx.sweeper.deadlines[1], ) } @@ -3159,7 +3160,8 @@ func assertResolverReport(t *testing.T, reports chan *channeldb.ResolverReport, select { case report := <-reports: if !reflect.DeepEqual(report, expected) { - t.Fatalf("expected: %v, got: %v", expected, report) + t.Fatalf("expected: %v, got: %v", spew.Sdump(expected), + spew.Sdump(report)) } case <-time.After(defaultTimeout): diff --git a/contractcourt/contract_resolver.go b/contractcourt/contract_resolver.go index 293c849662..d11bd2f597 100644 --- a/contractcourt/contract_resolver.go +++ b/contractcourt/contract_resolver.go @@ -37,6 +37,17 @@ type ContractResolver interface { // resides within. ResolverKey() []byte + // Launch starts the resolver by constructing an input and offering it + // to the sweeper. Once offered, it's expected to monitor the sweeping + // result in a goroutine invoked by calling Resolve. + // + // NOTE: We can call `Resolve` inside a goroutine at the end of this + // method to avoid calling it in the ChannelArbitrator. However, there + // are some DB-related operations such as SwapContract/ResolveContract + // which need to be done inside the resolvers instead, which needs a + // deeper refactoring. + Launch() error + // Resolve instructs the contract resolver to resolve the output // on-chain. Once the output has been *fully* resolved, the function // should return immediately with a nil ContractResolver value for the diff --git a/contractcourt/htlc_success_resolver_test.go b/contractcourt/htlc_success_resolver_test.go index 417ef907ed..fe6ee1ad0e 100644 --- a/contractcourt/htlc_success_resolver_test.go +++ b/contractcourt/htlc_success_resolver_test.go @@ -146,6 +146,9 @@ func (i *htlcResolverTestContext) resolve() { i.resolverResultChan = make(chan resolveResult, 1) go func() { + err := i.resolver.Launch() + require.NoError(i.t, err) + nextResolver, err := i.resolver.Resolve() i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, From cebad6d0051721bc15fac4b2b224951b6bc8914f Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Mon, 25 Nov 2024 13:55:11 +0800 Subject: [PATCH 59/59] contractcourt: offer outgoing htlc one block earlier before its expiry We need to offer the outgoing htlc one block earlier to make sure when the expiry height hits, the sweeper will not miss sweeping it in the same block. This also means the outgoing contest resolver now only does one thing - watch for preimage spend till height expiry-1, which can easily be moved into the timeout resolver instead in the future. --- contractcourt/htlc_outgoing_contest_resolver.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go index 19574dee27..b66a3fdf0b 100644 --- a/contractcourt/htlc_outgoing_contest_resolver.go +++ b/contractcourt/htlc_outgoing_contest_resolver.go @@ -155,7 +155,14 @@ func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) { // fail. 
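The final patch's off-by-one adjustment (continued in the hunk below) is easiest to see with concrete heights. The following sketch only illustrates the boundary behaviour described in the commit message above, using hypothetical values; it is not lnd code.

package main

import "fmt"

func main() {
	expiry := uint32(100) // hypothetical HTLC expiry height

	for _, newHeight := range []uint32{98, 99, 100} {
		// Old check: newHeight >= expiry   -> hand off at height 100.
		// New check: newHeight >= expiry-1 -> hand off at height 99, so
		// the sweeper already holds the input when the expiry height
		// hits and can sweep in that same block, per the commit
		// message.
		if newHeight >= expiry-1 {
			fmt.Printf("height %d: transform into timeout resolver\n",
				newHeight)

			continue
		}

		fmt.Printf("height %d: keep watching for a preimage spend\n",
			newHeight)
	}
}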
newHeight := uint32(newBlock.Height) expiry := h.htlcResolution.Expiry - if newHeight >= expiry { + + // Check if the expiry height is about to be reached. + // We offer this HTLC one block earlier to make sure + // when the next block arrives, the sweeper will pick + // up this input and sweep it immediately. The sweeper + // will handle the waiting for the one last block till + // expiry. + if newHeight >= expiry-1 { h.log.Infof("HTLC about to expire "+ "(height=%v, expiry=%v), transforming "+ "into timeout resolver", h,