diff --git a/aliasmgr/aliasmgr.go b/aliasmgr/aliasmgr.go index f06cb53d79..a3227b18b8 100644 --- a/aliasmgr/aliasmgr.go +++ b/aliasmgr/aliasmgr.go @@ -5,7 +5,7 @@ import ( "fmt" "sync" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" @@ -432,9 +432,9 @@ func (m *Manager) DeleteLocalAlias(alias, } // We'll filter the alias set and remove the alias from it. - aliasSet = fn.Filter(func(a lnwire.ShortChannelID) bool { + aliasSet = fn.Filter(aliasSet, func(a lnwire.ShortChannelID) bool { return a.ToUint64() != alias.ToUint64() - }, aliasSet) + }) // If the alias set is empty, we'll delete the base SCID from the // baseToSet map. @@ -514,11 +514,17 @@ func (m *Manager) RequestAlias() (lnwire.ShortChannelID, error) { // haveAlias returns true if the passed alias is already assigned to a // channel in the baseToSet map. haveAlias := func(maybeNextAlias lnwire.ShortChannelID) bool { - return fn.Any(func(aliasList []lnwire.ShortChannelID) bool { - return fn.Any(func(alias lnwire.ShortChannelID) bool { - return alias == maybeNextAlias - }, aliasList) - }, maps.Values(m.baseToSet)) + return fn.Any( + maps.Values(m.baseToSet), + func(aliasList []lnwire.ShortChannelID) bool { + return fn.Any( + aliasList, + func(alias lnwire.ShortChannelID) bool { + return alias == maybeNextAlias + }, + ) + }, + ) } err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error { diff --git a/chainio/README.md b/chainio/README.md new file mode 100644 index 0000000000..b11e38157c --- /dev/null +++ b/chainio/README.md @@ -0,0 +1,152 @@ +# Chainio + +`chainio` is a package designed to provide blockchain data access to various +subsystems within `lnd`. When a new block is received, it is encapsulated in a +`Blockbeat` object and disseminated to all registered consumers. Consumers may +receive these updates either concurrently or sequentially, based on their +registration configuration, ensuring that each subsystem maintains a +synchronized view of the current block state. + +The main components include: + +- `Blockbeat`: An interface that provides information about the block. + +- `Consumer`: An interface that specifies how subsystems handle the blockbeat. + +- `BlockbeatDispatcher`: The core service responsible for receiving each block + and distributing it to all consumers. + +Additionally, the `BeatConsumer` struct provides a partial implementation of +the `Consumer` interface. This struct helps reduce code duplication, allowing +subsystems to avoid re-implementing the `ProcessBlock` method and provides a +commonly used `NotifyBlockProcessed` method. + + +### Register a Consumer + +Consumers within the same queue are notified **sequentially**, while all queues +are notified **concurrently**. A queue consists of a slice of consumers, which +are notified in left-to-right order. Developers are responsible for determining +dependencies in block consumption across subsystems: independent subsystems +should be notified concurrently, whereas dependent subsystems should be +notified sequentially. + +To notify the consumers concurrently, put them in different queues, +```go +// consumer1 and consumer2 will be notified concurrently. +queue1 := []chainio.Consumer{consumer1} +blockbeatDispatcher.RegisterQueue(consumer1) + +queue2 := []chainio.Consumer{consumer2} +blockbeatDispatcher.RegisterQueue(consumer2) +``` + +To notify the consumers sequentially, put them in the same queue, +```go +// consumers will be notified sequentially via, +// consumer1 -> consumer2 -> consumer3 +queue := []chainio.Consumer{ + consumer1, + consumer2, + consumer3, +} +blockbeatDispatcher.RegisterQueue(queue) +``` + +### Implement the `Consumer` Interface + +Implementing the `Consumer` interface is straightforward. Below is an example +of how +[`sweep.TxPublisher`](https://github.com/lightningnetwork/lnd/blob/5cec466fad44c582a64cfaeb91f6d5fd302fcf85/sweep/fee_bumper.go#L310) +implements this interface. + +To start, embed the partial implementation `chainio.BeatConsumer`, which +already provides the `ProcessBlock` implementation and commonly used +`NotifyBlockProcessed` method, and exposes `BlockbeatChan` for the consumer to +receive blockbeats. + +```go +type TxPublisher struct { + started atomic.Bool + stopped atomic.Bool + + chainio.BeatConsumer + + ... +``` + +We should also remember to initialize this `BeatConsumer`, + +```go +... +// Mount the block consumer. +tp.BeatConsumer = chainio.NewBeatConsumer(tp.quit, tp.Name()) +``` + +Finally, in the main event loop, read from `BlockbeatChan`, process the +received blockbeat, and, crucially, call `tp.NotifyBlockProcessed` to inform +the blockbeat dispatcher that processing is complete. + +```go +for { + select { + case beat := <-tp.BlockbeatChan: + // Consume this blockbeat, usually it means updating the subsystem + // using the new block data. + + // Notify we've processed the block. + tp.NotifyBlockProcessed(beat, nil) + + ... +``` + +### Existing Queues + +Currently, we have a single queue of consumers dedicated to handling force +closures. This queue includes `ChainArbitrator`, `UtxoSweeper`, and +`TxPublisher`, with `ChainArbitrator` managing two internal consumers: +`chainWatcher` and `ChannelArbitrator`. The blockbeat flows sequentially +through the chain as follows: `ChainArbitrator => chainWatcher => +ChannelArbitrator => UtxoSweeper => TxPublisher`. The following diagram +illustrates the flow within the public subsystems. + +```mermaid +sequenceDiagram + autonumber + participant bb as BlockBeat + participant cc as ChainArb + participant us as UtxoSweeper + participant tp as TxPublisher + + note left of bb: 0. received block x,
dispatching... + + note over bb,cc: 1. send block x to ChainArb,
wait for its done signal + bb->>cc: block x + rect rgba(165, 0, 85, 0.8) + critical signal processed + cc->>bb: processed block + option Process error or timeout + bb->>bb: error and exit + end + end + + note over bb,us: 2. send block x to UtxoSweeper, wait for its done signal + bb->>us: block x + rect rgba(165, 0, 85, 0.8) + critical signal processed + us->>bb: processed block + option Process error or timeout + bb->>bb: error and exit + end + end + + note over bb,tp: 3. send block x to TxPublisher, wait for its done signal + bb->>tp: block x + rect rgba(165, 0, 85, 0.8) + critical signal processed + tp->>bb: processed block + option Process error or timeout + bb->>bb: error and exit + end + end +``` diff --git a/chainio/blockbeat.go b/chainio/blockbeat.go new file mode 100644 index 0000000000..79188657fe --- /dev/null +++ b/chainio/blockbeat.go @@ -0,0 +1,54 @@ +package chainio + +import ( + "fmt" + + "github.com/btcsuite/btclog/v2" + "github.com/lightningnetwork/lnd/chainntnfs" +) + +// Beat implements the Blockbeat interface. It contains the block epoch and a +// customized logger. +// +// TODO(yy): extend this to check for confirmation status - which serves as the +// single source of truth, to avoid the potential race between receiving blocks +// and `GetTransactionDetails/RegisterSpendNtfn/RegisterConfirmationsNtfn`. +type Beat struct { + // epoch is the current block epoch the blockbeat is aware of. + epoch chainntnfs.BlockEpoch + + // log is the customized logger for the blockbeat which prints the + // block height. + log btclog.Logger +} + +// Compile-time check to ensure Beat satisfies the Blockbeat interface. +var _ Blockbeat = (*Beat)(nil) + +// NewBeat creates a new beat with the specified block epoch and a customized +// logger. +func NewBeat(epoch chainntnfs.BlockEpoch) *Beat { + b := &Beat{ + epoch: epoch, + } + + // Create a customized logger for the blockbeat. + logPrefix := fmt.Sprintf("Height[%6d]:", b.Height()) + b.log = clog.WithPrefix(logPrefix) + + return b +} + +// Height returns the height of the block epoch. +// +// NOTE: Part of the Blockbeat interface. +func (b *Beat) Height() int32 { + return b.epoch.Height +} + +// logger returns the logger for the blockbeat. +// +// NOTE: Part of the private blockbeat interface. +func (b *Beat) logger() btclog.Logger { + return b.log +} diff --git a/chainio/blockbeat_test.go b/chainio/blockbeat_test.go new file mode 100644 index 0000000000..9326651b38 --- /dev/null +++ b/chainio/blockbeat_test.go @@ -0,0 +1,28 @@ +package chainio + +import ( + "errors" + "testing" + + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/stretchr/testify/require" +) + +var errDummy = errors.New("dummy error") + +// TestNewBeat tests the NewBeat and Height functions. +func TestNewBeat(t *testing.T) { + t.Parallel() + + // Create a testing epoch. + epoch := chainntnfs.BlockEpoch{ + Height: 1, + } + + // Create the beat and check the internal state. + beat := NewBeat(epoch) + require.Equal(t, epoch, beat.epoch) + + // Check the height function. + require.Equal(t, epoch.Height, beat.Height()) +} diff --git a/chainio/consumer.go b/chainio/consumer.go new file mode 100644 index 0000000000..a9ec25745b --- /dev/null +++ b/chainio/consumer.go @@ -0,0 +1,113 @@ +package chainio + +// BeatConsumer defines a supplementary component that should be used by +// subsystems which implement the `Consumer` interface. It partially implements +// the `Consumer` interface by providing the method `ProcessBlock` such that +// subsystems don't need to re-implement it. +// +// While inheritance is not commonly used in Go, subsystems embedding this +// struct cannot pass the interface check for `Consumer` because the `Name` +// method is not implemented, which gives us a "mortise and tenon" structure. +// In addition to reducing code duplication, this design allows `ProcessBlock` +// to work on the concrete type `Beat` to access its internal states. +type BeatConsumer struct { + // BlockbeatChan is a channel to receive blocks from Blockbeat. The + // received block contains the best known height and the txns confirmed + // in this block. + BlockbeatChan chan Blockbeat + + // name is the name of the consumer which embeds the BlockConsumer. + name string + + // quit is a channel that closes when the BlockConsumer is shutting + // down. + // + // NOTE: this quit channel should be mounted to the same quit channel + // used by the subsystem. + quit chan struct{} + + // errChan is a buffered chan that receives an error returned from + // processing this block. + errChan chan error +} + +// NewBeatConsumer creates a new BlockConsumer. +func NewBeatConsumer(quit chan struct{}, name string) BeatConsumer { + // Refuse to start `lnd` if the quit channel is not initialized. We + // treat this case as if we are facing a nil pointer dereference, as + // there's no point to return an error here, which will cause the node + // to fail to be started anyway. + if quit == nil { + panic("quit channel is nil") + } + + b := BeatConsumer{ + BlockbeatChan: make(chan Blockbeat), + name: name, + errChan: make(chan error, 1), + quit: quit, + } + + return b +} + +// ProcessBlock takes a blockbeat and sends it to the consumer's blockbeat +// channel. It will send it to the subsystem's BlockbeatChan, and block until +// the processed result is received from the subsystem. The subsystem must call +// `NotifyBlockProcessed` after it has finished processing the block. +// +// NOTE: part of the `chainio.Consumer` interface. +func (b *BeatConsumer) ProcessBlock(beat Blockbeat) error { + // Update the current height. + beat.logger().Tracef("set current height for [%s]", b.name) + + select { + // Send the beat to the blockbeat channel. It's expected that the + // consumer will read from this channel and process the block. Once + // processed, it should return the error or nil to the beat.Err chan. + case b.BlockbeatChan <- beat: + beat.logger().Tracef("Sent blockbeat to [%s]", b.name) + + case <-b.quit: + beat.logger().Debugf("[%s] received shutdown before sending "+ + "beat", b.name) + + return nil + } + + // Check the consumer's err chan. We expect the consumer to call + // `beat.NotifyBlockProcessed` to send the error back here. + select { + case err := <-b.errChan: + beat.logger().Debugf("[%s] processed beat: err=%v", b.name, err) + + return err + + case <-b.quit: + beat.logger().Debugf("[%s] received shutdown", b.name) + } + + return nil +} + +// NotifyBlockProcessed signals that the block has been processed. It takes the +// blockbeat being processed and an error resulted from processing it. This +// error is then sent back to the consumer's err chan to unblock +// `ProcessBlock`. +// +// NOTE: This method must be called by the subsystem after it has finished +// processing the block. +func (b *BeatConsumer) NotifyBlockProcessed(beat Blockbeat, err error) { + // Update the current height. + beat.logger().Debugf("[%s]: notifying beat processed", b.name) + + select { + case b.errChan <- err: + beat.logger().Debugf("[%s]: notified beat processed, err=%v", + b.name, err) + + case <-b.quit: + beat.logger().Debugf("[%s] received shutdown before notifying "+ + "beat processed", b.name) + } +} diff --git a/chainio/consumer_test.go b/chainio/consumer_test.go new file mode 100644 index 0000000000..d1cabf3168 --- /dev/null +++ b/chainio/consumer_test.go @@ -0,0 +1,202 @@ +package chainio + +import ( + "testing" + "time" + + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/stretchr/testify/require" +) + +// TestNewBeatConsumer tests the NewBeatConsumer function. +func TestNewBeatConsumer(t *testing.T) { + t.Parallel() + + quitChan := make(chan struct{}) + name := "test" + + // Test the NewBeatConsumer function. + b := NewBeatConsumer(quitChan, name) + + // Assert the state. + require.Equal(t, quitChan, b.quit) + require.Equal(t, name, b.name) + require.NotNil(t, b.BlockbeatChan) +} + +// TestProcessBlockSuccess tests when the block is processed successfully, no +// error is returned. +func TestProcessBlockSuccess(t *testing.T) { + t.Parallel() + + // Create a test consumer. + quitChan := make(chan struct{}) + b := NewBeatConsumer(quitChan, "test") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock the consumer's err chan. + consumerErrChan := make(chan error, 1) + b.errChan = consumerErrChan + + // Call the method under test. + resultChan := make(chan error, 1) + go func() { + resultChan <- b.ProcessBlock(mockBeat) + }() + + // Assert the beat is sent to the blockbeat channel. + beat, err := fn.RecvOrTimeout(b.BlockbeatChan, time.Second) + require.NoError(t, err) + require.Equal(t, mockBeat, beat) + + // Send nil to the consumer's error channel. + consumerErrChan <- nil + + // Assert the result of ProcessBlock is nil. + result, err := fn.RecvOrTimeout(resultChan, time.Second) + require.NoError(t, err) + require.Nil(t, result) +} + +// TestProcessBlockConsumerQuitBeforeSend tests when the consumer is quit +// before sending the beat, the method returns immediately. +func TestProcessBlockConsumerQuitBeforeSend(t *testing.T) { + t.Parallel() + + // Create a test consumer. + quitChan := make(chan struct{}) + b := NewBeatConsumer(quitChan, "test") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Call the method under test. + resultChan := make(chan error, 1) + go func() { + resultChan <- b.ProcessBlock(mockBeat) + }() + + // Instead of reading the BlockbeatChan, close the quit channel. + close(quitChan) + + // Assert ProcessBlock returned nil. + result, err := fn.RecvOrTimeout(resultChan, time.Second) + require.NoError(t, err) + require.Nil(t, result) +} + +// TestProcessBlockConsumerQuitAfterSend tests when the consumer is quit after +// sending the beat, the method returns immediately. +func TestProcessBlockConsumerQuitAfterSend(t *testing.T) { + t.Parallel() + + // Create a test consumer. + quitChan := make(chan struct{}) + b := NewBeatConsumer(quitChan, "test") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock the consumer's err chan. + consumerErrChan := make(chan error, 1) + b.errChan = consumerErrChan + + // Call the method under test. + resultChan := make(chan error, 1) + go func() { + resultChan <- b.ProcessBlock(mockBeat) + }() + + // Assert the beat is sent to the blockbeat channel. + beat, err := fn.RecvOrTimeout(b.BlockbeatChan, time.Second) + require.NoError(t, err) + require.Equal(t, mockBeat, beat) + + // Instead of sending nil to the consumer's error channel, close the + // quit chanel. + close(quitChan) + + // Assert ProcessBlock returned nil. + result, err := fn.RecvOrTimeout(resultChan, time.Second) + require.NoError(t, err) + require.Nil(t, result) +} + +// TestNotifyBlockProcessedSendErr asserts the error can be sent and read by +// the beat via NotifyBlockProcessed. +func TestNotifyBlockProcessedSendErr(t *testing.T) { + t.Parallel() + + // Create a test consumer. + quitChan := make(chan struct{}) + b := NewBeatConsumer(quitChan, "test") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock the consumer's err chan. + consumerErrChan := make(chan error, 1) + b.errChan = consumerErrChan + + // Call the method under test. + done := make(chan error) + go func() { + defer close(done) + b.NotifyBlockProcessed(mockBeat, errDummy) + }() + + // Assert the error is sent to the beat's err chan. + result, err := fn.RecvOrTimeout(consumerErrChan, time.Second) + require.NoError(t, err) + require.ErrorIs(t, result, errDummy) + + // Assert the done channel is closed. + result, err = fn.RecvOrTimeout(done, time.Second) + require.NoError(t, err) + require.Nil(t, result) +} + +// TestNotifyBlockProcessedOnQuit asserts NotifyBlockProcessed exits +// immediately when the quit channel is closed. +func TestNotifyBlockProcessedOnQuit(t *testing.T) { + t.Parallel() + + // Create a test consumer. + quitChan := make(chan struct{}) + b := NewBeatConsumer(quitChan, "test") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock the consumer's err chan - we don't buffer it so it will block + // on sending the error. + consumerErrChan := make(chan error) + b.errChan = consumerErrChan + + // Call the method under test. + done := make(chan error) + go func() { + defer close(done) + b.NotifyBlockProcessed(mockBeat, errDummy) + }() + + // Close the quit channel so the method will return. + close(b.quit) + + // Assert the done channel is closed. + result, err := fn.RecvOrTimeout(done, time.Second) + require.NoError(t, err) + require.Nil(t, result) +} diff --git a/chainio/dispatcher.go b/chainio/dispatcher.go new file mode 100644 index 0000000000..87bc21fbaa --- /dev/null +++ b/chainio/dispatcher.go @@ -0,0 +1,296 @@ +package chainio + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/btcsuite/btclog/v2" + "github.com/lightningnetwork/lnd/chainntnfs" + "golang.org/x/sync/errgroup" +) + +// DefaultProcessBlockTimeout is the timeout value used when waiting for one +// consumer to finish processing the new block epoch. +var DefaultProcessBlockTimeout = 60 * time.Second + +// ErrProcessBlockTimeout is the error returned when a consumer takes too long +// to process the block. +var ErrProcessBlockTimeout = errors.New("process block timeout") + +// BlockbeatDispatcher is a service that handles dispatching new blocks to +// `lnd`'s subsystems. During startup, subsystems that are block-driven should +// implement the `Consumer` interface and register themselves via +// `RegisterQueue`. When two subsystems are independent of each other, they +// should be registered in different queues so blocks are notified concurrently. +// Otherwise, when living in the same queue, the subsystems are notified of the +// new blocks sequentially, which means it's critical to understand the +// relationship of these systems to properly handle the order. +type BlockbeatDispatcher struct { + wg sync.WaitGroup + + // notifier is used to receive new block epochs. + notifier chainntnfs.ChainNotifier + + // beat is the latest blockbeat received. + beat Blockbeat + + // consumerQueues is a map of consumers that will receive blocks. Its + // key is a unique counter and its value is a queue of consumers. Each + // queue is notified concurrently, and consumers in the same queue is + // notified sequentially. + consumerQueues map[uint32][]Consumer + + // counter is used to assign a unique id to each queue. + counter atomic.Uint32 + + // quit is used to signal the BlockbeatDispatcher to stop. + quit chan struct{} +} + +// NewBlockbeatDispatcher returns a new blockbeat dispatcher instance. +func NewBlockbeatDispatcher(n chainntnfs.ChainNotifier) *BlockbeatDispatcher { + return &BlockbeatDispatcher{ + notifier: n, + quit: make(chan struct{}), + consumerQueues: make(map[uint32][]Consumer), + } +} + +// RegisterQueue takes a list of consumers and registers them in the same +// queue. +// +// NOTE: these consumers are notified sequentially. +func (b *BlockbeatDispatcher) RegisterQueue(consumers []Consumer) { + qid := b.counter.Add(1) + + b.consumerQueues[qid] = append(b.consumerQueues[qid], consumers...) + clog.Infof("Registered queue=%d with %d blockbeat consumers", qid, + len(consumers)) + + for _, c := range consumers { + clog.Debugf("Consumer [%s] registered in queue %d", c.Name(), + qid) + } +} + +// Start starts the blockbeat dispatcher - it registers a block notification +// and monitors and dispatches new blocks in a goroutine. It will refuse to +// start if there are no registered consumers. +func (b *BlockbeatDispatcher) Start() error { + // Make sure consumers are registered. + if len(b.consumerQueues) == 0 { + return fmt.Errorf("no consumers registered") + } + + // Start listening to new block epochs. We should get a notification + // with the current best block immediately. + blockEpochs, err := b.notifier.RegisterBlockEpochNtfn(nil) + if err != nil { + return fmt.Errorf("register block epoch ntfn: %w", err) + } + + clog.Infof("BlockbeatDispatcher is starting with %d consumer queues", + len(b.consumerQueues)) + defer clog.Debug("BlockbeatDispatcher started") + + b.wg.Add(1) + go b.dispatchBlocks(blockEpochs) + + return nil +} + +// Stop shuts down the blockbeat dispatcher. +func (b *BlockbeatDispatcher) Stop() { + clog.Info("BlockbeatDispatcher is stopping") + defer clog.Debug("BlockbeatDispatcher stopped") + + // Signal the dispatchBlocks goroutine to stop. + close(b.quit) + b.wg.Wait() +} + +func (b *BlockbeatDispatcher) log() btclog.Logger { + return b.beat.logger() +} + +// dispatchBlocks listens to new block epoch and dispatches it to all the +// consumers. Each queue is notified concurrently, and the consumers in the +// same queue are notified sequentially. +// +// NOTE: Must be run as a goroutine. +func (b *BlockbeatDispatcher) dispatchBlocks( + blockEpochs *chainntnfs.BlockEpochEvent) { + + defer b.wg.Done() + defer blockEpochs.Cancel() + + for { + select { + case blockEpoch, ok := <-blockEpochs.Epochs: + if !ok { + clog.Debugf("Block epoch channel closed") + + return + } + + clog.Infof("Received new block %v at height %d, "+ + "notifying consumers...", blockEpoch.Hash, + blockEpoch.Height) + + // Record the time it takes the consumer to process + // this block. + start := time.Now() + + // Update the current block epoch. + b.beat = NewBeat(*blockEpoch) + + // Notify all consumers. + err := b.notifyQueues() + if err != nil { + b.log().Errorf("Notify block failed: %v", err) + } + + b.log().Infof("Notified all consumers on new block "+ + "in %v", time.Since(start)) + + case <-b.quit: + b.log().Debugf("BlockbeatDispatcher quit signal " + + "received") + + return + } + } +} + +// notifyQueues notifies each queue concurrently about the latest block epoch. +func (b *BlockbeatDispatcher) notifyQueues() error { + // errChans is a map of channels that will be used to receive errors + // returned from notifying the consumers. + errChans := make(map[uint32]chan error, len(b.consumerQueues)) + + // Notify each queue in goroutines. + for qid, consumers := range b.consumerQueues { + b.log().Debugf("Notifying queue=%d with %d consumers", qid, + len(consumers)) + + // Create a signal chan. + errChan := make(chan error, 1) + errChans[qid] = errChan + + // Notify each queue concurrently. + go func(qid uint32, c []Consumer, beat Blockbeat) { + // Notify each consumer in this queue sequentially. + errChan <- DispatchSequential(beat, c) + }(qid, consumers, b.beat) + } + + // Wait for all consumers in each queue to finish. + for qid, errChan := range errChans { + select { + case err := <-errChan: + if err != nil { + return fmt.Errorf("queue=%d got err: %w", qid, + err) + } + + b.log().Debugf("Notified queue=%d", qid) + + case <-b.quit: + b.log().Debugf("BlockbeatDispatcher quit signal " + + "received, exit notifyQueues") + + return nil + } + } + + return nil +} + +// DispatchSequential takes a list of consumers and notify them about the new +// epoch sequentially. It requires the consumer to finish processing the block +// within the specified time, otherwise a timeout error is returned. +func DispatchSequential(b Blockbeat, consumers []Consumer) error { + for _, c := range consumers { + // Send the beat to the consumer. + err := notifyAndWait(b, c, DefaultProcessBlockTimeout) + if err != nil { + b.logger().Errorf("Failed to process block: %v", err) + + return err + } + } + + return nil +} + +// DispatchConcurrent notifies each consumer concurrently about the blockbeat. +// It requires the consumer to finish processing the block within the specified +// time, otherwise a timeout error is returned. +func DispatchConcurrent(b Blockbeat, consumers []Consumer) error { + eg := &errgroup.Group{} + + // Notify each queue in goroutines. + for _, c := range consumers { + // Notify each consumer concurrently. + eg.Go(func() error { + // Send the beat to the consumer. + err := notifyAndWait(b, c, DefaultProcessBlockTimeout) + + // Exit early if there's no error. + if err == nil { + return nil + } + + b.logger().Errorf("Consumer=%v failed to process "+ + "block: %v", c.Name(), err) + + return err + }) + } + + // Wait for all consumers in each queue to finish. + if err := eg.Wait(); err != nil { + return err + } + + return nil +} + +// notifyAndWait sends the blockbeat to the specified consumer. It requires the +// consumer to finish processing the block within the specified time, otherwise +// a timeout error is returned. +func notifyAndWait(b Blockbeat, c Consumer, timeout time.Duration) error { + b.logger().Debugf("Waiting for consumer[%s] to process it", c.Name()) + + // Record the time it takes the consumer to process this block. + start := time.Now() + + errChan := make(chan error, 1) + go func() { + errChan <- c.ProcessBlock(b) + }() + + // We expect the consumer to finish processing this block under 30s, + // otherwise a timeout error is returned. + select { + case err := <-errChan: + if err == nil { + break + } + + return fmt.Errorf("%s got err in ProcessBlock: %w", c.Name(), + err) + + case <-time.After(timeout): + return fmt.Errorf("consumer %s: %w", c.Name(), + ErrProcessBlockTimeout) + } + + b.logger().Debugf("Consumer[%s] processed block in %v", c.Name(), + time.Since(start)) + + return nil +} diff --git a/chainio/dispatcher_test.go b/chainio/dispatcher_test.go new file mode 100644 index 0000000000..11abbeb65e --- /dev/null +++ b/chainio/dispatcher_test.go @@ -0,0 +1,383 @@ +package chainio + +import ( + "testing" + "time" + + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// TestNotifyAndWaitOnConsumerErr asserts when the consumer returns an error, +// it's returned by notifyAndWait. +func TestNotifyAndWaitOnConsumerErr(t *testing.T) { + t.Parallel() + + // Create a mock consumer. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock ProcessBlock to return an error. + consumer.On("ProcessBlock", mockBeat).Return(errDummy).Once() + + // Call the method under test. + err := notifyAndWait(mockBeat, consumer, DefaultProcessBlockTimeout) + + // We expect the error to be returned. + require.ErrorIs(t, err, errDummy) +} + +// TestNotifyAndWaitOnConsumerErr asserts when the consumer successfully +// processed the beat, no error is returned. +func TestNotifyAndWaitOnConsumerSuccess(t *testing.T) { + t.Parallel() + + // Create a mock consumer. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock ProcessBlock to return nil. + consumer.On("ProcessBlock", mockBeat).Return(nil).Once() + + // Call the method under test. + err := notifyAndWait(mockBeat, consumer, DefaultProcessBlockTimeout) + + // We expect a nil error to be returned. + require.NoError(t, err) +} + +// TestNotifyAndWaitOnConsumerTimeout asserts when the consumer times out +// processing the block, the timeout error is returned. +func TestNotifyAndWaitOnConsumerTimeout(t *testing.T) { + t.Parallel() + + // Set timeout to be 10ms. + processBlockTimeout := 10 * time.Millisecond + + // Create a mock consumer. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock ProcessBlock to return nil but blocks on returning. + consumer.On("ProcessBlock", mockBeat).Return(nil).Run( + func(args mock.Arguments) { + // Sleep one second to block on the method. + time.Sleep(processBlockTimeout * 100) + }).Once() + + // Call the method under test. + err := notifyAndWait(mockBeat, consumer, processBlockTimeout) + + // We expect a timeout error to be returned. + require.ErrorIs(t, err, ErrProcessBlockTimeout) +} + +// TestDispatchSequential checks that the beat is sent to the consumers +// sequentially. +func TestDispatchSequential(t *testing.T) { + t.Parallel() + + // Create three mock consumers. + consumer1 := &MockConsumer{} + defer consumer1.AssertExpectations(t) + consumer1.On("Name").Return("mocker1") + + consumer2 := &MockConsumer{} + defer consumer2.AssertExpectations(t) + consumer2.On("Name").Return("mocker2") + + consumer3 := &MockConsumer{} + defer consumer3.AssertExpectations(t) + consumer3.On("Name").Return("mocker3") + + consumers := []Consumer{consumer1, consumer2, consumer3} + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // prevConsumer specifies the previous consumer that was called. + var prevConsumer string + + // Mock the ProcessBlock on consumers to reutrn immediately. + consumer1.On("ProcessBlock", mockBeat).Return(nil).Run( + func(args mock.Arguments) { + // Check the order of the consumers. + // + // The first consumer should have no previous consumer. + require.Empty(t, prevConsumer) + + // Set the consumer as the previous consumer. + prevConsumer = consumer1.Name() + }).Once() + + consumer2.On("ProcessBlock", mockBeat).Return(nil).Run( + func(args mock.Arguments) { + // Check the order of the consumers. + // + // The second consumer should see consumer1. + require.Equal(t, consumer1.Name(), prevConsumer) + + // Set the consumer as the previous consumer. + prevConsumer = consumer2.Name() + }).Once() + + consumer3.On("ProcessBlock", mockBeat).Return(nil).Run( + func(args mock.Arguments) { + // Check the order of the consumers. + // + // The third consumer should see consumer2. + require.Equal(t, consumer2.Name(), prevConsumer) + + // Set the consumer as the previous consumer. + prevConsumer = consumer3.Name() + }).Once() + + // Call the method under test. + err := DispatchSequential(mockBeat, consumers) + require.NoError(t, err) + + // Check the previous consumer is the last consumer. + require.Equal(t, consumer3.Name(), prevConsumer) +} + +// TestRegisterQueue tests the RegisterQueue function. +func TestRegisterQueue(t *testing.T) { + t.Parallel() + + // Create two mock consumers. + consumer1 := &MockConsumer{} + defer consumer1.AssertExpectations(t) + consumer1.On("Name").Return("mocker1") + + consumer2 := &MockConsumer{} + defer consumer2.AssertExpectations(t) + consumer2.On("Name").Return("mocker2") + + consumers := []Consumer{consumer1, consumer2} + + // Create a mock chain notifier. + mockNotifier := &chainntnfs.MockChainNotifier{} + defer mockNotifier.AssertExpectations(t) + + // Create a new dispatcher. + b := NewBlockbeatDispatcher(mockNotifier) + + // Register the consumers. + b.RegisterQueue(consumers) + + // Assert that the consumers have been registered. + // + // We should have one queue. + require.Len(t, b.consumerQueues, 1) + + // The queue should have two consumers. + queue, ok := b.consumerQueues[1] + require.True(t, ok) + require.Len(t, queue, 2) +} + +// TestStartDispatcher tests the Start method. +func TestStartDispatcher(t *testing.T) { + t.Parallel() + + // Create a mock chain notifier. + mockNotifier := &chainntnfs.MockChainNotifier{} + defer mockNotifier.AssertExpectations(t) + + // Create a new dispatcher. + b := NewBlockbeatDispatcher(mockNotifier) + + // Start the dispatcher without consumers should return an error. + err := b.Start() + require.Error(t, err) + + // Create a consumer and register it. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker1") + b.RegisterQueue([]Consumer{consumer}) + + // Mock the chain notifier to return an error. + mockNotifier.On("RegisterBlockEpochNtfn", + mock.Anything).Return(nil, errDummy).Once() + + // Start the dispatcher now should return the error. + err = b.Start() + require.ErrorIs(t, err, errDummy) + + // Mock the chain notifier to return a valid notifier. + blockEpochs := &chainntnfs.BlockEpochEvent{} + mockNotifier.On("RegisterBlockEpochNtfn", + mock.Anything).Return(blockEpochs, nil).Once() + + // Start the dispatcher now should not return an error. + err = b.Start() + require.NoError(t, err) +} + +// TestDispatchBlocks asserts the blocks are properly dispatched to the queues. +func TestDispatchBlocks(t *testing.T) { + t.Parallel() + + // Create a mock chain notifier. + mockNotifier := &chainntnfs.MockChainNotifier{} + defer mockNotifier.AssertExpectations(t) + + // Create a new dispatcher. + b := NewBlockbeatDispatcher(mockNotifier) + + // Create the beat and attach it to the dispatcher. + epoch := chainntnfs.BlockEpoch{Height: 1} + beat := NewBeat(epoch) + b.beat = beat + + // Create a consumer and register it. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker1") + b.RegisterQueue([]Consumer{consumer}) + + // Mock the consumer to return nil error on ProcessBlock. This + // implictly asserts that the step `notifyQueues` is successfully + // reached in the `dispatchBlocks` method. + consumer.On("ProcessBlock", mock.Anything).Return(nil).Once() + + // Create a test epoch chan. + epochChan := make(chan *chainntnfs.BlockEpoch, 1) + blockEpochs := &chainntnfs.BlockEpochEvent{ + Epochs: epochChan, + Cancel: func() {}, + } + + // Call the method in a goroutine. + done := make(chan struct{}) + b.wg.Add(1) + go func() { + defer close(done) + b.dispatchBlocks(blockEpochs) + }() + + // Send an epoch. + epoch = chainntnfs.BlockEpoch{Height: 2} + epochChan <- &epoch + + // Wait for the dispatcher to process the epoch. + time.Sleep(100 * time.Millisecond) + + // Stop the dispatcher. + b.Stop() + + // We expect the dispatcher to stop immediately. + _, err := fn.RecvOrTimeout(done, time.Second) + require.NoError(t, err) +} + +// TestNotifyQueuesSuccess checks when the dispatcher successfully notifies all +// the queues, no error is returned. +func TestNotifyQueuesSuccess(t *testing.T) { + t.Parallel() + + // Create two mock consumers. + consumer1 := &MockConsumer{} + defer consumer1.AssertExpectations(t) + consumer1.On("Name").Return("mocker1") + + consumer2 := &MockConsumer{} + defer consumer2.AssertExpectations(t) + consumer2.On("Name").Return("mocker2") + + // Create two queues. + queue1 := []Consumer{consumer1} + queue2 := []Consumer{consumer2} + + // Create a mock chain notifier. + mockNotifier := &chainntnfs.MockChainNotifier{} + defer mockNotifier.AssertExpectations(t) + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Create a new dispatcher. + b := NewBlockbeatDispatcher(mockNotifier) + + // Register the queues. + b.RegisterQueue(queue1) + b.RegisterQueue(queue2) + + // Attach the blockbeat. + b.beat = mockBeat + + // Mock the consumers to return nil error on ProcessBlock for + // both calls. + consumer1.On("ProcessBlock", mockBeat).Return(nil).Once() + consumer2.On("ProcessBlock", mockBeat).Return(nil).Once() + + // Notify the queues. The mockers will be asserted in the end to + // validate the calls. + err := b.notifyQueues() + require.NoError(t, err) +} + +// TestNotifyQueuesError checks when one of the queue returns an error, this +// error is returned by the method. +func TestNotifyQueuesError(t *testing.T) { + t.Parallel() + + // Create a mock consumer. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker1") + + // Create one queue. + queue := []Consumer{consumer} + + // Create a mock chain notifier. + mockNotifier := &chainntnfs.MockChainNotifier{} + defer mockNotifier.AssertExpectations(t) + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Create a new dispatcher. + b := NewBlockbeatDispatcher(mockNotifier) + + // Register the queues. + b.RegisterQueue(queue) + + // Attach the blockbeat. + b.beat = mockBeat + + // Mock the consumer to return an error on ProcessBlock. + consumer.On("ProcessBlock", mockBeat).Return(errDummy).Once() + + // Notify the queues. The mockers will be asserted in the end to + // validate the calls. + err := b.notifyQueues() + require.ErrorIs(t, err, errDummy) +} diff --git a/chainio/interface.go b/chainio/interface.go new file mode 100644 index 0000000000..03c09faf7c --- /dev/null +++ b/chainio/interface.go @@ -0,0 +1,53 @@ +package chainio + +import "github.com/btcsuite/btclog/v2" + +// Blockbeat defines an interface that can be used by subsystems to retrieve +// block data. It is sent by the BlockbeatDispatcher to all the registered +// consumers whenever a new block is received. Once the consumer finishes +// processing the block, it must signal it by calling `NotifyBlockProcessed`. +// +// The blockchain is a state machine - whenever there's a state change, it's +// manifested in a block. The blockbeat is a way to notify subsystems of this +// state change, and to provide them with the data they need to process it. In +// other words, subsystems must react to this state change and should consider +// being driven by the blockbeat in their own state machines. +type Blockbeat interface { + // blockbeat is a private interface that's only used in this package. + blockbeat + + // Height returns the current block height. + Height() int32 +} + +// blockbeat defines a set of private methods used in this package to make +// interaction with the blockbeat easier. +type blockbeat interface { + // logger returns the internal logger used by the blockbeat which has a + // block height prefix. + logger() btclog.Logger +} + +// Consumer defines a blockbeat consumer interface. Subsystems that need block +// info must implement it. +type Consumer interface { + // TODO(yy): We should also define the start methods used by the + // consumers such that when implementing the interface, the consumer + // will always be started with a blockbeat. This cannot be enforced at + // the moment as we need refactor all the start methods to only take a + // beat. + // + // Start(beat Blockbeat) error + + // Name returns a human-readable string for this subsystem. + Name() string + + // ProcessBlock takes a blockbeat and processes it. It should not + // return until the subsystem has updated its state based on the block + // data. + // + // NOTE: The consumer must try its best to NOT return an error. If an + // error is returned from processing the block, it means the subsystem + // cannot react to onchain state changes and lnd will shutdown. + ProcessBlock(b Blockbeat) error +} diff --git a/chainio/log.go b/chainio/log.go new file mode 100644 index 0000000000..2d8c26f7a5 --- /dev/null +++ b/chainio/log.go @@ -0,0 +1,32 @@ +package chainio + +import ( + "github.com/btcsuite/btclog/v2" + "github.com/lightningnetwork/lnd/build" +) + +// Subsystem defines the logging code for this subsystem. +const Subsystem = "CHIO" + +// clog is a logger that is initialized with no output filters. This means the +// package will not perform any logging by default until the caller requests +// it. +var clog btclog.Logger + +// The default amount of logging is none. +func init() { + UseLogger(build.NewSubLogger(Subsystem, nil)) +} + +// DisableLog disables all library log output. Logging output is disabled by +// default until UseLogger is called. +func DisableLog() { + UseLogger(btclog.Disabled) +} + +// UseLogger uses a specified Logger to output package logging info. This +// should be used in preference to SetLogWriter if the caller is also using +// btclog. +func UseLogger(logger btclog.Logger) { + clog = logger +} diff --git a/chainio/mocks.go b/chainio/mocks.go new file mode 100644 index 0000000000..5677734e1d --- /dev/null +++ b/chainio/mocks.go @@ -0,0 +1,50 @@ +package chainio + +import ( + "github.com/btcsuite/btclog/v2" + "github.com/stretchr/testify/mock" +) + +// MockConsumer is a mock implementation of the Consumer interface. +type MockConsumer struct { + mock.Mock +} + +// Compile-time constraint to ensure MockConsumer implements Consumer. +var _ Consumer = (*MockConsumer)(nil) + +// Name returns a human-readable string for this subsystem. +func (m *MockConsumer) Name() string { + args := m.Called() + return args.String(0) +} + +// ProcessBlock takes a blockbeat and processes it. A receive-only error chan +// must be returned. +func (m *MockConsumer) ProcessBlock(b Blockbeat) error { + args := m.Called(b) + + return args.Error(0) +} + +// MockBlockbeat is a mock implementation of the Blockbeat interface. +type MockBlockbeat struct { + mock.Mock +} + +// Compile-time constraint to ensure MockBlockbeat implements Blockbeat. +var _ Blockbeat = (*MockBlockbeat)(nil) + +// Height returns the current block height. +func (m *MockBlockbeat) Height() int32 { + args := m.Called() + + return args.Get(0).(int32) +} + +// logger returns the logger for the blockbeat. +func (m *MockBlockbeat) logger() btclog.Logger { + args := m.Called() + + return args.Get(0).(btclog.Logger) +} diff --git a/chainntnfs/bitcoindnotify/bitcoind.go b/chainntnfs/bitcoindnotify/bitcoind.go index fc20fbb857..59c03d5171 100644 --- a/chainntnfs/bitcoindnotify/bitcoind.go +++ b/chainntnfs/bitcoindnotify/bitcoind.go @@ -15,7 +15,7 @@ import ( "github.com/btcsuite/btcwallet/chain" "github.com/lightningnetwork/lnd/blockcache" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/queue" ) diff --git a/chainntnfs/btcdnotify/btcd.go b/chainntnfs/btcdnotify/btcd.go index c3a40a00bf..e3bff289cf 100644 --- a/chainntnfs/btcdnotify/btcd.go +++ b/chainntnfs/btcdnotify/btcd.go @@ -17,7 +17,7 @@ import ( "github.com/btcsuite/btcwallet/chain" "github.com/lightningnetwork/lnd/blockcache" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/queue" ) diff --git a/chainntnfs/interface.go b/chainntnfs/interface.go index b2383636aa..1b8a5acb50 100644 --- a/chainntnfs/interface.go +++ b/chainntnfs/interface.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" ) var ( diff --git a/chainntnfs/mocks.go b/chainntnfs/mocks.go index d9ab9928d0..4a888b162e 100644 --- a/chainntnfs/mocks.go +++ b/chainntnfs/mocks.go @@ -3,7 +3,7 @@ package chainntnfs import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/stretchr/testify/mock" ) diff --git a/chainreg/chainregistry.go b/chainreg/chainregistry.go index edc422482e..a9a9ede704 100644 --- a/chainreg/chainregistry.go +++ b/chainreg/chainregistry.go @@ -23,7 +23,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs/btcdnotify" "github.com/lightningnetwork/lnd/chainntnfs/neutrinonotify" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" diff --git a/chanbackup/backup.go b/chanbackup/backup.go index 5853b37e45..afffe5a2e8 100644 --- a/chanbackup/backup.go +++ b/chanbackup/backup.go @@ -5,7 +5,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" ) // LiveChannelSource is an interface that allows us to query for the set of diff --git a/chanbackup/single.go b/chanbackup/single.go index b741320b07..01d14f6c07 100644 --- a/chanbackup/single.go +++ b/chanbackup/single.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnencrypt" "github.com/lightningnetwork/lnd/lnwire" diff --git a/chanbackup/single_test.go b/chanbackup/single_test.go index d2212bd859..0fe402926d 100644 --- a/chanbackup/single_test.go +++ b/chanbackup/single_test.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnencrypt" "github.com/lightningnetwork/lnd/lnwire" diff --git a/channeldb/channel.go b/channeldb/channel.go index 9ca57312aa..f4e99a6f8c 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -19,7 +19,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/walletdb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 2cac0baced..b1ca100eb3 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -18,7 +18,7 @@ import ( _ "github.com/btcsuite/btcwallet/walletdb/bdb" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/clock" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" diff --git a/channeldb/migration/lnwire21/custom_records.go b/channeldb/migration/lnwire21/custom_records.go index f0f59185e9..7771c8ec8b 100644 --- a/channeldb/migration/lnwire21/custom_records.go +++ b/channeldb/migration/lnwire21/custom_records.go @@ -6,7 +6,7 @@ import ( "io" "sort" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/tlv" ) @@ -163,9 +163,12 @@ func (c CustomRecords) SerializeTo(w io.Writer) error { // ProduceRecordsSorted converts a slice of record producers into a slice of // records and then sorts it by type. func ProduceRecordsSorted(recordProducers ...tlv.RecordProducer) []tlv.Record { - records := fn.Map(func(producer tlv.RecordProducer) tlv.Record { - return producer.Record() - }, recordProducers) + records := fn.Map( + recordProducers, + func(producer tlv.RecordProducer) tlv.Record { + return producer.Record() + }, + ) // Ensure that the set of records are sorted before we attempt to // decode from the stream, to ensure they're canonical. @@ -196,9 +199,9 @@ func TlvMapToRecords(tlvMap tlv.TypeMap) []tlv.Record { // RecordsAsProducers converts a slice of records into a slice of record // producers. func RecordsAsProducers(records []tlv.Record) []tlv.RecordProducer { - return fn.Map(func(record tlv.Record) tlv.RecordProducer { + return fn.Map(records, func(record tlv.Record) tlv.RecordProducer { return &record - }, records) + }) } // EncodeRecords encodes the given records into a byte slice. diff --git a/channeldb/migration32/mission_control_store.go b/channeldb/migration32/mission_control_store.go index 3ac9d6114c..76463eb6ca 100644 --- a/channeldb/migration32/mission_control_store.go +++ b/channeldb/migration32/mission_control_store.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/wire" lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/tlv" ) @@ -371,7 +371,7 @@ func extractMCRoute(r *Route) *mcRoute { // extractMCHops extracts the Hop fields that MC actually uses from a slice of // Hops. func extractMCHops(hops []*Hop) mcHops { - return fn.Map(extractMCHop, hops) + return fn.Map(hops, extractMCHop) } // extractMCHop extracts the Hop fields that MC actually uses from a Hop. diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index 3abc73f81e..ea6eaf13f2 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -7,7 +7,7 @@ import ( "math" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go index 4290552eee..2df6627e2c 100644 --- a/channeldb/revocation_log_test.go +++ b/channeldb/revocation_log_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/channels" "github.com/lightningnetwork/lnd/lnwire" diff --git a/cmd/commands/cmd_macaroon.go b/cmd/commands/cmd_macaroon.go index 15c29380a7..d7d6d5f9dc 100644 --- a/cmd/commands/cmd_macaroon.go +++ b/cmd/commands/cmd_macaroon.go @@ -10,7 +10,7 @@ import ( "strings" "unicode" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/macaroons" @@ -177,12 +177,15 @@ func bakeMacaroon(ctx *cli.Context) error { "%w", err) } - ops := fn.Map(func(p *lnrpc.MacaroonPermission) bakery.Op { - return bakery.Op{ - Entity: p.Entity, - Action: p.Action, - } - }, parsedPermissions) + ops := fn.Map( + parsedPermissions, + func(p *lnrpc.MacaroonPermission) bakery.Op { + return bakery.Op{ + Entity: p.Entity, + Action: p.Action, + } + }, + ) rawMacaroon, err = macaroons.BakeFromRootKey(macRootKey, ops) if err != nil { diff --git a/config_builder.go b/config_builder.go index 42650bb68b..afafbed754 100644 --- a/config_builder.go +++ b/config_builder.go @@ -33,9 +33,10 @@ import ( "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/funding" graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -47,7 +48,6 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/rpcwallet" "github.com/lightningnetwork/lnd/macaroons" "github.com/lightningnetwork/lnd/msgmux" - "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/rpcperms" "github.com/lightningnetwork/lnd/signal" "github.com/lightningnetwork/lnd/sqldb" @@ -166,7 +166,7 @@ type AuxComponents struct { // TrafficShaper is an optional traffic shaper that can be used to // control the outgoing channel of a payment. - TrafficShaper fn.Option[routing.TlvTrafficShaper] + TrafficShaper fn.Option[htlcswitch.AuxTrafficShaper] // MsgRouter is an optional message router that if set will be used in // place of a new blank default message router. diff --git a/contractcourt/anchor_resolver.go b/contractcourt/anchor_resolver.go index b4d6877202..eb42ad3cb4 100644 --- a/contractcourt/anchor_resolver.go +++ b/contractcourt/anchor_resolver.go @@ -2,6 +2,7 @@ package contractcourt import ( "errors" + "fmt" "io" "sync" @@ -9,7 +10,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/sweep" ) @@ -23,9 +24,6 @@ type anchorResolver struct { // anchor is the outpoint on the commitment transaction. anchor wire.OutPoint - // resolved reflects if the contract has been fully resolved or not. - resolved bool - // broadcastHeight is the height that the original contract was // broadcast to the main-chain at. We'll use this value to bound any // historical queries to the chain for spends/confirmations. @@ -71,7 +69,7 @@ func newAnchorResolver(anchorSignDescriptor input.SignDescriptor, currentReport: report, } - r.initLogger(r) + r.initLogger(fmt.Sprintf("%T(%v)", r, r.anchor)) return r } @@ -83,49 +81,12 @@ func (c *anchorResolver) ResolverKey() []byte { return nil } -// Resolve offers the anchor output to the sweeper and waits for it to be swept. -func (c *anchorResolver) Resolve(_ bool) (ContractResolver, error) { - // Attempt to update the sweep parameters to the post-confirmation - // situation. We don't want to force sweep anymore, because the anchor - // lost its special purpose to get the commitment confirmed. It is just - // an output that we want to sweep only if it is economical to do so. - // - // An exclusive group is not necessary anymore, because we know that - // this is the only anchor that can be swept. - // - // We also clear the parent tx information for cpfp, because the - // commitment tx is confirmed. - // - // After a restart or when the remote force closes, the sweeper is not - // yet aware of the anchor. In that case, it will be added as new input - // to the sweeper. - witnessType := input.CommitmentAnchor - - // For taproot channels, we need to use the proper witness type. - if c.chanType.IsTaproot() { - witnessType = input.TaprootAnchorSweepSpend - } - - anchorInput := input.MakeBaseInput( - &c.anchor, witnessType, &c.anchorSignDescriptor, - c.broadcastHeight, nil, - ) - - resultChan, err := c.Sweeper.SweepInput( - &anchorInput, - sweep.Params{ - // For normal anchor sweeping, the budget is 330 sats. - Budget: btcutil.Amount( - anchorInput.SignDesc().Output.Value, - ), - - // There's no rush to sweep the anchor, so we use a nil - // deadline here. - DeadlineHeight: fn.None[int32](), - }, - ) - if err != nil { - return nil, err +// Resolve waits for the output to be swept. +func (c *anchorResolver) Resolve() (ContractResolver, error) { + // If we're already resolved, then we can exit early. + if c.IsResolved() { + c.log.Errorf("already resolved") + return nil, nil } var ( @@ -134,7 +95,7 @@ func (c *anchorResolver) Resolve(_ bool) (ContractResolver, error) { ) select { - case sweepRes := <-resultChan: + case sweepRes := <-c.sweepResultChan: switch sweepRes.Err { // Anchor was swept successfully. case nil: @@ -160,6 +121,8 @@ func (c *anchorResolver) Resolve(_ bool) (ContractResolver, error) { return nil, errResolverShuttingDown } + c.log.Infof("resolved in tx %v", spendTx) + // Update report to reflect that funds are no longer in limbo. c.reportLock.Lock() if outcome == channeldb.ResolverOutcomeClaimed { @@ -171,7 +134,7 @@ func (c *anchorResolver) Resolve(_ bool) (ContractResolver, error) { ) c.reportLock.Unlock() - c.resolved = true + c.markResolved() return nil, c.PutResolverReport(nil, report) } @@ -180,15 +143,10 @@ func (c *anchorResolver) Resolve(_ bool) (ContractResolver, error) { // // NOTE: Part of the ContractResolver interface. func (c *anchorResolver) Stop() { - close(c.quit) -} + c.log.Debugf("stopping...") + defer c.log.Debugf("stopped") -// IsResolved returns true if the stored state in the resolve is fully -// resolved. In this case the target output can be forgotten. -// -// NOTE: Part of the ContractResolver interface. -func (c *anchorResolver) IsResolved() bool { - return c.resolved + close(c.quit) } // SupplementState allows the user of a ContractResolver to supplement it with @@ -215,3 +173,68 @@ func (c *anchorResolver) Encode(w io.Writer) error { // A compile time assertion to ensure anchorResolver meets the // ContractResolver interface. var _ ContractResolver = (*anchorResolver)(nil) + +// Launch offers the anchor output to the sweeper. +func (c *anchorResolver) Launch() error { + if c.isLaunched() { + c.log.Tracef("already launched") + return nil + } + + c.log.Debugf("launching resolver...") + c.markLaunched() + + // If we're already resolved, then we can exit early. + if c.IsResolved() { + c.log.Errorf("already resolved") + return nil + } + + // Attempt to update the sweep parameters to the post-confirmation + // situation. We don't want to force sweep anymore, because the anchor + // lost its special purpose to get the commitment confirmed. It is just + // an output that we want to sweep only if it is economical to do so. + // + // An exclusive group is not necessary anymore, because we know that + // this is the only anchor that can be swept. + // + // We also clear the parent tx information for cpfp, because the + // commitment tx is confirmed. + // + // After a restart or when the remote force closes, the sweeper is not + // yet aware of the anchor. In that case, it will be added as new input + // to the sweeper. + witnessType := input.CommitmentAnchor + + // For taproot channels, we need to use the proper witness type. + if c.chanType.IsTaproot() { + witnessType = input.TaprootAnchorSweepSpend + } + + anchorInput := input.MakeBaseInput( + &c.anchor, witnessType, &c.anchorSignDescriptor, + c.broadcastHeight, nil, + ) + + resultChan, err := c.Sweeper.SweepInput( + &anchorInput, + sweep.Params{ + // For normal anchor sweeping, the budget is 330 sats. + Budget: btcutil.Amount( + anchorInput.SignDesc().Output.Value, + ), + + // There's no rush to sweep the anchor, so we use a nil + // deadline here. + DeadlineHeight: fn.None[int32](), + }, + ) + + if err != nil { + return err + } + + c.sweepResultChan = resultChan + + return nil +} diff --git a/contractcourt/breach_arbitrator.go b/contractcourt/breach_arbitrator.go index d59829b5e5..33bc7f7e33 100644 --- a/contractcourt/breach_arbitrator.go +++ b/contractcourt/breach_arbitrator.go @@ -15,7 +15,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" @@ -1537,9 +1537,9 @@ func (b *BreachArbitrator) createSweepTx( // outputs from the regular, BTC only outputs. So we only need one such // output, which'll carry the custom channel "valuables" from both the // breached commitment and HTLC outputs. - hasBlobs := fn.Any(func(i input.Input) bool { + hasBlobs := fn.Any(inputs, func(i input.Input) bool { return i.ResolutionBlob().IsSome() - }, inputs) + }) if hasBlobs { weightEstimate.AddP2TROutput() } @@ -1624,7 +1624,7 @@ func (b *BreachArbitrator) sweepSpendableOutputsTxn(txWeight lntypes.WeightUnit, // First, we'll add the extra sweep output if it exists, subtracting the // amount from the sweep amt. if b.cfg.AuxSweeper.IsSome() { - extraChangeOut.WhenResult(func(o sweep.SweepOutput) { + extraChangeOut.WhenOk(func(o sweep.SweepOutput) { sweepAmt -= o.Value txn.AddTxOut(&o.TxOut) @@ -1697,7 +1697,7 @@ func (b *BreachArbitrator) sweepSpendableOutputsTxn(txWeight lntypes.WeightUnit, return &justiceTxCtx{ justiceTx: txn, sweepAddr: pkScript, - extraTxOut: extraChangeOut.Option(), + extraTxOut: extraChangeOut.OkToSome(), fee: txFee, inputs: inputs, }, nil diff --git a/contractcourt/breach_arbitrator_test.go b/contractcourt/breach_arbitrator_test.go index 576009eda4..99ed852696 100644 --- a/contractcourt/breach_arbitrator_test.go +++ b/contractcourt/breach_arbitrator_test.go @@ -22,7 +22,7 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntest/channels" @@ -36,7 +36,7 @@ import ( ) var ( - defaultTimeout = 30 * time.Second + defaultTimeout = 10 * time.Second breachOutPoints = []wire.OutPoint{ { diff --git a/contractcourt/breach_resolver.go b/contractcourt/breach_resolver.go index 740b4471d5..5644e60fad 100644 --- a/contractcourt/breach_resolver.go +++ b/contractcourt/breach_resolver.go @@ -2,6 +2,7 @@ package contractcourt import ( "encoding/binary" + "fmt" "io" "github.com/lightningnetwork/lnd/channeldb" @@ -11,9 +12,6 @@ import ( // future, this will likely take over the duties the current BreachArbitrator // has. type breachResolver struct { - // resolved reflects if the contract has been fully resolved or not. - resolved bool - // subscribed denotes whether or not the breach resolver has subscribed // to the BreachArbitrator for breach resolution. subscribed bool @@ -32,7 +30,7 @@ func newBreachResolver(resCfg ResolverConfig) *breachResolver { replyChan: make(chan struct{}), } - r.initLogger(r) + r.initLogger(fmt.Sprintf("%T(%v)", r, r.ChanPoint)) return r } @@ -47,7 +45,7 @@ func (b *breachResolver) ResolverKey() []byte { // been broadcast. // // TODO(yy): let sweeper handle the breach inputs. -func (b *breachResolver) Resolve(_ bool) (ContractResolver, error) { +func (b *breachResolver) Resolve() (ContractResolver, error) { if !b.subscribed { complete, err := b.SubscribeBreachComplete( &b.ChanPoint, b.replyChan, @@ -59,7 +57,7 @@ func (b *breachResolver) Resolve(_ bool) (ContractResolver, error) { // If the breach resolution process is already complete, then // we can cleanup and checkpoint the resolved state. if complete { - b.resolved = true + b.markResolved() return nil, b.Checkpoint(b) } @@ -72,8 +70,9 @@ func (b *breachResolver) Resolve(_ bool) (ContractResolver, error) { // The replyChan has been closed, signalling that the breach // has been fully resolved. Checkpoint the resolved state and // exit. - b.resolved = true + b.markResolved() return nil, b.Checkpoint(b) + case <-b.quit: } @@ -82,22 +81,17 @@ func (b *breachResolver) Resolve(_ bool) (ContractResolver, error) { // Stop signals the breachResolver to stop. func (b *breachResolver) Stop() { + b.log.Debugf("stopping...") close(b.quit) } -// IsResolved returns true if the breachResolver is fully resolved and cleanup -// can occur. -func (b *breachResolver) IsResolved() bool { - return b.resolved -} - // SupplementState adds additional state to the breachResolver. func (b *breachResolver) SupplementState(_ *channeldb.OpenChannel) { } // Encode encodes the breachResolver to the passed writer. func (b *breachResolver) Encode(w io.Writer) error { - return binary.Write(w, endian, b.resolved) + return binary.Write(w, endian, b.IsResolved()) } // newBreachResolverFromReader attempts to decode an encoded breachResolver @@ -110,11 +104,15 @@ func newBreachResolverFromReader(r io.Reader, resCfg ResolverConfig) ( replyChan: make(chan struct{}), } - if err := binary.Read(r, endian, &b.resolved); err != nil { + var resolved bool + if err := binary.Read(r, endian, &resolved); err != nil { return nil, err } + if resolved { + b.markResolved() + } - b.initLogger(b) + b.initLogger(fmt.Sprintf("%T(%v)", b, b.ChanPoint)) return b, nil } @@ -122,3 +120,16 @@ func newBreachResolverFromReader(r io.Reader, resCfg ResolverConfig) ( // A compile time assertion to ensure breachResolver meets the ContractResolver // interface. var _ ContractResolver = (*breachResolver)(nil) + +// TODO(yy): implement it once the outputs are offered to the sweeper. +func (b *breachResolver) Launch() error { + if b.isLaunched() { + b.log.Tracef("already launched") + return nil + } + + b.log.Debugf("launching resolver...") + b.markLaunched() + + return nil +} diff --git a/contractcourt/briefcase.go b/contractcourt/briefcase.go index a0908ea3fa..7d199c5c28 100644 --- a/contractcourt/briefcase.go +++ b/contractcourt/briefcase.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/contractcourt/briefcase_test.go b/contractcourt/briefcase_test.go index 0f44db2abb..aa2e711efc 100644 --- a/contractcourt/briefcase_test.go +++ b/contractcourt/briefcase_test.go @@ -14,7 +14,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnmock" @@ -206,8 +206,8 @@ func assertResolversEqual(t *testing.T, originalResolver ContractResolver, ogRes.outputIncubating, diskRes.outputIncubating) } if ogRes.resolved != diskRes.resolved { - t.Fatalf("expected %v, got %v", ogRes.resolved, - diskRes.resolved) + t.Fatalf("expected %v, got %v", ogRes.resolved.Load(), + diskRes.resolved.Load()) } if ogRes.broadcastHeight != diskRes.broadcastHeight { t.Fatalf("expected %v, got %v", @@ -229,8 +229,8 @@ func assertResolversEqual(t *testing.T, originalResolver ContractResolver, ogRes.outputIncubating, diskRes.outputIncubating) } if ogRes.resolved != diskRes.resolved { - t.Fatalf("expected %v, got %v", ogRes.resolved, - diskRes.resolved) + t.Fatalf("expected %v, got %v", ogRes.resolved.Load(), + diskRes.resolved.Load()) } if ogRes.broadcastHeight != diskRes.broadcastHeight { t.Fatalf("expected %v, got %v", @@ -275,8 +275,8 @@ func assertResolversEqual(t *testing.T, originalResolver ContractResolver, ogRes.commitResolution, diskRes.commitResolution) } if ogRes.resolved != diskRes.resolved { - t.Fatalf("expected %v, got %v", ogRes.resolved, - diskRes.resolved) + t.Fatalf("expected %v, got %v", ogRes.resolved.Load(), + diskRes.resolved.Load()) } if ogRes.broadcastHeight != diskRes.broadcastHeight { t.Fatalf("expected %v, got %v", @@ -312,13 +312,14 @@ func TestContractInsertionRetrieval(t *testing.T) { SweepSignDesc: testSignDesc, }, outputIncubating: true, - resolved: true, broadcastHeight: 102, htlc: channeldb.HTLC{ HtlcIndex: 12, }, } - successResolver := htlcSuccessResolver{ + timeoutResolver.resolved.Store(true) + + successResolver := &htlcSuccessResolver{ htlcResolution: lnwallet.IncomingHtlcResolution{ Preimage: testPreimage, SignedSuccessTx: nil, @@ -327,40 +328,49 @@ func TestContractInsertionRetrieval(t *testing.T) { SweepSignDesc: testSignDesc, }, outputIncubating: true, - resolved: true, broadcastHeight: 109, htlc: channeldb.HTLC{ RHash: testPreimage, }, } - resolvers := []ContractResolver{ - &timeoutResolver, - &successResolver, - &commitSweepResolver{ - commitResolution: lnwallet.CommitOutputResolution{ - SelfOutPoint: testChanPoint2, - SelfOutputSignDesc: testSignDesc, - MaturityDelay: 99, - }, - resolved: false, - broadcastHeight: 109, - chanPoint: testChanPoint1, + successResolver.resolved.Store(true) + + commitResolver := &commitSweepResolver{ + commitResolution: lnwallet.CommitOutputResolution{ + SelfOutPoint: testChanPoint2, + SelfOutputSignDesc: testSignDesc, + MaturityDelay: 99, }, + broadcastHeight: 109, + chanPoint: testChanPoint1, + } + commitResolver.resolved.Store(false) + + resolvers := []ContractResolver{ + &timeoutResolver, successResolver, commitResolver, } // All resolvers require a unique ResolverKey() output. To achieve this // for the composite resolvers, we'll mutate the underlying resolver // with a new outpoint. - contestTimeout := timeoutResolver - contestTimeout.htlcResolution.ClaimOutpoint = randOutPoint() + contestTimeout := htlcTimeoutResolver{ + htlcResolution: lnwallet.OutgoingHtlcResolution{ + ClaimOutpoint: randOutPoint(), + SweepSignDesc: testSignDesc, + }, + } resolvers = append(resolvers, &htlcOutgoingContestResolver{ htlcTimeoutResolver: &contestTimeout, }) - contestSuccess := successResolver - contestSuccess.htlcResolution.ClaimOutpoint = randOutPoint() + contestSuccess := &htlcSuccessResolver{ + htlcResolution: lnwallet.IncomingHtlcResolution{ + ClaimOutpoint: randOutPoint(), + SweepSignDesc: testSignDesc, + }, + } resolvers = append(resolvers, &htlcIncomingContestResolver{ htlcExpiry: 100, - htlcSuccessResolver: &contestSuccess, + htlcSuccessResolver: contestSuccess, }) // For quick lookup during the test, we'll create this map which allow @@ -438,12 +448,12 @@ func TestContractResolution(t *testing.T) { SweepSignDesc: testSignDesc, }, outputIncubating: true, - resolved: true, broadcastHeight: 192, htlc: channeldb.HTLC{ HtlcIndex: 9912, }, } + timeoutResolver.resolved.Store(true) // First, we'll insert the resolver into the database and ensure that // we get the same resolver out the other side. We do not need to apply @@ -491,12 +501,13 @@ func TestContractSwapping(t *testing.T) { SweepSignDesc: testSignDesc, }, outputIncubating: true, - resolved: true, broadcastHeight: 102, htlc: channeldb.HTLC{ HtlcIndex: 12, }, } + timeoutResolver.resolved.Store(true) + contestResolver := &htlcOutgoingContestResolver{ htlcTimeoutResolver: timeoutResolver, } diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 6d9b30d208..011b5225cd 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -11,10 +11,11 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/walletdb" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" @@ -244,6 +245,10 @@ type ChainArbitrator struct { started int32 // To be used atomically. stopped int32 // To be used atomically. + // Embed the blockbeat consumer struct to get access to the method + // `NotifyBlockProcessed` and the `BlockbeatChan`. + chainio.BeatConsumer + sync.Mutex // activeChannels is a map of all the active contracts that are still @@ -262,6 +267,9 @@ type ChainArbitrator struct { // active channels that it must still watch over. chanSource *channeldb.DB + // beat is the current best known blockbeat. + beat chainio.Blockbeat + quit chan struct{} wg sync.WaitGroup @@ -272,15 +280,23 @@ type ChainArbitrator struct { func NewChainArbitrator(cfg ChainArbitratorConfig, db *channeldb.DB) *ChainArbitrator { - return &ChainArbitrator{ + c := &ChainArbitrator{ cfg: cfg, activeChannels: make(map[wire.OutPoint]*ChannelArbitrator), activeWatchers: make(map[wire.OutPoint]*chainWatcher), chanSource: db, quit: make(chan struct{}), } + + // Mount the block consumer. + c.BeatConsumer = chainio.NewBeatConsumer(c.quit, c.Name()) + + return c } +// Compile-time check for the chainio.Consumer interface. +var _ chainio.Consumer = (*ChainArbitrator)(nil) + // arbChannel is a wrapper around an open channel that channel arbitrators // interact with. type arbChannel struct { @@ -554,147 +570,30 @@ func (c *ChainArbitrator) ResolveContract(chanPoint wire.OutPoint) error { } // Start launches all goroutines that the ChainArbitrator needs to operate. -func (c *ChainArbitrator) Start() error { +func (c *ChainArbitrator) Start(beat chainio.Blockbeat) error { if !atomic.CompareAndSwapInt32(&c.started, 0, 1) { return nil } - log.Infof("ChainArbitrator starting with config: budget=[%v]", - &c.cfg.Budget) + // Set the current beat. + c.beat = beat + + log.Infof("ChainArbitrator starting at height %d with budget=[%v]", + &c.cfg.Budget, c.beat.Height()) // First, we'll fetch all the channels that are still open, in order to // collect them within our set of active contracts. - openChannels, err := c.chanSource.ChannelStateDB().FetchAllChannels() - if err != nil { + if err := c.loadOpenChannels(); err != nil { return err } - if len(openChannels) > 0 { - log.Infof("Creating ChannelArbitrators for %v active channels", - len(openChannels)) - } - - // For each open channel, we'll configure then launch a corresponding - // ChannelArbitrator. - for _, channel := range openChannels { - chanPoint := channel.FundingOutpoint - channel := channel - - // First, we'll create an active chainWatcher for this channel - // to ensure that we detect any relevant on chain events. - breachClosure := func(ret *lnwallet.BreachRetribution) error { - return c.cfg.ContractBreach(chanPoint, ret) - } - - chainWatcher, err := newChainWatcher( - chainWatcherConfig{ - chanState: channel, - notifier: c.cfg.Notifier, - signer: c.cfg.Signer, - isOurAddr: c.cfg.IsOurAddress, - contractBreach: breachClosure, - extractStateNumHint: lnwallet.GetStateNumHint, - auxLeafStore: c.cfg.AuxLeafStore, - auxResolver: c.cfg.AuxResolver, - }, - ) - if err != nil { - return err - } - - c.activeWatchers[chanPoint] = chainWatcher - channelArb, err := newActiveChannelArbitrator( - channel, c, chainWatcher.SubscribeChannelEvents(), - ) - if err != nil { - return err - } - - c.activeChannels[chanPoint] = channelArb - - // Republish any closing transactions for this channel. - err = c.republishClosingTxs(channel) - if err != nil { - log.Errorf("Failed to republish closing txs for "+ - "channel %v", chanPoint) - } - } - // In addition to the channels that we know to be open, we'll also // launch arbitrators to finishing resolving any channels that are in // the pending close state. - closingChannels, err := c.chanSource.ChannelStateDB().FetchClosedChannels( - true, - ) - if err != nil { + if err := c.loadPendingCloseChannels(); err != nil { return err } - if len(closingChannels) > 0 { - log.Infof("Creating ChannelArbitrators for %v closing channels", - len(closingChannels)) - } - - // Next, for each channel is the closing state, we'll launch a - // corresponding more restricted resolver, as we don't have to watch - // the chain any longer, only resolve the contracts on the confirmed - // commitment. - //nolint:ll - for _, closeChanInfo := range closingChannels { - // We can leave off the CloseContract and ForceCloseChan - // methods as the channel is already closed at this point. - chanPoint := closeChanInfo.ChanPoint - arbCfg := ChannelArbitratorConfig{ - ChanPoint: chanPoint, - ShortChanID: closeChanInfo.ShortChanID, - ChainArbitratorConfig: c.cfg, - ChainEvents: &ChainEventSubscription{}, - IsPendingClose: true, - ClosingHeight: closeChanInfo.CloseHeight, - CloseType: closeChanInfo.CloseType, - PutResolverReport: func(tx kvdb.RwTx, - report *channeldb.ResolverReport) error { - - return c.chanSource.PutResolverReport( - tx, c.cfg.ChainHash, &chanPoint, report, - ) - }, - FetchHistoricalChannel: func() (*channeldb.OpenChannel, error) { - chanStateDB := c.chanSource.ChannelStateDB() - return chanStateDB.FetchHistoricalChannel(&chanPoint) - }, - FindOutgoingHTLCDeadline: func( - htlc channeldb.HTLC) fn.Option[int32] { - - return c.FindOutgoingHTLCDeadline( - closeChanInfo.ShortChanID, htlc, - ) - }, - } - chanLog, err := newBoltArbitratorLog( - c.chanSource.Backend, arbCfg, c.cfg.ChainHash, chanPoint, - ) - if err != nil { - return err - } - arbCfg.MarkChannelResolved = func() error { - if c.cfg.NotifyFullyResolvedChannel != nil { - c.cfg.NotifyFullyResolvedChannel(chanPoint) - } - - return c.ResolveContract(chanPoint) - } - - // We create an empty map of HTLC's here since it's possible - // that the channel is in StateDefault and updateActiveHTLCs is - // called. We want to avoid writing to an empty map. Since the - // channel is already in the process of being resolved, no new - // HTLCs will be added. - c.activeChannels[chanPoint] = NewChannelArbitrator( - arbCfg, make(map[HtlcSetKey]htlcSet), chanLog, - ) - } - // Now, we'll start all chain watchers in parallel to shorten start up // duration. In neutrino mode, this allows spend registrations to take // advantage of batch spend reporting, instead of doing a single rescan @@ -746,7 +645,7 @@ func (c *ChainArbitrator) Start() error { // transaction. var startStates map[wire.OutPoint]*chanArbStartState - err = kvdb.View(c.chanSource, func(tx walletdb.ReadTx) error { + err := kvdb.View(c.chanSource, func(tx walletdb.ReadTx) error { for _, arbitrator := range c.activeChannels { startState, err := arbitrator.getStartState(tx) if err != nil { @@ -778,24 +677,17 @@ func (c *ChainArbitrator) Start() error { arbitrator.cfg.ChanPoint) } - if err := arbitrator.Start(startState); err != nil { + if err := arbitrator.Start(startState, c.beat); err != nil { stopAndLog() return err } } - // Subscribe to a single stream of block epoch notifications that we - // will dispatch to all active arbitrators. - blockEpoch, err := c.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return err - } - // Start our goroutine which will dispatch blocks to each arbitrator. c.wg.Add(1) go func() { defer c.wg.Done() - c.dispatchBlocks(blockEpoch) + c.dispatchBlocks() }() // TODO(roasbeef): eventually move all breach watching here @@ -803,94 +695,22 @@ func (c *ChainArbitrator) Start() error { return nil } -// blockRecipient contains the information we need to dispatch a block to a -// channel arbitrator. -type blockRecipient struct { - // chanPoint is the funding outpoint of the channel. - chanPoint wire.OutPoint - - // blocks is the channel that new block heights are sent into. This - // channel should be sufficiently buffered as to not block the sender. - blocks chan<- int32 - - // quit is closed if the receiving entity is shutting down. - quit chan struct{} -} - // dispatchBlocks consumes a block epoch notification stream and dispatches // blocks to each of the chain arb's active channel arbitrators. This function // must be run in a goroutine. -func (c *ChainArbitrator) dispatchBlocks( - blockEpoch *chainntnfs.BlockEpochEvent) { - - // getRecipients is a helper function which acquires the chain arb - // lock and returns a set of block recipients which can be used to - // dispatch blocks. - getRecipients := func() []blockRecipient { - c.Lock() - blocks := make([]blockRecipient, 0, len(c.activeChannels)) - for _, channel := range c.activeChannels { - blocks = append(blocks, blockRecipient{ - chanPoint: channel.cfg.ChanPoint, - blocks: channel.blocks, - quit: channel.quit, - }) - } - c.Unlock() - - return blocks - } - - // On exit, cancel our blocks subscription and close each block channel - // so that the arbitrators know they will no longer be receiving blocks. - defer func() { - blockEpoch.Cancel() - - recipients := getRecipients() - for _, recipient := range recipients { - close(recipient.blocks) - } - }() - +func (c *ChainArbitrator) dispatchBlocks() { // Consume block epochs until we receive the instruction to shutdown. for { select { // Consume block epochs, exiting if our subscription is // terminated. - case block, ok := <-blockEpoch.Epochs: - if !ok { - log.Trace("dispatchBlocks block epoch " + - "cancelled") - return - } + case beat := <-c.BlockbeatChan: + // Set the current blockbeat. + c.beat = beat - // Get the set of currently active channels block - // subscription channels and dispatch the block to - // each. - for _, recipient := range getRecipients() { - select { - // Deliver the block to the arbitrator. - case recipient.blocks <- block.Height: - - // If the recipient is shutting down, exit - // without delivering the block. This may be - // the case when two blocks are mined in quick - // succession, and the arbitrator resolves - // after the first block, and does not need to - // consume the second block. - case <-recipient.quit: - log.Debugf("channel: %v exit without "+ - "receiving block: %v", - recipient.chanPoint, - block.Height) - - // If the chain arb is shutting down, we don't - // need to deliver any more blocks (everything - // will be shutting down). - case <-c.quit: - return - } - } + // Send this blockbeat to all the active channels and + // wait for them to finish processing it. + c.handleBlockbeat(beat) // Exit if the chain arbitrator is shutting down. case <-c.quit: @@ -899,6 +719,32 @@ func (c *ChainArbitrator) dispatchBlocks( } } +// handleBlockbeat sends the blockbeat to all active channel arbitrator in +// parallel and wait for them to finish processing it. +func (c *ChainArbitrator) handleBlockbeat(beat chainio.Blockbeat) { + // Read the active channels in a lock. + c.Lock() + + // Create a slice to record active channel arbitrator. + channels := make([]chainio.Consumer, 0, len(c.activeChannels)) + + // Copy the active channels to the slice. + for _, channel := range c.activeChannels { + channels = append(channels, channel) + } + + c.Unlock() + + // Iterate all the copied channels and send the blockbeat to them. + // + // NOTE: This method will timeout if the processing of blocks of the + // subsystems is too long (60s). + err := chainio.DispatchConcurrent(beat, channels) + + // Notify the chain arbitrator has processed the block. + c.NotifyBlockProcessed(beat, err) +} + // republishClosingTxs will load any stored cooperative or unilateral closing // transactions and republish them. This helps ensure propagation of the // transactions in the event that prior publications failed. @@ -1248,7 +1094,7 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error // arbitrators, then launch it. c.activeChannels[chanPoint] = channelArb - if err := channelArb.Start(nil); err != nil { + if err := channelArb.Start(nil, c.beat); err != nil { return err } @@ -1361,3 +1207,152 @@ func (c *ChainArbitrator) FindOutgoingHTLCDeadline(scid lnwire.ShortChannelID, // TODO(roasbeef): arbitration reports // * types: contested, waiting for success conf, etc + +// NOTE: part of the `chainio.Consumer` interface. +func (c *ChainArbitrator) Name() string { + return "ChainArbitrator" +} + +// loadOpenChannels loads all channels that are currently open in the database +// and registers them with the chainWatcher for future notification. +func (c *ChainArbitrator) loadOpenChannels() error { + openChannels, err := c.chanSource.ChannelStateDB().FetchAllChannels() + if err != nil { + return err + } + + if len(openChannels) == 0 { + return nil + } + + log.Infof("Creating ChannelArbitrators for %v active channels", + len(openChannels)) + + // For each open channel, we'll configure then launch a corresponding + // ChannelArbitrator. + for _, channel := range openChannels { + chanPoint := channel.FundingOutpoint + channel := channel + + // First, we'll create an active chainWatcher for this channel + // to ensure that we detect any relevant on chain events. + breachClosure := func(ret *lnwallet.BreachRetribution) error { + return c.cfg.ContractBreach(chanPoint, ret) + } + + chainWatcher, err := newChainWatcher( + chainWatcherConfig{ + chanState: channel, + notifier: c.cfg.Notifier, + signer: c.cfg.Signer, + isOurAddr: c.cfg.IsOurAddress, + contractBreach: breachClosure, + extractStateNumHint: lnwallet.GetStateNumHint, + auxLeafStore: c.cfg.AuxLeafStore, + auxResolver: c.cfg.AuxResolver, + }, + ) + if err != nil { + return err + } + + c.activeWatchers[chanPoint] = chainWatcher + channelArb, err := newActiveChannelArbitrator( + channel, c, chainWatcher.SubscribeChannelEvents(), + ) + if err != nil { + return err + } + + c.activeChannels[chanPoint] = channelArb + + // Republish any closing transactions for this channel. + err = c.republishClosingTxs(channel) + if err != nil { + log.Errorf("Failed to republish closing txs for "+ + "channel %v", chanPoint) + } + } + + return nil +} + +// loadPendingCloseChannels loads all channels that are currently pending +// closure in the database and registers them with the ChannelArbitrator to +// continue the resolution process. +func (c *ChainArbitrator) loadPendingCloseChannels() error { + chanStateDB := c.chanSource.ChannelStateDB() + + closingChannels, err := chanStateDB.FetchClosedChannels(true) + if err != nil { + return err + } + + if len(closingChannels) == 0 { + return nil + } + + log.Infof("Creating ChannelArbitrators for %v closing channels", + len(closingChannels)) + + // Next, for each channel is the closing state, we'll launch a + // corresponding more restricted resolver, as we don't have to watch + // the chain any longer, only resolve the contracts on the confirmed + // commitment. + //nolint:ll + for _, closeChanInfo := range closingChannels { + // We can leave off the CloseContract and ForceCloseChan + // methods as the channel is already closed at this point. + chanPoint := closeChanInfo.ChanPoint + arbCfg := ChannelArbitratorConfig{ + ChanPoint: chanPoint, + ShortChanID: closeChanInfo.ShortChanID, + ChainArbitratorConfig: c.cfg, + ChainEvents: &ChainEventSubscription{}, + IsPendingClose: true, + ClosingHeight: closeChanInfo.CloseHeight, + CloseType: closeChanInfo.CloseType, + PutResolverReport: func(tx kvdb.RwTx, + report *channeldb.ResolverReport) error { + + return c.chanSource.PutResolverReport( + tx, c.cfg.ChainHash, &chanPoint, report, + ) + }, + FetchHistoricalChannel: func() (*channeldb.OpenChannel, error) { + return chanStateDB.FetchHistoricalChannel(&chanPoint) + }, + FindOutgoingHTLCDeadline: func( + htlc channeldb.HTLC) fn.Option[int32] { + + return c.FindOutgoingHTLCDeadline( + closeChanInfo.ShortChanID, htlc, + ) + }, + } + chanLog, err := newBoltArbitratorLog( + c.chanSource.Backend, arbCfg, c.cfg.ChainHash, chanPoint, + ) + if err != nil { + return err + } + arbCfg.MarkChannelResolved = func() error { + if c.cfg.NotifyFullyResolvedChannel != nil { + c.cfg.NotifyFullyResolvedChannel(chanPoint) + } + + return c.ResolveContract(chanPoint) + } + + // We create an empty map of HTLC's here since it's possible + // that the channel is in StateDefault and updateActiveHTLCs is + // called. We want to avoid writing to an empty map. Since the + // channel is already in the process of being resolved, no new + // HTLCs will be added. + c.activeChannels[chanPoint] = NewChannelArbitrator( + arbCfg, make(map[HtlcSetKey]htlcSet), chanLog, + ) + } + + return nil +} diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index fe2603ca5a..622686f76c 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -77,7 +77,6 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { ChainIO: &mock.ChainIO{}, Notifier: &mock.ChainNotifier{ SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), ConfChan: make(chan *chainntnfs.TxConfirmation), }, PublishTx: func(tx *wire.MsgTx, _ string) error { @@ -91,7 +90,8 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { chainArbCfg, db, ) - if err := chainArb.Start(); err != nil { + beat := newBeatFromHeight(0) + if err := chainArb.Start(beat); err != nil { t.Fatal(err) } t.Cleanup(func() { @@ -158,7 +158,6 @@ func TestResolveContract(t *testing.T) { ChainIO: &mock.ChainIO{}, Notifier: &mock.ChainNotifier{ SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), ConfChan: make(chan *chainntnfs.TxConfirmation), }, PublishTx: func(tx *wire.MsgTx, _ string) error { @@ -175,7 +174,8 @@ func TestResolveContract(t *testing.T) { chainArb := NewChainArbitrator( chainArbCfg, db, ) - if err := chainArb.Start(); err != nil { + beat := newBeatFromHeight(0) + if err := chainArb.Start(beat); err != nil { t.Fatal(err) } t.Cleanup(func() { diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index e79c8d546b..e29f21e7f4 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -18,7 +18,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" @@ -451,7 +451,7 @@ func (c *chainWatcher) handleUnknownLocalState( leaseExpiry = c.cfg.chanState.ThawHeight } - remoteAuxLeaf := fn.ChainOption( + remoteAuxLeaf := fn.FlatMapOption( func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf { return l.RemoteAuxLeaf }, @@ -468,7 +468,7 @@ func (c *chainWatcher) handleUnknownLocalState( // Next, we'll derive our script that includes the revocation base for // the remote party allowing them to claim this output before the CSV // delay if we breach. - localAuxLeaf := fn.ChainOption( + localAuxLeaf := fn.FlatMapOption( func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf { return l.LocalAuxLeaf }, @@ -1062,15 +1062,15 @@ func (c *chainWatcher) toSelfAmount(tx *wire.MsgTx) btcutil.Amount { return false } - return fn.Any(c.cfg.isOurAddr, addrs) + return fn.Any(addrs, c.cfg.isOurAddr) } // Grab all of the outputs that correspond with our delivery address // or our wallet is aware of. - outs := fn.Filter(fn.PredOr(isDeliveryOutput, isWalletOutput), tx.TxOut) + outs := fn.Filter(tx.TxOut, fn.PredOr(isDeliveryOutput, isWalletOutput)) // Grab the values for those outputs. - vals := fn.Map(func(o *wire.TxOut) int64 { return o.Value }, outs) + vals := fn.Map(outs, func(o *wire.TxOut) int64 { return o.Value }) // Return the sum. return btcutil.Amount(fn.Sum(vals)) diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 319b437e4e..9b554cb648 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -14,8 +14,9 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" @@ -330,6 +331,10 @@ type ChannelArbitrator struct { started int32 // To be used atomically. stopped int32 // To be used atomically. + // Embed the blockbeat consumer struct to get access to the method + // `NotifyBlockProcessed` and the `BlockbeatChan`. + chainio.BeatConsumer + // startTimestamp is the time when this ChannelArbitrator was started. startTimestamp time.Time @@ -352,11 +357,6 @@ type ChannelArbitrator struct { // to do its duty. cfg ChannelArbitratorConfig - // blocks is a channel that the arbitrator will receive new blocks on. - // This channel should be buffered by so that it does not block the - // sender. - blocks chan int32 - // signalUpdates is a channel that any new live signals for the channel // we're watching over will be sent. signalUpdates chan *signalUpdateMsg @@ -404,9 +404,8 @@ func NewChannelArbitrator(cfg ChannelArbitratorConfig, unmerged[RemotePendingHtlcSet] = htlcSets[RemotePendingHtlcSet] } - return &ChannelArbitrator{ + c := &ChannelArbitrator{ log: log, - blocks: make(chan int32, arbitratorBlockBufferSize), signalUpdates: make(chan *signalUpdateMsg), resolutionSignal: make(chan struct{}), forceCloseReqs: make(chan *forceCloseReq), @@ -415,8 +414,16 @@ func NewChannelArbitrator(cfg ChannelArbitratorConfig, cfg: cfg, quit: make(chan struct{}), } + + // Mount the block consumer. + c.BeatConsumer = chainio.NewBeatConsumer(c.quit, c.Name()) + + return c } +// Compile-time check for the chainio.Consumer interface. +var _ chainio.Consumer = (*ChannelArbitrator)(nil) + // chanArbStartState contains the information from disk that we need to start // up a channel arbitrator. type chanArbStartState struct { @@ -455,7 +462,9 @@ func (c *ChannelArbitrator) getStartState(tx kvdb.RTx) (*chanArbStartState, // Start starts all the goroutines that the ChannelArbitrator needs to operate. // If takes a start state, which will be looked up on disk if it is not // provided. -func (c *ChannelArbitrator) Start(state *chanArbStartState) error { +func (c *ChannelArbitrator) Start(state *chanArbStartState, + beat chainio.Blockbeat) error { + if !atomic.CompareAndSwapInt32(&c.started, 0, 1) { return nil } @@ -477,10 +486,22 @@ func (c *ChannelArbitrator) Start(state *chanArbStartState) error { // Set our state from our starting state. c.state = state.currentState - _, bestHeight, err := c.cfg.ChainIO.GetBestBlock() - if err != nil { - return err - } + // Get the starting height. + bestHeight := beat.Height() + + c.wg.Add(1) + go c.channelAttendant(bestHeight, state.commitSet) + + return nil +} + +// progressStateMachineAfterRestart attempts to progress the state machine +// after a restart. This makes sure that if the state transition failed, we +// will try to progress the state machine again. Moreover it will relaunch +// resolvers if the channel is still in the pending close state and has not +// been fully resolved yet. +func (c *ChannelArbitrator) progressStateMachineAfterRestart(bestHeight int32, + commitSet *CommitSet) error { // If the channel has been marked pending close in the database, and we // haven't transitioned the state machine to StateContractClosed (or a @@ -527,7 +548,7 @@ func (c *ChannelArbitrator) Start(state *chanArbStartState) error { // on-chain state, and our set of active contracts. startingState := c.state nextState, _, err := c.advanceState( - triggerHeight, trigger, state.commitSet, + triggerHeight, trigger, commitSet, ) if err != nil { switch err { @@ -564,14 +585,12 @@ func (c *ChannelArbitrator) Start(state *chanArbStartState) error { // receive a chain event from the chain watcher that the // commitment has been confirmed on chain, and before we // advance our state step, we call InsertConfirmedCommitSet. - err := c.relaunchResolvers(state.commitSet, triggerHeight) + err := c.relaunchResolvers(commitSet, triggerHeight) if err != nil { return err } } - c.wg.Add(1) - go c.channelAttendant(bestHeight) return nil } @@ -797,7 +816,7 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet, // TODO(roasbeef): this isn't re-launched? } - c.launchResolvers(unresolvedContracts, true) + c.resolveContracts(unresolvedContracts) return nil } @@ -997,7 +1016,7 @@ func (c *ChannelArbitrator) stateStep( getIdx := func(htlc channeldb.HTLC) uint64 { return htlc.HtlcIndex } - dustHTLCSet := fn.NewSet(fn.Map(getIdx, dustHTLCs)...) + dustHTLCSet := fn.NewSet(fn.Map(dustHTLCs, getIdx)...) err = c.abandonForwards(dustHTLCSet) if err != nil { return StateError, closeTx, err @@ -1306,7 +1325,7 @@ func (c *ChannelArbitrator) stateStep( return htlc.HtlcIndex } remoteDangling := fn.NewSet(fn.Map( - getIdx, htlcActions[HtlcFailDanglingAction], + htlcActions[HtlcFailDanglingAction], getIdx, )...) err := c.abandonForwards(remoteDangling) if err != nil { @@ -1336,7 +1355,7 @@ func (c *ChannelArbitrator) stateStep( // Finally, we'll launch all the required contract resolvers. // Once they're all resolved, we're no longer needed. - c.launchResolvers(resolvers, false) + c.resolveContracts(resolvers) nextState = StateWaitingFullResolution @@ -1559,17 +1578,72 @@ func (c *ChannelArbitrator) findCommitmentDeadlineAndValue(heightHint uint32, return fn.Some(int32(deadline)), valueLeft, nil } -// launchResolvers updates the activeResolvers list and starts the resolvers. -func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver, - immediate bool) { - +// resolveContracts updates the activeResolvers list and starts to resolve each +// contract concurrently, and launches them. +func (c *ChannelArbitrator) resolveContracts(resolvers []ContractResolver) { c.activeResolversLock.Lock() - defer c.activeResolversLock.Unlock() - c.activeResolvers = resolvers + c.activeResolversLock.Unlock() + + // Launch all resolvers. + c.launchResolvers() + for _, contract := range resolvers { c.wg.Add(1) - go c.resolveContract(contract, immediate) + go c.resolveContract(contract) + } +} + +// launchResolvers launches all the active resolvers concurrently. +func (c *ChannelArbitrator) launchResolvers() { + c.activeResolversLock.Lock() + resolvers := c.activeResolvers + c.activeResolversLock.Unlock() + + // errChans is a map of channels that will be used to receive errors + // returned from launching the resolvers. + errChans := make(map[ContractResolver]chan error, len(resolvers)) + + // Launch each resolver in goroutines. + for _, r := range resolvers { + // If the contract is already resolved, there's no need to + // launch it again. + if r.IsResolved() { + log.Debugf("ChannelArbitrator(%v): skipping resolver "+ + "%T as it's already resolved", c.cfg.ChanPoint, + r) + + continue + } + + // Create a signal chan. + errChan := make(chan error, 1) + errChans[r] = errChan + + go func() { + err := r.Launch() + errChan <- err + }() + } + + // Wait for all resolvers to finish launching. + for r, errChan := range errChans { + select { + case err := <-errChan: + if err == nil { + continue + } + + log.Errorf("ChannelArbitrator(%v): unable to launch "+ + "contract resolver(%T): %v", c.cfg.ChanPoint, r, + err) + + case <-c.quit: + log.Debugf("ChannelArbitrator quit signal received, " + + "exit launchResolvers") + + return + } } } @@ -1593,8 +1667,8 @@ func (c *ChannelArbitrator) advanceState( for { priorState = c.state log.Debugf("ChannelArbitrator(%v): attempting state step with "+ - "trigger=%v from state=%v", c.cfg.ChanPoint, trigger, - priorState) + "trigger=%v from state=%v at height=%v", + c.cfg.ChanPoint, trigger, priorState, triggerHeight) nextState, closeTx, err := c.stateStep( triggerHeight, trigger, confCommitSet, @@ -2541,9 +2615,7 @@ func (c *ChannelArbitrator) replaceResolver(oldResolver, // contracts. // // NOTE: This MUST be run as a goroutine. -func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver, - immediate bool) { - +func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver) { defer c.wg.Done() log.Debugf("ChannelArbitrator(%v): attempting to resolve %T", @@ -2564,7 +2636,7 @@ func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver, default: // Otherwise, we'll attempt to resolve the current // contract. - nextContract, err := currentContract.Resolve(immediate) + nextContract, err := currentContract.Resolve() if err != nil { if err == errResolverShuttingDown { return @@ -2613,6 +2685,13 @@ func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver, // loop. currentContract = nextContract + // Launch the new contract. + err = currentContract.Launch() + if err != nil { + log.Errorf("Failed to launch %T: %v", + currentContract, err) + } + // If this contract is actually fully resolved, then // we'll mark it as such within the database. case currentContract.IsResolved(): @@ -2716,44 +2795,49 @@ func (c *ChannelArbitrator) updateActiveHTLCs() { // Nursery for incubation, and ultimate sweeping. // // NOTE: This MUST be run as a goroutine. -func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { +// +//nolint:funlen +func (c *ChannelArbitrator) channelAttendant(bestHeight int32, + commitSet *CommitSet) { // TODO(roasbeef): tell top chain arb we're done defer func() { c.wg.Done() }() + err := c.progressStateMachineAfterRestart(bestHeight, commitSet) + if err != nil { + // In case of an error, we return early but we do not shutdown + // LND, because there might be other channels that still can be + // resolved and we don't want to interfere with that. + // We continue to run the channel attendant in case the channel + // closes via other means for example the remote pary force + // closes the channel. So we log the error and continue. + log.Errorf("Unable to progress state machine after "+ + "restart: %v", err) + } + for { select { // A new block has arrived, we'll examine all the active HTLC's // to see if any of them have expired, and also update our // track of the best current height. - case blockHeight, ok := <-c.blocks: - if !ok { - return - } - bestHeight = blockHeight + case beat := <-c.BlockbeatChan: + bestHeight = beat.Height() - // If we're not in the default state, then we can - // ignore this signal as we're waiting for contract - // resolution. - if c.state != StateDefault { - continue - } + log.Debugf("ChannelArbitrator(%v): new block height=%v", + c.cfg.ChanPoint, bestHeight) - // Now that a new block has arrived, we'll attempt to - // advance our state forward. - nextState, _, err := c.advanceState( - uint32(bestHeight), chainTrigger, nil, - ) + err := c.handleBlockbeat(beat) if err != nil { - log.Errorf("Unable to advance state: %v", err) + log.Errorf("Handle block=%v got err: %v", + bestHeight, err) } // If as a result of this trigger, the contract is // fully resolved, then well exit. - if nextState == StateFullyResolved { + if c.state == StateFullyResolved { return } @@ -2802,14 +2886,12 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { // We have broadcasted our commitment, and it is now confirmed // on-chain. case closeInfo := <-c.cfg.ChainEvents.LocalUnilateralClosure: - log.Infof("ChannelArbitrator(%v): local on-chain "+ - "channel close", c.cfg.ChanPoint) - if c.state != StateCommitmentBroadcasted { log.Errorf("ChannelArbitrator(%v): unexpected "+ "local on-chain channel close", c.cfg.ChanPoint) } + closeTx := closeInfo.CloseTx resolutions, err := closeInfo.ContractResolutions. @@ -2837,6 +2919,10 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { return } + log.Infof("ChannelArbitrator(%v): local force close "+ + "tx=%v confirmed", c.cfg.ChanPoint, + closeTx.TxHash()) + contractRes := &ContractResolutions{ CommitHash: closeTx.TxHash(), CommitResolution: resolutions.CommitResolution, @@ -3104,6 +3190,37 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { } } +// handleBlockbeat processes a newly received blockbeat by advancing the +// arbitrator's internal state using the received block height. +func (c *ChannelArbitrator) handleBlockbeat(beat chainio.Blockbeat) error { + // Notify we've processed the block. + defer c.NotifyBlockProcessed(beat, nil) + + // Try to advance the state if we are in StateDefault. + if c.state == StateDefault { + // Now that a new block has arrived, we'll attempt to advance + // our state forward. + _, _, err := c.advanceState( + uint32(beat.Height()), chainTrigger, nil, + ) + if err != nil { + return fmt.Errorf("unable to advance state: %w", err) + } + } + + // Launch all active resolvers when a new blockbeat is received. + c.launchResolvers() + + return nil +} + +// Name returns a human-readable string for this subsystem. +// +// NOTE: Part of chainio.Consumer interface. +func (c *ChannelArbitrator) Name() string { + return fmt.Sprintf("ChannelArbitrator(%v)", c.cfg.ChanPoint) +} + // checkLegacyBreach returns StateFullyResolved if the channel was closed with // a breach transaction before the channel arbitrator launched its own breach // resolver. StateContractClosed is returned if this is a modern breach close diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 92ad608eb9..827e10b5c7 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -13,14 +13,17 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/mock" + "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -226,6 +229,15 @@ func (c *chanArbTestCtx) CleanUp() { } } +// receiveBlockbeat mocks the behavior of a blockbeat being sent by the +// BlockbeatDispatcher, which essentially mocks the method `ProcessBlock`. +func (c *chanArbTestCtx) receiveBlockbeat(height int) { + go func() { + beat := newBeatFromHeight(int32(height)) + c.chanArb.BlockbeatChan <- beat + }() +} + // AssertStateTransitions asserts that the state machine steps through the // passed states in order. func (c *chanArbTestCtx) AssertStateTransitions(expectedStates ...ArbitratorState) { @@ -285,7 +297,8 @@ func (c *chanArbTestCtx) Restart(restartClosure func(*chanArbTestCtx)) (*chanArb restartClosure(newCtx) } - if err := newCtx.chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := newCtx.chanArb.Start(nil, beat); err != nil { return nil, err } @@ -512,7 +525,8 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) { chanArbCtx, err := createTestChannelArbitrator(t, log) require.NoError(t, err, "unable to create ChannelArbitrator") - if err := chanArbCtx.chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArbCtx.chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } t.Cleanup(func() { @@ -570,7 +584,8 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -623,7 +638,8 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -735,7 +751,8 @@ func TestChannelArbitratorBreachClose(t *testing.T) { chanArb.cfg.PreimageDB = newMockWitnessBeacon() chanArb.cfg.Registry = &mockRegistry{} - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } t.Cleanup(func() { @@ -862,7 +879,8 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { chanArb.cfg.PreimageDB = newMockWitnessBeacon() chanArb.cfg.Registry = &mockRegistry{} - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -965,6 +983,7 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { }, }, } + closeTxid := closeTx.TxHash() htlcOp := wire.OutPoint{ Hash: closeTx.TxHash(), @@ -1036,7 +1055,7 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { } require.Equal(t, expectedFinalHtlcs, chanArbCtx.finalHtlcs) - // We'll no re-create the resolver, notice that we use the existing + // We'll now re-create the resolver, notice that we use the existing // arbLog so it carries over the same on-disk state. chanArbCtxNew, err := chanArbCtx.Restart(nil) require.NoError(t, err, "unable to create ChannelArbitrator") @@ -1045,10 +1064,19 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { // Post restart, it should be the case that our resolver was properly // supplemented, and we only have a single resolver in the final set. - if len(chanArb.activeResolvers) != 1 { - t.Fatalf("expected single resolver, instead got: %v", - len(chanArb.activeResolvers)) - } + // The resolvers are added concurrently so we need to wait here. + err = wait.NoError(func() error { + chanArb.activeResolversLock.Lock() + defer chanArb.activeResolversLock.Unlock() + + if len(chanArb.activeResolvers) != 1 { + return fmt.Errorf("expected single resolver, instead "+ + "got: %v", len(chanArb.activeResolvers)) + } + + return nil + }, defaultTimeout) + require.NoError(t, err) // We'll now examine the in-memory state of the active resolvers to // ensure t hey were populated properly. @@ -1086,7 +1114,11 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { // Notify resolver that the HTLC output of the commitment has been // spent. - oldNotifier.SpendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx} + oldNotifier.SpendChan <- &chainntnfs.SpendDetail{ + SpendingTx: closeTx, + SpentOutPoint: &wire.OutPoint{}, + SpenderTxHash: &closeTxid, + } // Finally, we should also receive a resolution message instructing the // switch to cancel back the HTLC. @@ -1113,8 +1145,12 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { default: } - // Notify resolver that the second level transaction is spent. - oldNotifier.SpendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx} + // Notify resolver that the output of the timeout tx has been spent. + oldNotifier.SpendChan <- &chainntnfs.SpendDetail{ + SpendingTx: closeTx, + SpentOutPoint: &wire.OutPoint{}, + SpenderTxHash: &closeTxid, + } // At this point channel should be marked as resolved. chanArbCtxNew.AssertStateTransitions(StateFullyResolved) @@ -1138,7 +1174,8 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1245,7 +1282,8 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1351,7 +1389,8 @@ func TestChannelArbitratorPersistence(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1469,7 +1508,8 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1656,7 +1696,8 @@ func TestChannelArbitratorCommitFailure(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1740,7 +1781,8 @@ func TestChannelArbitratorEmptyResolutions(t *testing.T) { chanArb.cfg.ClosingHeight = 100 chanArb.cfg.CloseType = channeldb.RemoteForceClose - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(100) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1770,7 +1812,8 @@ func TestChannelArbitratorAlreadyForceClosed(t *testing.T) { chanArbCtx, err := createTestChannelArbitrator(t, log) require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1868,9 +1911,10 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { t.Fatalf("unable to create ChannelArbitrator: %v", err) } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { - t.Fatalf("unable to start ChannelArbitrator: %v", err) - } + beat := newBeatFromHeight(0) + err = chanArb.Start(nil, beat) + require.NoError(t, err) + defer chanArb.Stop() // Now that our channel arb has started, we'll set up @@ -1914,7 +1958,8 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { // now mine a block (height 5), which is 5 blocks away // (our grace delta) from the expiry of that HTLC. case testCase.htlcExpired: - chanArbCtx.chanArb.blocks <- 5 + beat := newBeatFromHeight(5) + chanArbCtx.chanArb.BlockbeatChan <- beat // Otherwise, we'll just trigger a regular force close // request. @@ -2026,8 +2071,7 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { // so instead, we'll mine another block which'll cause // it to re-examine its state and realize there're no // more HTLCs. - chanArbCtx.chanArb.blocks <- 6 - chanArbCtx.AssertStateTransitions(StateFullyResolved) + chanArbCtx.receiveBlockbeat(6) }) } } @@ -2064,7 +2108,8 @@ func TestChannelArbitratorPendingExpiredHTLC(t *testing.T) { return false } - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } t.Cleanup(func() { @@ -2098,13 +2143,15 @@ func TestChannelArbitratorPendingExpiredHTLC(t *testing.T) { // We will advance the uptime to 10 seconds which should be still within // the grace period and should not trigger going to chain. testClock.SetTime(startTime.Add(time.Second * 10)) - chanArbCtx.chanArb.blocks <- 5 + beat = newBeatFromHeight(5) + chanArbCtx.chanArb.BlockbeatChan <- beat chanArbCtx.AssertState(StateDefault) // We will advance the uptime to 16 seconds which should trigger going // to chain. testClock.SetTime(startTime.Add(time.Second * 16)) - chanArbCtx.chanArb.blocks <- 6 + beat = newBeatFromHeight(6) + chanArbCtx.chanArb.BlockbeatChan <- beat chanArbCtx.AssertStateTransitions( StateBroadcastCommit, StateCommitmentBroadcasted, @@ -2217,8 +2264,8 @@ func TestRemoteCloseInitiator(t *testing.T) { "ChannelArbitrator: %v", err) } chanArb := chanArbCtx.chanArb - - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start "+ "ChannelArbitrator: %v", err) } @@ -2472,7 +2519,7 @@ func TestSweepAnchors(t *testing.T) { // Set current block height. heightHint := uint32(1000) - chanArbCtx.chanArb.blocks <- int32(heightHint) + chanArbCtx.receiveBlockbeat(int(heightHint)) htlcIndexBase := uint64(99) deadlineDelta := uint32(10) @@ -2635,7 +2682,7 @@ func TestSweepLocalAnchor(t *testing.T) { // Set current block height. heightHint := uint32(1000) - chanArbCtx.chanArb.blocks <- int32(heightHint) + chanArbCtx.receiveBlockbeat(int(heightHint)) htlcIndex := uint64(99) deadlineDelta := uint32(10) @@ -2769,7 +2816,9 @@ func TestChannelArbitratorAnchors(t *testing.T) { }, } - if err := chanArb.Start(nil); err != nil { + heightHint := uint32(1000) + beat := newBeatFromHeight(int32(heightHint)) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } t.Cleanup(func() { @@ -2781,27 +2830,28 @@ func TestChannelArbitratorAnchors(t *testing.T) { } chanArb.UpdateContractSignals(signals) - // Set current block height. - heightHint := uint32(1000) - chanArbCtx.chanArb.blocks <- int32(heightHint) - htlcAmt := lnwire.MilliSatoshi(1_000_000) // Create testing HTLCs. - deadlineDelta := uint32(10) - deadlinePreimageDelta := deadlineDelta + 2 + spendingHeight := uint32(beat.Height()) + deadlineDelta := uint32(100) + + deadlinePreimageDelta := deadlineDelta htlcWithPreimage := channeldb.HTLC{ - HtlcIndex: 99, - RefundTimeout: heightHint + deadlinePreimageDelta, + HtlcIndex: 99, + // RefundTimeout is 101. + RefundTimeout: spendingHeight + deadlinePreimageDelta, RHash: rHash, Incoming: true, Amt: htlcAmt, } + expectedDeadline := deadlineDelta/2 + spendingHeight - deadlineHTLCdelta := deadlineDelta + 3 + deadlineHTLCdelta := deadlineDelta + 40 htlc := channeldb.HTLC{ - HtlcIndex: 100, - RefundTimeout: heightHint + deadlineHTLCdelta, + HtlcIndex: 100, + // RefundTimeout is 141. + RefundTimeout: spendingHeight + deadlineHTLCdelta, Amt: htlcAmt, } @@ -2886,7 +2936,9 @@ func TestChannelArbitratorAnchors(t *testing.T) { //nolint:ll chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{ - SpendDetail: &chainntnfs.SpendDetail{}, + SpendDetail: &chainntnfs.SpendDetail{ + SpendingHeight: int32(spendingHeight), + }, LocalForceCloseSummary: &lnwallet.LocalForceCloseSummary{ CloseTx: closeTx, ContractResolutions: fn.Some(lnwallet.ContractResolutions{ @@ -2950,12 +3002,14 @@ func TestChannelArbitratorAnchors(t *testing.T) { // to htlcWithPreimage's CLTV. require.Equal(t, 2, len(chanArbCtx.sweeper.deadlines)) require.EqualValues(t, - heightHint+deadlinePreimageDelta/2, - chanArbCtx.sweeper.deadlines[0], + expectedDeadline, + chanArbCtx.sweeper.deadlines[0], "want %d, got %d", + expectedDeadline, chanArbCtx.sweeper.deadlines[0], ) require.EqualValues(t, - heightHint+deadlinePreimageDelta/2, - chanArbCtx.sweeper.deadlines[1], + expectedDeadline, + chanArbCtx.sweeper.deadlines[1], "want %d, got %d", + expectedDeadline, chanArbCtx.sweeper.deadlines[1], ) } @@ -3000,9 +3054,12 @@ func TestChannelArbitratorStartForceCloseFail(t *testing.T) { { name: "Commitment is rejected with an " + "unmatched error", - broadcastErr: fmt.Errorf("Reject Commitment Tx"), - expectedState: StateBroadcastCommit, - expectedStartup: false, + broadcastErr: fmt.Errorf("Reject Commitment Tx"), + expectedState: StateBroadcastCommit, + // We should still be able to start up since we other + // channels might be closing as well and we should + // resolve the contracts. + expectedStartup: true, }, // We started after the DLP was triggered, and try to force @@ -3054,7 +3111,8 @@ func TestChannelArbitratorStartForceCloseFail(t *testing.T) { return test.broadcastErr } - err = chanArb.Start(nil) + beat := newBeatFromHeight(0) + err = chanArb.Start(nil, beat) if !test.expectedStartup { require.ErrorIs(t, err, test.broadcastErr) @@ -3102,7 +3160,8 @@ func assertResolverReport(t *testing.T, reports chan *channeldb.ResolverReport, select { case report := <-reports: if !reflect.DeepEqual(report, expected) { - t.Fatalf("expected: %v, got: %v", expected, report) + t.Fatalf("expected: %v, got: %v", spew.Sdump(expected), + spew.Sdump(report)) } case <-time.After(defaultTimeout): @@ -3133,3 +3192,11 @@ func (m *mockChannel) ForceCloseChan() (*wire.MsgTx, error) { return &wire.MsgTx{}, nil } + +func newBeatFromHeight(height int32) *chainio.Beat { + epoch := chainntnfs.BlockEpoch{ + Height: height, + } + + return chainio.NewBeat(epoch) +} diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index 4b47a34294..55ee08e5d8 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/sweep" @@ -39,9 +39,6 @@ type commitSweepResolver struct { // this HTLC on-chain. commitResolution lnwallet.CommitOutputResolution - // resolved reflects if the contract has been fully resolved or not. - resolved bool - // broadcastHeight is the height that the original contract was // broadcast to the main-chain at. We'll use this value to bound any // historical queries to the chain for spends/confirmations. @@ -88,7 +85,7 @@ func newCommitSweepResolver(res lnwallet.CommitOutputResolution, chanPoint: chanPoint, } - r.initLogger(r) + r.initLogger(fmt.Sprintf("%T(%v)", r, r.commitResolution.SelfOutPoint)) r.initReport() return r @@ -101,36 +98,6 @@ func (c *commitSweepResolver) ResolverKey() []byte { return key[:] } -// waitForHeight registers for block notifications and waits for the provided -// block height to be reached. -func waitForHeight(waitHeight uint32, notifier chainntnfs.ChainNotifier, - quit <-chan struct{}) error { - - // Register for block epochs. After registration, the current height - // will be sent on the channel immediately. - blockEpochs, err := notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return err - } - defer blockEpochs.Cancel() - - for { - select { - case newBlock, ok := <-blockEpochs.Epochs: - if !ok { - return errResolverShuttingDown - } - height := newBlock.Height - if height >= int32(waitHeight) { - return nil - } - - case <-quit: - return errResolverShuttingDown - } - } -} - // waitForSpend waits for the given outpoint to be spent, and returns the // details of the spending tx. func waitForSpend(op *wire.OutPoint, pkScript []byte, heightHint uint32, @@ -195,203 +162,17 @@ func (c *commitSweepResolver) getCommitTxConfHeight() (uint32, error) { // returned. // // NOTE: This function MUST be run as a goroutine. + +// TODO(yy): fix the funlen in the next PR. // //nolint:funlen -func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) { +func (c *commitSweepResolver) Resolve() (ContractResolver, error) { // If we're already resolved, then we can exit early. - if c.resolved { + if c.IsResolved() { + c.log.Errorf("already resolved") return nil, nil } - confHeight, err := c.getCommitTxConfHeight() - if err != nil { - return nil, err - } - - // Wait up until the CSV expires, unless we also have a CLTV that - // expires after. - unlockHeight := confHeight + c.commitResolution.MaturityDelay - if c.hasCLTV() { - unlockHeight = uint32(math.Max( - float64(unlockHeight), float64(c.leaseExpiry), - )) - } - - c.log.Debugf("commit conf_height=%v, unlock_height=%v", - confHeight, unlockHeight) - - // Update report now that we learned the confirmation height. - c.reportLock.Lock() - c.currentReport.MaturityHeight = unlockHeight - c.reportLock.Unlock() - - // If there is a csv/cltv lock, we'll wait for that. - if c.commitResolution.MaturityDelay > 0 || c.hasCLTV() { - // Determine what height we should wait until for the locks to - // expire. - var waitHeight uint32 - switch { - // If we have both a csv and cltv lock, we'll need to look at - // both and see which expires later. - case c.commitResolution.MaturityDelay > 0 && c.hasCLTV(): - c.log.Debugf("waiting for CSV and CLTV lock to expire "+ - "at height %v", unlockHeight) - // If the CSV expires after the CLTV, or there is no - // CLTV, then we can broadcast a sweep a block before. - // Otherwise, we need to broadcast at our expected - // unlock height. - waitHeight = uint32(math.Max( - float64(unlockHeight-1), float64(c.leaseExpiry), - )) - - // If we only have a csv lock, wait for the height before the - // lock expires as the spend path should be unlocked by then. - case c.commitResolution.MaturityDelay > 0: - c.log.Debugf("waiting for CSV lock to expire at "+ - "height %v", unlockHeight) - waitHeight = unlockHeight - 1 - } - - err := waitForHeight(waitHeight, c.Notifier, c.quit) - if err != nil { - return nil, err - } - } - - var ( - isLocalCommitTx bool - - signDesc = c.commitResolution.SelfOutputSignDesc - ) - - switch { - // For taproot channels, we'll know if this is the local commit based - // on the timelock value. For remote commitment transactions, the - // witness script has a timelock of 1. - case c.chanType.IsTaproot(): - delayKey := c.localChanCfg.DelayBasePoint.PubKey - nonDelayKey := c.localChanCfg.PaymentBasePoint.PubKey - - signKey := c.commitResolution.SelfOutputSignDesc.KeyDesc.PubKey - - // If the key in the script is neither of these, we shouldn't - // proceed. This should be impossible. - if !signKey.IsEqual(delayKey) && !signKey.IsEqual(nonDelayKey) { - return nil, fmt.Errorf("unknown sign key %v", signKey) - } - - // The commitment transaction is ours iff the signing key is - // the delay key. - isLocalCommitTx = signKey.IsEqual(delayKey) - - // The output is on our local commitment if the script starts with - // OP_IF for the revocation clause. On the remote commitment it will - // either be a regular P2WKH or a simple sig spend with a CSV delay. - default: - isLocalCommitTx = signDesc.WitnessScript[0] == txscript.OP_IF - } - isDelayedOutput := c.commitResolution.MaturityDelay != 0 - - c.log.Debugf("isDelayedOutput=%v, isLocalCommitTx=%v", isDelayedOutput, - isLocalCommitTx) - - // There're three types of commitments, those that have tweaks for the - // remote key (us in this case), those that don't, and a third where - // there is no tweak and the output is delayed. On the local commitment - // our output will always be delayed. We'll rely on the presence of the - // commitment tweak to discern which type of commitment this is. - var witnessType input.WitnessType - switch { - // The local delayed output for a taproot channel. - case isLocalCommitTx && c.chanType.IsTaproot(): - witnessType = input.TaprootLocalCommitSpend - - // The CSV 1 delayed output for a taproot channel. - case !isLocalCommitTx && c.chanType.IsTaproot(): - witnessType = input.TaprootRemoteCommitSpend - - // Delayed output to us on our local commitment for a channel lease in - // which we are the initiator. - case isLocalCommitTx && c.hasCLTV(): - witnessType = input.LeaseCommitmentTimeLock - - // Delayed output to us on our local commitment. - case isLocalCommitTx: - witnessType = input.CommitmentTimeLock - - // A confirmed output to us on the remote commitment for a channel lease - // in which we are the initiator. - case isDelayedOutput && c.hasCLTV(): - witnessType = input.LeaseCommitmentToRemoteConfirmed - - // A confirmed output to us on the remote commitment. - case isDelayedOutput: - witnessType = input.CommitmentToRemoteConfirmed - - // A non-delayed output on the remote commitment where the key is - // tweakless. - case c.commitResolution.SelfOutputSignDesc.SingleTweak == nil: - witnessType = input.CommitSpendNoDelayTweakless - - // A non-delayed output on the remote commitment where the key is - // tweaked. - default: - witnessType = input.CommitmentNoDelay - } - - c.log.Infof("Sweeping with witness type: %v", witnessType) - - // We'll craft an input with all the information required for the - // sweeper to create a fully valid sweeping transaction to recover - // these coins. - var inp *input.BaseInput - if c.hasCLTV() { - inp = input.NewCsvInputWithCltv( - &c.commitResolution.SelfOutPoint, witnessType, - &c.commitResolution.SelfOutputSignDesc, - c.broadcastHeight, c.commitResolution.MaturityDelay, - c.leaseExpiry, - input.WithResolutionBlob( - c.commitResolution.ResolutionBlob, - ), - ) - } else { - inp = input.NewCsvInput( - &c.commitResolution.SelfOutPoint, witnessType, - &c.commitResolution.SelfOutputSignDesc, - c.broadcastHeight, c.commitResolution.MaturityDelay, - input.WithResolutionBlob( - c.commitResolution.ResolutionBlob, - ), - ) - } - - // TODO(roasbeef): instead of ading ctrl block to the sign desc, make - // new input type, have sweeper set it? - - // Calculate the budget for the sweeping this input. - budget := calculateBudget( - btcutil.Amount(inp.SignDesc().Output.Value), - c.Budget.ToLocalRatio, c.Budget.ToLocal, - ) - c.log.Infof("Sweeping commit output using budget=%v", budget) - - // With our input constructed, we'll now offer it to the sweeper. - resultChan, err := c.Sweeper.SweepInput( - inp, sweep.Params{ - Budget: budget, - - // Specify a nil deadline here as there's no time - // pressure. - DeadlineHeight: fn.None[int32](), - }, - ) - if err != nil { - c.log.Errorf("unable to sweep input: %v", err) - - return nil, err - } - var sweepTxID chainhash.Hash // Sweeper is going to join this input with other inputs if possible @@ -400,7 +181,7 @@ func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) { // happen. outcome := channeldb.ResolverOutcomeClaimed select { - case sweepResult := <-resultChan: + case sweepResult := <-c.sweepResultChan: switch sweepResult.Err { // If the remote party was able to sweep this output it's // likely what we sent was actually a revoked commitment. @@ -440,7 +221,7 @@ func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) { report := c.currentReport.resolverReport( &sweepTxID, channeldb.ResolverTypeCommit, outcome, ) - c.resolved = true + c.markResolved() // Checkpoint the resolver with a closure that will write the outcome // of the resolver and its sweep transaction to disk. @@ -452,17 +233,11 @@ func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) { // // NOTE: Part of the ContractResolver interface. func (c *commitSweepResolver) Stop() { + c.log.Debugf("stopping...") + defer c.log.Debugf("stopped") close(c.quit) } -// IsResolved returns true if the stored state in the resolve is fully -// resolved. In this case the target output can be forgotten. -// -// NOTE: Part of the ContractResolver interface. -func (c *commitSweepResolver) IsResolved() bool { - return c.resolved -} - // SupplementState allows the user of a ContractResolver to supplement it with // state required for the proper resolution of a contract. // @@ -491,7 +266,7 @@ func (c *commitSweepResolver) Encode(w io.Writer) error { return err } - if err := binary.Write(w, endian, c.resolved); err != nil { + if err := binary.Write(w, endian, c.IsResolved()); err != nil { return err } if err := binary.Write(w, endian, c.broadcastHeight); err != nil { @@ -526,9 +301,14 @@ func newCommitSweepResolverFromReader(r io.Reader, resCfg ResolverConfig) ( return nil, err } - if err := binary.Read(r, endian, &c.resolved); err != nil { + var resolved bool + if err := binary.Read(r, endian, &resolved); err != nil { return nil, err } + if resolved { + c.markResolved() + } + if err := binary.Read(r, endian, &c.broadcastHeight); err != nil { return nil, err } @@ -545,7 +325,7 @@ func newCommitSweepResolverFromReader(r io.Reader, resCfg ResolverConfig) ( // removed this, but keep in mind that this data may still be present in // the database. - c.initLogger(c) + c.initLogger(fmt.Sprintf("%T(%v)", c, c.commitResolution.SelfOutPoint)) c.initReport() return c, nil @@ -585,3 +365,181 @@ func (c *commitSweepResolver) initReport() { // A compile time assertion to ensure commitSweepResolver meets the // ContractResolver interface. var _ reportingContractResolver = (*commitSweepResolver)(nil) + +// Launch constructs a commit input and offers it to the sweeper. +func (c *commitSweepResolver) Launch() error { + if c.isLaunched() { + c.log.Tracef("already launched") + return nil + } + + c.log.Debugf("launching resolver...") + c.markLaunched() + + // If we're already resolved, then we can exit early. + if c.IsResolved() { + c.log.Errorf("already resolved") + return nil + } + + confHeight, err := c.getCommitTxConfHeight() + if err != nil { + return err + } + + // Wait up until the CSV expires, unless we also have a CLTV that + // expires after. + unlockHeight := confHeight + c.commitResolution.MaturityDelay + if c.hasCLTV() { + unlockHeight = uint32(math.Max( + float64(unlockHeight), float64(c.leaseExpiry), + )) + } + + // Update report now that we learned the confirmation height. + c.reportLock.Lock() + c.currentReport.MaturityHeight = unlockHeight + c.reportLock.Unlock() + + // Derive the witness type for this input. + witnessType, err := c.decideWitnessType() + if err != nil { + return err + } + + // We'll craft an input with all the information required for the + // sweeper to create a fully valid sweeping transaction to recover + // these coins. + var inp *input.BaseInput + if c.hasCLTV() { + inp = input.NewCsvInputWithCltv( + &c.commitResolution.SelfOutPoint, witnessType, + &c.commitResolution.SelfOutputSignDesc, + c.broadcastHeight, c.commitResolution.MaturityDelay, + c.leaseExpiry, + ) + } else { + inp = input.NewCsvInput( + &c.commitResolution.SelfOutPoint, witnessType, + &c.commitResolution.SelfOutputSignDesc, + c.broadcastHeight, c.commitResolution.MaturityDelay, + ) + } + + // TODO(roasbeef): instead of ading ctrl block to the sign desc, make + // new input type, have sweeper set it? + + // Calculate the budget for the sweeping this input. + budget := calculateBudget( + btcutil.Amount(inp.SignDesc().Output.Value), + c.Budget.ToLocalRatio, c.Budget.ToLocal, + ) + c.log.Infof("sweeping commit output %v using budget=%v", witnessType, + budget) + + // With our input constructed, we'll now offer it to the sweeper. + resultChan, err := c.Sweeper.SweepInput( + inp, sweep.Params{ + Budget: budget, + + // Specify a nil deadline here as there's no time + // pressure. + DeadlineHeight: fn.None[int32](), + }, + ) + if err != nil { + c.log.Errorf("unable to sweep input: %v", err) + + return err + } + + c.sweepResultChan = resultChan + + return nil +} + +// decideWitnessType returns the witness type for the input. +func (c *commitSweepResolver) decideWitnessType() (input.WitnessType, error) { + var ( + isLocalCommitTx bool + signDesc = c.commitResolution.SelfOutputSignDesc + ) + + switch { + // For taproot channels, we'll know if this is the local commit based + // on the timelock value. For remote commitment transactions, the + // witness script has a timelock of 1. + case c.chanType.IsTaproot(): + delayKey := c.localChanCfg.DelayBasePoint.PubKey + nonDelayKey := c.localChanCfg.PaymentBasePoint.PubKey + + signKey := c.commitResolution.SelfOutputSignDesc.KeyDesc.PubKey + + // If the key in the script is neither of these, we shouldn't + // proceed. This should be impossible. + if !signKey.IsEqual(delayKey) && !signKey.IsEqual(nonDelayKey) { + return nil, fmt.Errorf("unknown sign key %v", signKey) + } + + // The commitment transaction is ours iff the signing key is + // the delay key. + isLocalCommitTx = signKey.IsEqual(delayKey) + + // The output is on our local commitment if the script starts with + // OP_IF for the revocation clause. On the remote commitment it will + // either be a regular P2WKH or a simple sig spend with a CSV delay. + default: + isLocalCommitTx = signDesc.WitnessScript[0] == txscript.OP_IF + } + + isDelayedOutput := c.commitResolution.MaturityDelay != 0 + + c.log.Debugf("isDelayedOutput=%v, isLocalCommitTx=%v", isDelayedOutput, + isLocalCommitTx) + + // There're three types of commitments, those that have tweaks for the + // remote key (us in this case), those that don't, and a third where + // there is no tweak and the output is delayed. On the local commitment + // our output will always be delayed. We'll rely on the presence of the + // commitment tweak to discern which type of commitment this is. + var witnessType input.WitnessType + switch { + // The local delayed output for a taproot channel. + case isLocalCommitTx && c.chanType.IsTaproot(): + witnessType = input.TaprootLocalCommitSpend + + // The CSV 1 delayed output for a taproot channel. + case !isLocalCommitTx && c.chanType.IsTaproot(): + witnessType = input.TaprootRemoteCommitSpend + + // Delayed output to us on our local commitment for a channel lease in + // which we are the initiator. + case isLocalCommitTx && c.hasCLTV(): + witnessType = input.LeaseCommitmentTimeLock + + // Delayed output to us on our local commitment. + case isLocalCommitTx: + witnessType = input.CommitmentTimeLock + + // A confirmed output to us on the remote commitment for a channel lease + // in which we are the initiator. + case isDelayedOutput && c.hasCLTV(): + witnessType = input.LeaseCommitmentToRemoteConfirmed + + // A confirmed output to us on the remote commitment. + case isDelayedOutput: + witnessType = input.CommitmentToRemoteConfirmed + + // A non-delayed output on the remote commitment where the key is + // tweakless. + case c.commitResolution.SelfOutputSignDesc.SingleTweak == nil: + witnessType = input.CommitSpendNoDelayTweakless + + // A non-delayed output on the remote commitment where the key is + // tweaked. + default: + witnessType = input.CommitmentNoDelay + } + + return witnessType, nil +} diff --git a/contractcourt/commit_sweep_resolver_test.go b/contractcourt/commit_sweep_resolver_test.go index 077fb8f82c..6855fddcd3 100644 --- a/contractcourt/commit_sweep_resolver_test.go +++ b/contractcourt/commit_sweep_resolver_test.go @@ -15,6 +15,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/sweep" + "github.com/stretchr/testify/require" ) type commitSweepResolverTestContext struct { @@ -82,7 +83,10 @@ func (i *commitSweepResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false) + err := i.resolver.Launch() + require.NoError(i.t, err) + + nextResolver, err := i.resolver.Resolve() i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, @@ -90,12 +94,6 @@ func (i *commitSweepResolverTestContext) resolve() { }() } -func (i *commitSweepResolverTestContext) notifyEpoch(height int32) { - i.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: height, - } -} - func (i *commitSweepResolverTestContext) waitForResult() { i.t.Helper() @@ -292,22 +290,10 @@ func testCommitSweepResolverDelay(t *testing.T, sweepErr error) { t.Fatal("report maturity height incorrect") } - // Notify initial block height. The csv lock is still in effect, so we - // don't expect any sweep to happen yet. - ctx.notifyEpoch(testInitialBlockHeight) - - select { - case <-ctx.sweeper.sweptInputs: - t.Fatal("no sweep expected") - case <-time.After(sweepProcessInterval): - } - - // A new block arrives. The commit tx confirmed at height -1 and the csv - // is 3, so a spend will be valid in the first block after height +1. - ctx.notifyEpoch(testInitialBlockHeight + 1) - - <-ctx.sweeper.sweptInputs - + // Notify initial block height. Although the csv lock is still in + // effect, we expect the input being sent to the sweeper before the csv + // lock expires. + // // Set the resolution report outcome based on whether our sweep // succeeded. outcome := channeldb.ResolverOutcomeClaimed diff --git a/contractcourt/contract_resolver.go b/contractcourt/contract_resolver.go index 53f4f680d0..d11bd2f597 100644 --- a/contractcourt/contract_resolver.go +++ b/contractcourt/contract_resolver.go @@ -5,11 +5,13 @@ import ( "errors" "fmt" "io" + "sync/atomic" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btclog/v2" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/sweep" ) var ( @@ -35,6 +37,17 @@ type ContractResolver interface { // resides within. ResolverKey() []byte + // Launch starts the resolver by constructing an input and offering it + // to the sweeper. Once offered, it's expected to monitor the sweeping + // result in a goroutine invoked by calling Resolve. + // + // NOTE: We can call `Resolve` inside a goroutine at the end of this + // method to avoid calling it in the ChannelArbitrator. However, there + // are some DB-related operations such as SwapContract/ResolveContract + // which need to be done inside the resolvers instead, which needs a + // deeper refactoring. + Launch() error + // Resolve instructs the contract resolver to resolve the output // on-chain. Once the output has been *fully* resolved, the function // should return immediately with a nil ContractResolver value for the @@ -42,7 +55,7 @@ type ContractResolver interface { // resolution, then another resolve is returned. // // NOTE: This function MUST be run as a goroutine. - Resolve(immediate bool) (ContractResolver, error) + Resolve() (ContractResolver, error) // SupplementState allows the user of a ContractResolver to supplement // it with state required for the proper resolution of a contract. @@ -109,6 +122,21 @@ type contractResolverKit struct { log btclog.Logger quit chan struct{} + + // sweepResultChan is the result chan returned from calling + // `SweepInput`. It should be mounted to the specific resolver once the + // input has been offered to the sweeper. + sweepResultChan chan sweep.Result + + // launched specifies whether the resolver has been launched. Calling + // `Launch` will be a no-op if this is true. This value is not saved to + // db, as it's fine to relaunch a resolver after a restart. It's only + // used to avoid resending requests to the sweeper when a new blockbeat + // is received. + launched atomic.Bool + + // resolved reflects if the contract has been fully resolved or not. + resolved atomic.Bool } // newContractResolverKit instantiates the mix-in struct. @@ -120,11 +148,36 @@ func newContractResolverKit(cfg ResolverConfig) *contractResolverKit { } // initLogger initializes the resolver-specific logger. -func (r *contractResolverKit) initLogger(resolver ContractResolver) { - logPrefix := fmt.Sprintf("%T(%v):", resolver, r.ChanPoint) +func (r *contractResolverKit) initLogger(prefix string) { + logPrefix := fmt.Sprintf("ChannelArbitrator(%v): %s:", r.ChanPoint, + prefix) + r.log = log.WithPrefix(logPrefix) } +// IsResolved returns true if the stored state in the resolve is fully +// resolved. In this case the target output can be forgotten. +// +// NOTE: Part of the ContractResolver interface. +func (r *contractResolverKit) IsResolved() bool { + return r.resolved.Load() +} + +// markResolved marks the resolver as resolved. +func (r *contractResolverKit) markResolved() { + r.resolved.Store(true) +} + +// isLaunched returns true if the resolver has been launched. +func (r *contractResolverKit) isLaunched() bool { + return r.launched.Load() +} + +// markLaunched marks the resolver as launched. +func (r *contractResolverKit) markLaunched() { + r.launched.Store(true) +} + var ( // errResolverShuttingDown is returned when the resolver stops // progressing because it received the quit signal. diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go index 73841eb88c..bc5948487d 100644 --- a/contractcourt/htlc_incoming_contest_resolver.go +++ b/contractcourt/htlc_incoming_contest_resolver.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/invoices" @@ -78,6 +78,37 @@ func (h *htlcIncomingContestResolver) processFinalHtlcFail() error { return nil } +// Launch will call the inner resolver's launch method if the preimage can be +// found, otherwise it's a no-op. +func (h *htlcIncomingContestResolver) Launch() error { + // NOTE: we don't mark this resolver as launched as the inner resolver + // will set it when it's launched. + if h.isLaunched() { + h.log.Tracef("already launched") + return nil + } + + h.log.Debugf("launching contest resolver...") + + // Query the preimage and apply it if we already know it. + applied, err := h.findAndapplyPreimage() + if err != nil { + return err + } + + // No preimage found, leave it to be handled by the resolver. + if !applied { + return nil + } + + h.log.Debugf("found preimage for htlc=%x, transforming into success "+ + "resolver and launching it", h.htlc.RHash) + + // Once we've applied the preimage, we'll launch the inner resolver to + // attempt to claim the HTLC. + return h.htlcSuccessResolver.Launch() +} + // Resolve attempts to resolve this contract. As we don't yet know of the // preimage for the contract, we'll wait for one of two things to happen: // @@ -90,12 +121,11 @@ func (h *htlcIncomingContestResolver) processFinalHtlcFail() error { // as we have no remaining actions left at our disposal. // // NOTE: Part of the ContractResolver interface. -func (h *htlcIncomingContestResolver) Resolve( - _ bool) (ContractResolver, error) { - +func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { // If we're already full resolved, then we don't have anything further // to do. - if h.resolved { + if h.IsResolved() { + h.log.Errorf("already resolved") return nil, nil } @@ -103,15 +133,14 @@ func (h *htlcIncomingContestResolver) Resolve( // now. payload, nextHopOnionBlob, err := h.decodePayload() if err != nil { - log.Debugf("ChannelArbitrator(%v): cannot decode payload of "+ - "htlc %v", h.ChanPoint, h.HtlcPoint()) + h.log.Debugf("cannot decode payload of htlc %v", h.HtlcPoint()) // If we've locked in an htlc with an invalid payload on our // commitment tx, we don't need to resolve it. The other party // will time it out and get their funds back. This situation // can present itself when we crash before processRemoteAdds in // the link has ran. - h.resolved = true + h.markResolved() if err := h.processFinalHtlcFail(); err != nil { return nil, err @@ -164,7 +193,7 @@ func (h *htlcIncomingContestResolver) Resolve( log.Infof("%T(%v): HTLC has timed out (expiry=%v, height=%v), "+ "abandoning", h, h.htlcResolution.ClaimOutpoint, h.htlcExpiry, currentHeight) - h.resolved = true + h.markResolved() if err := h.processFinalHtlcFail(); err != nil { return nil, err @@ -179,65 +208,6 @@ func (h *htlcIncomingContestResolver) Resolve( return nil, h.Checkpoint(h, report) } - // applyPreimage is a helper function that will populate our internal - // resolver with the preimage we learn of. This should be called once - // the preimage is revealed so the inner resolver can properly complete - // its duties. The error return value indicates whether the preimage - // was properly applied. - applyPreimage := func(preimage lntypes.Preimage) error { - // Sanity check to see if this preimage matches our htlc. At - // this point it should never happen that it does not match. - if !preimage.Matches(h.htlc.RHash) { - return errors.New("preimage does not match hash") - } - - // Update htlcResolution with the matching preimage. - h.htlcResolution.Preimage = preimage - - log.Infof("%T(%v): applied preimage=%v", h, - h.htlcResolution.ClaimOutpoint, preimage) - - isSecondLevel := h.htlcResolution.SignedSuccessTx != nil - - // If we didn't have to go to the second level to claim (this - // is the remote commitment transaction), then we don't need to - // modify our canned witness. - if !isSecondLevel { - return nil - } - - isTaproot := txscript.IsPayToTaproot( - h.htlcResolution.SignedSuccessTx.TxOut[0].PkScript, - ) - - // If this is our commitment transaction, then we'll need to - // populate the witness for the second-level HTLC transaction. - switch { - // For taproot channels, the witness for sweeping with success - // looks like: - // - - // - // - // So we'll insert it at the 3rd index of the witness. - case isTaproot: - //nolint:ll - h.htlcResolution.SignedSuccessTx.TxIn[0].Witness[2] = preimage[:] - - // Within the witness for the success transaction, the - // preimage is the 4th element as it looks like: - // - // * <0> - // - // We'll populate it within the witness, as since this - // was a "contest" resolver, we didn't yet know of the - // preimage. - case !isTaproot: - h.htlcResolution.SignedSuccessTx.TxIn[0].Witness[3] = preimage[:] - } - - return nil - } - // Define a closure to process htlc resolutions either directly or // triggered by future notifications. processHtlcResolution := func(e invoices.HtlcResolution) ( @@ -249,7 +219,7 @@ func (h *htlcIncomingContestResolver) Resolve( // If the htlc resolution was a settle, apply the // preimage and return a success resolver. case *invoices.HtlcSettleResolution: - err := applyPreimage(resolution.Preimage) + err := h.applyPreimage(resolution.Preimage) if err != nil { return nil, err } @@ -264,7 +234,7 @@ func (h *htlcIncomingContestResolver) Resolve( h.htlcResolution.ClaimOutpoint, h.htlcExpiry, currentHeight) - h.resolved = true + h.markResolved() if err := h.processFinalHtlcFail(); err != nil { return nil, err @@ -314,6 +284,9 @@ func (h *htlcIncomingContestResolver) Resolve( return nil, err } + h.log.Debugf("received resolution from registry: %v", + resolution) + defer func() { h.Registry.HodlUnsubscribeAll(hodlQueue.ChanIn()) @@ -371,7 +344,9 @@ func (h *htlcIncomingContestResolver) Resolve( // However, we don't know how to ourselves, so we'll // return our inner resolver which has the knowledge to // do so. - if err := applyPreimage(preimage); err != nil { + h.log.Debugf("Found preimage for htlc=%x", h.htlc.RHash) + + if err := h.applyPreimage(preimage); err != nil { return nil, err } @@ -390,7 +365,10 @@ func (h *htlcIncomingContestResolver) Resolve( continue } - if err := applyPreimage(preimage); err != nil { + h.log.Debugf("Received preimage for htlc=%x", + h.htlc.RHash) + + if err := h.applyPreimage(preimage); err != nil { return nil, err } @@ -417,7 +395,8 @@ func (h *htlcIncomingContestResolver) Resolve( "(expiry=%v, height=%v), abandoning", h, h.htlcResolution.ClaimOutpoint, h.htlcExpiry, currentHeight) - h.resolved = true + + h.markResolved() if err := h.processFinalHtlcFail(); err != nil { return nil, err @@ -437,6 +416,76 @@ func (h *htlcIncomingContestResolver) Resolve( } } +// applyPreimage is a helper function that will populate our internal resolver +// with the preimage we learn of. This should be called once the preimage is +// revealed so the inner resolver can properly complete its duties. The error +// return value indicates whether the preimage was properly applied. +func (h *htlcIncomingContestResolver) applyPreimage( + preimage lntypes.Preimage) error { + + // Sanity check to see if this preimage matches our htlc. At this point + // it should never happen that it does not match. + if !preimage.Matches(h.htlc.RHash) { + return errors.New("preimage does not match hash") + } + + // We may already have the preimage since both the `Launch` and + // `Resolve` methods will look for it. + if h.htlcResolution.Preimage != lntypes.ZeroHash { + h.log.Debugf("already applied preimage for htlc=%x", + h.htlc.RHash) + + return nil + } + + // Update htlcResolution with the matching preimage. + h.htlcResolution.Preimage = preimage + + log.Infof("%T(%v): applied preimage=%v", h, + h.htlcResolution.ClaimOutpoint, preimage) + + isSecondLevel := h.htlcResolution.SignedSuccessTx != nil + + // If we didn't have to go to the second level to claim (this + // is the remote commitment transaction), then we don't need to + // modify our canned witness. + if !isSecondLevel { + return nil + } + + isTaproot := txscript.IsPayToTaproot( + h.htlcResolution.SignedSuccessTx.TxOut[0].PkScript, + ) + + // If this is our commitment transaction, then we'll need to + // populate the witness for the second-level HTLC transaction. + switch { + // For taproot channels, the witness for sweeping with success + // looks like: + // - + // + // + // So we'll insert it at the 3rd index of the witness. + case isTaproot: + //nolint:ll + h.htlcResolution.SignedSuccessTx.TxIn[0].Witness[2] = preimage[:] + + // Within the witness for the success transaction, the + // preimage is the 4th element as it looks like: + // + // * <0> + // + // We'll populate it within the witness, as since this + // was a "contest" resolver, we didn't yet know of the + // preimage. + case !isTaproot: + //nolint:ll + h.htlcResolution.SignedSuccessTx.TxIn[0].Witness[3] = preimage[:] + } + + return nil +} + // report returns a report on the resolution state of the contract. func (h *htlcIncomingContestResolver) report() *ContractReport { // No locking needed as these values are read-only. @@ -463,17 +512,11 @@ func (h *htlcIncomingContestResolver) report() *ContractReport { // // NOTE: Part of the ContractResolver interface. func (h *htlcIncomingContestResolver) Stop() { + h.log.Debugf("stopping...") + defer h.log.Debugf("stopped") close(h.quit) } -// IsResolved returns true if the stored state in the resolve is fully -// resolved. In this case the target output can be forgotten. -// -// NOTE: Part of the ContractResolver interface. -func (h *htlcIncomingContestResolver) IsResolved() bool { - return h.resolved -} - // Encode writes an encoded version of the ContractResolver into the passed // Writer. // @@ -562,3 +605,82 @@ func (h *htlcIncomingContestResolver) decodePayload() (*hop.Payload, // A compile time assertion to ensure htlcIncomingContestResolver meets the // ContractResolver interface. var _ htlcContractResolver = (*htlcIncomingContestResolver)(nil) + +// findAndapplyPreimage performs a non-blocking read to find the preimage for +// the incoming HTLC. If found, it will be applied to the resolver. This method +// is used for the resolver to decide whether it wants to transform into a +// success resolver during launching. +// +// NOTE: Since we have two places to query the preimage, we need to check both +// the preimage db and the invoice db to look up the preimage. +func (h *htlcIncomingContestResolver) findAndapplyPreimage() (bool, error) { + // Query to see if we already know the preimage. + preimage, ok := h.PreimageDB.LookupPreimage(h.htlc.RHash) + + // If the preimage is known, we'll apply it. + if ok { + if err := h.applyPreimage(preimage); err != nil { + return false, err + } + + // Successfully applied the preimage, we can now return. + return true, nil + } + + // First try to parse the payload. + payload, _, err := h.decodePayload() + if err != nil { + h.log.Errorf("Cannot decode payload of htlc %v", h.HtlcPoint()) + + // If we cannot decode the payload, we will return a nil error + // and let it to be handled in `Resolve`. + return false, nil + } + + // Exit early if this is not the exit hop, which means we are not the + // payment receiver and don't have preimage. + if payload.FwdInfo.NextHop != hop.Exit { + return false, nil + } + + // Notify registry that we are potentially resolving as an exit hop + // on-chain. If this HTLC indeed pays to an existing invoice, the + // invoice registry will tell us what to do with the HTLC. This is + // identical to HTLC resolution in the link. + circuitKey := models.CircuitKey{ + ChanID: h.ShortChanID, + HtlcID: h.htlc.HtlcIndex, + } + + // Try get the resolution - if it doesn't give us a resolution + // immediately, we'll assume we don't know it yet and let the `Resolve` + // handle the waiting. + // + // NOTE: we use a nil subscriber here and a zero current height as we + // are only interested in the settle resolution. + // + // TODO(yy): move this logic to link and let the preimage be accessed + // via the preimage beacon. + resolution, err := h.Registry.NotifyExitHopHtlc( + h.htlc.RHash, h.htlc.Amt, h.htlcExpiry, 0, + circuitKey, nil, nil, payload, + ) + if err != nil { + return false, err + } + + res, ok := resolution.(*invoices.HtlcSettleResolution) + + // Exit early if it's not a settle resolution. + if !ok { + return false, nil + } + + // Otherwise we have a settle resolution, apply the preimage. + err = h.applyPreimage(res.Preimage) + if err != nil { + return false, err + } + + return true, nil +} diff --git a/contractcourt/htlc_incoming_contest_resolver_test.go b/contractcourt/htlc_incoming_contest_resolver_test.go index 22280f953e..f17190e96e 100644 --- a/contractcourt/htlc_incoming_contest_resolver_test.go +++ b/contractcourt/htlc_incoming_contest_resolver_test.go @@ -5,11 +5,13 @@ import ( "io" "testing" + "github.com/btcsuite/btcd/wire" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" + "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnmock" @@ -356,6 +358,7 @@ func newIncomingResolverTestContext(t *testing.T, isExit bool) *incomingResolver return nil }, + Sweeper: newMockSweeper(), }, PutResolverReport: func(_ kvdb.RwTx, _ *channeldb.ResolverReport) error { @@ -374,10 +377,16 @@ func newIncomingResolverTestContext(t *testing.T, isExit bool) *incomingResolver }, } + res := lnwallet.IncomingHtlcResolution{ + SweepSignDesc: input.SignDescriptor{ + Output: &wire.TxOut{}, + }, + } + c.resolver = &htlcIncomingContestResolver{ htlcSuccessResolver: &htlcSuccessResolver{ contractResolverKit: *newContractResolverKit(cfg), - htlcResolution: lnwallet.IncomingHtlcResolution{}, + htlcResolution: res, htlc: channeldb.HTLC{ Amt: lnwire.MilliSatoshi(testHtlcAmount), RHash: testResHash, @@ -386,6 +395,7 @@ func newIncomingResolverTestContext(t *testing.T, isExit bool) *incomingResolver }, htlcExpiry: testHtlcExpiry, } + c.resolver.initLogger("htlcIncomingContestResolver") return c } @@ -395,7 +405,11 @@ func (i *incomingResolverTestContext) resolve() { i.resolveErr = make(chan error, 1) go func() { var err error - i.nextResolver, err = i.resolver.Resolve(false) + + err = i.resolver.Launch() + require.NoError(i.t, err) + + i.nextResolver, err = i.resolver.Resolve() i.resolveErr <- err }() diff --git a/contractcourt/htlc_lease_resolver.go b/contractcourt/htlc_lease_resolver.go index 53fa893553..6230f96777 100644 --- a/contractcourt/htlc_lease_resolver.go +++ b/contractcourt/htlc_lease_resolver.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/tlv" ) @@ -57,10 +57,10 @@ func (h *htlcLeaseResolver) makeSweepInput(op *wire.OutPoint, signDesc *input.SignDescriptor, csvDelay, broadcastHeight uint32, payHash [32]byte, resBlob fn.Option[tlv.Blob]) *input.BaseInput { - if h.hasCLTV() { - log.Infof("%T(%x): CSV and CLTV locks expired, offering "+ - "second-layer output to sweeper: %v", h, payHash, op) + log.Infof("%T(%x): offering second-layer output to sweeper: %v", h, + payHash, op) + if h.hasCLTV() { return input.NewCsvInputWithCltv( op, cltvWtype, signDesc, broadcastHeight, csvDelay, diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go index 2466544c98..b66a3fdf0b 100644 --- a/contractcourt/htlc_outgoing_contest_resolver.go +++ b/contractcourt/htlc_outgoing_contest_resolver.go @@ -1,12 +1,11 @@ package contractcourt import ( - "fmt" "io" "github.com/btcsuite/btcd/btcutil" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet" ) @@ -36,6 +35,37 @@ func newOutgoingContestResolver(res lnwallet.OutgoingHtlcResolution, } } +// Launch will call the inner resolver's launch method if the expiry height has +// been reached, otherwise it's a no-op. +func (h *htlcOutgoingContestResolver) Launch() error { + // NOTE: we don't mark this resolver as launched as the inner resolver + // will set it when it's launched. + if h.isLaunched() { + h.log.Tracef("already launched") + return nil + } + + h.log.Debugf("launching contest resolver...") + + _, bestHeight, err := h.ChainIO.GetBestBlock() + if err != nil { + return err + } + + if uint32(bestHeight) < h.htlcResolution.Expiry { + return nil + } + + // If the current height is >= expiry, then a timeout path spend will + // be valid to be included in the next block, and we can immediately + // return the resolver. + h.log.Infof("expired (height=%v, expiry=%v), transforming into "+ + "timeout resolver and launching it", bestHeight, + h.htlcResolution.Expiry) + + return h.htlcTimeoutResolver.Launch() +} + // Resolve commences the resolution of this contract. As this contract hasn't // yet timed out, we'll wait for one of two things to happen // @@ -49,12 +79,11 @@ func newOutgoingContestResolver(res lnwallet.OutgoingHtlcResolution, // When either of these two things happens, we'll create a new resolver which // is able to handle the final resolution of the contract. We're only the pivot // point. -func (h *htlcOutgoingContestResolver) Resolve( - _ bool) (ContractResolver, error) { - +func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) { // If we're already full resolved, then we don't have anything further // to do. - if h.resolved { + if h.IsResolved() { + h.log.Errorf("already resolved") return nil, nil } @@ -88,8 +117,7 @@ func (h *htlcOutgoingContestResolver) Resolve( return nil, errResolverShuttingDown } - // TODO(roasbeef): Checkpoint? - return h.claimCleanUp(commitSpend) + return nil, h.claimCleanUp(commitSpend) // If it hasn't, then we'll watch for both the expiration, and the // sweeping out this output. @@ -126,12 +154,21 @@ func (h *htlcOutgoingContestResolver) Resolve( // finalized` will be returned and the broadcast will // fail. newHeight := uint32(newBlock.Height) - if newHeight >= h.htlcResolution.Expiry { - log.Infof("%T(%v): HTLC has expired "+ + expiry := h.htlcResolution.Expiry + + // Check if the expiry height is about to be reached. + // We offer this HTLC one block earlier to make sure + // when the next block arrives, the sweeper will pick + // up this input and sweep it immediately. The sweeper + // will handle the waiting for the one last block till + // expiry. + if newHeight >= expiry-1 { + h.log.Infof("HTLC about to expire "+ "(height=%v, expiry=%v), transforming "+ "into timeout resolver", h, h.htlcResolution.ClaimOutpoint, newHeight, h.htlcResolution.Expiry) + return h.htlcTimeoutResolver, nil } @@ -146,10 +183,10 @@ func (h *htlcOutgoingContestResolver) Resolve( // party is by revealing the preimage. So we'll perform // our duties to clean up the contract once it has been // claimed. - return h.claimCleanUp(commitSpend) + return nil, h.claimCleanUp(commitSpend) case <-h.quit: - return nil, fmt.Errorf("resolver canceled") + return nil, errResolverShuttingDown } } } @@ -180,17 +217,11 @@ func (h *htlcOutgoingContestResolver) report() *ContractReport { // // NOTE: Part of the ContractResolver interface. func (h *htlcOutgoingContestResolver) Stop() { + h.log.Debugf("stopping...") + defer h.log.Debugf("stopped") close(h.quit) } -// IsResolved returns true if the stored state in the resolve is fully -// resolved. In this case the target output can be forgotten. -// -// NOTE: Part of the ContractResolver interface. -func (h *htlcOutgoingContestResolver) IsResolved() bool { - return h.resolved -} - // Encode writes an encoded version of the ContractResolver into the passed // Writer. // diff --git a/contractcourt/htlc_outgoing_contest_resolver_test.go b/contractcourt/htlc_outgoing_contest_resolver_test.go index 6608a6fb51..625df60bf1 100644 --- a/contractcourt/htlc_outgoing_contest_resolver_test.go +++ b/contractcourt/htlc_outgoing_contest_resolver_test.go @@ -15,6 +15,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" ) const ( @@ -159,6 +160,7 @@ func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext { return nil }, + ChainIO: &mock.ChainIO{}, }, PutResolverReport: func(_ kvdb.RwTx, _ *channeldb.ResolverReport) error { @@ -195,6 +197,7 @@ func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext { }, }, } + resolver.initLogger("htlcOutgoingContestResolver") return &outgoingResolverTestContext{ resolver: resolver, @@ -209,7 +212,10 @@ func (i *outgoingResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false) + err := i.resolver.Launch() + require.NoError(i.t, err) + + nextResolver, err := i.resolver.Resolve() i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index b2716ad305..a4d27ba4e8 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -2,6 +2,7 @@ package contractcourt import ( "encoding/binary" + "fmt" "io" "sync" @@ -9,10 +10,8 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" @@ -43,9 +42,6 @@ type htlcSuccessResolver struct { // second-level output (true). outputIncubating bool - // resolved reflects if the contract has been fully resolved or not. - resolved bool - // broadcastHeight is the height that the original contract was // broadcast to the main-chain at. We'll use this value to bound any // historical queries to the chain for spends/confirmations. @@ -81,27 +77,30 @@ func newSuccessResolver(res lnwallet.IncomingHtlcResolution, } h.initReport() + h.initLogger(fmt.Sprintf("%T(%v)", h, h.outpoint())) return h } -// ResolverKey returns an identifier which should be globally unique for this -// particular resolver within the chain the original contract resides within. -// -// NOTE: Part of the ContractResolver interface. -func (h *htlcSuccessResolver) ResolverKey() []byte { +// outpoint returns the outpoint of the HTLC output we're attempting to sweep. +func (h *htlcSuccessResolver) outpoint() wire.OutPoint { // The primary key for this resolver will be the outpoint of the HTLC // on the commitment transaction itself. If this is our commitment, // then the output can be found within the signed success tx, // otherwise, it's just the ClaimOutpoint. - var op wire.OutPoint if h.htlcResolution.SignedSuccessTx != nil { - op = h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint - } else { - op = h.htlcResolution.ClaimOutpoint + return h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint } - key := newResolverID(op) + return h.htlcResolution.ClaimOutpoint +} + +// ResolverKey returns an identifier which should be globally unique for this +// particular resolver within the chain the original contract resides within. +// +// NOTE: Part of the ContractResolver interface. +func (h *htlcSuccessResolver) ResolverKey() []byte { + key := newResolverID(h.outpoint()) return key[:] } @@ -112,423 +111,66 @@ func (h *htlcSuccessResolver) ResolverKey() []byte { // anymore. Every HTLC has already passed through the incoming contest resolver // and in there the invoice was already marked as settled. // -// TODO(roasbeef): create multi to batch -// // NOTE: Part of the ContractResolver interface. -func (h *htlcSuccessResolver) Resolve( - immediate bool) (ContractResolver, error) { - - // If we're already resolved, then we can exit early. - if h.resolved { - return nil, nil - } - - // If we don't have a success transaction, then this means that this is - // an output on the remote party's commitment transaction. - if h.htlcResolution.SignedSuccessTx == nil { - return h.resolveRemoteCommitOutput(immediate) - } - - // Otherwise this an output on our own commitment, and we must start by - // broadcasting the second-level success transaction. - secondLevelOutpoint, err := h.broadcastSuccessTx(immediate) - if err != nil { - return nil, err - } - - // To wrap this up, we'll wait until the second-level transaction has - // been spent, then fully resolve the contract. - log.Infof("%T(%x): waiting for second-level HTLC output to be spent "+ - "after csv_delay=%v", h, h.htlc.RHash[:], h.htlcResolution.CsvDelay) - - spend, err := waitForSpend( - secondLevelOutpoint, - h.htlcResolution.SweepSignDesc.Output.PkScript, - h.broadcastHeight, h.Notifier, h.quit, - ) - if err != nil { - return nil, err - } - - h.reportLock.Lock() - h.currentReport.RecoveredBalance = h.currentReport.LimboBalance - h.currentReport.LimboBalance = 0 - h.reportLock.Unlock() - - h.resolved = true - return nil, h.checkpointClaim( - spend.SpenderTxHash, channeldb.ResolverOutcomeClaimed, - ) -} - -// broadcastSuccessTx handles an HTLC output on our local commitment by -// broadcasting the second-level success transaction. It returns the ultimate -// outpoint of the second-level tx, that we must wait to be spent for the -// resolver to be fully resolved. -func (h *htlcSuccessResolver) broadcastSuccessTx( - immediate bool) (*wire.OutPoint, error) { - - // If we have non-nil SignDetails, this means that have a 2nd level - // HTLC transaction that is signed using sighash SINGLE|ANYONECANPAY - // (the case for anchor type channels). In this case we can re-sign it - // and attach fees at will. We let the sweeper handle this job. We use - // the checkpointed outputIncubating field to determine if we already - // swept the HTLC output into the second level transaction. - if h.htlcResolution.SignDetails != nil { - return h.broadcastReSignedSuccessTx(immediate) - } - - // Otherwise we'll publish the second-level transaction directly and - // offer the resolution to the nursery to handle. - log.Infof("%T(%x): broadcasting second-layer transition tx: %v", - h, h.htlc.RHash[:], spew.Sdump(h.htlcResolution.SignedSuccessTx)) - - // We'll now broadcast the second layer transaction so we can kick off - // the claiming process. - // - // TODO(roasbeef): after changing sighashes send to tx bundler - label := labels.MakeLabel( - labels.LabelTypeChannelClose, &h.ShortChanID, - ) - err := h.PublishTx(h.htlcResolution.SignedSuccessTx, label) - if err != nil { - return nil, err - } - - // Otherwise, this is an output on our commitment transaction. In this - // case, we'll send it to the incubator, but only if we haven't already - // done so. - if !h.outputIncubating { - log.Infof("%T(%x): incubating incoming htlc output", - h, h.htlc.RHash[:]) - - err := h.IncubateOutputs( - h.ChanPoint, fn.None[lnwallet.OutgoingHtlcResolution](), - fn.Some(h.htlcResolution), - h.broadcastHeight, fn.Some(int32(h.htlc.RefundTimeout)), - ) - if err != nil { - return nil, err - } - - h.outputIncubating = true - - if err := h.Checkpoint(h); err != nil { - log.Errorf("unable to Checkpoint: %v", err) - return nil, err - } - } - - return &h.htlcResolution.ClaimOutpoint, nil -} - -// broadcastReSignedSuccessTx handles the case where we have non-nil -// SignDetails, and offers the second level transaction to the Sweeper, that -// will re-sign it and attach fees at will. // -//nolint:funlen -func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( - *wire.OutPoint, error) { - - // Keep track of the tx spending the HTLC output on the commitment, as - // this will be the confirmed second-level tx we'll ultimately sweep. - var commitSpend *chainntnfs.SpendDetail - - // We will have to let the sweeper re-sign the success tx and wait for - // it to confirm, if we haven't already. - isTaproot := txscript.IsPayToTaproot( - h.htlcResolution.SweepSignDesc.Output.PkScript, - ) - if !h.outputIncubating { - var secondLevelInput input.HtlcSecondLevelAnchorInput - if isTaproot { - //nolint:ll - secondLevelInput = input.MakeHtlcSecondLevelSuccessTaprootInput( - h.htlcResolution.SignedSuccessTx, - h.htlcResolution.SignDetails, h.htlcResolution.Preimage, - h.broadcastHeight, - input.WithResolutionBlob( - h.htlcResolution.ResolutionBlob, - ), - ) - } else { - //nolint:ll - secondLevelInput = input.MakeHtlcSecondLevelSuccessAnchorInput( - h.htlcResolution.SignedSuccessTx, - h.htlcResolution.SignDetails, h.htlcResolution.Preimage, - h.broadcastHeight, - ) - } - - // Calculate the budget for this sweep. - value := btcutil.Amount( - secondLevelInput.SignDesc().Output.Value, - ) - budget := calculateBudget( - value, h.Budget.DeadlineHTLCRatio, - h.Budget.DeadlineHTLC, - ) - - // The deadline would be the CLTV in this HTLC output. If we - // are the initiator of this force close, with the default - // `IncomingBroadcastDelta`, it means we have 10 blocks left - // when going onchain. Given we need to mine one block to - // confirm the force close tx, and one more block to trigger - // the sweep, we have 8 blocks left to sweep the HTLC. - deadline := fn.Some(int32(h.htlc.RefundTimeout)) - - log.Infof("%T(%x): offering second-level HTLC success tx to "+ - "sweeper with deadline=%v, budget=%v", h, - h.htlc.RHash[:], h.htlc.RefundTimeout, budget) - - // We'll now offer the second-level transaction to the sweeper. - _, err := h.Sweeper.SweepInput( - &secondLevelInput, - sweep.Params{ - Budget: budget, - DeadlineHeight: deadline, - Immediate: immediate, - }, - ) - if err != nil { - return nil, err - } - - log.Infof("%T(%x): waiting for second-level HTLC success "+ - "transaction to confirm", h, h.htlc.RHash[:]) - - // Wait for the second level transaction to confirm. - commitSpend, err = waitForSpend( - &h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint, - h.htlcResolution.SignDetails.SignDesc.Output.PkScript, - h.broadcastHeight, h.Notifier, h.quit, - ) - if err != nil { - return nil, err - } +// TODO(yy): refactor the interface method to return an error only. +func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { + var err error - // Now that the second-level transaction has confirmed, we - // checkpoint the state so we'll go to the next stage in case - // of restarts. - h.outputIncubating = true - if err := h.Checkpoint(h); err != nil { - log.Errorf("unable to Checkpoint: %v", err) - return nil, err - } - - log.Infof("%T(%x): second-level HTLC success transaction "+ - "confirmed!", h, h.htlc.RHash[:]) - } - - // If we ended up here after a restart, we must again get the - // spend notification. - if commitSpend == nil { - var err error - commitSpend, err = waitForSpend( - &h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint, - h.htlcResolution.SignDetails.SignDesc.Output.PkScript, - h.broadcastHeight, h.Notifier, h.quit, - ) - if err != nil { - return nil, err - } - } - - // The HTLC success tx has a CSV lock that we must wait for, and if - // this is a lease enforced channel and we're the imitator, we may need - // to wait for longer. - waitHeight := h.deriveWaitHeight( - h.htlcResolution.CsvDelay, commitSpend, - ) - - // Now that the sweeper has broadcasted the second-level transaction, - // it has confirmed, and we have checkpointed our state, we'll sweep - // the second level output. We report the resolver has moved the next - // stage. - h.reportLock.Lock() - h.currentReport.Stage = 2 - h.currentReport.MaturityHeight = waitHeight - h.reportLock.Unlock() - - if h.hasCLTV() { - log.Infof("%T(%x): waiting for CSV and CLTV lock to "+ - "expire at height %v", h, h.htlc.RHash[:], - waitHeight) - } else { - log.Infof("%T(%x): waiting for CSV lock to expire at "+ - "height %v", h, h.htlc.RHash[:], waitHeight) - } - - // Deduct one block so this input is offered to the sweeper one block - // earlier since the sweeper will wait for one block to trigger the - // sweeping. - // - // TODO(yy): this is done so the outputs can be aggregated - // properly. Suppose CSV locks of five 2nd-level outputs all - // expire at height 840000, there is a race in block digestion - // between contractcourt and sweeper: - // - G1: block 840000 received in contractcourt, it now offers - // the outputs to the sweeper. - // - G2: block 840000 received in sweeper, it now starts to - // sweep the received outputs - there's no guarantee all - // fives have been received. - // To solve this, we either offer the outputs earlier, or - // implement `blockbeat`, and force contractcourt and sweeper - // to consume each block sequentially. - waitHeight-- - - // TODO(yy): let sweeper handles the wait? - err := waitForHeight(waitHeight, h.Notifier, h.quit) - if err != nil { - return nil, err - } - - // We'll use this input index to determine the second-level output - // index on the transaction, as the signatures requires the indexes to - // be the same. We don't look for the second-level output script - // directly, as there might be more than one HTLC output to the same - // pkScript. - op := &wire.OutPoint{ - Hash: *commitSpend.SpenderTxHash, - Index: commitSpend.SpenderInputIndex, - } - - // Let the sweeper sweep the second-level output now that the - // CSV/CLTV locks have expired. - var witType input.StandardWitnessType - if isTaproot { - witType = input.TaprootHtlcAcceptedSuccessSecondLevel - } else { - witType = input.HtlcAcceptedSuccessSecondLevel - } - inp := h.makeSweepInput( - op, witType, - input.LeaseHtlcAcceptedSuccessSecondLevel, - &h.htlcResolution.SweepSignDesc, - h.htlcResolution.CsvDelay, uint32(commitSpend.SpendingHeight), - h.htlc.RHash, h.htlcResolution.ResolutionBlob, - ) - - // Calculate the budget for this sweep. - budget := calculateBudget( - btcutil.Amount(inp.SignDesc().Output.Value), - h.Budget.NoDeadlineHTLCRatio, - h.Budget.NoDeadlineHTLC, - ) + switch { + // If we're already resolved, then we can exit early. + case h.IsResolved(): + h.log.Errorf("already resolved") - log.Infof("%T(%x): offering second-level success tx output to sweeper "+ - "with no deadline and budget=%v at height=%v", h, - h.htlc.RHash[:], budget, waitHeight) + // If this is an output on the remote party's commitment transaction, + // use the direct-spend path to sweep the htlc. + case h.isRemoteCommitOutput(): + err = h.resolveRemoteCommitOutput() - // TODO(roasbeef): need to update above for leased types - _, err = h.Sweeper.SweepInput( - inp, - sweep.Params{ - Budget: budget, + // If this is an output on our commitment transaction using post-anchor + // channel type, it will be handled by the sweeper. + case h.isZeroFeeOutput(): + err = h.resolveSuccessTx() - // For second level success tx, there's no rush to get - // it confirmed, so we use a nil deadline. - DeadlineHeight: fn.None[int32](), - }, - ) - if err != nil { - return nil, err + // If this is an output on our own commitment using pre-anchor channel + // type, we will publish the success tx and offer the output to the + // nursery. + default: + err = h.resolveLegacySuccessTx() } - // Will return this outpoint, when this is spent the resolver is fully - // resolved. - return op, nil + return nil, err } // resolveRemoteCommitOutput handles sweeping an HTLC output on the remote // commitment with the preimage. In this case we can sweep the output directly, // and don't have to broadcast a second-level transaction. -func (h *htlcSuccessResolver) resolveRemoteCommitOutput(immediate bool) ( - ContractResolver, error) { - - isTaproot := txscript.IsPayToTaproot( - h.htlcResolution.SweepSignDesc.Output.PkScript, - ) - - // Before we can craft out sweeping transaction, we need to - // create an input which contains all the items required to add - // this input to a sweeping transaction, and generate a - // witness. - var inp input.Input - if isTaproot { - inp = lnutils.Ptr(input.MakeTaprootHtlcSucceedInput( - &h.htlcResolution.ClaimOutpoint, - &h.htlcResolution.SweepSignDesc, - h.htlcResolution.Preimage[:], - h.broadcastHeight, - h.htlcResolution.CsvDelay, - input.WithResolutionBlob( - h.htlcResolution.ResolutionBlob, - ), - )) - } else { - inp = lnutils.Ptr(input.MakeHtlcSucceedInput( - &h.htlcResolution.ClaimOutpoint, - &h.htlcResolution.SweepSignDesc, - h.htlcResolution.Preimage[:], - h.broadcastHeight, - h.htlcResolution.CsvDelay, - )) - } - - // Calculate the budget for this sweep. - budget := calculateBudget( - btcutil.Amount(inp.SignDesc().Output.Value), - h.Budget.DeadlineHTLCRatio, - h.Budget.DeadlineHTLC, - ) - - deadline := fn.Some(int32(h.htlc.RefundTimeout)) - - log.Infof("%T(%x): offering direct-preimage HTLC output to sweeper "+ - "with deadline=%v, budget=%v", h, h.htlc.RHash[:], - h.htlc.RefundTimeout, budget) - - // We'll now offer the direct preimage HTLC to the sweeper. - _, err := h.Sweeper.SweepInput( - inp, - sweep.Params{ - Budget: budget, - DeadlineHeight: deadline, - Immediate: immediate, - }, - ) - if err != nil { - return nil, err - } +func (h *htlcSuccessResolver) resolveRemoteCommitOutput() error { + h.log.Info("waiting for direct-preimage spend of the htlc to confirm") // Wait for the direct-preimage HTLC sweep tx to confirm. + // + // TODO(yy): use the result chan returned from `SweepInput`. sweepTxDetails, err := waitForSpend( &h.htlcResolution.ClaimOutpoint, h.htlcResolution.SweepSignDesc.Output.PkScript, h.broadcastHeight, h.Notifier, h.quit, ) if err != nil { - return nil, err + return err } - // Once the transaction has received a sufficient number of - // confirmations, we'll mark ourselves as fully resolved and exit. - h.resolved = true + // TODO(yy): should also update the `RecoveredBalance` and + // `LimboBalance` like other paths? // Checkpoint the resolver, and write the outcome to disk. - return nil, h.checkpointClaim( - sweepTxDetails.SpenderTxHash, - channeldb.ResolverOutcomeClaimed, - ) + return h.checkpointClaim(sweepTxDetails.SpenderTxHash) } // checkpointClaim checkpoints the success resolver with the reports it needs. // If this htlc was claimed two stages, it will write reports for both stages, // otherwise it will just write for the single htlc claim. -func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash, - outcome channeldb.ResolverOutcome) error { - +func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash) error { // Mark the htlc as final settled. err := h.ChainArbitratorConfig.PutFinalHtlcOutcome( h.ChannelArbitratorConfig.ShortChanID, h.htlc.HtlcIndex, true, @@ -556,7 +198,7 @@ func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash, OutPoint: h.htlcResolution.ClaimOutpoint, Amount: amt, ResolverType: channeldb.ResolverTypeIncomingHtlc, - ResolverOutcome: outcome, + ResolverOutcome: channeldb.ResolverOutcomeClaimed, SpendTxID: spendTx, }, } @@ -581,6 +223,7 @@ func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash, } // Finally, we checkpoint the resolver with our report(s). + h.markResolved() return h.Checkpoint(h, reports...) } @@ -589,15 +232,10 @@ func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash, // // NOTE: Part of the ContractResolver interface. func (h *htlcSuccessResolver) Stop() { - close(h.quit) -} + h.log.Debugf("stopping...") + defer h.log.Debugf("stopped") -// IsResolved returns true if the stored state in the resolve is fully -// resolved. In this case the target output can be forgotten. -// -// NOTE: Part of the ContractResolver interface. -func (h *htlcSuccessResolver) IsResolved() bool { - return h.resolved + close(h.quit) } // report returns a report on the resolution state of the contract. @@ -649,7 +287,7 @@ func (h *htlcSuccessResolver) Encode(w io.Writer) error { if err := binary.Write(w, endian, h.outputIncubating); err != nil { return err } - if err := binary.Write(w, endian, h.resolved); err != nil { + if err := binary.Write(w, endian, h.IsResolved()); err != nil { return err } if err := binary.Write(w, endian, h.broadcastHeight); err != nil { @@ -688,9 +326,15 @@ func newSuccessResolverFromReader(r io.Reader, resCfg ResolverConfig) ( if err := binary.Read(r, endian, &h.outputIncubating); err != nil { return nil, err } - if err := binary.Read(r, endian, &h.resolved); err != nil { + + var resolved bool + if err := binary.Read(r, endian, &resolved); err != nil { return nil, err } + if resolved { + h.markResolved() + } + if err := binary.Read(r, endian, &h.broadcastHeight); err != nil { return nil, err } @@ -709,6 +353,7 @@ func newSuccessResolverFromReader(r io.Reader, resCfg ResolverConfig) ( } h.initReport() + h.initLogger(fmt.Sprintf("%T(%v)", h, h.outpoint())) return h, nil } @@ -737,3 +382,391 @@ func (h *htlcSuccessResolver) SupplementDeadline(_ fn.Option[int32]) { // A compile time assertion to ensure htlcSuccessResolver meets the // ContractResolver interface. var _ htlcContractResolver = (*htlcSuccessResolver)(nil) + +// isRemoteCommitOutput returns a bool to indicate whether the htlc output is +// on the remote commitment. +func (h *htlcSuccessResolver) isRemoteCommitOutput() bool { + // If we don't have a success transaction, then this means that this is + // an output on the remote party's commitment transaction. + return h.htlcResolution.SignedSuccessTx == nil +} + +// isZeroFeeOutput returns a boolean indicating whether the htlc output is from +// a anchor-enabled channel, which uses the sighash SINGLE|ANYONECANPAY. +func (h *htlcSuccessResolver) isZeroFeeOutput() bool { + // If we have non-nil SignDetails, this means it has a 2nd level HTLC + // transaction that is signed using sighash SINGLE|ANYONECANPAY (the + // case for anchor type channels). In this case we can re-sign it and + // attach fees at will. + return h.htlcResolution.SignedSuccessTx != nil && + h.htlcResolution.SignDetails != nil +} + +// isTaproot returns true if the resolver is for a taproot output. +func (h *htlcSuccessResolver) isTaproot() bool { + return txscript.IsPayToTaproot( + h.htlcResolution.SweepSignDesc.Output.PkScript, + ) +} + +// sweepRemoteCommitOutput creates a sweep request to sweep the HTLC output on +// the remote commitment via the direct preimage-spend. +func (h *htlcSuccessResolver) sweepRemoteCommitOutput() error { + // Before we can craft out sweeping transaction, we need to create an + // input which contains all the items required to add this input to a + // sweeping transaction, and generate a witness. + var inp input.Input + + if h.isTaproot() { + inp = lnutils.Ptr(input.MakeTaprootHtlcSucceedInput( + &h.htlcResolution.ClaimOutpoint, + &h.htlcResolution.SweepSignDesc, + h.htlcResolution.Preimage[:], + h.broadcastHeight, + h.htlcResolution.CsvDelay, + input.WithResolutionBlob( + h.htlcResolution.ResolutionBlob, + ), + )) + } else { + inp = lnutils.Ptr(input.MakeHtlcSucceedInput( + &h.htlcResolution.ClaimOutpoint, + &h.htlcResolution.SweepSignDesc, + h.htlcResolution.Preimage[:], + h.broadcastHeight, + h.htlcResolution.CsvDelay, + )) + } + + // Calculate the budget for this sweep. + budget := calculateBudget( + btcutil.Amount(inp.SignDesc().Output.Value), + h.Budget.DeadlineHTLCRatio, + h.Budget.DeadlineHTLC, + ) + + deadline := fn.Some(int32(h.htlc.RefundTimeout)) + + log.Infof("%T(%x): offering direct-preimage HTLC output to sweeper "+ + "with deadline=%v, budget=%v", h, h.htlc.RHash[:], + h.htlc.RefundTimeout, budget) + + // We'll now offer the direct preimage HTLC to the sweeper. + _, err := h.Sweeper.SweepInput( + inp, + sweep.Params{ + Budget: budget, + DeadlineHeight: deadline, + }, + ) + + return err +} + +// sweepSuccessTx attempts to sweep the second level success tx. +func (h *htlcSuccessResolver) sweepSuccessTx() error { + var secondLevelInput input.HtlcSecondLevelAnchorInput + if h.isTaproot() { + secondLevelInput = input.MakeHtlcSecondLevelSuccessTaprootInput( + h.htlcResolution.SignedSuccessTx, + h.htlcResolution.SignDetails, h.htlcResolution.Preimage, + h.broadcastHeight, input.WithResolutionBlob( + h.htlcResolution.ResolutionBlob, + ), + ) + } else { + secondLevelInput = input.MakeHtlcSecondLevelSuccessAnchorInput( + h.htlcResolution.SignedSuccessTx, + h.htlcResolution.SignDetails, h.htlcResolution.Preimage, + h.broadcastHeight, + ) + } + + // Calculate the budget for this sweep. + value := btcutil.Amount(secondLevelInput.SignDesc().Output.Value) + budget := calculateBudget( + value, h.Budget.DeadlineHTLCRatio, h.Budget.DeadlineHTLC, + ) + + // The deadline would be the CLTV in this HTLC output. If we are the + // initiator of this force close, with the default + // `IncomingBroadcastDelta`, it means we have 10 blocks left when going + // onchain. + deadline := fn.Some(int32(h.htlc.RefundTimeout)) + + h.log.Infof("offering second-level HTLC success tx to sweeper with "+ + "deadline=%v, budget=%v", h.htlc.RefundTimeout, budget) + + // We'll now offer the second-level transaction to the sweeper. + _, err := h.Sweeper.SweepInput( + &secondLevelInput, + sweep.Params{ + Budget: budget, + DeadlineHeight: deadline, + }, + ) + + return err +} + +// sweepSuccessTxOutput attempts to sweep the output of the second level +// success tx. +func (h *htlcSuccessResolver) sweepSuccessTxOutput() error { + h.log.Debugf("sweeping output %v from 2nd-level HTLC success tx", + h.htlcResolution.ClaimOutpoint) + + // This should be non-blocking as we will only attempt to sweep the + // output when the second level tx has already been confirmed. In other + // words, waitForSpend will return immediately. + commitSpend, err := waitForSpend( + &h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint, + h.htlcResolution.SignDetails.SignDesc.Output.PkScript, + h.broadcastHeight, h.Notifier, h.quit, + ) + if err != nil { + return err + } + + // The HTLC success tx has a CSV lock that we must wait for, and if + // this is a lease enforced channel and we're the imitator, we may need + // to wait for longer. + waitHeight := h.deriveWaitHeight(h.htlcResolution.CsvDelay, commitSpend) + + // Now that the sweeper has broadcasted the second-level transaction, + // it has confirmed, and we have checkpointed our state, we'll sweep + // the second level output. We report the resolver has moved the next + // stage. + h.reportLock.Lock() + h.currentReport.Stage = 2 + h.currentReport.MaturityHeight = waitHeight + h.reportLock.Unlock() + + if h.hasCLTV() { + log.Infof("%T(%x): waiting for CSV and CLTV lock to expire at "+ + "height %v", h, h.htlc.RHash[:], waitHeight) + } else { + log.Infof("%T(%x): waiting for CSV lock to expire at height %v", + h, h.htlc.RHash[:], waitHeight) + } + + // We'll use this input index to determine the second-level output + // index on the transaction, as the signatures requires the indexes to + // be the same. We don't look for the second-level output script + // directly, as there might be more than one HTLC output to the same + // pkScript. + op := &wire.OutPoint{ + Hash: *commitSpend.SpenderTxHash, + Index: commitSpend.SpenderInputIndex, + } + + // Let the sweeper sweep the second-level output now that the + // CSV/CLTV locks have expired. + var witType input.StandardWitnessType + if h.isTaproot() { + witType = input.TaprootHtlcAcceptedSuccessSecondLevel + } else { + witType = input.HtlcAcceptedSuccessSecondLevel + } + inp := h.makeSweepInput( + op, witType, + input.LeaseHtlcAcceptedSuccessSecondLevel, + &h.htlcResolution.SweepSignDesc, + h.htlcResolution.CsvDelay, uint32(commitSpend.SpendingHeight), + h.htlc.RHash, h.htlcResolution.ResolutionBlob, + ) + + // Calculate the budget for this sweep. + budget := calculateBudget( + btcutil.Amount(inp.SignDesc().Output.Value), + h.Budget.NoDeadlineHTLCRatio, + h.Budget.NoDeadlineHTLC, + ) + + log.Infof("%T(%x): offering second-level success tx output to sweeper "+ + "with no deadline and budget=%v at height=%v", h, + h.htlc.RHash[:], budget, waitHeight) + + // TODO(yy): use the result chan returned from SweepInput. + _, err = h.Sweeper.SweepInput( + inp, + sweep.Params{ + Budget: budget, + + // For second level success tx, there's no rush to get + // it confirmed, so we use a nil deadline. + DeadlineHeight: fn.None[int32](), + }, + ) + + return err +} + +// resolveLegacySuccessTx handles an HTLC output from a pre-anchor type channel +// by broadcasting the second-level success transaction. +func (h *htlcSuccessResolver) resolveLegacySuccessTx() error { + // Otherwise we'll publish the second-level transaction directly and + // offer the resolution to the nursery to handle. + h.log.Infof("broadcasting legacy second-level success tx: %v", + h.htlcResolution.SignedSuccessTx.TxHash()) + + // We'll now broadcast the second layer transaction so we can kick off + // the claiming process. + // + // TODO(yy): offer it to the sweeper instead. + label := labels.MakeLabel( + labels.LabelTypeChannelClose, &h.ShortChanID, + ) + err := h.PublishTx(h.htlcResolution.SignedSuccessTx, label) + if err != nil { + return err + } + + // Fast-forward to resolve the output from the success tx if the it has + // already been sent to the UtxoNursery. + if h.outputIncubating { + return h.resolveSuccessTxOutput(h.htlcResolution.ClaimOutpoint) + } + + h.log.Infof("incubating incoming htlc output") + + // Send the output to the incubator. + err = h.IncubateOutputs( + h.ChanPoint, fn.None[lnwallet.OutgoingHtlcResolution](), + fn.Some(h.htlcResolution), + h.broadcastHeight, fn.Some(int32(h.htlc.RefundTimeout)), + ) + if err != nil { + return err + } + + // Mark the output as incubating and checkpoint it. + h.outputIncubating = true + if err := h.Checkpoint(h); err != nil { + return err + } + + // Move to resolve the output. + return h.resolveSuccessTxOutput(h.htlcResolution.ClaimOutpoint) +} + +// resolveSuccessTx waits for the sweeping tx of the second-level success tx to +// confirm and offers the output from the success tx to the sweeper. +func (h *htlcSuccessResolver) resolveSuccessTx() error { + h.log.Infof("waiting for 2nd-level HTLC success transaction to confirm") + + // Create aliases to make the code more readable. + outpoint := h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint + pkScript := h.htlcResolution.SignDetails.SignDesc.Output.PkScript + + // Wait for the second level transaction to confirm. + commitSpend, err := waitForSpend( + &outpoint, pkScript, h.broadcastHeight, h.Notifier, h.quit, + ) + if err != nil { + return err + } + + // We'll use this input index to determine the second-level output + // index on the transaction, as the signatures requires the indexes to + // be the same. We don't look for the second-level output script + // directly, as there might be more than one HTLC output to the same + // pkScript. + op := wire.OutPoint{ + Hash: *commitSpend.SpenderTxHash, + Index: commitSpend.SpenderInputIndex, + } + + // If the 2nd-stage sweeping has already been started, we can + // fast-forward to start the resolving process for the stage two + // output. + if h.outputIncubating { + return h.resolveSuccessTxOutput(op) + } + + // Now that the second-level transaction has confirmed, we checkpoint + // the state so we'll go to the next stage in case of restarts. + h.outputIncubating = true + if err := h.Checkpoint(h); err != nil { + log.Errorf("unable to Checkpoint: %v", err) + return err + } + + h.log.Infof("2nd-level HTLC success tx=%v confirmed", + commitSpend.SpenderTxHash) + + // Send the sweep request for the output from the success tx. + if err := h.sweepSuccessTxOutput(); err != nil { + return err + } + + return h.resolveSuccessTxOutput(op) +} + +// resolveSuccessTxOutput waits for the spend of the output from the 2nd-level +// success tx. +func (h *htlcSuccessResolver) resolveSuccessTxOutput(op wire.OutPoint) error { + // To wrap this up, we'll wait until the second-level transaction has + // been spent, then fully resolve the contract. + log.Infof("%T(%x): waiting for second-level HTLC output to be spent "+ + "after csv_delay=%v", h, h.htlc.RHash[:], + h.htlcResolution.CsvDelay) + + spend, err := waitForSpend( + &op, h.htlcResolution.SweepSignDesc.Output.PkScript, + h.broadcastHeight, h.Notifier, h.quit, + ) + if err != nil { + return err + } + + h.reportLock.Lock() + h.currentReport.RecoveredBalance = h.currentReport.LimboBalance + h.currentReport.LimboBalance = 0 + h.reportLock.Unlock() + + return h.checkpointClaim(spend.SpenderTxHash) +} + +// Launch creates an input based on the details of the incoming htlc resolution +// and offers it to the sweeper. +func (h *htlcSuccessResolver) Launch() error { + if h.isLaunched() { + h.log.Tracef("already launched") + return nil + } + + h.log.Debugf("launching resolver...") + h.markLaunched() + + switch { + // If we're already resolved, then we can exit early. + case h.IsResolved(): + h.log.Errorf("already resolved") + return nil + + // If this is an output on the remote party's commitment transaction, + // use the direct-spend path. + case h.isRemoteCommitOutput(): + return h.sweepRemoteCommitOutput() + + // If this is an anchor type channel, we now sweep either the + // second-level success tx or the output from the second-level success + // tx. + case h.isZeroFeeOutput(): + // If the second-level success tx has already been swept, we + // can go ahead and sweep its output. + if h.outputIncubating { + return h.sweepSuccessTxOutput() + } + + // Otherwise, sweep the second level tx. + return h.sweepSuccessTx() + + // If this is a legacy channel type, the output is handled by the + // nursery via the Resolve so we do nothing here. + // + // TODO(yy): handle the legacy output by offering it to the sweeper. + default: + return nil + } +} diff --git a/contractcourt/htlc_success_resolver_test.go b/contractcourt/htlc_success_resolver_test.go index 23023729fa..fe6ee1ad0e 100644 --- a/contractcourt/htlc_success_resolver_test.go +++ b/contractcourt/htlc_success_resolver_test.go @@ -5,6 +5,7 @@ import ( "fmt" "reflect" "testing" + "time" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -12,7 +13,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" @@ -20,6 +21,7 @@ import ( "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" ) var testHtlcAmt = lnwire.MilliSatoshi(200000) @@ -39,6 +41,15 @@ type htlcResolverTestContext struct { t *testing.T } +func newHtlcResolverTestContextFromReader(t *testing.T, + newResolver func(htlc channeldb.HTLC, + cfg ResolverConfig) ContractResolver) *htlcResolverTestContext { + + ctx := newHtlcResolverTestContext(t, newResolver) + + return ctx +} + func newHtlcResolverTestContext(t *testing.T, newResolver func(htlc channeldb.HTLC, cfg ResolverConfig) ContractResolver) *htlcResolverTestContext { @@ -133,8 +144,12 @@ func newHtlcResolverTestContext(t *testing.T, func (i *htlcResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) + go func() { - nextResolver, err := i.resolver.Resolve(false) + err := i.resolver.Launch() + require.NoError(i.t, err) + + nextResolver, err := i.resolver.Resolve() i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, @@ -192,6 +207,7 @@ func TestHtlcSuccessSingleStage(t *testing.T) { // sweeper. details := &chainntnfs.SpendDetail{ SpendingTx: sweepTx, + SpentOutPoint: &htlcOutpoint, SpenderTxHash: &sweepTxid, } ctx.notifier.SpendChan <- details @@ -215,8 +231,8 @@ func TestHtlcSuccessSingleStage(t *testing.T) { ) } -// TestSecondStageResolution tests successful sweep of a second stage htlc -// claim, going through the Nursery. +// TestHtlcSuccessSecondStageResolution tests successful sweep of a second +// stage htlc claim, going through the Nursery. func TestHtlcSuccessSecondStageResolution(t *testing.T) { commitOutpoint := wire.OutPoint{Index: 2} htlcOutpoint := wire.OutPoint{Index: 3} @@ -279,6 +295,7 @@ func TestHtlcSuccessSecondStageResolution(t *testing.T) { ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ SpendingTx: sweepTx, + SpentOutPoint: &htlcOutpoint, SpenderTxHash: &sweepHash, } @@ -302,6 +319,8 @@ func TestHtlcSuccessSecondStageResolution(t *testing.T) { // TestHtlcSuccessSecondStageResolutionSweeper test that a resolver with // non-nil SignDetails will offer the second-level transaction to the sweeper // for re-signing. +// +//nolint:ll func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) { commitOutpoint := wire.OutPoint{Index: 2} htlcOutpoint := wire.OutPoint{Index: 3} @@ -399,7 +418,20 @@ func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) { _ bool) error { resolver := ctx.resolver.(*htlcSuccessResolver) - inp := <-resolver.Sweeper.(*mockSweeper).sweptInputs + + var ( + inp input.Input + ok bool + ) + + select { + case inp, ok = <-resolver.Sweeper.(*mockSweeper).sweptInputs: + require.True(t, ok) + + case <-time.After(1 * time.Second): + t.Fatal("expected input to be swept") + } + op := inp.OutPoint() if op != commitOutpoint { return fmt.Errorf("outpoint %v swept, "+ @@ -412,6 +444,7 @@ func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) { SpenderTxHash: &reSignedHash, SpenderInputIndex: 1, SpendingHeight: 10, + SpentOutPoint: &commitOutpoint, } return nil }, @@ -434,17 +467,37 @@ func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) { SpenderTxHash: &reSignedHash, SpenderInputIndex: 1, SpendingHeight: 10, + SpentOutPoint: &commitOutpoint, } } - ctx.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: 13, - } - // We expect it to sweep the second-level // transaction we notfied about above. resolver := ctx.resolver.(*htlcSuccessResolver) - inp := <-resolver.Sweeper.(*mockSweeper).sweptInputs + + // Mock `waitForSpend` to return the commit + // spend. + ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ + SpendingTx: reSignedSuccessTx, + SpenderTxHash: &reSignedHash, + SpenderInputIndex: 1, + SpendingHeight: 10, + SpentOutPoint: &commitOutpoint, + } + + var ( + inp input.Input + ok bool + ) + + select { + case inp, ok = <-resolver.Sweeper.(*mockSweeper).sweptInputs: + require.True(t, ok) + + case <-time.After(1 * time.Second): + t.Fatal("expected input to be swept") + } + op := inp.OutPoint() exp := wire.OutPoint{ Hash: reSignedHash, @@ -461,6 +514,7 @@ func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) { SpendingTx: sweepTx, SpenderTxHash: &sweepHash, SpendingHeight: 14, + SpentOutPoint: &op, } return nil @@ -508,11 +562,14 @@ func testHtlcSuccess(t *testing.T, resolution lnwallet.IncomingHtlcResolution, // for the next portion of the test. ctx := newHtlcResolverTestContext(t, func(htlc channeldb.HTLC, cfg ResolverConfig) ContractResolver { - return &htlcSuccessResolver{ + r := &htlcSuccessResolver{ contractResolverKit: *newContractResolverKit(cfg), htlc: htlc, htlcResolution: resolution, } + r.initLogger("htlcSuccessResolver") + + return r }, ) @@ -562,11 +619,11 @@ func runFromCheckpoint(t *testing.T, ctx *htlcResolverTestContext, var resolved, incubating bool if h, ok := resolver.(*htlcSuccessResolver); ok { - resolved = h.resolved + resolved = h.resolved.Load() incubating = h.outputIncubating } if h, ok := resolver.(*htlcTimeoutResolver); ok { - resolved = h.resolved + resolved = h.resolved.Load() incubating = h.outputIncubating } @@ -610,7 +667,12 @@ func runFromCheckpoint(t *testing.T, ctx *htlcResolverTestContext, checkpointedState = append(checkpointedState, b.Bytes()) nextCheckpoint++ - checkpointChan <- struct{}{} + select { + case checkpointChan <- struct{}{}: + case <-time.After(1 * time.Second): + t.Fatal("checkpoint timeout") + } + return nil } @@ -621,6 +683,8 @@ func runFromCheckpoint(t *testing.T, ctx *htlcResolverTestContext, // preCheckpoint logic if needed. resumed := true for i, cp := range expectedCheckpoints { + t.Logf("Running checkpoint %d", i) + if cp.preCheckpoint != nil { if err := cp.preCheckpoint(ctx, resumed); err != nil { t.Fatalf("failure at stage %d: %v", i, err) @@ -629,15 +693,15 @@ func runFromCheckpoint(t *testing.T, ctx *htlcResolverTestContext, resumed = false // Wait for the resolver to have checkpointed its state. - <-checkpointChan + select { + case <-checkpointChan: + case <-time.After(1 * time.Second): + t.Fatalf("resolver did not checkpoint at stage %d", i) + } } // Wait for the resolver to fully complete. ctx.waitForResult() - if nextCheckpoint < len(expectedCheckpoints) { - t.Fatalf("not all checkpoints hit") - } - return checkpointedState } diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index 9954c3c0db..1782cfb3ba 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -7,12 +7,13 @@ import ( "sync" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" @@ -37,9 +38,6 @@ type htlcTimeoutResolver struct { // incubator (utxo nursery). outputIncubating bool - // resolved reflects if the contract has been fully resolved or not. - resolved bool - // broadcastHeight is the height that the original contract was // broadcast to the main-chain at. We'll use this value to bound any // historical queries to the chain for spends/confirmations. @@ -82,6 +80,7 @@ func newTimeoutResolver(res lnwallet.OutgoingHtlcResolution, } h.initReport() + h.initLogger(fmt.Sprintf("%T(%v)", h, h.outpoint())) return h } @@ -93,23 +92,25 @@ func (h *htlcTimeoutResolver) isTaproot() bool { ) } -// ResolverKey returns an identifier which should be globally unique for this -// particular resolver within the chain the original contract resides within. -// -// NOTE: Part of the ContractResolver interface. -func (h *htlcTimeoutResolver) ResolverKey() []byte { +// outpoint returns the outpoint of the HTLC output we're attempting to sweep. +func (h *htlcTimeoutResolver) outpoint() wire.OutPoint { // The primary key for this resolver will be the outpoint of the HTLC // on the commitment transaction itself. If this is our commitment, // then the output can be found within the signed timeout tx, // otherwise, it's just the ClaimOutpoint. - var op wire.OutPoint if h.htlcResolution.SignedTimeoutTx != nil { - op = h.htlcResolution.SignedTimeoutTx.TxIn[0].PreviousOutPoint - } else { - op = h.htlcResolution.ClaimOutpoint + return h.htlcResolution.SignedTimeoutTx.TxIn[0].PreviousOutPoint } - key := newResolverID(op) + return h.htlcResolution.ClaimOutpoint +} + +// ResolverKey returns an identifier which should be globally unique for this +// particular resolver within the chain the original contract resides within. +// +// NOTE: Part of the ContractResolver interface. +func (h *htlcTimeoutResolver) ResolverKey() []byte { + key := newResolverID(h.outpoint()) return key[:] } @@ -157,7 +158,7 @@ const ( // by the remote party. It'll extract the preimage, add it to the global cache, // and finally send the appropriate clean up message. func (h *htlcTimeoutResolver) claimCleanUp( - commitSpend *chainntnfs.SpendDetail) (ContractResolver, error) { + commitSpend *chainntnfs.SpendDetail) error { // Depending on if this is our commitment or not, then we'll be looking // for a different witness pattern. @@ -192,7 +193,7 @@ func (h *htlcTimeoutResolver) claimCleanUp( // element, then we're actually on the losing side of a breach // attempt... case h.isTaproot() && len(spendingInput.Witness) == 1: - return nil, fmt.Errorf("breach attempt failed") + return fmt.Errorf("breach attempt failed") // Otherwise, they'll be spending directly from our commitment output. // In which case the witness stack looks like: @@ -209,8 +210,8 @@ func (h *htlcTimeoutResolver) claimCleanUp( preimage, err := lntypes.MakePreimage(preimageBytes) if err != nil { - return nil, fmt.Errorf("unable to create pre-image from "+ - "witness: %v", err) + return fmt.Errorf("unable to create pre-image from witness: %w", + err) } log.Infof("%T(%v): extracting preimage=%v from on-chain "+ @@ -232,9 +233,9 @@ func (h *htlcTimeoutResolver) claimCleanUp( HtlcIndex: h.htlc.HtlcIndex, PreImage: &pre, }); err != nil { - return nil, err + return err } - h.resolved = true + h.markResolved() // Checkpoint our resolver with a report which reflects the preimage // claim by the remote party. @@ -247,7 +248,7 @@ func (h *htlcTimeoutResolver) claimCleanUp( SpendTxID: commitSpend.SpenderTxHash, } - return nil, h.Checkpoint(h, report) + return h.Checkpoint(h, report) } // chainDetailsToWatch returns the output and script which we use to watch for @@ -418,70 +419,33 @@ func checkSizeAndIndex(witness wire.TxWitness, size, index int) bool { // see a direct sweep via the timeout clause. // // NOTE: Part of the ContractResolver interface. -func (h *htlcTimeoutResolver) Resolve( - immediate bool) (ContractResolver, error) { - +func (h *htlcTimeoutResolver) Resolve() (ContractResolver, error) { // If we're already resolved, then we can exit early. - if h.resolved { + if h.IsResolved() { + h.log.Errorf("already resolved") return nil, nil } - // Start by spending the HTLC output, either by broadcasting the - // second-level timeout transaction, or directly if this is the remote - // commitment. - commitSpend, err := h.spendHtlcOutput(immediate) - if err != nil { - return nil, err + // If this is an output on the remote party's commitment transaction, + // use the direct-spend path to sweep the htlc. + if h.isRemoteCommitOutput() { + return nil, h.resolveRemoteCommitOutput() } - // If the spend reveals the pre-image, then we'll enter the clean up - // workflow to pass the pre-image back to the incoming link, add it to - // the witness cache, and exit. - if isPreimageSpend( - h.isTaproot(), commitSpend, - h.htlcResolution.SignedTimeoutTx != nil, - ) { - - log.Infof("%T(%v): HTLC has been swept with pre-image by "+ - "remote party during timeout flow! Adding pre-image to "+ - "witness cache", h, h.htlc.RHash[:], - h.htlcResolution.ClaimOutpoint) - - return h.claimCleanUp(commitSpend) - } - - // At this point, the second-level transaction is sufficiently - // confirmed, or a transaction directly spending the output is. - // Therefore, we can now send back our clean up message, failing the - // HTLC on the incoming link. - // - // NOTE: This can be called twice if the outgoing resolver restarts - // before the second-stage timeout transaction is confirmed. - log.Infof("%T(%v): resolving htlc with incoming fail msg, "+ - "fully confirmed", h, h.htlcResolution.ClaimOutpoint) - - failureMsg := &lnwire.FailPermanentChannelFailure{} - err = h.DeliverResolutionMsg(ResolutionMsg{ - SourceChan: h.ShortChanID, - HtlcIndex: h.htlc.HtlcIndex, - Failure: failureMsg, - }) - if err != nil { - return nil, err + // If this is a zero-fee HTLC, we now handle the spend from our + // commitment transaction. + if h.isZeroFeeOutput() { + return nil, h.resolveTimeoutTx() } - // Depending on whether this was a local or remote commit, we must - // handle the spending transaction accordingly. - return h.handleCommitSpend(commitSpend) + // If this is an output on our own commitment using pre-anchor channel + // type, we will let the utxo nursery handle it. + return nil, h.resolveSecondLevelTxLegacy() } -// sweepSecondLevelTx sends a second level timeout transaction to the sweeper. +// sweepTimeoutTx sends a second level timeout transaction to the sweeper. // This transaction uses the SINLGE|ANYONECANPAY flag. -func (h *htlcTimeoutResolver) sweepSecondLevelTx(immediate bool) error { - log.Infof("%T(%x): offering second-layer timeout tx to sweeper: %v", - h, h.htlc.RHash[:], - spew.Sdump(h.htlcResolution.SignedTimeoutTx)) - +func (h *htlcTimeoutResolver) sweepTimeoutTx() error { var inp input.Input if h.isTaproot() { inp = lnutils.Ptr(input.MakeHtlcSecondLevelTimeoutTaprootInput( @@ -512,33 +476,17 @@ func (h *htlcTimeoutResolver) sweepSecondLevelTx(immediate bool) error { btcutil.Amount(inp.SignDesc().Output.Value), 2, 0, ) + h.log.Infof("offering 2nd-level HTLC timeout tx to sweeper "+ + "with deadline=%v, budget=%v", h.incomingHTLCExpiryHeight, + budget) + // For an outgoing HTLC, it must be swept before the RefundTimeout of // its incoming HTLC is reached. - // - // TODO(yy): we may end up mixing inputs with different time locks. - // Suppose we have two outgoing HTLCs, - // - HTLC1: nLocktime is 800000, CLTV delta is 80. - // - HTLC2: nLocktime is 800001, CLTV delta is 79. - // This means they would both have an incoming HTLC that expires at - // 800080, hence they share the same deadline but different locktimes. - // However, with current design, when we are at block 800000, HTLC1 is - // offered to the sweeper. When block 800001 is reached, HTLC1's - // sweeping process is already started, while HTLC2 is being offered to - // the sweeper, so they won't be mixed. This can become an issue tho, - // if we decide to sweep per X blocks. Or the contractcourt sees the - // block first while the sweeper is only aware of the last block. To - // properly fix it, we need `blockbeat` to make sure subsystems are in - // sync. - log.Infof("%T(%x): offering second-level HTLC timeout tx to sweeper "+ - "with deadline=%v, budget=%v", h, h.htlc.RHash[:], - h.incomingHTLCExpiryHeight, budget) - _, err := h.Sweeper.SweepInput( inp, sweep.Params{ Budget: budget, DeadlineHeight: h.incomingHTLCExpiryHeight, - Immediate: immediate, }, ) if err != nil { @@ -548,12 +496,13 @@ func (h *htlcTimeoutResolver) sweepSecondLevelTx(immediate bool) error { return err } -// sendSecondLevelTxLegacy sends a second level timeout transaction to the utxo -// nursery. This transaction uses the legacy SIGHASH_ALL flag. -func (h *htlcTimeoutResolver) sendSecondLevelTxLegacy() error { - log.Debugf("%T(%v): incubating htlc output", h, - h.htlcResolution.ClaimOutpoint) +// resolveSecondLevelTxLegacy sends a second level timeout transaction to the +// utxo nursery. This transaction uses the legacy SIGHASH_ALL flag. +func (h *htlcTimeoutResolver) resolveSecondLevelTxLegacy() error { + h.log.Debug("incubating htlc output") + // The utxo nursery will take care of broadcasting the second-level + // timeout tx and sweeping its output once it confirms. err := h.IncubateOutputs( h.ChanPoint, fn.Some(h.htlcResolution), fn.None[lnwallet.IncomingHtlcResolution](), @@ -563,16 +512,14 @@ func (h *htlcTimeoutResolver) sendSecondLevelTxLegacy() error { return err } - h.outputIncubating = true - - return h.Checkpoint(h) + return h.resolveTimeoutTx() } // sweepDirectHtlcOutput sends the direct spend of the HTLC output to the // sweeper. This is used when the remote party goes on chain, and we're able to // sweep an HTLC we offered after a timeout. Only the CLTV encumbered outputs // are resolved via this path. -func (h *htlcTimeoutResolver) sweepDirectHtlcOutput(immediate bool) error { +func (h *htlcTimeoutResolver) sweepDirectHtlcOutput() error { var htlcWitnessType input.StandardWitnessType if h.isTaproot() { htlcWitnessType = input.TaprootHtlcOfferedRemoteTimeout @@ -612,7 +559,6 @@ func (h *htlcTimeoutResolver) sweepDirectHtlcOutput(immediate bool) error { // This is an outgoing HTLC, so we want to make sure // that we sweep it before the incoming HTLC expires. DeadlineHeight: h.incomingHTLCExpiryHeight, - Immediate: immediate, }, ) if err != nil { @@ -622,53 +568,6 @@ func (h *htlcTimeoutResolver) sweepDirectHtlcOutput(immediate bool) error { return nil } -// spendHtlcOutput handles the initial spend of an HTLC output via the timeout -// clause. If this is our local commitment, the second-level timeout TX will be -// used to spend the output into the next stage. If this is the remote -// commitment, the output will be swept directly without the timeout -// transaction. -func (h *htlcTimeoutResolver) spendHtlcOutput( - immediate bool) (*chainntnfs.SpendDetail, error) { - - switch { - // If we have non-nil SignDetails, this means that have a 2nd level - // HTLC transaction that is signed using sighash SINGLE|ANYONECANPAY - // (the case for anchor type channels). In this case we can re-sign it - // and attach fees at will. We let the sweeper handle this job. - case h.htlcResolution.SignDetails != nil && !h.outputIncubating: - if err := h.sweepSecondLevelTx(immediate); err != nil { - log.Errorf("Sending timeout tx to sweeper: %v", err) - - return nil, err - } - - // If this is a remote commitment there's no second level timeout txn, - // and we can just send this directly to the sweeper. - case h.htlcResolution.SignedTimeoutTx == nil && !h.outputIncubating: - if err := h.sweepDirectHtlcOutput(immediate); err != nil { - log.Errorf("Sending direct spend to sweeper: %v", err) - - return nil, err - } - - // If we have a SignedTimeoutTx but no SignDetails, this is a local - // commitment for a non-anchor channel, so we'll send it to the utxo - // nursery. - case h.htlcResolution.SignDetails == nil && !h.outputIncubating: - if err := h.sendSecondLevelTxLegacy(); err != nil { - log.Errorf("Sending timeout tx to nursery: %v", err) - - return nil, err - } - } - - // Now that we've handed off the HTLC to the nursery or sweeper, we'll - // watch for a spend of the output, and make our next move off of that. - // Depending on if this is our commitment, or the remote party's - // commitment, we'll be watching a different outpoint and script. - return h.watchHtlcSpend() -} - // watchHtlcSpend watches for a spend of the HTLC output. For neutrino backend, // it will check blocks for the confirmed spend. For btcd and bitcoind, it will // check both the mempool and the blocks. @@ -697,9 +596,6 @@ func (h *htlcTimeoutResolver) watchHtlcSpend() (*chainntnfs.SpendDetail, func (h *htlcTimeoutResolver) waitForConfirmedSpend(op *wire.OutPoint, pkScript []byte) (*chainntnfs.SpendDetail, error) { - log.Infof("%T(%v): waiting for spent of HTLC output %v to be "+ - "fully confirmed", h, h.htlcResolution.ClaimOutpoint, op) - // We'll block here until either we exit, or the HTLC output on the // commitment transaction has been spent. spend, err := waitForSpend( @@ -709,239 +605,18 @@ func (h *htlcTimeoutResolver) waitForConfirmedSpend(op *wire.OutPoint, return nil, err } - // Once confirmed, persist the state on disk. - if err := h.checkPointSecondLevelTx(); err != nil { - return nil, err - } - return spend, err } -// checkPointSecondLevelTx persists the state of a second level HTLC tx to disk -// if it's published by the sweeper. -func (h *htlcTimeoutResolver) checkPointSecondLevelTx() error { - // If this was the second level transaction published by the sweeper, - // we can checkpoint the resolver now that it's confirmed. - if h.htlcResolution.SignDetails != nil && !h.outputIncubating { - h.outputIncubating = true - if err := h.Checkpoint(h); err != nil { - log.Errorf("unable to Checkpoint: %v", err) - return err - } - } - - return nil -} - -// handleCommitSpend handles the spend of the HTLC output on the commitment -// transaction. If this was our local commitment, the spend will be he -// confirmed second-level timeout transaction, and we'll sweep that into our -// wallet. If the was a remote commitment, the resolver will resolve -// immetiately. -func (h *htlcTimeoutResolver) handleCommitSpend( - commitSpend *chainntnfs.SpendDetail) (ContractResolver, error) { - - var ( - // claimOutpoint will be the outpoint of the second level - // transaction, or on the remote commitment directly. It will - // start out as set in the resolution, but we'll update it if - // the second-level goes through the sweeper and changes its - // txid. - claimOutpoint = h.htlcResolution.ClaimOutpoint - - // spendTxID will be the ultimate spend of the claimOutpoint. - // We set it to the commit spend for now, as this is the - // ultimate spend in case this is a remote commitment. If we go - // through the second-level transaction, we'll update this - // accordingly. - spendTxID = commitSpend.SpenderTxHash - - reports []*channeldb.ResolverReport - ) - - switch { - - // If we swept an HTLC directly off the remote party's commitment - // transaction, then we can exit here as there's no second level sweep - // to do. - case h.htlcResolution.SignedTimeoutTx == nil: - break - - // If the sweeper is handling the second level transaction, wait for - // the CSV and possible CLTV lock to expire, before sweeping the output - // on the second-level. - case h.htlcResolution.SignDetails != nil: - waitHeight := h.deriveWaitHeight( - h.htlcResolution.CsvDelay, commitSpend, - ) - - h.reportLock.Lock() - h.currentReport.Stage = 2 - h.currentReport.MaturityHeight = waitHeight - h.reportLock.Unlock() - - if h.hasCLTV() { - log.Infof("%T(%x): waiting for CSV and CLTV lock to "+ - "expire at height %v", h, h.htlc.RHash[:], - waitHeight) - } else { - log.Infof("%T(%x): waiting for CSV lock to expire at "+ - "height %v", h, h.htlc.RHash[:], waitHeight) - } - - // Deduct one block so this input is offered to the sweeper one - // block earlier since the sweeper will wait for one block to - // trigger the sweeping. - // - // TODO(yy): this is done so the outputs can be aggregated - // properly. Suppose CSV locks of five 2nd-level outputs all - // expire at height 840000, there is a race in block digestion - // between contractcourt and sweeper: - // - G1: block 840000 received in contractcourt, it now offers - // the outputs to the sweeper. - // - G2: block 840000 received in sweeper, it now starts to - // sweep the received outputs - there's no guarantee all - // fives have been received. - // To solve this, we either offer the outputs earlier, or - // implement `blockbeat`, and force contractcourt and sweeper - // to consume each block sequentially. - waitHeight-- - - // TODO(yy): let sweeper handles the wait? - err := waitForHeight(waitHeight, h.Notifier, h.quit) - if err != nil { - return nil, err - } - - // We'll use this input index to determine the second-level - // output index on the transaction, as the signatures requires - // the indexes to be the same. We don't look for the - // second-level output script directly, as there might be more - // than one HTLC output to the same pkScript. - op := &wire.OutPoint{ - Hash: *commitSpend.SpenderTxHash, - Index: commitSpend.SpenderInputIndex, - } - - var csvWitnessType input.StandardWitnessType - if h.isTaproot() { - //nolint:ll - csvWitnessType = input.TaprootHtlcOfferedTimeoutSecondLevel - } else { - csvWitnessType = input.HtlcOfferedTimeoutSecondLevel - } - - // Let the sweeper sweep the second-level output now that the - // CSV/CLTV locks have expired. - inp := h.makeSweepInput( - op, csvWitnessType, - input.LeaseHtlcOfferedTimeoutSecondLevel, - &h.htlcResolution.SweepSignDesc, - h.htlcResolution.CsvDelay, - uint32(commitSpend.SpendingHeight), h.htlc.RHash, - h.htlcResolution.ResolutionBlob, - ) - - // Calculate the budget for this sweep. - budget := calculateBudget( - btcutil.Amount(inp.SignDesc().Output.Value), - h.Budget.NoDeadlineHTLCRatio, - h.Budget.NoDeadlineHTLC, - ) - - log.Infof("%T(%x): offering second-level timeout tx output to "+ - "sweeper with no deadline and budget=%v at height=%v", - h, h.htlc.RHash[:], budget, waitHeight) - - _, err = h.Sweeper.SweepInput( - inp, - sweep.Params{ - Budget: budget, - - // For second level success tx, there's no rush - // to get it confirmed, so we use a nil - // deadline. - DeadlineHeight: fn.None[int32](), - }, - ) - if err != nil { - return nil, err - } - - // Update the claim outpoint to point to the second-level - // transaction created by the sweeper. - claimOutpoint = *op - fallthrough - - // Finally, if this was an output on our commitment transaction, we'll - // wait for the second-level HTLC output to be spent, and for that - // transaction itself to confirm. - case h.htlcResolution.SignedTimeoutTx != nil: - log.Infof("%T(%v): waiting for nursery/sweeper to spend CSV "+ - "delayed output", h, claimOutpoint) - - sweepTx, err := waitForSpend( - &claimOutpoint, - h.htlcResolution.SweepSignDesc.Output.PkScript, - h.broadcastHeight, h.Notifier, h.quit, - ) - if err != nil { - return nil, err - } - - // Update the spend txid to the hash of the sweep transaction. - spendTxID = sweepTx.SpenderTxHash - - // Once our sweep of the timeout tx has confirmed, we add a - // resolution for our timeoutTx tx first stage transaction. - timeoutTx := commitSpend.SpendingTx - index := commitSpend.SpenderInputIndex - spendHash := commitSpend.SpenderTxHash - - reports = append(reports, &channeldb.ResolverReport{ - OutPoint: timeoutTx.TxIn[index].PreviousOutPoint, - Amount: h.htlc.Amt.ToSatoshis(), - ResolverType: channeldb.ResolverTypeOutgoingHtlc, - ResolverOutcome: channeldb.ResolverOutcomeFirstStage, - SpendTxID: spendHash, - }) - } - - // With the clean up message sent, we'll now mark the contract - // resolved, update the recovered balance, record the timeout and the - // sweep txid on disk, and wait. - h.resolved = true - h.reportLock.Lock() - h.currentReport.RecoveredBalance = h.currentReport.LimboBalance - h.currentReport.LimboBalance = 0 - h.reportLock.Unlock() - - amt := btcutil.Amount(h.htlcResolution.SweepSignDesc.Output.Value) - reports = append(reports, &channeldb.ResolverReport{ - OutPoint: claimOutpoint, - Amount: amt, - ResolverType: channeldb.ResolverTypeOutgoingHtlc, - ResolverOutcome: channeldb.ResolverOutcomeTimeout, - SpendTxID: spendTxID, - }) - - return nil, h.Checkpoint(h, reports...) -} - // Stop signals the resolver to cancel any current resolution processes, and // suspend. // // NOTE: Part of the ContractResolver interface. func (h *htlcTimeoutResolver) Stop() { - close(h.quit) -} + h.log.Debugf("stopping...") + defer h.log.Debugf("stopped") -// IsResolved returns true if the stored state in the resolve is fully -// resolved. In this case the target output can be forgotten. -// -// NOTE: Part of the ContractResolver interface. -func (h *htlcTimeoutResolver) IsResolved() bool { - return h.resolved + close(h.quit) } // report returns a report on the resolution state of the contract. @@ -1003,7 +678,7 @@ func (h *htlcTimeoutResolver) Encode(w io.Writer) error { if err := binary.Write(w, endian, h.outputIncubating); err != nil { return err } - if err := binary.Write(w, endian, h.resolved); err != nil { + if err := binary.Write(w, endian, h.IsResolved()); err != nil { return err } if err := binary.Write(w, endian, h.broadcastHeight); err != nil { @@ -1044,9 +719,15 @@ func newTimeoutResolverFromReader(r io.Reader, resCfg ResolverConfig) ( if err := binary.Read(r, endian, &h.outputIncubating); err != nil { return nil, err } - if err := binary.Read(r, endian, &h.resolved); err != nil { + + var resolved bool + if err := binary.Read(r, endian, &resolved); err != nil { return nil, err } + if resolved { + h.markResolved() + } + if err := binary.Read(r, endian, &h.broadcastHeight); err != nil { return nil, err } @@ -1066,6 +747,7 @@ func newTimeoutResolverFromReader(r io.Reader, resCfg ResolverConfig) ( } h.initReport() + h.initLogger(fmt.Sprintf("%T(%v)", h, h.outpoint())) return h, nil } @@ -1173,12 +855,6 @@ func (h *htlcTimeoutResolver) consumeSpendEvents(resultChan chan *spendResult, // Create a result chan to hold the results. result := &spendResult{} - // hasMempoolSpend is a flag that indicates whether we have found a - // preimage spend from the mempool. This is used to determine whether - // to checkpoint the resolver or not when later we found the - // corresponding block spend. - hasMempoolSpent := false - // Wait for a spend event to arrive. for { select { @@ -1206,23 +882,6 @@ func (h *htlcTimeoutResolver) consumeSpendEvents(resultChan chan *spendResult, // Once confirmed, persist the state on disk if // we haven't seen the output's spending tx in // mempool before. - // - // NOTE: we don't checkpoint the resolver if - // it's spending tx has already been found in - // mempool - the resolver will take care of the - // checkpoint in its `claimCleanUp`. If we do - // checkpoint here, however, we'd create a new - // record in db for the same htlc resolver - // which won't be cleaned up later, resulting - // the channel to stay in unresolved state. - // - // TODO(yy): when fee bumper is implemented, we - // need to further check whether this is a - // preimage spend. Also need to refactor here - // to save us some indentation. - if !hasMempoolSpent { - result.err = h.checkPointSecondLevelTx() - } } // Send the result and exit the loop. @@ -1256,7 +915,7 @@ func (h *htlcTimeoutResolver) consumeSpendEvents(resultChan chan *spendResult, // continue the loop. hasPreimage := isPreimageSpend( h.isTaproot(), spendDetail, - h.htlcResolution.SignedTimeoutTx != nil, + !h.isRemoteCommitOutput(), ) if !hasPreimage { log.Debugf("HTLC output %s spent doesn't "+ @@ -1269,10 +928,6 @@ func (h *htlcTimeoutResolver) consumeSpendEvents(resultChan chan *spendResult, result.spend = spendDetail resultChan <- result - // Set the hasMempoolSpent flag to true so we won't - // checkpoint the resolver again in db. - hasMempoolSpent = true - continue // If the resolver exits, we exit the goroutine. @@ -1284,3 +939,379 @@ func (h *htlcTimeoutResolver) consumeSpendEvents(resultChan chan *spendResult, } } } + +// isRemoteCommitOutput returns a bool to indicate whether the htlc output is +// on the remote commitment. +func (h *htlcTimeoutResolver) isRemoteCommitOutput() bool { + // If we don't have a timeout transaction, then this means that this is + // an output on the remote party's commitment transaction. + return h.htlcResolution.SignedTimeoutTx == nil +} + +// isZeroFeeOutput returns a boolean indicating whether the htlc output is from +// a anchor-enabled channel, which uses the sighash SINGLE|ANYONECANPAY. +func (h *htlcTimeoutResolver) isZeroFeeOutput() bool { + // If we have non-nil SignDetails, this means it has a 2nd level HTLC + // transaction that is signed using sighash SINGLE|ANYONECANPAY (the + // case for anchor type channels). In this case we can re-sign it and + // attach fees at will. + return h.htlcResolution.SignedTimeoutTx != nil && + h.htlcResolution.SignDetails != nil +} + +// waitHtlcSpendAndCheckPreimage waits for the htlc output to be spent and +// checks whether the spending reveals the preimage. If the preimage is found, +// it will be added to the preimage beacon to settle the incoming link, and a +// nil spend details will be returned. Otherwise, the spend details will be +// returned, indicating this is a non-preimage spend. +func (h *htlcTimeoutResolver) waitHtlcSpendAndCheckPreimage() ( + *chainntnfs.SpendDetail, error) { + + // Wait for the htlc output to be spent, which can happen in one of the + // paths, + // 1. The remote party spends the htlc output using the preimage. + // 2. The local party spends the htlc timeout tx from the local + // commitment. + // 3. The local party spends the htlc output directlt from the remote + // commitment. + spend, err := h.watchHtlcSpend() + if err != nil { + return nil, err + } + + // If the spend reveals the pre-image, then we'll enter the clean up + // workflow to pass the preimage back to the incoming link, add it to + // the witness cache, and exit. + if isPreimageSpend(h.isTaproot(), spend, !h.isRemoteCommitOutput()) { + return nil, h.claimCleanUp(spend) + } + + return spend, nil +} + +// sweepTimeoutTxOutput attempts to sweep the output of the second level +// timeout tx. +func (h *htlcTimeoutResolver) sweepTimeoutTxOutput() error { + h.log.Debugf("sweeping output %v from 2nd-level HTLC timeout tx", + h.htlcResolution.ClaimOutpoint) + + // This should be non-blocking as we will only attempt to sweep the + // output when the second level tx has already been confirmed. In other + // words, waitHtlcSpendAndCheckPreimage will return immediately. + commitSpend, err := h.waitHtlcSpendAndCheckPreimage() + if err != nil { + return err + } + + // Exit early if the spend is nil, as this means it's a remote spend + // using the preimage path, which is handled in claimCleanUp. + if commitSpend == nil { + h.log.Infof("preimage spend detected, skipping 2nd-level " + + "HTLC output sweep") + + return nil + } + + waitHeight := h.deriveWaitHeight(h.htlcResolution.CsvDelay, commitSpend) + + // Now that the sweeper has broadcasted the second-level transaction, + // it has confirmed, and we have checkpointed our state, we'll sweep + // the second level output. We report the resolver has moved the next + // stage. + h.reportLock.Lock() + h.currentReport.Stage = 2 + h.currentReport.MaturityHeight = waitHeight + h.reportLock.Unlock() + + if h.hasCLTV() { + h.log.Infof("waiting for CSV and CLTV lock to expire at "+ + "height %v", waitHeight) + } else { + h.log.Infof("waiting for CSV lock to expire at height %v", + waitHeight) + } + + // We'll use this input index to determine the second-level output + // index on the transaction, as the signatures requires the indexes to + // be the same. We don't look for the second-level output script + // directly, as there might be more than one HTLC output to the same + // pkScript. + op := &wire.OutPoint{ + Hash: *commitSpend.SpenderTxHash, + Index: commitSpend.SpenderInputIndex, + } + + var witType input.StandardWitnessType + if h.isTaproot() { + witType = input.TaprootHtlcOfferedTimeoutSecondLevel + } else { + witType = input.HtlcOfferedTimeoutSecondLevel + } + + // Let the sweeper sweep the second-level output now that the CSV/CLTV + // locks have expired. + inp := h.makeSweepInput( + op, witType, + input.LeaseHtlcOfferedTimeoutSecondLevel, + &h.htlcResolution.SweepSignDesc, + h.htlcResolution.CsvDelay, uint32(commitSpend.SpendingHeight), + h.htlc.RHash, h.htlcResolution.ResolutionBlob, + ) + + // Calculate the budget for this sweep. + budget := calculateBudget( + btcutil.Amount(inp.SignDesc().Output.Value), + h.Budget.NoDeadlineHTLCRatio, + h.Budget.NoDeadlineHTLC, + ) + + h.log.Infof("offering output from 2nd-level timeout tx to sweeper "+ + "with no deadline and budget=%v", budget) + + // TODO(yy): use the result chan returned from SweepInput to get the + // confirmation status of this sweeping tx so we don't need to make + // anothe subscription via `RegisterSpendNtfn` for this outpoint here + // in the resolver. + _, err = h.Sweeper.SweepInput( + inp, + sweep.Params{ + Budget: budget, + + // For second level success tx, there's no rush + // to get it confirmed, so we use a nil + // deadline. + DeadlineHeight: fn.None[int32](), + }, + ) + + return err +} + +// checkpointStageOne creates a checkpoint for the first stage of the htlc +// timeout transaction. This is used to ensure that the resolver can resume +// watching for the second stage spend in case of a restart. +func (h *htlcTimeoutResolver) checkpointStageOne( + spendTxid chainhash.Hash) error { + + h.log.Debugf("checkpoint stage one spend of HTLC output %v, spent "+ + "in tx %v", h.outpoint(), spendTxid) + + // Now that the second-level transaction has confirmed, we checkpoint + // the state so we'll go to the next stage in case of restarts. + h.outputIncubating = true + + // Create stage-one report. + report := &channeldb.ResolverReport{ + OutPoint: h.outpoint(), + Amount: h.htlc.Amt.ToSatoshis(), + ResolverType: channeldb.ResolverTypeOutgoingHtlc, + ResolverOutcome: channeldb.ResolverOutcomeFirstStage, + SpendTxID: &spendTxid, + } + + // At this point, the second-level transaction is sufficiently + // confirmed. We can now send back our clean up message, failing the + // HTLC on the incoming link. + failureMsg := &lnwire.FailPermanentChannelFailure{} + err := h.DeliverResolutionMsg(ResolutionMsg{ + SourceChan: h.ShortChanID, + HtlcIndex: h.htlc.HtlcIndex, + Failure: failureMsg, + }) + if err != nil { + return err + } + + return h.Checkpoint(h, report) +} + +// checkpointClaim checkpoints the timeout resolver with the reports it needs. +func (h *htlcTimeoutResolver) checkpointClaim( + spendDetail *chainntnfs.SpendDetail) error { + + h.log.Infof("resolving htlc with incoming fail msg, output=%v "+ + "confirmed in tx=%v", spendDetail.SpentOutPoint, + spendDetail.SpenderTxHash) + + // Create a resolver report for the claiming of the HTLC. + amt := btcutil.Amount(h.htlcResolution.SweepSignDesc.Output.Value) + report := &channeldb.ResolverReport{ + OutPoint: *spendDetail.SpentOutPoint, + Amount: amt, + ResolverType: channeldb.ResolverTypeOutgoingHtlc, + ResolverOutcome: channeldb.ResolverOutcomeTimeout, + SpendTxID: spendDetail.SpenderTxHash, + } + + // Finally, we checkpoint the resolver with our report(s). + h.markResolved() + + return h.Checkpoint(h, report) +} + +// resolveRemoteCommitOutput handles sweeping an HTLC output on the remote +// commitment with via the timeout path. In this case we can sweep the output +// directly, and don't have to broadcast a second-level transaction. +func (h *htlcTimeoutResolver) resolveRemoteCommitOutput() error { + h.log.Debug("waiting for direct-timeout spend of the htlc to confirm") + + // Wait for the direct-timeout HTLC sweep tx to confirm. + spend, err := h.watchHtlcSpend() + if err != nil { + return err + } + + // If the spend reveals the preimage, then we'll enter the clean up + // workflow to pass the preimage back to the incoming link, add it to + // the witness cache, and exit. + if isPreimageSpend(h.isTaproot(), spend, !h.isRemoteCommitOutput()) { + return h.claimCleanUp(spend) + } + + // Send the clean up msg to fail the incoming HTLC. + failureMsg := &lnwire.FailPermanentChannelFailure{} + err = h.DeliverResolutionMsg(ResolutionMsg{ + SourceChan: h.ShortChanID, + HtlcIndex: h.htlc.HtlcIndex, + Failure: failureMsg, + }) + if err != nil { + return err + } + + // TODO(yy): should also update the `RecoveredBalance` and + // `LimboBalance` like other paths? + + // Checkpoint the resolver, and write the outcome to disk. + return h.checkpointClaim(spend) +} + +// resolveTimeoutTx waits for the sweeping tx of the second-level +// timeout tx to confirm and offers the output from the timeout tx to the +// sweeper. +func (h *htlcTimeoutResolver) resolveTimeoutTx() error { + h.log.Debug("waiting for first-stage 2nd-level HTLC timeout tx to " + + "confirm") + + // Wait for the second level transaction to confirm. + spend, err := h.watchHtlcSpend() + if err != nil { + return err + } + + // If the spend reveals the preimage, then we'll enter the clean up + // workflow to pass the preimage back to the incoming link, add it to + // the witness cache, and exit. + if isPreimageSpend(h.isTaproot(), spend, !h.isRemoteCommitOutput()) { + return h.claimCleanUp(spend) + } + + op := h.htlcResolution.ClaimOutpoint + spenderTxid := *spend.SpenderTxHash + + // If the timeout tx is a re-signed tx, we will need to find the actual + // spent outpoint from the spending tx. + if h.isZeroFeeOutput() { + op = wire.OutPoint{ + Hash: spenderTxid, + Index: spend.SpenderInputIndex, + } + } + + // If the 2nd-stage sweeping has already been started, we can + // fast-forward to start the resolving process for the stage two + // output. + if h.outputIncubating { + return h.resolveTimeoutTxOutput(op) + } + + h.log.Infof("2nd-level HTLC timeout tx=%v confirmed", spenderTxid) + + // Start the process to sweep the output from the timeout tx. + if h.isZeroFeeOutput() { + err = h.sweepTimeoutTxOutput() + if err != nil { + return err + } + } + + // Create a checkpoint since the timeout tx is confirmed and the sweep + // request has been made. + if err := h.checkpointStageOne(spenderTxid); err != nil { + return err + } + + // Start the resolving process for the stage two output. + return h.resolveTimeoutTxOutput(op) +} + +// resolveTimeoutTxOutput waits for the spend of the output from the 2nd-level +// timeout tx. +func (h *htlcTimeoutResolver) resolveTimeoutTxOutput(op wire.OutPoint) error { + h.log.Debugf("waiting for second-stage 2nd-level timeout tx output %v "+ + "to be spent after csv_delay=%v", op, h.htlcResolution.CsvDelay) + + spend, err := waitForSpend( + &op, h.htlcResolution.SweepSignDesc.Output.PkScript, + h.broadcastHeight, h.Notifier, h.quit, + ) + if err != nil { + return err + } + + h.reportLock.Lock() + h.currentReport.RecoveredBalance = h.currentReport.LimboBalance + h.currentReport.LimboBalance = 0 + h.reportLock.Unlock() + + return h.checkpointClaim(spend) +} + +// Launch creates an input based on the details of the outgoing htlc resolution +// and offers it to the sweeper. +func (h *htlcTimeoutResolver) Launch() error { + if h.isLaunched() { + h.log.Tracef("already launched") + return nil + } + + h.log.Debugf("launching resolver...") + h.launched.Store(true) + + switch { + // If we're already resolved, then we can exit early. + case h.IsResolved(): + h.log.Errorf("already resolved") + return nil + + // If this is an output on the remote party's commitment transaction, + // use the direct timeout spend path. + // + // NOTE: When the outputIncubating is false, it means that the output + // has been offered to the utxo nursery as starting in 0.18.4, we + // stopped marking this flag for direct timeout spends (#9062). In that + // case, we will do nothing and let the utxo nursery handle it. + case h.isRemoteCommitOutput() && !h.outputIncubating: + return h.sweepDirectHtlcOutput() + + // If this is an anchor type channel, we now sweep either the + // second-level timeout tx or the output from the second-level timeout + // tx. + case h.isZeroFeeOutput(): + // If the second-level timeout tx has already been swept, we + // can go ahead and sweep its output. + if h.outputIncubating { + return h.sweepTimeoutTxOutput() + } + + // Otherwise, sweep the second level tx. + return h.sweepTimeoutTx() + + // If this is an output on our own commitment using pre-anchor channel + // type, we will let the utxo nursery handle it via Resolve. + // + // TODO(yy): handle the legacy output by offering it to the sweeper. + default: + return nil + } +} diff --git a/contractcourt/htlc_timeout_resolver_test.go b/contractcourt/htlc_timeout_resolver_test.go index f3f23c385c..017d3d3886 100644 --- a/contractcourt/htlc_timeout_resolver_test.go +++ b/contractcourt/htlc_timeout_resolver_test.go @@ -14,7 +14,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" @@ -40,7 +40,7 @@ type mockWitnessBeacon struct { func newMockWitnessBeacon() *mockWitnessBeacon { return &mockWitnessBeacon{ preImageUpdates: make(chan lntypes.Preimage, 1), - newPreimages: make(chan []lntypes.Preimage), + newPreimages: make(chan []lntypes.Preimage, 1), lookupPreimage: make(map[lntypes.Hash]lntypes.Preimage), } } @@ -280,7 +280,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { notifier := &mock.ChainNotifier{ EpochChan: make(chan *chainntnfs.BlockEpoch), - SpendChan: make(chan *chainntnfs.SpendDetail), + SpendChan: make(chan *chainntnfs.SpendDetail, 1), ConfChan: make(chan *chainntnfs.TxConfirmation), } @@ -321,6 +321,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { return nil }, + HtlcNotifier: &mockHTLCNotifier{}, }, PutResolverReport: func(_ kvdb.RwTx, _ *channeldb.ResolverReport) error { @@ -356,6 +357,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { Amt: testHtlcAmt, }, } + resolver.initLogger("timeoutResolver") var reports []*channeldb.ResolverReport @@ -390,7 +392,12 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { go func() { defer wg.Done() - _, err := resolver.Resolve(false) + err := resolver.Launch() + if err != nil { + resolveErr <- err + } + + _, err = resolver.Resolve() if err != nil { resolveErr <- err } @@ -406,8 +413,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { sweepChan = mockSweeper.sweptInputs } - // The output should be offered to either the sweeper or - // the nursery. + // The output should be offered to either the sweeper or the nursery. select { case <-incubateChan: case <-sweepChan: @@ -431,6 +437,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { case notifier.SpendChan <- &chainntnfs.SpendDetail{ SpendingTx: spendingTx, SpenderTxHash: &spendTxHash, + SpentOutPoint: &testChanPoint2, }: case <-time.After(time.Second * 5): t.Fatalf("failed to request spend ntfn") @@ -487,6 +494,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { case notifier.SpendChan <- &chainntnfs.SpendDetail{ SpendingTx: spendingTx, SpenderTxHash: &spendTxHash, + SpentOutPoint: &testChanPoint2, }: case <-time.After(time.Second * 5): t.Fatalf("failed to request spend ntfn") @@ -524,7 +532,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { wg.Wait() // Finally, the resolver should be marked as resolved. - if !resolver.resolved { + if !resolver.resolved.Load() { t.Fatalf("resolver should be marked as resolved") } } @@ -549,6 +557,8 @@ func TestHtlcTimeoutResolver(t *testing.T) { // TestHtlcTimeoutSingleStage tests a remote commitment confirming, and the // local node sweeping the HTLC output directly after timeout. +// +//nolint:ll func TestHtlcTimeoutSingleStage(t *testing.T) { commitOutpoint := wire.OutPoint{Index: 3} @@ -573,6 +583,12 @@ func TestHtlcTimeoutSingleStage(t *testing.T) { SpendTxID: &sweepTxid, } + sweepSpend := &chainntnfs.SpendDetail{ + SpendingTx: sweepTx, + SpentOutPoint: &commitOutpoint, + SpenderTxHash: &sweepTxid, + } + checkpoints := []checkpoint{ { // We send a confirmation the sweep tx from published @@ -582,9 +598,10 @@ func TestHtlcTimeoutSingleStage(t *testing.T) { // The nursery will create and publish a sweep // tx. - ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ - SpendingTx: sweepTx, - SpenderTxHash: &sweepTxid, + select { + case ctx.notifier.SpendChan <- sweepSpend: + case <-time.After(time.Second * 5): + t.Fatalf("failed to send spend ntfn") } // The resolver should deliver a failure @@ -620,7 +637,9 @@ func TestHtlcTimeoutSingleStage(t *testing.T) { // TestHtlcTimeoutSecondStage tests a local commitment being confirmed, and the // local node claiming the HTLC output using the second-level timeout tx. -func TestHtlcTimeoutSecondStage(t *testing.T) { +// +//nolint:ll +func TestHtlcTimeoutSecondStagex(t *testing.T) { commitOutpoint := wire.OutPoint{Index: 2} htlcOutpoint := wire.OutPoint{Index: 3} @@ -678,23 +697,57 @@ func TestHtlcTimeoutSecondStage(t *testing.T) { SpendTxID: &sweepHash, } + timeoutSpend := &chainntnfs.SpendDetail{ + SpendingTx: timeoutTx, + SpentOutPoint: &commitOutpoint, + SpenderTxHash: &timeoutTxid, + } + + sweepSpend := &chainntnfs.SpendDetail{ + SpendingTx: sweepTx, + SpentOutPoint: &htlcOutpoint, + SpenderTxHash: &sweepHash, + } + checkpoints := []checkpoint{ { + preCheckpoint: func(ctx *htlcResolverTestContext, + _ bool) error { + + // Deliver spend of timeout tx. + ctx.notifier.SpendChan <- timeoutSpend + + return nil + }, + // Output should be handed off to the nursery. incubating: true, + reports: []*channeldb.ResolverReport{ + firstStage, + }, }, { // We send a confirmation for our sweep tx to indicate // that our sweep succeeded. preCheckpoint: func(ctx *htlcResolverTestContext, - _ bool) error { + resumed bool) error { - // The nursery will publish the timeout tx. - ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ - SpendingTx: timeoutTx, - SpenderTxHash: &timeoutTxid, + // When it's reloaded from disk, we need to + // re-send the notification to mock the first + // `watchHtlcSpend`. + if resumed { + // Deliver spend of timeout tx. + ctx.notifier.SpendChan <- timeoutSpend + + // Deliver spend of timeout tx output. + ctx.notifier.SpendChan <- sweepSpend + + return nil } + // Deliver spend of timeout tx output. + ctx.notifier.SpendChan <- sweepSpend + // The resolver should deliver a failure // resolution message (indicating we // successfully timed out the HTLC). @@ -707,12 +760,6 @@ func TestHtlcTimeoutSecondStage(t *testing.T) { t.Fatalf("resolution not sent") } - // Deliver spend of timeout tx. - ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ - SpendingTx: sweepTx, - SpenderTxHash: &sweepHash, - } - return nil }, @@ -722,7 +769,7 @@ func TestHtlcTimeoutSecondStage(t *testing.T) { incubating: true, resolved: true, reports: []*channeldb.ResolverReport{ - firstStage, secondState, + secondState, }, }, } @@ -796,10 +843,6 @@ func TestHtlcTimeoutSingleStageRemoteSpend(t *testing.T) { } checkpoints := []checkpoint{ - { - // Output should be handed off to the nursery. - incubating: true, - }, { // We send a spend notification for a remote spend with // the preimage. @@ -812,6 +855,7 @@ func TestHtlcTimeoutSingleStageRemoteSpend(t *testing.T) { // the preimage. ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ SpendingTx: spendTx, + SpentOutPoint: &commitOutpoint, SpenderTxHash: &spendTxHash, } @@ -847,7 +891,7 @@ func TestHtlcTimeoutSingleStageRemoteSpend(t *testing.T) { // After the success tx has confirmed, we expect the // checkpoint to be resolved, and with the above // report. - incubating: true, + incubating: false, resolved: true, reports: []*channeldb.ResolverReport{ claim, @@ -914,6 +958,7 @@ func TestHtlcTimeoutSecondStageRemoteSpend(t *testing.T) { ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ SpendingTx: remoteSuccessTx, + SpentOutPoint: &commitOutpoint, SpenderTxHash: &successTxid, } @@ -967,20 +1012,15 @@ func TestHtlcTimeoutSecondStageRemoteSpend(t *testing.T) { // TestHtlcTimeoutSecondStageSweeper tests that for anchor channels, when a // local commitment confirms, the timeout tx is handed to the sweeper to claim // the HTLC output. +// +//nolint:ll func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { - commitOutpoint := wire.OutPoint{Index: 2} htlcOutpoint := wire.OutPoint{Index: 3} - sweepTx := &wire.MsgTx{ - TxIn: []*wire.TxIn{{}}, - TxOut: []*wire.TxOut{{}}, - } - sweepHash := sweepTx.TxHash() - timeoutTx := &wire.MsgTx{ TxIn: []*wire.TxIn{ { - PreviousOutPoint: commitOutpoint, + PreviousOutPoint: htlcOutpoint, }, }, TxOut: []*wire.TxOut{ @@ -1027,11 +1067,16 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { }, } reSignedHash := reSignedTimeoutTx.TxHash() - reSignedOutPoint := wire.OutPoint{ + + timeoutTxOutpoint := wire.OutPoint{ Hash: reSignedHash, Index: 1, } + // Make a copy so `isPreimageSpend` can easily pass. + sweepTx := reSignedTimeoutTx.Copy() + sweepHash := sweepTx.TxHash() + // twoStageResolution is a resolution for a htlc on the local // party's commitment, where the timeout tx can be re-signed. twoStageResolution := lnwallet.OutgoingHtlcResolution{ @@ -1045,7 +1090,7 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { } firstStage := &channeldb.ResolverReport{ - OutPoint: commitOutpoint, + OutPoint: htlcOutpoint, Amount: testHtlcAmt.ToSatoshis(), ResolverType: channeldb.ResolverTypeOutgoingHtlc, ResolverOutcome: channeldb.ResolverOutcomeFirstStage, @@ -1053,12 +1098,45 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { } secondState := &channeldb.ResolverReport{ - OutPoint: reSignedOutPoint, + OutPoint: timeoutTxOutpoint, Amount: btcutil.Amount(testSignDesc.Output.Value), ResolverType: channeldb.ResolverTypeOutgoingHtlc, ResolverOutcome: channeldb.ResolverOutcomeTimeout, SpendTxID: &sweepHash, } + // mockTimeoutTxSpend is a helper closure to mock `waitForSpend` to + // return the commit spend in `sweepTimeoutTxOutput`. + mockTimeoutTxSpend := func(ctx *htlcResolverTestContext) { + select { + case ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ + SpendingTx: reSignedTimeoutTx, + SpenderInputIndex: 1, + SpenderTxHash: &reSignedHash, + SpendingHeight: 10, + SpentOutPoint: &htlcOutpoint, + }: + + case <-time.After(time.Second * 1): + t.Fatalf("spend not sent") + } + } + + // mockSweepTxSpend is a helper closure to mock `waitForSpend` to + // return the commit spend in `sweepTimeoutTxOutput`. + mockSweepTxSpend := func(ctx *htlcResolverTestContext) { + select { + case ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ + SpendingTx: sweepTx, + SpenderInputIndex: 1, + SpenderTxHash: &sweepHash, + SpendingHeight: 10, + SpentOutPoint: &timeoutTxOutpoint, + }: + + case <-time.After(time.Second * 1): + t.Fatalf("spend not sent") + } + } checkpoints := []checkpoint{ { @@ -1067,28 +1145,40 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { _ bool) error { resolver := ctx.resolver.(*htlcTimeoutResolver) - inp := <-resolver.Sweeper.(*mockSweeper).sweptInputs + + var ( + inp input.Input + ok bool + ) + + select { + case inp, ok = <-resolver.Sweeper.(*mockSweeper).sweptInputs: + require.True(t, ok) + + case <-time.After(1 * time.Second): + t.Fatal("expected input to be swept") + } + op := inp.OutPoint() - if op != commitOutpoint { + if op != htlcOutpoint { return fmt.Errorf("outpoint %v swept, "+ - "expected %v", op, - commitOutpoint) + "expected %v", op, htlcOutpoint) } - // Emulat the sweeper spending using the - // re-signed timeout tx. - ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ - SpendingTx: reSignedTimeoutTx, - SpenderInputIndex: 1, - SpenderTxHash: &reSignedHash, - SpendingHeight: 10, - } + // Mock `waitForSpend` twice, called in, + // - `resolveReSignedTimeoutTx` + // - `sweepTimeoutTxOutput`. + mockTimeoutTxSpend(ctx) + mockTimeoutTxSpend(ctx) return nil }, // incubating=true is used to signal that the // second-level transaction was confirmed. incubating: true, + reports: []*channeldb.ResolverReport{ + firstStage, + }, }, { // We send a confirmation for our sweep tx to indicate @@ -1096,18 +1186,18 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { preCheckpoint: func(ctx *htlcResolverTestContext, resumed bool) error { - // If we are resuming from a checkpoint, we - // expect the resolver to re-subscribe to a - // spend, hence we must resend it. + // Mock `waitForSpend` to return the commit + // spend. if resumed { - ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ - SpendingTx: reSignedTimeoutTx, - SpenderInputIndex: 1, - SpenderTxHash: &reSignedHash, - SpendingHeight: 10, - } + mockTimeoutTxSpend(ctx) + mockTimeoutTxSpend(ctx) + mockSweepTxSpend(ctx) + + return nil } + mockSweepTxSpend(ctx) + // The resolver should deliver a failure // resolution message (indicating we // successfully timed out the HTLC). @@ -1120,15 +1210,23 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { t.Fatalf("resolution not sent") } - // Mimic CSV lock expiring. - ctx.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: 13, - } - // The timeout tx output should now be given to // the sweeper. resolver := ctx.resolver.(*htlcTimeoutResolver) - inp := <-resolver.Sweeper.(*mockSweeper).sweptInputs + + var ( + inp input.Input + ok bool + ) + + select { + case inp, ok = <-resolver.Sweeper.(*mockSweeper).sweptInputs: + require.True(t, ok) + + case <-time.After(1 * time.Second): + t.Fatal("expected input to be swept") + } + op := inp.OutPoint() exp := wire.OutPoint{ Hash: reSignedHash, @@ -1138,14 +1236,6 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { return fmt.Errorf("wrong outpoint swept") } - // Notify about the spend, which should resolve - // the resolver. - ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ - SpendingTx: sweepTx, - SpenderTxHash: &sweepHash, - SpendingHeight: 14, - } - return nil }, @@ -1155,7 +1245,6 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { incubating: true, resolved: true, reports: []*channeldb.ResolverReport{ - firstStage, secondState, }, }, @@ -1236,33 +1325,6 @@ func TestHtlcTimeoutSecondStageSweeperRemoteSpend(t *testing.T) { } checkpoints := []checkpoint{ - { - // The output should be given to the sweeper. - preCheckpoint: func(ctx *htlcResolverTestContext, - _ bool) error { - - resolver := ctx.resolver.(*htlcTimeoutResolver) - inp := <-resolver.Sweeper.(*mockSweeper).sweptInputs - op := inp.OutPoint() - if op != commitOutpoint { - return fmt.Errorf("outpoint %v swept, "+ - "expected %v", op, - commitOutpoint) - } - - // Emulate the remote sweeping the output with the preimage. - // re-signed timeout tx. - ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ - SpendingTx: spendTx, - SpenderTxHash: &spendTxHash, - } - - return nil - }, - // incubating=true is used to signal that the - // second-level transaction was confirmed. - incubating: true, - }, { // We send a confirmation for our sweep tx to indicate // that our sweep succeeded. @@ -1277,6 +1339,7 @@ func TestHtlcTimeoutSecondStageSweeperRemoteSpend(t *testing.T) { ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ SpendingTx: spendTx, SpenderTxHash: &spendTxHash, + SpentOutPoint: &commitOutpoint, } } @@ -1314,7 +1377,7 @@ func TestHtlcTimeoutSecondStageSweeperRemoteSpend(t *testing.T) { // After the sweep has confirmed, we expect the // checkpoint to be resolved, and with the above // reports. - incubating: true, + incubating: false, resolved: true, reports: []*channeldb.ResolverReport{ claim, @@ -1339,21 +1402,26 @@ func testHtlcTimeout(t *testing.T, resolution lnwallet.OutgoingHtlcResolution, // for the next portion of the test. ctx := newHtlcResolverTestContext(t, func(htlc channeldb.HTLC, cfg ResolverConfig) ContractResolver { - return &htlcTimeoutResolver{ + r := &htlcTimeoutResolver{ contractResolverKit: *newContractResolverKit(cfg), htlc: htlc, htlcResolution: resolution, } + r.initLogger("htlcTimeoutResolver") + + return r }, ) checkpointedState := runFromCheckpoint(t, ctx, checkpoints) + t.Log("Running resolver to completion after restart") + // Now, from every checkpoint created, we re-create the resolver, and // run the test from that checkpoint. for i := range checkpointedState { cp := bytes.NewReader(checkpointedState[i]) - ctx := newHtlcResolverTestContext(t, + ctx := newHtlcResolverTestContextFromReader(t, func(htlc channeldb.HTLC, cfg ResolverConfig) ContractResolver { resolver, err := newTimeoutResolverFromReader(cp, cfg) if err != nil { @@ -1361,7 +1429,8 @@ func testHtlcTimeout(t *testing.T, resolution lnwallet.OutgoingHtlcResolution, } resolver.Supplement(htlc) - resolver.htlcResolution = resolution + resolver.initLogger("htlcTimeoutResolver") + return resolver }, ) diff --git a/contractcourt/mock_registry_test.go b/contractcourt/mock_registry_test.go index 5bba11afcb..0530ab51dd 100644 --- a/contractcourt/mock_registry_test.go +++ b/contractcourt/mock_registry_test.go @@ -29,6 +29,11 @@ func (r *mockRegistry) NotifyExitHopHtlc(payHash lntypes.Hash, wireCustomRecords lnwire.CustomRecords, payload invoices.Payload) (invoices.HtlcResolution, error) { + // Exit early if the notification channel is nil. + if hodlChan == nil { + return r.notifyResolution, r.notifyErr + } + r.notifyChan <- notifyExitHopData{ hodlChan: hodlChan, payHash: payHash, diff --git a/contractcourt/utxonursery.go b/contractcourt/utxonursery.go index aef906a0ad..f78be9fa49 100644 --- a/contractcourt/utxonursery.go +++ b/contractcourt/utxonursery.go @@ -15,7 +15,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" @@ -794,7 +794,7 @@ func (u *UtxoNursery) graduateClass(classHeight uint32) error { return err } - utxnLog.Infof("Attempting to graduate height=%v: num_kids=%v, "+ + utxnLog.Debugf("Attempting to graduate height=%v: num_kids=%v, "+ "num_babies=%v", classHeight, len(kgtnOutputs), len(cribOutputs)) // Offer the outputs to the sweeper and set up notifications that will diff --git a/contractcourt/utxonursery_test.go b/contractcourt/utxonursery_test.go index 796d1ed239..f1b47cc2ca 100644 --- a/contractcourt/utxonursery_test.go +++ b/contractcourt/utxonursery_test.go @@ -18,7 +18,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/discovery/gossiper.go b/discovery/gossiper.go index aafadba479..9c51734396 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -19,7 +19,7 @@ import ( "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 85a4e0657e..b74f69bf0a 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -24,7 +24,7 @@ import ( "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" diff --git a/docs/grpc/ruby.md b/docs/grpc/ruby.md index 599dd2bc7a..457246423f 100644 --- a/docs/grpc/ruby.md +++ b/docs/grpc/ruby.md @@ -58,7 +58,7 @@ $:.unshift(File.dirname(__FILE__)) require 'grpc' require 'lightning_services_pb' -# Due to updated ECDSA generated tls.cert we need to let gprc know that +# Due to updated ECDSA generated tls.cert we need to let grpc know that # we need to use that cipher suite otherwise there will be a handshake # error when we communicate with the lnd rpc server. ENV['GRPC_SSL_CIPHER_SUITES'] = "HIGH+ECDSA" diff --git a/docs/release-notes/release-notes-0.18.4.md b/docs/release-notes/release-notes-0.18.4.md index 1fd299f3d7..ab72c01c31 100644 --- a/docs/release-notes/release-notes-0.18.4.md +++ b/docs/release-notes/release-notes-0.18.4.md @@ -23,6 +23,10 @@ cause a nil pointer dereference during the probing of a payment request that does not contain a payment address. +* [Fixed a bug](https://github.com/lightningnetwork/lnd/pull/9324) to prevent + potential deadlocks when LND depends on external components (e.g. aux + components, hooks). + # New Features The main channel state machine and database now allow for processing and storing diff --git a/docs/release-notes/release-notes-0.19.0.md b/docs/release-notes/release-notes-0.19.0.md index db5ca738ef..ee0b0c1360 100644 --- a/docs/release-notes/release-notes-0.19.0.md +++ b/docs/release-notes/release-notes-0.19.0.md @@ -90,6 +90,10 @@ * [The `walletrpc.FundPsbt` method now has a new option to specify the maximum fee to output amounts ratio.](https://github.com/lightningnetwork/lnd/pull/8600) +* When returning the response from list invoices RPC, the `lnrpc.Invoice.Htlcs` + are now [sorted](https://github.com/lightningnetwork/lnd/pull/9337) based on + the `InvoiceHTLC.HtlcIndex`. + ## lncli Additions * [A pre-generated macaroon root key can now be specified in `lncli create` and diff --git a/funding/aux_funding.go b/funding/aux_funding.go index 492612145a..c7ef653f47 100644 --- a/funding/aux_funding.go +++ b/funding/aux_funding.go @@ -2,7 +2,7 @@ package funding import ( "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/msgmux" diff --git a/funding/manager.go b/funding/manager.go index c8a54d9588..395cccb2a6 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -23,7 +23,7 @@ import ( "github.com/lightningnetwork/lnd/chanacceptor" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/discovery" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" diff --git a/funding/manager_test.go b/funding/manager_test.go index 525f69f9a5..b6130176d1 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -27,7 +27,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/discovery" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" diff --git a/go.mod b/go.mod index 1330f9a84a..bbb421de40 100644 --- a/go.mod +++ b/go.mod @@ -36,13 +36,13 @@ require ( github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb github.com/lightningnetwork/lnd/cert v1.2.2 github.com/lightningnetwork/lnd/clock v1.1.1 - github.com/lightningnetwork/lnd/fn v1.2.5 + github.com/lightningnetwork/lnd/fn/v2 v2.0.2 github.com/lightningnetwork/lnd/healthcheck v1.2.6 github.com/lightningnetwork/lnd/kvdb v1.4.11 github.com/lightningnetwork/lnd/queue v1.1.1 github.com/lightningnetwork/lnd/sqldb v1.0.5 github.com/lightningnetwork/lnd/ticker v1.1.1 - github.com/lightningnetwork/lnd/tlv v1.2.6 + github.com/lightningnetwork/lnd/tlv v1.3.0 github.com/lightningnetwork/lnd/tor v1.1.4 github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796 github.com/miekg/dns v1.1.43 diff --git a/go.sum b/go.sum index aa04dc5fce..4c452df735 100644 --- a/go.sum +++ b/go.sum @@ -456,8 +456,8 @@ github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= github.com/lightningnetwork/lnd/clock v1.1.1/go.mod h1:mGnAhPyjYZQJmebS7aevElXKTFDuO+uNFFfMXK1W8xQ= -github.com/lightningnetwork/lnd/fn v1.2.5 h1:pGMz0BDUxrhvOtShD4FIysdVy+ulfFAnFvTKjZO5Pp8= -github.com/lightningnetwork/lnd/fn v1.2.5/go.mod h1:SyFohpVrARPKH3XVAJZlXdVe+IwMYc4OMAvrDY32kw0= +github.com/lightningnetwork/lnd/fn/v2 v2.0.2 h1:M7o2lYrh/zCp+lntPB3WP/rWTu5U+4ssyHW+kqNJ0fs= +github.com/lightningnetwork/lnd/fn/v2 v2.0.2/go.mod h1:TOzwrhjB/Azw1V7aa8t21ufcQmdsQOQMDtxVOQWNl8s= github.com/lightningnetwork/lnd/healthcheck v1.2.6 h1:1sWhqr93GdkWy4+6U7JxBfcyZIE78MhIHTJZfPx7qqI= github.com/lightningnetwork/lnd/healthcheck v1.2.6/go.mod h1:Mu02um4CWY/zdTOvFje7WJgJcHyX2zq/FG3MhOAiGaQ= github.com/lightningnetwork/lnd/kvdb v1.4.11 h1:fk1HMVFrsVK3xqU7q+JWHRgBltw/a2qIg1E3zazMb/8= @@ -468,8 +468,8 @@ github.com/lightningnetwork/lnd/sqldb v1.0.5 h1:ax5vBPf44tN/uD6C5+hBPBjOJ7cRMrUL github.com/lightningnetwork/lnd/sqldb v1.0.5/go.mod h1:OG09zL/PHPaBJefp4HsPz2YLUJ+zIQHbpgCtLnOx8I4= github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM= github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA= -github.com/lightningnetwork/lnd/tlv v1.2.6 h1:icvQG2yDr6k3ZuZzfRdG3EJp6pHurcuh3R6dg0gv/Mw= -github.com/lightningnetwork/lnd/tlv v1.2.6/go.mod h1:/CmY4VbItpOldksocmGT4lxiJqRP9oLxwSZOda2kzNQ= +github.com/lightningnetwork/lnd/tlv v1.3.0 h1:exS/KCPEgpOgviIttfiXAPaUqw2rHQrnUOpP7HPBPiY= +github.com/lightningnetwork/lnd/tlv v1.3.0/go.mod h1:pJuiBj1ecr1WWLOtcZ+2+hu9Ey25aJWFIsjmAoPPnmc= github.com/lightningnetwork/lnd/tor v1.1.4 h1:TUW27EXqoZCcCAQPlD4aaDfh8jMbBS9CghNz50qqwtA= github.com/lightningnetwork/lnd/tor v1.1.4/go.mod h1:qSRB8llhAK+a6kaTPWOLLXSZc6Hg8ZC0mq1sUQ/8JfI= github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796 h1:sjOGyegMIhvgfq5oaue6Td+hxZuf3tDC8lAPrFldqFw= diff --git a/graph/builder.go b/graph/builder.go index c0133e02ec..d6984af709 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -16,7 +16,7 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" diff --git a/graph/db/models/channel_edge_info.go b/graph/db/models/channel_edge_info.go index 0f91e2bbec..6aa67acc6a 100644 --- a/graph/db/models/channel_edge_info.go +++ b/graph/db/models/channel_edge_info.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" ) // ChannelEdgeInfo represents a fully authenticated channel along with all its diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index c48436173f..6414c9f802 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -8,7 +8,7 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lntypes" diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index d8f55afc69..78cb019892 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" @@ -206,6 +206,11 @@ const ( Outgoing LinkDirection = true ) +// OptionalBandwidth is a type alias for the result of a bandwidth query that +// may return a bandwidth value or fn.None if the bandwidth is not available or +// not applicable. +type OptionalBandwidth = fn.Option[lnwire.MilliSatoshi] + // ChannelLink is an interface which represents the subsystem for managing the // incoming htlc requests, applying the changes to the channel, and also // propagating/forwarding it to htlc switch. @@ -267,10 +272,10 @@ type ChannelLink interface { // in order to signal to the source of the HTLC, the policy consistency // issue. CheckHtlcForward(payHash [32]byte, incomingAmt lnwire.MilliSatoshi, - amtToForward lnwire.MilliSatoshi, - incomingTimeout, outgoingTimeout uint32, - inboundFee models.InboundFee, - heightNow uint32, scid lnwire.ShortChannelID) *LinkError + amtToForward lnwire.MilliSatoshi, incomingTimeout, + outgoingTimeout uint32, inboundFee models.InboundFee, + heightNow uint32, scid lnwire.ShortChannelID, + customRecords lnwire.CustomRecords) *LinkError // CheckHtlcTransit should return a nil error if the passed HTLC details // satisfy the current channel policy. Otherwise, a LinkError with a @@ -278,14 +283,15 @@ type ChannelLink interface { // the violation. This call is intended to be used for locally initiated // payments for which there is no corresponding incoming htlc. CheckHtlcTransit(payHash [32]byte, amt lnwire.MilliSatoshi, - timeout uint32, heightNow uint32) *LinkError + timeout uint32, heightNow uint32, + customRecords lnwire.CustomRecords) *LinkError // Stats return the statistics of channel link. Number of updates, // total sent/received milli-satoshis. Stats() (uint64, lnwire.MilliSatoshi, lnwire.MilliSatoshi) - // Peer returns the serialized public key of remote peer with which we - // have the channel link opened. + // PeerPubKey returns the serialized public key of remote peer with + // which we have the channel link opened. PeerPubKey() [33]byte // AttachMailBox delivers an active MailBox to the link. The MailBox may @@ -302,9 +308,18 @@ type ChannelLink interface { // commitment of the channel that this link is associated with. CommitmentCustomBlob() fn.Option[tlv.Blob] - // Start/Stop are used to initiate the start/stop of the channel link - // functioning. + // AuxBandwidth returns the bandwidth that can be used for a channel, + // expressed in milli-satoshi. This might be different from the regular + // BTC bandwidth for custom channels. This will always return fn.None() + // for a regular (non-custom) channel. + AuxBandwidth(amount lnwire.MilliSatoshi, cid lnwire.ShortChannelID, + htlcBlob fn.Option[tlv.Blob], + ts AuxTrafficShaper) fn.Result[OptionalBandwidth] + + // Start starts the channel link. Start() error + + // Stop requests the channel link to be shut down. Stop() } @@ -440,7 +455,7 @@ type htlcNotifier interface { NotifyForwardingEvent(key HtlcKey, info HtlcInfo, eventType HtlcEventType) - // NotifyIncomingLinkFailEvent notifies that a htlc has failed on our + // NotifyLinkFailEvent notifies that a htlc has failed on our // incoming link. It takes an isReceive bool to differentiate between // our node's receives and forwards. NotifyLinkFailEvent(key HtlcKey, info HtlcInfo, @@ -461,3 +476,36 @@ type htlcNotifier interface { NotifyFinalHtlcEvent(key models.CircuitKey, info channeldb.FinalHtlcInfo) } + +// AuxHtlcModifier is an interface that allows the sender to modify the outgoing +// HTLC of a payment by changing the amount or the wire message tlv records. +type AuxHtlcModifier interface { + // ProduceHtlcExtraData is a function that, based on the previous extra + // data blob of an HTLC, may produce a different blob or modify the + // amount of bitcoin this htlc should carry. + ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi, + htlcCustomRecords lnwire.CustomRecords) (lnwire.MilliSatoshi, + lnwire.CustomRecords, error) +} + +// AuxTrafficShaper is an interface that allows the sender to determine if a +// payment should be carried by a channel based on the TLV records that may be +// present in the `update_add_htlc` message or the channel commitment itself. +type AuxTrafficShaper interface { + AuxHtlcModifier + + // ShouldHandleTraffic is called in order to check if the channel + // identified by the provided channel ID may have external mechanisms + // that would allow it to carry out the payment. + ShouldHandleTraffic(cid lnwire.ShortChannelID, + fundingBlob fn.Option[tlv.Blob]) (bool, error) + + // PaymentBandwidth returns the available bandwidth for a custom channel + // decided by the given channel aux blob and HTLC blob. A return value + // of 0 means there is no bandwidth available. To find out if a channel + // is a custom channel that should be handled by the traffic shaper, the + // ShouldHandleTraffic method should be called first. + PaymentBandwidth(htlcBlob, commitmentBlob fn.Option[tlv.Blob], + linkBandwidth, + htlcAmt lnwire.MilliSatoshi) (lnwire.MilliSatoshi, error) +} diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 60062862ef..6e67e87def 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -16,7 +16,7 @@ import ( "github.com/btcsuite/btclog/v2" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hop" @@ -293,6 +293,10 @@ type ChannelLinkConfig struct { // ShouldFwdExpEndorsement is a closure that indicates whether the link // should forward experimental endorsement signals. ShouldFwdExpEndorsement func() bool + + // AuxTrafficShaper is an optional auxiliary traffic shaper that can be + // used to manage the bandwidth of the link. + AuxTrafficShaper fn.Option[AuxTrafficShaper] } // channelLink is the service which drives a channel's commitment update @@ -3233,11 +3237,11 @@ func (l *channelLink) UpdateForwardingPolicy( // issue. // // NOTE: Part of the ChannelLink interface. -func (l *channelLink) CheckHtlcForward(payHash [32]byte, - incomingHtlcAmt, amtToForward lnwire.MilliSatoshi, - incomingTimeout, outgoingTimeout uint32, - inboundFee models.InboundFee, - heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError { +func (l *channelLink) CheckHtlcForward(payHash [32]byte, incomingHtlcAmt, + amtToForward lnwire.MilliSatoshi, incomingTimeout, + outgoingTimeout uint32, inboundFee models.InboundFee, + heightNow uint32, originalScid lnwire.ShortChannelID, + customRecords lnwire.CustomRecords) *LinkError { l.RLock() policy := l.cfg.FwrdingPolicy @@ -3286,7 +3290,7 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // Check whether the outgoing htlc satisfies the channel policy. err := l.canSendHtlc( policy, payHash, amtToForward, outgoingTimeout, heightNow, - originalScid, + originalScid, customRecords, ) if err != nil { return err @@ -3322,8 +3326,8 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // the violation. This call is intended to be used for locally initiated // payments for which there is no corresponding incoming htlc. func (l *channelLink) CheckHtlcTransit(payHash [32]byte, - amt lnwire.MilliSatoshi, timeout uint32, - heightNow uint32) *LinkError { + amt lnwire.MilliSatoshi, timeout uint32, heightNow uint32, + customRecords lnwire.CustomRecords) *LinkError { l.RLock() policy := l.cfg.FwrdingPolicy @@ -3334,6 +3338,7 @@ func (l *channelLink) CheckHtlcTransit(payHash [32]byte, // to occur. return l.canSendHtlc( policy, payHash, amt, timeout, heightNow, hop.Source, + customRecords, ) } @@ -3341,7 +3346,8 @@ func (l *channelLink) CheckHtlcTransit(payHash [32]byte, // the channel's amount and time lock constraints. func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32, - heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError { + heightNow uint32, originalScid lnwire.ShortChannelID, + customRecords lnwire.CustomRecords) *LinkError { // As our first sanity check, we'll ensure that the passed HTLC isn't // too small for the next hop. If so, then we'll cancel the HTLC @@ -3399,8 +3405,38 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, return NewLinkError(&lnwire.FailExpiryTooFar{}) } + // We now check the available bandwidth to see if this HTLC can be + // forwarded. + availableBandwidth := l.Bandwidth() + auxBandwidth, err := fn.MapOptionZ( + l.cfg.AuxTrafficShaper, + func(ts AuxTrafficShaper) fn.Result[OptionalBandwidth] { + var htlcBlob fn.Option[tlv.Blob] + blob, err := customRecords.Serialize() + if err != nil { + return fn.Err[OptionalBandwidth]( + fmt.Errorf("unable to serialize "+ + "custom records: %w", err)) + } + + if len(blob) > 0 { + htlcBlob = fn.Some(blob) + } + + return l.AuxBandwidth(amt, originalScid, htlcBlob, ts) + }, + ).Unpack() + if err != nil { + l.log.Errorf("Unable to determine aux bandwidth: %v", err) + return NewLinkError(&lnwire.FailTemporaryNodeFailure{}) + } + + auxBandwidth.WhenSome(func(bandwidth lnwire.MilliSatoshi) { + availableBandwidth = bandwidth + }) + // Check to see if there is enough balance in this channel. - if amt > l.Bandwidth() { + if amt > availableBandwidth { l.log.Warnf("insufficient bandwidth to route htlc: %v is "+ "larger than %v", amt, l.Bandwidth()) cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { @@ -3415,6 +3451,48 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, return nil } +// AuxBandwidth returns the bandwidth that can be used for a channel, expressed +// in milli-satoshi. This might be different from the regular BTC bandwidth for +// custom channels. This will always return fn.None() for a regular (non-custom) +// channel. +func (l *channelLink) AuxBandwidth(amount lnwire.MilliSatoshi, + cid lnwire.ShortChannelID, htlcBlob fn.Option[tlv.Blob], + ts AuxTrafficShaper) fn.Result[OptionalBandwidth] { + + unknownBandwidth := fn.None[lnwire.MilliSatoshi]() + + fundingBlob := l.FundingCustomBlob() + shouldHandle, err := ts.ShouldHandleTraffic(cid, fundingBlob) + if err != nil { + return fn.Err[OptionalBandwidth](fmt.Errorf("traffic shaper "+ + "failed to decide whether to handle traffic: %w", err)) + } + + log.Debugf("ShortChannelID=%v: aux traffic shaper is handling "+ + "traffic: %v", cid, shouldHandle) + + // If this channel isn't handled by the aux traffic shaper, we'll return + // early. + if !shouldHandle { + return fn.Ok(unknownBandwidth) + } + + // Ask for a specific bandwidth to be used for the channel. + commitmentBlob := l.CommitmentCustomBlob() + auxBandwidth, err := ts.PaymentBandwidth( + htlcBlob, commitmentBlob, l.Bandwidth(), amount, + ) + if err != nil { + return fn.Err[OptionalBandwidth](fmt.Errorf("failed to get "+ + "bandwidth from external traffic shaper: %w", err)) + } + + log.Debugf("ShortChannelID=%v: aux traffic shaper reported available "+ + "bandwidth: %v", cid, auxBandwidth) + + return fn.Ok(fn.Some(auxBandwidth)) +} + // Stats returns the statistics of channel link. // // NOTE: Part of the ChannelLink interface. diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 80632b07e9..1747105597 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -26,7 +26,7 @@ import ( "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hop" @@ -6243,9 +6243,9 @@ func TestCheckHtlcForward(t *testing.T) { var hash [32]byte t.Run("satisfied", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1500, 1000, - 200, 150, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1000, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if result != nil { t.Fatalf("expected policy to be satisfied") @@ -6253,9 +6253,9 @@ func TestCheckHtlcForward(t *testing.T) { }) t.Run("below minhtlc", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 100, 50, - 200, 150, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 100, 50, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailAmountBelowMinimum); !ok { t.Fatalf("expected FailAmountBelowMinimum failure code") @@ -6263,9 +6263,9 @@ func TestCheckHtlcForward(t *testing.T) { }) t.Run("above maxhtlc", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1500, 1200, - 200, 150, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1200, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailTemporaryChannelFailure); !ok { t.Fatalf("expected FailTemporaryChannelFailure failure code") @@ -6273,9 +6273,9 @@ func TestCheckHtlcForward(t *testing.T) { }) t.Run("insufficient fee", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1005, 1000, - 200, 150, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1005, 1000, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailFeeInsufficient); !ok { t.Fatalf("expected FailFeeInsufficient failure code") @@ -6288,17 +6288,17 @@ func TestCheckHtlcForward(t *testing.T) { t.Parallel() result := link.CheckHtlcForward( - hash, 100005, 100000, 200, - 150, models.InboundFee{}, 0, lnwire.ShortChannelID{}, + hash, 100005, 100000, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) _, ok := result.WireMessage().(*lnwire.FailFeeInsufficient) require.True(t, ok, "expected FailFeeInsufficient failure code") }) t.Run("expiry too soon", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1500, 1000, - 200, 150, models.InboundFee{}, 190, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1000, 200, 150, models.InboundFee{}, 190, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailExpiryTooSoon); !ok { t.Fatalf("expected FailExpiryTooSoon failure code") @@ -6306,9 +6306,9 @@ func TestCheckHtlcForward(t *testing.T) { }) t.Run("incorrect cltv expiry", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1500, 1000, - 200, 190, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1000, 200, 190, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailIncorrectCltvExpiry); !ok { t.Fatalf("expected FailIncorrectCltvExpiry failure code") @@ -6318,9 +6318,9 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("cltv expiry too far in the future", func(t *testing.T) { // Check that expiry isn't too far in the future. - result := link.CheckHtlcForward(hash, 1500, 1000, - 10200, 10100, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1000, 10200, 10100, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailExpiryTooFar); !ok { t.Fatalf("expected FailExpiryTooFar failure code") @@ -6330,9 +6330,11 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("inbound fee satisfied", func(t *testing.T) { t.Parallel() - result := link.CheckHtlcForward(hash, 1000+10-2-1, 1000, - 200, 150, models.InboundFee{Base: -2, Rate: -1_000}, - 0, lnwire.ShortChannelID{}) + result := link.CheckHtlcForward( + hash, 1000+10-2-1, 1000, 200, 150, + models.InboundFee{Base: -2, Rate: -1_000}, + 0, lnwire.ShortChannelID{}, nil, + ) if result != nil { t.Fatalf("expected policy to be satisfied") } @@ -6341,9 +6343,11 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("inbound fee insufficient", func(t *testing.T) { t.Parallel() - result := link.CheckHtlcForward(hash, 1000+10-10-101-1, 1000, + result := link.CheckHtlcForward( + hash, 1000+10-10-101-1, 1000, 200, 150, models.InboundFee{Base: -10, Rate: -100_000}, - 0, lnwire.ShortChannelID{}) + 0, lnwire.ShortChannelID{}, nil, + ) msg := result.WireMessage() if _, ok := msg.(*lnwire.FailFeeInsufficient); !ok { diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index ce791bef32..5cf7966573 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -24,7 +24,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/contractcourt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/invoices" @@ -846,14 +846,14 @@ func (f *mockChannelLink) UpdateForwardingPolicy(_ models.ForwardingPolicy) { } func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliSatoshi, lnwire.MilliSatoshi, uint32, uint32, models.InboundFee, uint32, - lnwire.ShortChannelID) *LinkError { + lnwire.ShortChannelID, lnwire.CustomRecords) *LinkError { return f.checkHtlcForwardResult } func (f *mockChannelLink) CheckHtlcTransit(payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32, - heightNow uint32) *LinkError { + heightNow uint32, _ lnwire.CustomRecords) *LinkError { return f.checkHtlcTransitResult } @@ -968,6 +968,17 @@ func (f *mockChannelLink) CommitmentCustomBlob() fn.Option[tlv.Blob] { return fn.None[tlv.Blob]() } +// AuxBandwidth returns the bandwidth that can be used for a channel, +// expressed in milli-satoshi. This might be different from the regular +// BTC bandwidth for custom channels. This will always return fn.None() +// for a regular (non-custom) channel. +func (f *mockChannelLink) AuxBandwidth(lnwire.MilliSatoshi, + lnwire.ShortChannelID, + fn.Option[tlv.Blob], AuxTrafficShaper) fn.Result[OptionalBandwidth] { + + return fn.Ok(fn.None[lnwire.MilliSatoshi]()) +} + var _ ChannelLink = (*mockChannelLink)(nil) const testInvoiceCltvExpiry = 6 diff --git a/htlcswitch/quiescer.go b/htlcswitch/quiescer.go index 27d0deb8c6..468ad5e708 100644 --- a/htlcswitch/quiescer.go +++ b/htlcswitch/quiescer.go @@ -6,7 +6,7 @@ import ( "time" "github.com/btcsuite/btclog/v2" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/htlcswitch/quiescer_test.go b/htlcswitch/quiescer_test.go index da08909d57..6ce9563e45 100644 --- a/htlcswitch/quiescer_test.go +++ b/htlcswitch/quiescer_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/stretchr/testify/require" diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 1a08275ec9..c94c677966 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -17,7 +17,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/contractcourt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/kvdb" @@ -917,6 +917,7 @@ func (s *Switch) getLocalLink(pkt *htlcPacket, htlc *lnwire.UpdateAddHTLC) ( currentHeight := atomic.LoadUint32(&s.bestHeight) htlcErr := link.CheckHtlcTransit( htlc.PaymentHash, htlc.Amount, htlc.Expiry, currentHeight, + htlc.CustomRecords, ) if htlcErr != nil { log.Errorf("Link %v policy for local forward not "+ @@ -1605,7 +1606,7 @@ out: } } - log.Infof("Received outside contract resolution, "+ + log.Debugf("Received outside contract resolution, "+ "mapping to: %v", spew.Sdump(pkt)) // We don't check the error, as the only failure we can @@ -2887,10 +2888,9 @@ func (s *Switch) handlePacketAdd(packet *htlcPacket, failure = link.CheckHtlcForward( htlc.PaymentHash, packet.incomingAmount, packet.amount, packet.incomingTimeout, - packet.outgoingTimeout, - packet.inboundFee, - currentHeight, - packet.originalOutgoingChanID, + packet.outgoingTimeout, packet.inboundFee, + currentHeight, packet.originalOutgoingChanID, + htlc.CustomRecords, ) } diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index abfb8e4d5b..8809321460 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -17,7 +17,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hop" diff --git a/input/input.go b/input/input.go index 088b20401f..4a9a4b55c0 100644 --- a/input/input.go +++ b/input/input.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/input/mocks.go b/input/mocks.go index bbd4550c5f..6d90bc28df 100644 --- a/input/mocks.go +++ b/input/mocks.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/tlv" diff --git a/input/script_utils.go b/input/script_utils.go index 91ca55292f..000efe9585 100644 --- a/input/script_utils.go +++ b/input/script_utils.go @@ -14,7 +14,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "golang.org/x/crypto/ripemd160" diff --git a/input/taproot.go b/input/taproot.go index 2ca6e97236..5ca4dd0c66 100644 --- a/input/taproot.go +++ b/input/taproot.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/waddrmgr" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" ) const ( diff --git a/input/taproot_test.go b/input/taproot_test.go index a1259be196..3a1e000374 100644 --- a/input/taproot_test.go +++ b/input/taproot_test.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" "github.com/stretchr/testify/require" diff --git a/intercepted_forward.go b/intercepted_forward.go index 791d4bd583..5cb1ca192b 100644 --- a/intercepted_forward.go +++ b/intercepted_forward.go @@ -3,7 +3,7 @@ package lnd import ( "errors" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index f5a6c6a95f..cc76d5aefa 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -1275,7 +1275,11 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked( invoiceToExpire = makeInvoiceExpiry(ctx.hash, invoice) } - i.hodlSubscribe(hodlChan, ctx.circuitKey) + // Subscribe to the resolution if the caller specified a + // notification channel. + if hodlChan != nil { + i.hodlSubscribe(hodlChan, ctx.circuitKey) + } default: panic("unknown action") diff --git a/invoices/modification_interceptor.go b/invoices/modification_interceptor.go index 97e75e8cc5..58f5b63d07 100644 --- a/invoices/modification_interceptor.go +++ b/invoices/modification_interceptor.go @@ -5,7 +5,7 @@ import ( "fmt" "sync/atomic" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" ) var ( diff --git a/itest/lnd_funding_test.go b/itest/lnd_funding_test.go index 54180abf57..0b08da32b3 100644 --- a/itest/lnd_funding_test.go +++ b/itest/lnd_funding_test.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainreg" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" diff --git a/itest/lnd_sweep_test.go b/itest/lnd_sweep_test.go index 099014aff0..158e8768f9 100644 --- a/itest/lnd_sweep_test.go +++ b/itest/lnd_sweep_test.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/contractcourt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" @@ -1119,9 +1119,9 @@ func testSweepHTLCs(ht *lntest.HarnessTest) { // The sweeping tx has two inputs, one from wallet, the other // from the force close tx. We now check whether the first tx // spends from the force close tx of Alice->Bob. - found := fn.Any(func(inp *wire.TxIn) bool { + found := fn.Any(txns[0].TxIn, func(inp *wire.TxIn) bool { return inp.PreviousOutPoint.Hash == abCloseTxid - }, txns[0].TxIn) + }) // If the first tx spends an outpoint from the force close tx // of Alice->Bob, then it must be the incoming HTLC sweeping diff --git a/lnrpc/devrpc/dev_server.go b/lnrpc/devrpc/dev_server.go index 60f30dd7ed..b26b144c81 100644 --- a/lnrpc/devrpc/dev_server.go +++ b/lnrpc/devrpc/dev_server.go @@ -16,7 +16,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/lnrpc" diff --git a/lnrpc/invoicesrpc/utils.go b/lnrpc/invoicesrpc/utils.go index 955ba6acf2..19ade28fd8 100644 --- a/lnrpc/invoicesrpc/utils.go +++ b/lnrpc/invoicesrpc/utils.go @@ -1,8 +1,10 @@ package invoicesrpc import ( + "cmp" "encoding/hex" "fmt" + "slices" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/chaincfg" @@ -160,6 +162,11 @@ func CreateRPCInvoice(invoice *invoices.Invoice, rpcHtlcs = append(rpcHtlcs, &rpcHtlc) } + // Perform an inplace sort of the HTLCs to ensure they are ordered. + slices.SortFunc(rpcHtlcs, func(i, j *lnrpc.InvoiceHTLC) int { + return cmp.Compare(i.HtlcIndex, j.HtlcIndex) + }) + rpcInvoice := &lnrpc.Invoice{ Memo: string(invoice.Memo), RHash: rHash, diff --git a/lnrpc/marshall_utils.go b/lnrpc/marshall_utils.go index 230fea35b6..96d3342d83 100644 --- a/lnrpc/marshall_utils.go +++ b/lnrpc/marshall_utils.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/wallet" "github.com/lightningnetwork/lnd/aliasmgr" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "golang.org/x/exp/maps" @@ -221,12 +221,18 @@ func UnmarshallCoinSelectionStrategy(strategy CoinSelectionStrategy, // MarshalAliasMap converts a ScidAliasMap to its proto counterpart. This is // used in various RPCs that handle scid alias mappings. func MarshalAliasMap(scidMap aliasmgr.ScidAliasMap) []*AliasMap { - return fn.Map(func(base lnwire.ShortChannelID) *AliasMap { - return &AliasMap{ - BaseScid: base.ToUint64(), - Aliases: fn.Map(func(a lnwire.ShortChannelID) uint64 { - return a.ToUint64() - }, scidMap[base]), - } - }, maps.Keys(scidMap)) + return fn.Map( + maps.Keys(scidMap), + func(base lnwire.ShortChannelID) *AliasMap { + return &AliasMap{ + BaseScid: base.ToUint64(), + Aliases: fn.Map( + scidMap[base], + func(a lnwire.ShortChannelID) uint64 { + return a.ToUint64() + }, + ), + } + }, + ) } diff --git a/lnrpc/routerrpc/forward_interceptor.go b/lnrpc/routerrpc/forward_interceptor.go index 9da831ac04..72df3d0199 100644 --- a/lnrpc/routerrpc/forward_interceptor.go +++ b/lnrpc/routerrpc/forward_interceptor.go @@ -3,7 +3,7 @@ package routerrpc import ( "errors" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnrpc" diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 9421e991b6..7d73681094 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -16,7 +16,7 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/feature" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go index 7f1a7edf07..9499fa25a3 100644 --- a/lnrpc/routerrpc/router_server.go +++ b/lnrpc/routerrpc/router_server.go @@ -16,7 +16,7 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/lightningnetwork/lnd/aliasmgr" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" "github.com/lightningnetwork/lnd/lntypes" diff --git a/lnrpc/walletrpc/walletkit_server.go b/lnrpc/walletrpc/walletkit_server.go index c6dec6fbd5..4f477cdbd4 100644 --- a/lnrpc/walletrpc/walletkit_server.go +++ b/lnrpc/walletrpc/walletkit_server.go @@ -31,7 +31,7 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/labels" @@ -1145,9 +1145,9 @@ func (w *WalletKit) getWaitingCloseChannel( return nil, err } - channel := fn.Find(func(c *channeldb.OpenChannel) bool { + channel := fn.Find(chans, func(c *channeldb.OpenChannel) bool { return c.FundingOutpoint == chanPoint - }, chans) + }) return channel.UnwrapOrErr(errors.New("channel not found")) } @@ -1231,18 +1231,23 @@ func (w *WalletKit) BumpForceCloseFee(_ context.Context, pendingSweeps := maps.Values(inputsMap) // Discard everything except for the anchor sweeps. - anchors := fn.Filter(func(sweep *sweep.PendingInputResponse) bool { - // Only filter for anchor inputs because these are the only - // inputs which can be used to bump a closed unconfirmed - // commitment transaction. - if sweep.WitnessType != input.CommitmentAnchor && - sweep.WitnessType != input.TaprootAnchorSweepSpend { - - return false - } + anchors := fn.Filter( + pendingSweeps, + func(sweep *sweep.PendingInputResponse) bool { + // Only filter for anchor inputs because these are the + // only inputs which can be used to bump a closed + // unconfirmed commitment transaction. + isCommitAnchor := sweep.WitnessType == + input.CommitmentAnchor + isTaprootSweepSpend := sweep.WitnessType == + input.TaprootAnchorSweepSpend + if !isCommitAnchor && !isTaprootSweepSpend { + return false + } - return commitSet.Contains(sweep.OutPoint.Hash) - }, pendingSweeps) + return commitSet.Contains(sweep.OutPoint.Hash) + }, + ) if len(anchors) == 0 { return nil, fmt.Errorf("unable to find pending anchor outputs") @@ -1754,7 +1759,7 @@ func (w *WalletKit) fundPsbtInternalWallet(account string, return true } - eligibleUtxos := fn.Filter(filterFn, utxos) + eligibleUtxos := fn.Filter(utxos, filterFn) // Validate all inputs against our known list of UTXOs // now. diff --git a/lntest/harness.go b/lntest/harness.go index f96a3aadd7..8e8fcd3936 100644 --- a/lntest/harness.go +++ b/lntest/harness.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/kvdb/etcd" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" diff --git a/lntest/harness_assertion.go b/lntest/harness_assertion.go index 1b079fea16..11cbefdd5c 100644 --- a/lntest/harness_assertion.go +++ b/lntest/harness_assertion.go @@ -19,7 +19,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" "github.com/lightningnetwork/lnd/lnrpc/routerrpc" @@ -270,7 +270,7 @@ func (h *HarnessTest) AssertNumActiveEdges(hn *node.HarnessNode, IncludeUnannounced: includeUnannounced, } resp := hn.RPC.DescribeGraph(req) - activeEdges := fn.Filter(filterDisabled, resp.Edges) + activeEdges := fn.Filter(resp.Edges, filterDisabled) total := len(activeEdges) if total-old == expected { diff --git a/lntest/miner/miner.go b/lntest/miner/miner.go index e9e380bbb3..0229d6a47f 100644 --- a/lntest/miner/miner.go +++ b/lntest/miner/miner.go @@ -17,7 +17,7 @@ import ( "github.com/btcsuite/btcd/integration/rpctest" "github.com/btcsuite/btcd/rpcclient" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lntest/node" "github.com/lightningnetwork/lnd/lntest/wait" "github.com/stretchr/testify/require" @@ -296,10 +296,7 @@ func (h *HarnessMiner) AssertTxInMempool(txid chainhash.Hash) *wire.MsgTx { return fmt.Errorf("empty mempool") } - isEqual := func(memTx chainhash.Hash) bool { - return memTx == txid - } - result := fn.Find(isEqual, mempool) + result := fn.Find(mempool, fn.Eq(txid)) if result.IsNone() { return fmt.Errorf("txid %v not found in "+ diff --git a/lntest/mock/walletcontroller.go b/lntest/mock/walletcontroller.go index 8b7ef55380..fa623bf84d 100644 --- a/lntest/mock/walletcontroller.go +++ b/lntest/mock/walletcontroller.go @@ -16,7 +16,7 @@ import ( base "github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/wallet/txauthor" "github.com/btcsuite/btcwallet/wtxmgr" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" ) diff --git a/lntest/node/state.go b/lntest/node/state.go index a89ab7d2cc..38f02f3a4c 100644 --- a/lntest/node/state.go +++ b/lntest/node/state.go @@ -7,7 +7,7 @@ import ( "time" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/walletrpc" "github.com/lightningnetwork/lnd/lntest/rpc" @@ -324,11 +324,11 @@ func (s *State) updateEdgeStats() { req := &lnrpc.ChannelGraphRequest{IncludeUnannounced: true} resp := s.rpc.DescribeGraph(req) - s.Edge.Total = len(fn.Filter(filterDisabled, resp.Edges)) + s.Edge.Total = len(fn.Filter(resp.Edges, filterDisabled)) req = &lnrpc.ChannelGraphRequest{IncludeUnannounced: false} resp = s.rpc.DescribeGraph(req) - s.Edge.Public = len(fn.Filter(filterDisabled, resp.Edges)) + s.Edge.Public = len(fn.Filter(resp.Edges, filterDisabled)) } // updateWalletBalance creates stats for the node's wallet balance. diff --git a/lnwallet/aux_leaf_store.go b/lnwallet/aux_leaf_store.go index c457a92509..28a78e09db 100644 --- a/lnwallet/aux_leaf_store.go +++ b/lnwallet/aux_leaf_store.go @@ -5,7 +5,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" diff --git a/lnwallet/aux_resolutions.go b/lnwallet/aux_resolutions.go index 382232640d..b36e2d6368 100644 --- a/lnwallet/aux_resolutions.go +++ b/lnwallet/aux_resolutions.go @@ -4,7 +4,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" diff --git a/lnwallet/aux_signer.go b/lnwallet/aux_signer.go index 01abe1aae3..510b64b5d1 100644 --- a/lnwallet/aux_signer.go +++ b/lnwallet/aux_signer.go @@ -2,7 +2,7 @@ package lnwallet import ( "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" diff --git a/lnwallet/btcwallet/btcwallet.go b/lnwallet/btcwallet/btcwallet.go index 5d28574cbe..b9a909fbd3 100644 --- a/lnwallet/btcwallet/btcwallet.go +++ b/lnwallet/btcwallet/btcwallet.go @@ -27,7 +27,7 @@ import ( "github.com/btcsuite/btcwallet/wtxmgr" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/blockcache" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" diff --git a/lnwallet/chainfee/filtermanager.go b/lnwallet/chainfee/filtermanager.go index 26fa56aef1..2d6fd0a2e1 100644 --- a/lnwallet/chainfee/filtermanager.go +++ b/lnwallet/chainfee/filtermanager.go @@ -9,7 +9,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/rpcclient" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" ) const ( diff --git a/lnwallet/chancloser/aux_closer.go b/lnwallet/chancloser/aux_closer.go index 8b1c445ca3..62f475dd43 100644 --- a/lnwallet/chancloser/aux_closer.go +++ b/lnwallet/chancloser/aux_closer.go @@ -4,7 +4,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" diff --git a/lnwallet/chancloser/chancloser.go b/lnwallet/chancloser/chancloser.go index 17112b29e0..398a8a9f3e 100644 --- a/lnwallet/chancloser/chancloser.go +++ b/lnwallet/chancloser/chancloser.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" diff --git a/lnwallet/chancloser/chancloser_test.go b/lnwallet/chancloser/chancloser_test.go index 28709fd5f8..fe71fe5e3b 100644 --- a/lnwallet/chancloser/chancloser_test.go +++ b/lnwallet/chancloser/chancloser_test.go @@ -14,7 +14,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" diff --git a/lnwallet/chancloser/interface.go b/lnwallet/chancloser/interface.go index 729cdc545b..f774c81039 100644 --- a/lnwallet/chancloser/interface.go +++ b/lnwallet/chancloser/interface.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/lnwallet/chanfunding/canned_assembler.go b/lnwallet/chanfunding/canned_assembler.go index b3457f21bf..e28cbb96d1 100644 --- a/lnwallet/chanfunding/canned_assembler.go +++ b/lnwallet/chanfunding/canned_assembler.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" ) diff --git a/lnwallet/chanfunding/interface.go b/lnwallet/chanfunding/interface.go index 3512b32ff9..e40c4a1157 100644 --- a/lnwallet/chanfunding/interface.go +++ b/lnwallet/chanfunding/interface.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/wtxmgr" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet/chainfee" ) diff --git a/lnwallet/chanfunding/psbt_assembler.go b/lnwallet/chanfunding/psbt_assembler.go index f678f520fc..dd1bedd05a 100644 --- a/lnwallet/chanfunding/psbt_assembler.go +++ b/lnwallet/chanfunding/psbt_assembler.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" ) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 7f70e600c6..d190acdf5e 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -25,7 +25,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -600,7 +600,7 @@ func (lc *LightningChannel) extractPayDescs(feeRate chainfee.SatPerKWeight, htlc := htlc - auxLeaf := fn.ChainOption( + auxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { leaves := l.OutgoingHtlcLeaves if htlc.Incoming { @@ -1106,7 +1106,7 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, feeRate, wireMsg.Amount.ToSatoshis(), remoteDustLimit, ) if !isDustRemote { - auxLeaf := fn.ChainOption( + auxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { leaves := l.OutgoingHtlcLeaves return leaves[pd.HtlcIndex].AuxTapLeaf @@ -2088,7 +2088,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, // Since it is the remote breach we are reconstructing, the output // going to us will be a to-remote script with our local params. - remoteAuxLeaf := fn.ChainOption( + remoteAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { return l.RemoteAuxLeaf }, @@ -2102,7 +2102,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, return nil, err } - localAuxLeaf := fn.ChainOption( + localAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { return l.LocalAuxLeaf }, @@ -2229,7 +2229,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, return nil, fmt.Errorf("unable to aux resolve: %w", err) } - br.LocalResolutionBlob = resolveBlob.Option() + br.LocalResolutionBlob = resolveBlob.OkToSome() } // Similarly, if their balance exceeds the remote party's dust limit, @@ -2308,7 +2308,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, return nil, fmt.Errorf("unable to aux resolve: %w", err) } - br.RemoteResolutionBlob = resolveBlob.Option() + br.RemoteResolutionBlob = resolveBlob.OkToSome() } // Finally, with all the necessary data constructed, we can pad the @@ -2338,7 +2338,7 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // We'll generate the original second level witness script now, as // we'll need it if we're revoking an HTLC output on the remote // commitment transaction, and *they* go to the second level. - secondLevelAuxLeaf := fn.ChainOption( + secondLevelAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) fn.Option[input.AuxTapLeaf] { return fn.MapOption(func(val uint16) input.AuxTapLeaf { idx := input.HtlcIndex(val) @@ -2366,7 +2366,7 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // HTLC script. Otherwise, is this was an outgoing HTLC that we sent, // then from the PoV of the remote commitment state, they're the // receiver of this HTLC. - htlcLeaf := fn.ChainOption( + htlcLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) fn.Option[input.AuxTapLeaf] { return fn.MapOption(func(val uint16) input.AuxTapLeaf { idx := input.HtlcIndex(val) @@ -2693,13 +2693,13 @@ type HtlcView struct { // AuxOurUpdates returns the outgoing HTLCs as a read-only copy of // AuxHtlcDescriptors. func (v *HtlcView) AuxOurUpdates() []AuxHtlcDescriptor { - return fn.Map(newAuxHtlcDescriptor, v.Updates.Local) + return fn.Map(v.Updates.Local, newAuxHtlcDescriptor) } // AuxTheirUpdates returns the incoming HTLCs as a read-only copy of // AuxHtlcDescriptors. func (v *HtlcView) AuxTheirUpdates() []AuxHtlcDescriptor { - return fn.Map(newAuxHtlcDescriptor, v.Updates.Remote) + return fn.Map(v.Updates.Remote, newAuxHtlcDescriptor) } // fetchHTLCView returns all the candidate HTLC updates which should be @@ -2917,9 +2917,9 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, // The fee rate of our view is always the last UpdateFee message from // the channel's OpeningParty. openerUpdates := view.Updates.GetForParty(lc.channelState.Initiator()) - feeUpdates := fn.Filter(func(u *paymentDescriptor) bool { + feeUpdates := fn.Filter(openerUpdates, func(u *paymentDescriptor) bool { return u.EntryType == FeeUpdate - }, openerUpdates) + }) lastFeeUpdate := fn.Last(feeUpdates) lastFeeUpdate.WhenSome(func(pd *paymentDescriptor) { newView.FeePerKw = chainfee.SatPerKWeight( @@ -2942,14 +2942,17 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, for _, party := range parties { // First we run through non-add entries in both logs, // populating the skip sets. - resolutions := fn.Filter(func(pd *paymentDescriptor) bool { - switch pd.EntryType { - case Settle, Fail, MalformedFail: - return true - default: - return false - } - }, view.Updates.GetForParty(party)) + resolutions := fn.Filter( + view.Updates.GetForParty(party), + func(pd *paymentDescriptor) bool { + switch pd.EntryType { + case Settle, Fail, MalformedFail: + return true + default: + return false + } + }, + ) for _, entry := range resolutions { addEntry, err := lc.fetchParent( @@ -3002,10 +3005,16 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, // settled HTLCs, and debiting the chain state balance due to any newly // added HTLCs. for _, party := range parties { - liveAdds := fn.Filter(func(pd *paymentDescriptor) bool { - return pd.EntryType == Add && - !skip.GetForParty(party).Contains(pd.HtlcIndex) - }, view.Updates.GetForParty(party)) + liveAdds := fn.Filter( + view.Updates.GetForParty(party), + func(pd *paymentDescriptor) bool { + isAdd := pd.EntryType == Add + shouldSkip := skip.GetForParty(party). + Contains(pd.HtlcIndex) + + return isAdd && !shouldSkip + }, + ) for _, entry := range liveAdds { // Skip the entries that have already had their add @@ -3063,7 +3072,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, uncommittedUpdates := lntypes.MapDual( view.Updates, func(us []*paymentDescriptor) []*paymentDescriptor { - return fn.Filter(isUncommitted, us) + return fn.Filter(us, isUncommitted) }, ) @@ -3189,7 +3198,7 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, htlcFee := HtlcTimeoutFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee - auxLeaf := fn.ChainOption( + auxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { leaves := l.IncomingHtlcLeaves return leaves[htlc.HtlcIndex].SecondLevelLeaf @@ -3270,7 +3279,7 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, htlcFee := HtlcSuccessFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee - auxLeaf := fn.ChainOption( + auxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { leaves := l.OutgoingHtlcLeaves return leaves[htlc.HtlcIndex].SecondLevelLeaf @@ -4802,7 +4811,7 @@ func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, htlcFee := HtlcSuccessFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee - auxLeaf := fn.ChainOption(func( + auxLeaf := fn.FlatMapOption(func( l CommitAuxLeaves) input.AuxTapLeaf { leaves := l.IncomingHtlcLeaves @@ -4895,7 +4904,7 @@ func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, htlcFee := HtlcTimeoutFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee - auxLeaf := fn.ChainOption(func( + auxLeaf := fn.FlatMapOption(func( l CommitAuxLeaves) input.AuxTapLeaf { leaves := l.OutgoingHtlcLeaves @@ -6766,7 +6775,7 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, //nolint:funlen // Before we can generate the proper sign descriptor, we'll need to // locate the output index of our non-delayed output on the commitment // transaction. - remoteAuxLeaf := fn.ChainOption( + remoteAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { return l.RemoteAuxLeaf }, @@ -6870,7 +6879,7 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, //nolint:funlen return nil, fmt.Errorf("unable to aux resolve: %w", err) } - commitResolution.ResolutionBlob = resolveBlob.Option() + commitResolution.ResolutionBlob = resolveBlob.OkToSome() } closeSummary := channeldb.ChannelCloseSummary{ @@ -7059,7 +7068,7 @@ func newOutgoingHtlcResolution(signer input.Signer, // First, we'll re-generate the script used to send the HTLC to the // remote party within their commitment transaction. - auxLeaf := fn.ChainOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + auxLeaf := fn.FlatMapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { return l.OutgoingHtlcLeaves[htlc.HtlcIndex].AuxTapLeaf })(auxLeaves) htlcScriptInfo, err := genHtlcScript( @@ -7149,7 +7158,7 @@ func newOutgoingHtlcResolution(signer input.Signer, return nil, fmt.Errorf("unable to aux resolve: %w", err) } - resolutionBlob := resolveRes.Option() + resolutionBlob := resolveRes.OkToSome() return &OutgoingHtlcResolution{ Expiry: htlc.RefundTimeout, @@ -7171,7 +7180,7 @@ func newOutgoingHtlcResolution(signer input.Signer, // With the fee calculated, re-construct the second level timeout // transaction. - secondLevelAuxLeaf := fn.ChainOption( + secondLevelAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { leaves := l.OutgoingHtlcLeaves return leaves[htlc.HtlcIndex].SecondLevelLeaf @@ -7366,7 +7375,7 @@ func newOutgoingHtlcResolution(signer input.Signer, if err := resolveRes.Err(); err != nil { return nil, fmt.Errorf("unable to aux resolve: %w", err) } - resolutionBlob := resolveRes.Option() + resolutionBlob := resolveRes.OkToSome() return &OutgoingHtlcResolution{ Expiry: htlc.RefundTimeout, @@ -7406,7 +7415,7 @@ func newIncomingHtlcResolution(signer input.Signer, // First, we'll re-generate the script the remote party used to // send the HTLC to us in their commitment transaction. - auxLeaf := fn.ChainOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + auxLeaf := fn.FlatMapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { return l.IncomingHtlcLeaves[htlc.HtlcIndex].AuxTapLeaf })(auxLeaves) scriptInfo, err := genHtlcScript( @@ -7497,7 +7506,7 @@ func newIncomingHtlcResolution(signer input.Signer, return nil, fmt.Errorf("unable to aux resolve: %w", err) } - resolutionBlob := resolveRes.Option() + resolutionBlob := resolveRes.OkToSome() return &IncomingHtlcResolution{ ClaimOutpoint: op, @@ -7507,7 +7516,7 @@ func newIncomingHtlcResolution(signer input.Signer, }, nil } - secondLevelAuxLeaf := fn.ChainOption( + secondLevelAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { leaves := l.IncomingHtlcLeaves return leaves[htlc.HtlcIndex].SecondLevelLeaf @@ -7707,7 +7716,7 @@ func newIncomingHtlcResolution(signer input.Signer, return nil, fmt.Errorf("unable to aux resolve: %w", err) } - resolutionBlob := resolveRes.Option() + resolutionBlob := resolveRes.OkToSome() return &IncomingHtlcResolution{ SignedSuccessTx: successTx, @@ -8011,7 +8020,7 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, leaseExpiry = chanState.ThawHeight } - localAuxLeaf := fn.ChainOption( + localAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { return l.LocalAuxLeaf }, @@ -8126,7 +8135,7 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, return nil, fmt.Errorf("unable to aux resolve: %w", err) } - commitResolution.ResolutionBlob = resolveBlob.Option() + commitResolution.ResolutionBlob = resolveBlob.OkToSome() } // Once the delay output has been found (if it exists), then we'll also diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index f7ecd32277..d0caa97812 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -25,7 +25,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" @@ -730,9 +730,12 @@ func TestCommitHTLCSigCustomRecordSize(t *testing.T) { // Replace the default PackSigs implementation to return a // large custom records blob. - mockSigner.ExpectedCalls = fn.Filter(func(c *mock.Call) bool { - return c.Method != "PackSigs" - }, mockSigner.ExpectedCalls) + mockSigner.ExpectedCalls = fn.Filter( + mockSigner.ExpectedCalls, + func(c *mock.Call) bool { + return c.Method != "PackSigs" + }, + ) mockSigner.On("PackSigs", mock.Anything). Return(fn.Ok(fn.Some(largeBlob))) }) diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index 8b364a01df..787e8a71e1 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -836,7 +836,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, continue } - auxLeaf := fn.ChainOption( + auxLeaf := fn.FlatMapOption( func(leaves input.HtlcAuxLeaves) input.AuxTapLeaf { return leaves[htlc.HtlcIndex].AuxTapLeaf }, @@ -864,7 +864,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, continue } - auxLeaf := fn.ChainOption( + auxLeaf := fn.FlatMapOption( func(leaves input.HtlcAuxLeaves) input.AuxTapLeaf { return leaves[htlc.HtlcIndex].AuxTapLeaf }, @@ -1323,7 +1323,7 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, // Compute the to_local script. From our PoV, when facing a remote // commitment, the to_local output belongs to them. - localAuxLeaf := fn.ChainOption( + localAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { return l.LocalAuxLeaf }, @@ -1338,7 +1338,7 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, // Compute the to_remote script. From our PoV, when facing a remote // commitment, the to_remote output belongs to us. - remoteAuxLeaf := fn.ChainOption( + remoteAuxLeaf := fn.FlatMapOption( func(l CommitAuxLeaves) input.AuxTapLeaf { return l.RemoteAuxLeaf }, diff --git a/lnwallet/commitment_chain.go b/lnwallet/commitment_chain.go index fa2abe0aa2..871a139c5c 100644 --- a/lnwallet/commitment_chain.go +++ b/lnwallet/commitment_chain.go @@ -1,7 +1,7 @@ package lnwallet import ( - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" ) // commitmentChain represents a chain of unrevoked commitments. The tail of the diff --git a/lnwallet/config.go b/lnwallet/config.go index 425fe15dad..c60974be6d 100644 --- a/lnwallet/config.go +++ b/lnwallet/config.go @@ -5,7 +5,7 @@ import ( "github.com/btcsuite/btcwallet/wallet" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet/chainfee" diff --git a/lnwallet/interface.go b/lnwallet/interface.go index c9dee9202a..64f8546310 100644 --- a/lnwallet/interface.go +++ b/lnwallet/interface.go @@ -19,7 +19,7 @@ import ( base "github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/wallet/txauthor" "github.com/btcsuite/btcwallet/wtxmgr" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chanvalidate" diff --git a/lnwallet/mock.go b/lnwallet/mock.go index a8610dc779..39e520d276 100644 --- a/lnwallet/mock.go +++ b/lnwallet/mock.go @@ -18,7 +18,7 @@ import ( "github.com/btcsuite/btcwallet/wtxmgr" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/tlv" diff --git a/lnwallet/musig_session.go b/lnwallet/musig_session.go index 822aa48a14..748e5fa958 100644 --- a/lnwallet/musig_session.go +++ b/lnwallet/musig_session.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" diff --git a/lnwallet/reservation.go b/lnwallet/reservation.go index fd35d95076..a8a0cacd4b 100644 --- a/lnwallet/reservation.go +++ b/lnwallet/reservation.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" diff --git a/lnwallet/rpcwallet/rpcwallet.go b/lnwallet/rpcwallet/rpcwallet.go index bf6aa61df3..426712b597 100644 --- a/lnwallet/rpcwallet/rpcwallet.go +++ b/lnwallet/rpcwallet/rpcwallet.go @@ -22,7 +22,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/waddrmgr" basewallet "github.com/btcsuite/btcwallet/wallet" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lncfg" diff --git a/lnwallet/test/test_interface.go b/lnwallet/test/test_interface.go index c006aa2e50..27de51708c 100644 --- a/lnwallet/test/test_interface.go +++ b/lnwallet/test/test_interface.go @@ -34,7 +34,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs/btcdnotify" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" diff --git a/lnwallet/test_utils.go b/lnwallet/test_utils.go index ff9adfbd79..738558e224 100644 --- a/lnwallet/test_utils.go +++ b/lnwallet/test_utils.go @@ -16,7 +16,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" diff --git a/lnwallet/transactions_test.go b/lnwallet/transactions_test.go index 135d1866bc..38131eaa72 100644 --- a/lnwallet/transactions_test.go +++ b/lnwallet/transactions_test.go @@ -21,7 +21,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" diff --git a/lnwallet/update_log.go b/lnwallet/update_log.go index 2d1f65c9fa..b2b8af58d1 100644 --- a/lnwallet/update_log.go +++ b/lnwallet/update_log.go @@ -1,7 +1,7 @@ package lnwallet import ( - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" ) // updateLog is an append-only log that stores updates to a node's commitment diff --git a/lnwallet/wallet.go b/lnwallet/wallet.go index ad6354e2e8..2646d7c8f0 100644 --- a/lnwallet/wallet.go +++ b/lnwallet/wallet.go @@ -23,7 +23,7 @@ import ( "github.com/btcsuite/btcwallet/wallet" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" @@ -733,7 +733,7 @@ func (l *LightningWallet) RegisterFundingIntent(expectedID [32]byte, } if _, ok := l.fundingIntents[expectedID]; ok { - return fmt.Errorf("%w: already has intent registered: %v", + return fmt.Errorf("%w: already has intent registered: %x", ErrDuplicatePendingChanID, expectedID[:]) } diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go index e523279498..577379623f 100644 --- a/lnwire/channel_reestablish.go +++ b/lnwire/channel_reestablish.go @@ -5,7 +5,7 @@ import ( "io" "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/lnwire/custom_records.go b/lnwire/custom_records.go index 8177cbe821..a63aa5dfb0 100644 --- a/lnwire/custom_records.go +++ b/lnwire/custom_records.go @@ -6,7 +6,7 @@ import ( "io" "sort" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/tlv" ) @@ -179,9 +179,12 @@ func (c CustomRecords) SerializeTo(w io.Writer) error { // ProduceRecordsSorted converts a slice of record producers into a slice of // records and then sorts it by type. func ProduceRecordsSorted(recordProducers ...tlv.RecordProducer) []tlv.Record { - records := fn.Map(func(producer tlv.RecordProducer) tlv.Record { - return producer.Record() - }, recordProducers) + records := fn.Map( + recordProducers, + func(producer tlv.RecordProducer) tlv.Record { + return producer.Record() + }, + ) // Ensure that the set of records are sorted before we attempt to // decode from the stream, to ensure they're canonical. @@ -212,9 +215,9 @@ func TlvMapToRecords(tlvMap tlv.TypeMap) []tlv.Record { // RecordsAsProducers converts a slice of records into a slice of record // producers. func RecordsAsProducers(records []tlv.Record) []tlv.RecordProducer { - return fn.Map(func(record tlv.Record) tlv.RecordProducer { + return fn.Map(records, func(record tlv.Record) tlv.RecordProducer { return &record - }, records) + }) } // EncodeRecords encodes the given records into a byte slice. diff --git a/lnwire/custom_records_test.go b/lnwire/custom_records_test.go index 8ff6af10ba..d4aad2e546 100644 --- a/lnwire/custom_records_test.go +++ b/lnwire/custom_records_test.go @@ -4,7 +4,7 @@ import ( "bytes" "testing" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -182,9 +182,12 @@ func TestCustomRecordsExtendRecordProducers(t *testing.T) { func serializeRecordProducers(t *testing.T, producers []tlv.RecordProducer) []byte { - tlvRecords := fn.Map(func(p tlv.RecordProducer) tlv.Record { - return p.Record() - }, producers) + tlvRecords := fn.Map( + producers, + func(p tlv.RecordProducer) tlv.Record { + return p.Record() + }, + ) stream, err := tlv.NewStream(tlvRecords...) require.NoError(t, err) diff --git a/lnwire/dyn_ack.go b/lnwire/dyn_ack.go index 24f23a228d..d477461e7b 100644 --- a/lnwire/dyn_ack.go +++ b/lnwire/dyn_ack.go @@ -5,7 +5,7 @@ import ( "io" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/lnwire/dyn_propose.go b/lnwire/dyn_propose.go index b0cc1198e9..394fff6f37 100644 --- a/lnwire/dyn_propose.go +++ b/lnwire/dyn_propose.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go index c4ca260e1e..4681426cbb 100644 --- a/lnwire/extra_bytes.go +++ b/lnwire/extra_bytes.go @@ -5,7 +5,7 @@ import ( "fmt" "io" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 952e90a7e6..6bfbb465ec 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -22,7 +22,7 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tor" diff --git a/lnwire/onion_error.go b/lnwire/onion_error.go index 5f05e1ef9f..7b65a85f4e 100644 --- a/lnwire/onion_error.go +++ b/lnwire/onion_error.go @@ -10,7 +10,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/lnwire/onion_error_test.go b/lnwire/onion_error_test.go index 9c39be6d5c..5c3d0291a5 100644 --- a/lnwire/onion_error_test.go +++ b/lnwire/onion_error_test.go @@ -9,7 +9,7 @@ import ( "testing" "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/stretchr/testify/require" ) diff --git a/log.go b/log.go index a3efd03335..46047fb56c 100644 --- a/log.go +++ b/log.go @@ -9,6 +9,7 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/chanacceptor" @@ -196,6 +197,7 @@ func SetupLoggers(root *build.SubLoggerManager, interceptor signal.Interceptor) root, blindedpath.Subsystem, interceptor, blindedpath.UseLogger, ) AddV1SubLogger(root, graphdb.Subsystem, interceptor, graphdb.UseLogger) + AddSubLogger(root, chainio.Subsystem, interceptor, chainio.UseLogger) } // AddSubLogger is a helper method to conveniently create and register the diff --git a/msgmux/msg_router.go b/msgmux/msg_router.go index db9e783990..736c085a95 100644 --- a/msgmux/msg_router.go +++ b/msgmux/msg_router.go @@ -6,7 +6,7 @@ import ( "sync" "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) @@ -91,8 +91,8 @@ func sendQueryErr[Q any](sendChan chan fn.Req[Q, error], queryArg Q, quitChan chan struct{}) error { return fn.ElimEither( - fn.Iden, fn.Iden, sendQuery(sendChan, queryArg, quitChan).Either, + fn.Iden, fn.Iden, ) } diff --git a/peer/brontide.go b/peer/brontide.go index 6bc49445ee..a41b5080cb 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -26,7 +26,7 @@ import ( "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/feature" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/funding" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" @@ -400,6 +400,10 @@ type Config struct { // way contracts are resolved. AuxResolver fn.Option[lnwallet.AuxContractResolver] + // AuxTrafficShaper is an optional auxiliary traffic shaper that can be + // used to manage the bandwidth of peer links. + AuxTrafficShaper fn.Option[htlcswitch.AuxTrafficShaper] + // PongBuf is a slice we'll reuse instead of allocating memory on the // heap. Since only reads will occur and no writes, there is no need // for any synchronization primitives. As a result, it's safe to share @@ -1330,6 +1334,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint, ShouldFwdExpEndorsement: p.cfg.ShouldFwdExpEndorsement, DisallowQuiescence: p.cfg.DisallowQuiescence || !p.remoteFeatures.HasFeature(lnwire.QuiescenceOptional), + AuxTrafficShaper: p.cfg.AuxTrafficShaper, } // Before adding our new link, purge the switch of any pending or live diff --git a/peer/brontide_test.go b/peer/brontide_test.go index c3d1bee48b..eded658887 100644 --- a/peer/brontide_test.go +++ b/peer/brontide_test.go @@ -13,7 +13,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/peer/musig_chan_closer.go b/peer/musig_chan_closer.go index 6f69a8c5b8..149ebcfa0c 100644 --- a/peer/musig_chan_closer.go +++ b/peer/musig_chan_closer.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chancloser" diff --git a/peer/test_utils.go b/peer/test_utils.go index eb510a53b1..34c42e2f7c 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -18,7 +18,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" diff --git a/protofsm/daemon_events.go b/protofsm/daemon_events.go index e5de0b6951..bca7283d39 100644 --- a/protofsm/daemon_events.go +++ b/protofsm/daemon_events.go @@ -5,7 +5,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/protofsm/msg_mapper.go b/protofsm/msg_mapper.go index b96d677e6b..5e24255fa3 100644 --- a/protofsm/msg_mapper.go +++ b/protofsm/msg_mapper.go @@ -1,7 +1,7 @@ package protofsm import ( - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/protofsm/state_machine.go b/protofsm/state_machine.go index b71d5efe42..a81f5746b2 100644 --- a/protofsm/state_machine.go +++ b/protofsm/state_machine.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwire" ) @@ -21,6 +21,12 @@ const ( pollInterval = time.Millisecond * 100 ) +var ( + // ErrStateMachineShutdown occurs when trying to feed an event to a + // StateMachine that has been asked to Stop. + ErrStateMachineShutdown = fmt.Errorf("StateMachine is shutting down") +) + // EmittedEvent is a special type that can be emitted by a state transition. // This can container internal events which are to be routed back to the state, // or external events which are to be sent to the daemon. @@ -287,7 +293,7 @@ func (s *StateMachine[Event, Env]) CurrentState() (State[Event, Env], error) { } if !fn.SendOrQuit(s.stateQuery, query, s.quit) { - return nil, fmt.Errorf("state machine is shutting down") + return nil, ErrStateMachineShutdown } return fn.RecvOrTimeout(query.CurrentState, time.Second) @@ -322,6 +328,8 @@ func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[ // executeDaemonEvent executes a daemon event, which is a special type of event // that can be emitted as part of the state transition function of the state // machine. An error is returned if the type of event is unknown. +// +//nolint:funlen func (s *StateMachine[Event, Env]) executeDaemonEvent( event DaemonEvent) error { @@ -347,7 +355,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( // If a post-send event was specified, then we'll funnel // that back into the main state machine now as well. return fn.MapOptionZ(daemonEvent.PostSendEvent, func(event Event) error { //nolint:ll - return s.wg.Go(func(ctx context.Context) { + launched := s.wg.Go(func(ctx context.Context) { log.Debugf("FSM(%v): sending "+ "post-send event: %v", s.cfg.Env.Name(), @@ -356,6 +364,12 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( s.SendEvent(event) }) + + if !launched { + return ErrStateMachineShutdown + } + + return nil }) } @@ -368,7 +382,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( // Otherwise, this has a SendWhen predicate, so we'll need // launch a goroutine to poll the SendWhen, then send only once // the predicate is true. - return s.wg.Go(func(ctx context.Context) { + launched := s.wg.Go(func(ctx context.Context) { predicateTicker := time.NewTicker( s.cfg.CustomPollInterval.UnwrapOr(pollInterval), ) @@ -407,6 +421,12 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( } }) + if !launched { + return ErrStateMachineShutdown + } + + return nil + // If this is a broadcast transaction event, then we'll broadcast with // the label attached. case *BroadcastTxn: @@ -436,7 +456,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( return fmt.Errorf("unable to register spend: %w", err) } - return s.wg.Go(func(ctx context.Context) { + launched := s.wg.Go(func(ctx context.Context) { for { select { case spend, ok := <-spendEvent.Spend: @@ -461,6 +481,12 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( } }) + if !launched { + return ErrStateMachineShutdown + } + + return nil + // The state machine has requested a new event to be sent once a // specified txid+pkScript pair has confirmed. case *RegisterConf[Event]: @@ -476,7 +502,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( return fmt.Errorf("unable to register conf: %w", err) } - return s.wg.Go(func(ctx context.Context) { + launched := s.wg.Go(func(ctx context.Context) { for { select { case <-confEvent.Confirmed: @@ -498,6 +524,12 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( } } }) + + if !launched { + return ErrStateMachineShutdown + } + + return nil } return fmt.Errorf("unknown daemon event: %T", event) diff --git a/protofsm/state_machine_test.go b/protofsm/state_machine_test.go index fc30fcefc3..fc7a4ccfdc 100644 --- a/protofsm/state_machine_test.go +++ b/protofsm/state_machine_test.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" diff --git a/routing/bandwidth.go b/routing/bandwidth.go index 12e82131dc..a552628c79 100644 --- a/routing/bandwidth.go +++ b/routing/bandwidth.go @@ -3,7 +3,7 @@ package routing import ( "fmt" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnwire" @@ -29,39 +29,6 @@ type bandwidthHints interface { firstHopCustomBlob() fn.Option[tlv.Blob] } -// TlvTrafficShaper is an interface that allows the sender to determine if a -// payment should be carried by a channel based on the TLV records that may be -// present in the `update_add_htlc` message or the channel commitment itself. -type TlvTrafficShaper interface { - AuxHtlcModifier - - // ShouldHandleTraffic is called in order to check if the channel - // identified by the provided channel ID may have external mechanisms - // that would allow it to carry out the payment. - ShouldHandleTraffic(cid lnwire.ShortChannelID, - fundingBlob fn.Option[tlv.Blob]) (bool, error) - - // PaymentBandwidth returns the available bandwidth for a custom channel - // decided by the given channel aux blob and HTLC blob. A return value - // of 0 means there is no bandwidth available. To find out if a channel - // is a custom channel that should be handled by the traffic shaper, the - // HandleTraffic method should be called first. - PaymentBandwidth(htlcBlob, commitmentBlob fn.Option[tlv.Blob], - linkBandwidth, - htlcAmt lnwire.MilliSatoshi) (lnwire.MilliSatoshi, error) -} - -// AuxHtlcModifier is an interface that allows the sender to modify the outgoing -// HTLC of a payment by changing the amount or the wire message tlv records. -type AuxHtlcModifier interface { - // ProduceHtlcExtraData is a function that, based on the previous extra - // data blob of an HTLC, may produce a different blob or modify the - // amount of bitcoin this htlc should carry. - ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi, - htlcCustomRecords lnwire.CustomRecords) (lnwire.MilliSatoshi, - lnwire.CustomRecords, error) -} - // getLinkQuery is the function signature used to lookup a link. type getLinkQuery func(lnwire.ShortChannelID) ( htlcswitch.ChannelLink, error) @@ -73,7 +40,7 @@ type bandwidthManager struct { getLink getLinkQuery localChans map[lnwire.ShortChannelID]struct{} firstHopBlob fn.Option[tlv.Blob] - trafficShaper fn.Option[TlvTrafficShaper] + trafficShaper fn.Option[htlcswitch.AuxTrafficShaper] } // newBandwidthManager creates a bandwidth manager for the source node provided @@ -84,13 +51,14 @@ type bandwidthManager struct { // that are inactive, or just don't have enough bandwidth to carry the payment. func newBandwidthManager(graph Graph, sourceNode route.Vertex, linkQuery getLinkQuery, firstHopBlob fn.Option[tlv.Blob], - trafficShaper fn.Option[TlvTrafficShaper]) (*bandwidthManager, error) { + ts fn.Option[htlcswitch.AuxTrafficShaper]) (*bandwidthManager, + error) { manager := &bandwidthManager{ getLink: linkQuery, localChans: make(map[lnwire.ShortChannelID]struct{}), firstHopBlob: firstHopBlob, - trafficShaper: trafficShaper, + trafficShaper: ts, } // First, we'll collect the set of outbound edges from the target @@ -166,44 +134,15 @@ func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID, result, err := fn.MapOptionZ( b.trafficShaper, - func(ts TlvTrafficShaper) fn.Result[bandwidthResult] { - fundingBlob := link.FundingCustomBlob() - shouldHandle, err := ts.ShouldHandleTraffic( - cid, fundingBlob, - ) - if err != nil { - return bandwidthErr(fmt.Errorf("traffic "+ - "shaper failed to decide whether to "+ - "handle traffic: %w", err)) - } - - log.Debugf("ShortChannelID=%v: external traffic "+ - "shaper is handling traffic: %v", cid, - shouldHandle) - - // If this channel isn't handled by the external traffic - // shaper, we'll return early. - if !shouldHandle { - return fn.Ok(bandwidthResult{}) - } - - // Ask for a specific bandwidth to be used for the - // channel. - commitmentBlob := link.CommitmentCustomBlob() - auxBandwidth, err := ts.PaymentBandwidth( - b.firstHopBlob, commitmentBlob, linkBandwidth, - amount, - ) + func(s htlcswitch.AuxTrafficShaper) fn.Result[bandwidthResult] { + auxBandwidth, err := link.AuxBandwidth( + amount, cid, b.firstHopBlob, s, + ).Unpack() if err != nil { return bandwidthErr(fmt.Errorf("failed to get "+ - "bandwidth from external traffic "+ - "shaper: %w", err)) + "auxiliary bandwidth: %w", err)) } - log.Debugf("ShortChannelID=%v: external traffic "+ - "shaper reported available bandwidth: %v", cid, - auxBandwidth) - // We don't know the actual HTLC amount that will be // sent using the custom channel. But we'll still want // to make sure we can add another HTLC, using the @@ -213,7 +152,7 @@ func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID, // the max number of HTLCs on the channel. A proper // balance check is done elsewhere. return fn.Ok(bandwidthResult{ - bandwidth: fn.Some(auxBandwidth), + bandwidth: auxBandwidth, htlcAmount: fn.Some[lnwire.MilliSatoshi](0), }) }, diff --git a/routing/bandwidth_test.go b/routing/bandwidth_test.go index 4872b5a7ec..b31d0095ac 100644 --- a/routing/bandwidth_test.go +++ b/routing/bandwidth_test.go @@ -5,7 +5,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" @@ -118,7 +118,9 @@ func TestBandwidthManager(t *testing.T) { m, err := newBandwidthManager( g, sourceNode.pubkey, testCase.linkQuery, fn.None[[]byte](), - fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), ) require.NoError(t, err) diff --git a/routing/blinding.go b/routing/blinding.go index 7c84063469..0c27e87439 100644 --- a/routing/blinding.go +++ b/routing/blinding.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" diff --git a/routing/blinding_test.go b/routing/blinding_test.go index 410dfaf643..8f83f7fd82 100644 --- a/routing/blinding_test.go +++ b/routing/blinding_test.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 315b0dff22..e4241dac53 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" diff --git a/routing/localchans/manager.go b/routing/localchans/manager.go index d7380439ac..cd9e58fcaa 100644 --- a/routing/localchans/manager.go +++ b/routing/localchans/manager.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/discovery" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwire" diff --git a/routing/missioncontrol.go b/routing/missioncontrol.go index de892392e7..3bc9be7aba 100644 --- a/routing/missioncontrol.go +++ b/routing/missioncontrol.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcwallet/walletdb" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" diff --git a/routing/mock_test.go b/routing/mock_test.go index 3cdb5ebaf2..86fd765499 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" @@ -107,7 +107,7 @@ var _ PaymentSessionSource = (*mockPaymentSessionSourceOld)(nil) func (m *mockPaymentSessionSourceOld) NewPaymentSession( _ *LightningPayment, _ fn.Option[tlv.Blob], - _ fn.Option[TlvTrafficShaper]) (PaymentSession, error) { + _ fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, error) { return &mockPaymentSessionOld{ routes: m.routes, @@ -635,7 +635,8 @@ var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil) func (m *mockPaymentSessionSource) NewPaymentSession( payment *LightningPayment, firstHopBlob fn.Option[tlv.Blob], - tlvShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) { + tlvShaper fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, + error) { args := m.Called(payment, firstHopBlob, tlvShaper) return args.Get(0).(PaymentSession), args.Error(1) @@ -895,6 +896,19 @@ func (m *mockLink) Bandwidth() lnwire.MilliSatoshi { return m.bandwidth } +// AuxBandwidth returns the bandwidth that can be used for a channel, +// expressed in milli-satoshi. This might be different from the regular +// BTC bandwidth for custom channels. This will always return fn.None() +// for a regular (non-custom) channel. +func (m *mockLink) AuxBandwidth(lnwire.MilliSatoshi, lnwire.ShortChannelID, + fn.Option[tlv.Blob], + htlcswitch.AuxTrafficShaper) fn.Result[htlcswitch.OptionalBandwidth] { + + return fn.Ok[htlcswitch.OptionalBandwidth]( + fn.None[lnwire.MilliSatoshi](), + ) +} + // EligibleToForward returns the mock's configured eligibility. func (m *mockLink) EligibleToForward() bool { return !m.ineligible diff --git a/routing/pathfind.go b/routing/pathfind.go index 8e40c5bc4b..80f4b1e68f 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/btcutil" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/feature" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnutils" diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index da29c79a25..c463b8135b 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -21,7 +21,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 267ce3965d..5c4bae5585 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -10,7 +10,7 @@ import ( "github.com/davecgh/go-spew/spew" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" @@ -761,7 +761,8 @@ func (p *paymentLifecycle) amendFirstHopData(rt *route.Route) error { // and apply its side effects to the UpdateAddHTLC message. result, err := fn.MapOptionZ( p.router.cfg.TrafficShaper, - func(ts TlvTrafficShaper) fn.Result[extraDataRequest] { + //nolint:ll + func(ts htlcswitch.AuxTrafficShaper) fn.Result[extraDataRequest] { newAmt, newRecords, err := ts.ProduceHtlcExtraData( rt.TotalAmount, p.firstHopCustomRecords, ) @@ -774,7 +775,7 @@ func (p *paymentLifecycle) amendFirstHopData(rt *route.Route) error { return fn.Err[extraDataRequest](err) } - log.Debugf("TLV traffic shaper returned custom "+ + log.Debugf("Aux traffic shaper returned custom "+ "records %v and amount %d msat for HTLC", spew.Sdump(newRecords), newAmt) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 315c1bad58..72aa631419 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -9,7 +9,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnmock" "github.com/lightningnetwork/lnd/lntest/wait" @@ -30,7 +30,7 @@ func createTestPaymentLifecycle() *paymentLifecycle { quitChan := make(chan struct{}) rt := &ChannelRouter{ cfg: &Config{ - TrafficShaper: fn.Some[TlvTrafficShaper]( + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( &mockTrafficShaper{}, ), }, @@ -83,7 +83,7 @@ func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) { Payer: mockPayer, Clock: mockClock, MissionControl: mockMissionControl, - TrafficShaper: fn.Some[TlvTrafficShaper]( + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( &mockTrafficShaper{}, ), }, diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index d5f1a6af41..5e4eb23d7f 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -2,8 +2,9 @@ package routing import ( "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" @@ -52,7 +53,8 @@ type SessionSource struct { // payment's destination. func (m *SessionSource) NewPaymentSession(p *LightningPayment, firstHopBlob fn.Option[tlv.Blob], - trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) { + trafficShaper fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, + error) { getBandwidthHints := func(graph Graph) (bandwidthHints, error) { return newBandwidthManager( diff --git a/routing/result_interpretation.go b/routing/result_interpretation.go index 089213d65e..bc1749dfb1 100644 --- a/routing/result_interpretation.go +++ b/routing/result_interpretation.go @@ -6,7 +6,7 @@ import ( "io" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" @@ -578,7 +578,7 @@ func extractMCRoute(r *route.Route) *mcRoute { // extractMCHops extracts the Hop fields that MC actually uses from a slice of // Hops. func extractMCHops(hops []*route.Hop) mcHops { - return fn.Map(extractMCHop, hops) + return fn.Map(hops, extractMCHop) } // extractMCHop extracts the Hop fields that MC actually uses from a Hop. diff --git a/routing/result_interpretation_test.go b/routing/result_interpretation_test.go index b213eb1835..8c67bdeea9 100644 --- a/routing/result_interpretation_test.go +++ b/routing/result_interpretation_test.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) diff --git a/routing/router.go b/routing/router.go index 9eabe0b2ae..468510a6c7 100644 --- a/routing/router.go +++ b/routing/router.go @@ -19,7 +19,7 @@ import ( "github.com/lightningnetwork/lnd/amp" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" @@ -157,7 +157,7 @@ type PaymentSessionSource interface { // finding a path to the payment's destination. NewPaymentSession(p *LightningPayment, firstHopBlob fn.Option[tlv.Blob], - trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession, + ts fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, error) // NewPaymentSessionEmpty creates a new paymentSession instance that is @@ -297,7 +297,7 @@ type Config struct { // TrafficShaper is an optional traffic shaper that can be used to // control the outgoing channel of a payment. - TrafficShaper fn.Option[TlvTrafficShaper] + TrafficShaper fn.Option[htlcswitch.AuxTrafficShaper] } // EdgeLocator is a struct used to identify a specific edge. diff --git a/routing/router_test.go b/routing/router_test.go index 2923f1fb90..22c9d14e50 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -23,7 +23,7 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" @@ -170,7 +170,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, Clock: clock.NewTestClock(time.Unix(1, 0)), ApplyChannelUpdate: graphBuilder.ApplyChannelUpdate, ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper]( + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( &mockTrafficShaper{}, ), }) @@ -2206,8 +2206,10 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Register mockers with the expected method calls. @@ -2291,8 +2293,10 @@ func TestSendToRouteSkipTempErrNonMPP(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Expect an error to be returned. @@ -2347,8 +2351,10 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Create the error to be returned. @@ -2431,8 +2437,10 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Create the error to be returned. @@ -2519,8 +2527,10 @@ func TestSendToRouteTempFailure(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Create the error to be returned. diff --git a/rpcserver.go b/rpcserver.go index d7d2e0186c..72e2fa4afd 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -46,7 +46,7 @@ import ( "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/feature" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/graph" graphdb "github.com/lightningnetwork/lnd/graph/db" @@ -8068,9 +8068,9 @@ func (r *rpcServer) VerifyChanBackup(ctx context.Context, } return &lnrpc.VerifyChanBackupResponse{ - ChanPoints: fn.Map(func(c chanbackup.Single) string { + ChanPoints: fn.Map(channels, func(c chanbackup.Single) string { return c.FundingOutpoint.String() - }, channels), + }), }, nil } diff --git a/rpcserver_test.go b/rpcserver_test.go index b4b66e719c..b686c9020a 100644 --- a/rpcserver_test.go +++ b/rpcserver_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnrpc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/server.go b/server.go index f8f8239ed6..9fd0b7a006 100644 --- a/server.go +++ b/server.go @@ -28,6 +28,7 @@ import ( "github.com/lightningnetwork/lnd/aliasmgr" "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/brontide" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/chanacceptor" "github.com/lightningnetwork/lnd/chanbackup" @@ -39,7 +40,7 @@ import ( "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/feature" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/graph" graphdb "github.com/lightningnetwork/lnd/graph/db" @@ -356,6 +357,10 @@ type server struct { // txPublisher is a publisher with fee-bumping capability. txPublisher *sweep.TxPublisher + // blockbeatDispatcher is a block dispatcher that notifies subscribers + // of new blocks. + blockbeatDispatcher *chainio.BlockbeatDispatcher + quit chan struct{} wg sync.WaitGroup @@ -623,6 +628,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, readPool: readPool, chansToRestore: chansToRestore, + blockbeatDispatcher: chainio.NewBlockbeatDispatcher( + cc.ChainNotifier, + ), channelNotifier: channelnotifier.New( dbs.ChanStateDB.ChannelStateDB(), ), @@ -665,6 +673,17 @@ func newServer(cfg *Config, listenAddrs []net.Addr, quit: make(chan struct{}), } + // Start the low-level services once they are initialized. + // + // TODO(yy): break the server startup into four steps, + // 1. init the low-level services. + // 2. start the low-level services. + // 3. init the high-level services. + // 4. start the high-level services. + if err := s.startLowLevelServices(); err != nil { + return nil, err + } + currentHash, currentHeight, err := s.cc.ChainIO.GetBestBlock() if err != nil { return nil, err @@ -1813,6 +1832,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, } s.connMgr = cmgr + // Finally, register the subsystems in blockbeat. + s.registerBlockConsumers() + return s, nil } @@ -1845,6 +1867,25 @@ func (s *server) UpdateRoutingConfig(cfg *routing.MissionControlConfig) { routerCfg.MaxMcHistory = cfg.MaxMcHistory } +// registerBlockConsumers registers the subsystems that consume block events. +// By calling `RegisterQueue`, a list of subsystems are registered in the +// blockbeat for block notifications. When a new block arrives, the subsystems +// in the same queue are notified sequentially, and different queues are +// notified concurrently. +// +// NOTE: To put a subsystem in a different queue, create a slice and pass it to +// a new `RegisterQueue` call. +func (s *server) registerBlockConsumers() { + // In this queue, when a new block arrives, it will be received and + // processed in this order: chainArb -> sweeper -> txPublisher. + consumers := []chainio.Consumer{ + s.chainArb, + s.sweeper, + s.txPublisher, + } + s.blockbeatDispatcher.RegisterQueue(consumers) +} + // signAliasUpdate takes a ChannelUpdate and returns the signature. This is // used for option_scid_alias channels where the ChannelUpdate to be sent back // may differ from what is on disk. @@ -2067,12 +2108,41 @@ func (c cleaner) run() { } } +// startLowLevelServices starts the low-level services of the server. These +// services must be started successfully before running the main server. The +// services are, +// 1. the chain notifier. +// +// TODO(yy): identify and add more low-level services here. +func (s *server) startLowLevelServices() error { + var startErr error + + cleanup := cleaner{} + + cleanup = cleanup.add(s.cc.ChainNotifier.Stop) + if err := s.cc.ChainNotifier.Start(); err != nil { + startErr = err + } + + if startErr != nil { + cleanup.run() + } + + return startErr +} + // Start starts the main daemon server, all requested listeners, and any helper // goroutines. // NOTE: This function is safe for concurrent access. // //nolint:funlen func (s *server) Start() error { + // Get the current blockbeat. + beat, err := s.getStartingBeat() + if err != nil { + return err + } + var startErr error // If one sub system fails to start, the following code ensures that the @@ -2126,12 +2196,6 @@ func (s *server) Start() error { return } - cleanup = cleanup.add(s.cc.ChainNotifier.Stop) - if err := s.cc.ChainNotifier.Start(); err != nil { - startErr = err - return - } - cleanup = cleanup.add(s.cc.BestBlockTracker.Stop) if err := s.cc.BestBlockTracker.Start(); err != nil { startErr = err @@ -2167,13 +2231,13 @@ func (s *server) Start() error { } cleanup = cleanup.add(s.txPublisher.Stop) - if err := s.txPublisher.Start(); err != nil { + if err := s.txPublisher.Start(beat); err != nil { startErr = err return } cleanup = cleanup.add(s.sweeper.Stop) - if err := s.sweeper.Start(); err != nil { + if err := s.sweeper.Start(beat); err != nil { startErr = err return } @@ -2218,7 +2282,7 @@ func (s *server) Start() error { } cleanup = cleanup.add(s.chainArb.Stop) - if err := s.chainArb.Start(); err != nil { + if err := s.chainArb.Start(beat); err != nil { startErr = err return } @@ -2459,6 +2523,17 @@ func (s *server) Start() error { srvrLog.Infof("Auto peer bootstrapping is disabled") } + // Start the blockbeat after all other subsystems have been + // started so they are ready to receive new blocks. + cleanup = cleanup.add(func() error { + s.blockbeatDispatcher.Stop() + return nil + }) + if err := s.blockbeatDispatcher.Start(); err != nil { + startErr = err + return + } + // Set the active flag now that we've completed the full // startup. atomic.StoreInt32(&s.active, 1) @@ -2483,6 +2558,9 @@ func (s *server) Stop() error { // Shutdown connMgr first to prevent conns during shutdown. s.connMgr.Stop() + // Stop dispatching blocks to other systems immediately. + s.blockbeatDispatcher.Stop() + // Shutdown the wallet, funding manager, and the rpc server. if err := s.chanStatusMgr.Stop(); err != nil { srvrLog.Warnf("failed to stop chanStatusMgr: %v", err) @@ -4222,6 +4300,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, MsgRouter: s.implCfg.MsgRouter, AuxChanCloser: s.implCfg.AuxChanCloser, AuxResolver: s.implCfg.AuxContractResolver, + AuxTrafficShaper: s.implCfg.TrafficShaper, ShouldFwdExpEndorsement: func() bool { if s.cfg.ProtocolOptions.NoExperimentalEndorsement() { return false @@ -5151,3 +5230,35 @@ func (s *server) fetchClosedChannelSCIDs() map[lnwire.ShortChannelID]struct{} { return closedSCIDs } + +// getStartingBeat returns the current beat. This is used during the startup to +// initialize blockbeat consumers. +func (s *server) getStartingBeat() (*chainio.Beat, error) { + // beat is the current blockbeat. + var beat *chainio.Beat + + // We should get a notification with the current best block immediately + // by passing a nil block. + blockEpochs, err := s.cc.ChainNotifier.RegisterBlockEpochNtfn(nil) + if err != nil { + return beat, fmt.Errorf("register block epoch ntfn: %w", err) + } + defer blockEpochs.Cancel() + + // We registered for the block epochs with a nil request. The notifier + // should send us the current best block immediately. So we need to + // wait for it here because we need to know the current best height. + select { + case bestBlock := <-blockEpochs.Epochs: + srvrLog.Infof("Received initial block %v at height %d", + bestBlock.Hash, bestBlock.Height) + + // Update the current blockbeat. + beat = chainio.NewBeat(*bestBlock) + + case <-s.quit: + srvrLog.Debug("LND shutting down") + } + + return beat, nil +} diff --git a/subrpcserver_config.go b/subrpcserver_config.go index 30755c05e4..102e211187 100644 --- a/subrpcserver_config.go +++ b/subrpcserver_config.go @@ -11,7 +11,7 @@ import ( "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/invoices" diff --git a/sweep/aggregator.go b/sweep/aggregator.go index a0a1b0a540..e97ccb9a21 100644 --- a/sweep/aggregator.go +++ b/sweep/aggregator.go @@ -5,7 +5,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/sweep/aggregator_test.go b/sweep/aggregator_test.go index 6df0d73fa2..2cb89bdc38 100644 --- a/sweep/aggregator_test.go +++ b/sweep/aggregator_test.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index adb4db65ed..a43f875850 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -12,8 +12,9 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/chain" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lntypes" @@ -65,7 +66,7 @@ type Bumper interface { // and monitors its confirmation status for potential fee bumping. It // returns a chan that the caller can use to receive updates about the // broadcast result and potential RBF attempts. - Broadcast(req *BumpRequest) (<-chan *BumpResult, error) + Broadcast(req *BumpRequest) <-chan *BumpResult } // BumpEvent represents the event of a fee bumping attempt. @@ -75,7 +76,17 @@ const ( // TxPublished is sent when the broadcast attempt is finished. TxPublished BumpEvent = iota - // TxFailed is sent when the broadcast attempt fails. + // TxFailed is sent when the tx has encountered a fee-related error + // during its creation or broadcast, or an internal error from the fee + // bumper. In either case the inputs in this tx should be retried with + // either a different grouping strategy or an increased budget. + // + // NOTE: We also send this event when there's a third party spend + // event, and the sweeper will handle cleaning this up once it's + // confirmed. + // + // TODO(yy): Remove the above usage once we remove sweeping non-CPFP + // anchors. TxFailed // TxReplaced is sent when the original tx is replaced by a new one. @@ -84,6 +95,11 @@ const ( // TxConfirmed is sent when the tx is confirmed. TxConfirmed + // TxFatal is sent when the inputs in this tx cannot be retried. Txns + // will end up in this state if they have encountered a non-fee related + // error, which means they cannot be retried with increased budget. + TxFatal + // sentinalEvent is used to check if an event is unknown. sentinalEvent ) @@ -99,6 +115,8 @@ func (e BumpEvent) String() string { return "Replaced" case TxConfirmed: return "Confirmed" + case TxFatal: + return "Fatal" default: return "Unknown" } @@ -136,6 +154,10 @@ type BumpRequest struct { // ExtraTxOut tracks if this bump request has an optional set of extra // outputs to add to the transaction. ExtraTxOut fn.Option[SweepOutput] + + // Immediate is used to specify that the tx should be broadcast + // immediately. + Immediate bool } // MaxFeeRateAllowed returns the maximum fee rate allowed for the given @@ -145,13 +167,13 @@ type BumpRequest struct { func (r *BumpRequest) MaxFeeRateAllowed() (chainfee.SatPerKWeight, error) { // We'll want to know if we have any blobs, as we need to factor this // into the max fee rate for this bump request. - hasBlobs := fn.Any(func(i input.Input) bool { + hasBlobs := fn.Any(r.Inputs, func(i input.Input) bool { return fn.MapOptionZ( i.ResolutionBlob(), func(b tlv.Blob) bool { return len(b) > 0 }, ) - }, r.Inputs) + }) sweepAddrs := [][]byte{ r.DeliveryAddress.DeliveryAddress, @@ -246,10 +268,22 @@ type BumpResult struct { requestID uint64 } +// String returns a human-readable string for the result. +func (b *BumpResult) String() string { + desc := fmt.Sprintf("Event=%v", b.Event) + if b.Tx != nil { + desc += fmt.Sprintf(", Tx=%v", b.Tx.TxHash()) + } + + return fmt.Sprintf("[%s]", desc) +} + // Validate validates the BumpResult so it's safe to use. func (b *BumpResult) Validate() error { - // Every result must have a tx. - if b.Tx == nil { + isFailureEvent := b.Event == TxFailed || b.Event == TxFatal + + // Every result must have a tx except the fatal or failed case. + if b.Tx == nil && !isFailureEvent { return fmt.Errorf("%w: nil tx", ErrInvalidBumpResult) } @@ -263,8 +297,8 @@ func (b *BumpResult) Validate() error { return fmt.Errorf("%w: nil replacing tx", ErrInvalidBumpResult) } - // If it's a failed event, it must have an error. - if b.Event == TxFailed && b.Err == nil { + // If it's a failed or fatal event, it must have an error. + if isFailureEvent && b.Err == nil { return fmt.Errorf("%w: nil error", ErrInvalidBumpResult) } @@ -311,6 +345,10 @@ type TxPublisher struct { started atomic.Bool stopped atomic.Bool + // Embed the blockbeat consumer struct to get access to the method + // `NotifyBlockProcessed` and the `BlockbeatChan`. + chainio.BeatConsumer + wg sync.WaitGroup // cfg specifies the configuration of the TxPublisher. @@ -338,14 +376,22 @@ type TxPublisher struct { // Compile-time constraint to ensure TxPublisher implements Bumper. var _ Bumper = (*TxPublisher)(nil) +// Compile-time check for the chainio.Consumer interface. +var _ chainio.Consumer = (*TxPublisher)(nil) + // NewTxPublisher creates a new TxPublisher. func NewTxPublisher(cfg TxPublisherConfig) *TxPublisher { - return &TxPublisher{ + tp := &TxPublisher{ cfg: &cfg, records: lnutils.SyncMap[uint64, *monitorRecord]{}, subscriberChans: lnutils.SyncMap[uint64, chan *BumpResult]{}, quit: make(chan struct{}), } + + // Mount the block consumer. + tp.BeatConsumer = chainio.NewBeatConsumer(tp.quit, tp.Name()) + + return tp } // isNeutrinoBackend checks if the wallet backend is neutrino. @@ -353,60 +399,69 @@ func (t *TxPublisher) isNeutrinoBackend() bool { return t.cfg.Wallet.BackEnd() == "neutrino" } -// Broadcast is used to publish the tx created from the given inputs. It will, -// 1. init a fee function based on the given strategy. -// 2. create an RBF-compliant tx and monitor it for confirmation. -// 3. notify the initial broadcast result back to the caller. -// The initial broadcast is guaranteed to be RBF-compliant unless the budget -// specified cannot cover the fee. +// Broadcast is used to publish the tx created from the given inputs. It will +// register the broadcast request and return a chan to the caller to subscribe +// the broadcast result. The initial broadcast is guaranteed to be +// RBF-compliant unless the budget specified cannot cover the fee. // // NOTE: part of the Bumper interface. -func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { - log.Tracef("Received broadcast request: %s", lnutils.SpewLogClosure( - req)) +func (t *TxPublisher) Broadcast(req *BumpRequest) <-chan *BumpResult { + log.Tracef("Received broadcast request: %s", + lnutils.SpewLogClosure(req)) - // Attempt an initial broadcast which is guaranteed to comply with the - // RBF rules. - result, err := t.initialBroadcast(req) - if err != nil { - log.Errorf("Initial broadcast failed: %v", err) - - return nil, err - } + // Store the request. + requestID, record := t.storeInitialRecord(req) // Create a chan to send the result to the caller. subscriber := make(chan *BumpResult, 1) - t.subscriberChans.Store(result.requestID, subscriber) + t.subscriberChans.Store(requestID, subscriber) - // Send the initial broadcast result to the caller. - t.handleResult(result) + // Publish the tx immediately if specified. + if req.Immediate { + t.handleInitialBroadcast(record, requestID) + } + + return subscriber +} + +// storeInitialRecord initializes a monitor record and saves it in the map. +func (t *TxPublisher) storeInitialRecord(req *BumpRequest) ( + uint64, *monitorRecord) { - return subscriber, nil + // Increase the request counter. + // + // NOTE: this is the only place where we increase the counter. + requestID := t.requestCounter.Add(1) + + // Register the record. + record := &monitorRecord{req: req} + t.records.Store(requestID, record) + + return requestID, record +} + +// NOTE: part of the `chainio.Consumer` interface. +func (t *TxPublisher) Name() string { + return "TxPublisher" } -// initialBroadcast initializes a fee function, creates an RBF-compliant tx and -// broadcasts it. -func (t *TxPublisher) initialBroadcast(req *BumpRequest) (*BumpResult, error) { +// initializeTx initializes a fee function and creates an RBF-compliant tx. If +// succeeded, the initial tx is stored in the records map. +func (t *TxPublisher) initializeTx(requestID uint64, req *BumpRequest) error { // Create a fee bumping algorithm to be used for future RBF. feeAlgo, err := t.initializeFeeFunction(req) if err != nil { - return nil, fmt.Errorf("init fee function: %w", err) + return fmt.Errorf("init fee function: %w", err) } // Create the initial tx to be broadcasted. This tx is guaranteed to // comply with the RBF restrictions. - requestID, err := t.createRBFCompliantTx(req, feeAlgo) + err = t.createRBFCompliantTx(requestID, req, feeAlgo) if err != nil { - return nil, fmt.Errorf("create RBF-compliant tx: %w", err) + return fmt.Errorf("create RBF-compliant tx: %w", err) } - // Broadcast the tx and return the monitored record. - result, err := t.broadcast(requestID) - if err != nil { - return nil, fmt.Errorf("broadcast sweep tx: %w", err) - } - - return result, nil + return nil } // initializeFeeFunction initializes a fee function to be used for this request @@ -442,8 +497,8 @@ func (t *TxPublisher) initializeFeeFunction( // so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee // and redo the process until the tx is valid, or return an error when non-RBF // related errors occur or the budget has been used up. -func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, - f FeeFunction) (uint64, error) { +func (t *TxPublisher) createRBFCompliantTx(requestID uint64, req *BumpRequest, + f FeeFunction) error { for { // Create a new tx with the given fee rate and check its @@ -452,18 +507,19 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, switch { case err == nil: - // The tx is valid, return the request ID. - requestID := t.storeRecord( - sweepCtx.tx, req, f, sweepCtx.fee, + // The tx is valid, store it. + t.storeRecord( + requestID, sweepCtx.tx, req, f, sweepCtx.fee, sweepCtx.outpointToTxIndex, ) - log.Infof("Created tx %v for %v inputs: feerate=%v, "+ - "fee=%v, inputs=%v", sweepCtx.tx.TxHash(), - len(req.Inputs), f.FeeRate(), sweepCtx.fee, + log.Infof("Created initial sweep tx=%v for %v inputs: "+ + "feerate=%v, fee=%v, inputs:\n%v", + sweepCtx.tx.TxHash(), len(req.Inputs), + f.FeeRate(), sweepCtx.fee, inputTypeSummary(req.Inputs)) - return requestID, nil + return nil // If the error indicates the fees paid is not enough, we will // ask the fee function to increase the fee rate and retry. @@ -494,7 +550,7 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, // cluster these inputs differetly. increased, err = f.Increment() if err != nil { - return 0, err + return err } } @@ -504,21 +560,15 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, // mempool acceptance. default: log.Debugf("Failed to create RBF-compliant tx: %v", err) - return 0, err + return err } } } // storeRecord stores the given record in the records map. -func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest, - f FeeFunction, fee btcutil.Amount, - outpointToTxIndex map[wire.OutPoint]int) uint64 { - - // Increase the request counter. - // - // NOTE: this is the only place where we increase the - // counter. - requestID := t.requestCounter.Add(1) +func (t *TxPublisher) storeRecord(requestID uint64, tx *wire.MsgTx, + req *BumpRequest, f FeeFunction, fee btcutil.Amount, + outpointToTxIndex map[wire.OutPoint]int) { // Register the record. t.records.Store(requestID, &monitorRecord{ @@ -528,8 +578,6 @@ func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest, fee: fee, outpointToTxIndex: outpointToTxIndex, }) - - return requestID } // createAndCheckTx creates a tx based on the given inputs, change output @@ -659,8 +707,7 @@ func (t *TxPublisher) notifyResult(result *BumpResult) { return } - log.Debugf("Sending result for requestID=%v, tx=%v", id, - result.Tx.TxHash()) + log.Debugf("Sending result %v for requestID=%v", result, id) select { // Send the result to the subscriber. @@ -678,20 +725,31 @@ func (t *TxPublisher) notifyResult(result *BumpResult) { func (t *TxPublisher) removeResult(result *BumpResult) { id := result.requestID - // Remove the record from the maps if there's an error. This means this - // tx has failed its broadcast and cannot be retried. There are two - // cases, - // - when the budget cannot cover the fee. - // - when a non-RBF related error occurs. + var txid chainhash.Hash + if result.Tx != nil { + txid = result.Tx.TxHash() + } + + // Remove the record from the maps if there's an error or the tx is + // confirmed. When there's an error, it means this tx has failed its + // broadcast and cannot be retried. There are two cases it may fail, + // - when the budget cannot cover the increased fee calculated by the + // fee function, hence the budget is used up. + // - when a non-fee related error returned from PublishTransaction. switch result.Event { case TxFailed: log.Errorf("Removing monitor record=%v, tx=%v, due to err: %v", - id, result.Tx.TxHash(), result.Err) + id, txid, result.Err) case TxConfirmed: - // Remove the record is the tx is confirmed. + // Remove the record if the tx is confirmed. log.Debugf("Removing confirmed monitor record=%v, tx=%v", id, - result.Tx.TxHash()) + txid) + + case TxFatal: + // Remove the record if there's an error. + log.Debugf("Removing monitor record=%v due to fatal err: %v", + id, result.Err) // Do nothing if it's neither failed or confirmed. default: @@ -737,20 +795,18 @@ type monitorRecord struct { // Start starts the publisher by subscribing to block epoch updates and kicking // off the monitor loop. -func (t *TxPublisher) Start() error { +func (t *TxPublisher) Start(beat chainio.Blockbeat) error { log.Info("TxPublisher starting...") if t.started.Swap(true) { return fmt.Errorf("TxPublisher started more than once") } - blockEvent, err := t.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return fmt.Errorf("register block epoch ntfn: %w", err) - } + // Set the current height. + t.currentHeight.Store(beat.Height()) t.wg.Add(1) - go t.monitor(blockEvent) + go t.monitor() log.Debugf("TxPublisher started") @@ -778,33 +834,25 @@ func (t *TxPublisher) Stop() error { // to be bumped. If so, it will attempt to bump the fee of the tx. // // NOTE: Must be run as a goroutine. -func (t *TxPublisher) monitor(blockEvent *chainntnfs.BlockEpochEvent) { - defer blockEvent.Cancel() +func (t *TxPublisher) monitor() { defer t.wg.Done() for { select { - case epoch, ok := <-blockEvent.Epochs: - if !ok { - // We should stop the publisher before stopping - // the chain service. Otherwise it indicates an - // error. - log.Error("Block epoch channel closed, exit " + - "monitor") - - return - } - - log.Debugf("TxPublisher received new block: %v", - epoch.Height) + case beat := <-t.BlockbeatChan: + height := beat.Height() + log.Debugf("TxPublisher received new block: %v", height) // Update the best known height for the publisher. - t.currentHeight.Store(epoch.Height) + t.currentHeight.Store(height) // Check all monitored txns to see if any of them needs // to be bumped. t.processRecords() + // Notify we've processed the block. + t.NotifyBlockProcessed(beat, nil) + case <-t.quit: log.Debug("Fee bumper stopped, exit monitor") return @@ -819,18 +867,27 @@ func (t *TxPublisher) processRecords() { // confirmed. confirmedRecords := make(map[uint64]*monitorRecord) - // feeBumpRecords stores a map of the records which need to be bumped. + // feeBumpRecords stores a map of records which need to be bumped. feeBumpRecords := make(map[uint64]*monitorRecord) - // failedRecords stores a map of the records which has inputs being - // spent by a third party. + // failedRecords stores a map of records which has inputs being spent + // by a third party. // // NOTE: this is only used for neutrino backend. failedRecords := make(map[uint64]*monitorRecord) + // initialRecords stores a map of records which are being created and + // published for the first time. + initialRecords := make(map[uint64]*monitorRecord) + // visitor is a helper closure that visits each record and divides them // into two groups. visitor := func(requestID uint64, r *monitorRecord) error { + if r.tx == nil { + initialRecords[requestID] = r + return nil + } + log.Tracef("Checking monitor recordID=%v for tx=%v", requestID, r.tx.TxHash()) @@ -858,17 +915,20 @@ func (t *TxPublisher) processRecords() { return nil } - // Iterate through all the records and divide them into two groups. + // Iterate through all the records and divide them into four groups. t.records.ForEach(visitor) + // Handle the initial broadcast. + for requestID, r := range initialRecords { + t.handleInitialBroadcast(r, requestID) + } + // For records that are confirmed, we'll notify the caller about this // result. for requestID, r := range confirmedRecords { - rec := r - log.Debugf("Tx=%v is confirmed", r.tx.TxHash()) t.wg.Add(1) - go t.handleTxConfirmed(rec, requestID) + go t.handleTxConfirmed(r, requestID) } // Get the current height to be used in the following goroutines. @@ -876,22 +936,18 @@ func (t *TxPublisher) processRecords() { // For records that are not confirmed, we perform a fee bump if needed. for requestID, r := range feeBumpRecords { - rec := r - log.Debugf("Attempting to fee bump Tx=%v", r.tx.TxHash()) t.wg.Add(1) - go t.handleFeeBumpTx(requestID, rec, currentHeight) + go t.handleFeeBumpTx(requestID, r, currentHeight) } // For records that are failed, we'll notify the caller about this // result. for requestID, r := range failedRecords { - rec := r - log.Debugf("Tx=%v has inputs been spent by a third party, "+ "failing it now", r.tx.TxHash()) t.wg.Add(1) - go t.handleThirdPartySpent(rec, requestID) + go t.handleThirdPartySpent(r, requestID) } } @@ -916,6 +972,96 @@ func (t *TxPublisher) handleTxConfirmed(r *monitorRecord, requestID uint64) { t.handleResult(result) } +// handleInitialTxError takes the error from `initializeTx` and decides the +// bump event. It will construct a BumpResult and handles it. +func (t *TxPublisher) handleInitialTxError(requestID uint64, err error) { + // We now decide what type of event to send. + var event BumpEvent + + switch { + // When the error is due to a dust output, we'll send a TxFailed so + // these inputs can be retried with a different group in the next + // block. + case errors.Is(err, ErrTxNoOutput): + event = TxFailed + + // When the error is due to budget being used up, we'll send a TxFailed + // so these inputs can be retried with a different group in the next + // block. + case errors.Is(err, ErrMaxPosition): + event = TxFailed + + // When the error is due to zero fee rate delta, we'll send a TxFailed + // so these inputs can be retried in the next block. + case errors.Is(err, ErrZeroFeeRateDelta): + event = TxFailed + + // Otherwise this is not a fee-related error and the tx cannot be + // retried. In that case we will fail ALL the inputs in this tx, which + // means they will be removed from the sweeper and never be tried + // again. + // + // TODO(yy): Find out which input is causing the failure and fail that + // one only. + default: + event = TxFatal + } + + result := &BumpResult{ + Event: event, + Err: err, + requestID: requestID, + } + + t.handleResult(result) +} + +// handleInitialBroadcast is called when a new request is received. It will +// handle the initial tx creation and broadcast. In details, +// 1. init a fee function based on the given strategy. +// 2. create an RBF-compliant tx and monitor it for confirmation. +// 3. notify the initial broadcast result back to the caller. +func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord, + requestID uint64) { + + log.Debugf("Initial broadcast for requestID=%v", requestID) + + var ( + result *BumpResult + err error + ) + + // Attempt an initial broadcast which is guaranteed to comply with the + // RBF rules. + // + // Create the initial tx to be broadcasted. + err = t.initializeTx(requestID, r.req) + if err != nil { + log.Errorf("Initial broadcast failed: %v", err) + + // We now handle the initialization error and exit. + t.handleInitialTxError(requestID, err) + + return + } + + // Successfully created the first tx, now broadcast it. + result, err = t.broadcast(requestID) + if err != nil { + // The broadcast failed, which can only happen if the tx record + // cannot be found or the aux sweeper returns an error. In + // either case, we will send back a TxFail event so these + // inputs can be retried. + result = &BumpResult{ + Event: TxFailed, + Err: err, + requestID: requestID, + } + } + + t.handleResult(result) +} + // handleFeeBumpTx checks if the tx needs to be bumped, and if so, it will // attempt to bump the fee of the tx. // @@ -1382,7 +1528,7 @@ func prepareSweepTx(inputs []input.Input, changePkScript lnwallet.AddrWithKey, return err } - extraChangeOut = extraOut.LeftToOption() + extraChangeOut = extraOut.LeftToSome() return nil }, diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index 5030dee227..54c67dbe28 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/chain" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -91,6 +91,12 @@ func TestBumpResultValidate(t *testing.T) { } require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + // A fatal event without a failure reason will give an error. + b = BumpResult{ + Event: TxFailed, + } + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + // A confirmed event without fee info will give an error. b = BumpResult{ Tx: &wire.MsgTx{}, @@ -104,6 +110,20 @@ func TestBumpResultValidate(t *testing.T) { Event: TxPublished, } require.NoError(t, b.Validate()) + + // Tx is allowed to be nil in a TxFailed event. + b = BumpResult{ + Event: TxFailed, + Err: errDummy, + } + require.NoError(t, b.Validate()) + + // Tx is allowed to be nil in a TxFatal event. + b = BumpResult{ + Event: TxFatal, + Err: errDummy, + } + require.NoError(t, b.Validate()) } // TestCalcSweepTxWeight checks that the weight of the sweep tx is calculated @@ -332,13 +352,10 @@ func TestStoreRecord(t *testing.T) { } // Call the method under test. - requestID := tp.storeRecord(tx, req, feeFunc, fee, utxoIndex) - - // Check the request ID is as expected. - require.Equal(t, initialCounter+1, requestID) + tp.storeRecord(initialCounter, tx, req, feeFunc, fee, utxoIndex) // Read the saved record and compare. - record, ok := tp.records.Load(requestID) + record, ok := tp.records.Load(initialCounter) require.True(t, ok) require.Equal(t, tx, record.tx) require.Equal(t, feeFunc, record.feeFunction) @@ -635,23 +652,19 @@ func TestCreateRBFCompliantTx(t *testing.T) { }, } + var requestCounter atomic.Uint64 for _, tc := range testCases { tc := tc + rid := requestCounter.Add(1) t.Run(tc.name, func(t *testing.T) { tc.setupMock() // Call the method under test. - id, err := tp.createRBFCompliantTx(req, m.feeFunc) + err := tp.createRBFCompliantTx(rid, req, m.feeFunc) // Check the result is as expected. require.ErrorIs(t, err, tc.expectedErr) - - // If there's an error, expect the requestID to be - // empty. - if tc.expectedErr != nil { - require.Zero(t, id) - } }) } } @@ -684,7 +697,8 @@ func TestTxPublisherBroadcast(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex) + requestID := uint64(1) + tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex) // Quickly check when the requestID cannot be found, an error is // returned. @@ -779,6 +793,9 @@ func TestRemoveResult(t *testing.T) { op: 0, } + // Create a test request ID counter. + requestCounter := atomic.Uint64{} + testCases := []struct { name string setupRecord func() uint64 @@ -790,12 +807,13 @@ func TestRemoveResult(t *testing.T) { // removed. name: "remove on TxConfirmed", setupRecord: func() uint64 { - id := tp.storeRecord( - tx, req, m.feeFunc, fee, utxoIndex, + rid := requestCounter.Add(1) + tp.storeRecord( + rid, tx, req, m.feeFunc, fee, utxoIndex, ) - tp.subscriberChans.Store(id, nil) + tp.subscriberChans.Store(rid, nil) - return id + return rid }, result: &BumpResult{ Event: TxConfirmed, @@ -807,12 +825,13 @@ func TestRemoveResult(t *testing.T) { // When the tx is failed, the records will be removed. name: "remove on TxFailed", setupRecord: func() uint64 { - id := tp.storeRecord( - tx, req, m.feeFunc, fee, utxoIndex, + rid := requestCounter.Add(1) + tp.storeRecord( + rid, tx, req, m.feeFunc, fee, utxoIndex, ) - tp.subscriberChans.Store(id, nil) + tp.subscriberChans.Store(rid, nil) - return id + return rid }, result: &BumpResult{ Event: TxFailed, @@ -825,12 +844,13 @@ func TestRemoveResult(t *testing.T) { // Noop when the tx is neither confirmed or failed. name: "noop when tx is not confirmed or failed", setupRecord: func() uint64 { - id := tp.storeRecord( - tx, req, m.feeFunc, fee, utxoIndex, + rid := requestCounter.Add(1) + tp.storeRecord( + rid, tx, req, m.feeFunc, fee, utxoIndex, ) - tp.subscriberChans.Store(id, nil) + tp.subscriberChans.Store(rid, nil) - return id + return rid }, result: &BumpResult{ Event: TxPublished, @@ -885,7 +905,8 @@ func TestNotifyResult(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex) + requestID := uint64(1) + tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex) // Create a subscription to the event. subscriber := make(chan *BumpResult, 1) @@ -933,41 +954,17 @@ func TestNotifyResult(t *testing.T) { } } -// TestBroadcastSuccess checks the public `Broadcast` method can successfully -// broadcast a tx based on the request. -func TestBroadcastSuccess(t *testing.T) { +// TestBroadcast checks the public `Broadcast` method can successfully register +// a broadcast request. +func TestBroadcast(t *testing.T) { t.Parallel() // Create a publisher using the mocks. - tp, m := createTestPublisher(t) + tp, _ := createTestPublisher(t) // Create a test feerate. feerate := chainfee.SatPerKWeight(1000) - // Mock the fee estimator to return the testing fee rate. - // - // We are not testing `NewLinearFeeFunction` here, so the actual params - // used are irrelevant. - m.estimator.On("EstimateFeePerKW", mock.Anything).Return( - feerate, nil).Once() - m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once() - - // Mock the signer to always return a valid script. - // - // NOTE: we are not testing the utility of creating valid txes here, so - // this is fine to be mocked. This behaves essentially as skipping the - // Signer check and alaways assume the tx has a valid sig. - script := &input.Script{} - m.signer.On("ComputeInputScript", mock.Anything, - mock.Anything).Return(script, nil) - - // Mock the testmempoolaccept to pass. - m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() - - // Mock the wallet to publish successfully. - m.wallet.On("PublishTransaction", - mock.Anything, mock.Anything).Return(nil).Once() - // Create a test request. inp := createTestInput(1000, input.WitnessKeyHash) @@ -981,27 +978,24 @@ func TestBroadcastSuccess(t *testing.T) { } // Send the req and expect no error. - resultChan, err := tp.Broadcast(req) - require.NoError(t, err) - - // Check the result is sent back. - select { - case <-time.After(time.Second): - t.Fatal("timeout waiting for subscriber to receive result") - - case result := <-resultChan: - // We expect the first result to be TxPublished. - require.Equal(t, TxPublished, result.Event) - } + resultChan := tp.Broadcast(req) + require.NotNil(t, resultChan) // Validate the record was stored. require.Equal(t, 1, tp.records.Len()) require.Equal(t, 1, tp.subscriberChans.Len()) + + // Validate the record. + rid := tp.requestCounter.Load() + record, found := tp.records.Load(rid) + require.True(t, found) + require.Equal(t, req, record.req) } -// TestBroadcastFail checks the public `Broadcast` returns the error or a -// failed result when the broadcast fails. -func TestBroadcastFail(t *testing.T) { +// TestBroadcastImmediate checks the public `Broadcast` method can successfully +// register a broadcast request and publish the tx when `Immediate` flag is +// set. +func TestBroadcastImmediate(t *testing.T) { t.Parallel() // Create a publisher using the mocks. @@ -1020,64 +1014,27 @@ func TestBroadcastFail(t *testing.T) { Budget: btcutil.Amount(1000), MaxFeeRate: feerate * 10, DeadlineHeight: 10, + Immediate: true, } - // Mock the fee estimator to return the testing fee rate. + // Mock the fee estimator to return an error. // - // We are not testing `NewLinearFeeFunction` here, so the actual params - // used are irrelevant. + // NOTE: We are not testing `handleInitialBroadcast` here, but only + // interested in checking that this method is indeed called when + // `Immediate` is true. Thus we mock the method to return an error to + // quickly abort. As long as this mocked method is called, we know the + // `Immediate` flag works. m.estimator.On("EstimateFeePerKW", mock.Anything).Return( - feerate, nil).Twice() - m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice() + chainfee.SatPerKWeight(0), errDummy).Once() - // Mock the signer to always return a valid script. - // - // NOTE: we are not testing the utility of creating valid txes here, so - // this is fine to be mocked. This behaves essentially as skipping the - // Signer check and alaways assume the tx has a valid sig. - script := &input.Script{} - m.signer.On("ComputeInputScript", mock.Anything, - mock.Anything).Return(script, nil) - - // Mock the testmempoolaccept to return an error. - m.wallet.On("CheckMempoolAcceptance", - mock.Anything).Return(errDummy).Once() - - // Send the req and expect an error returned. - resultChan, err := tp.Broadcast(req) - require.ErrorIs(t, err, errDummy) - require.Nil(t, resultChan) - - // Validate the record was NOT stored. - require.Equal(t, 0, tp.records.Len()) - require.Equal(t, 0, tp.subscriberChans.Len()) - - // Mock the testmempoolaccept again, this time it passes. - m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() - - // Mock the wallet to fail on publish. - m.wallet.On("PublishTransaction", - mock.Anything, mock.Anything).Return(errDummy).Once() - - // Send the req and expect no error returned. - resultChan, err = tp.Broadcast(req) - require.NoError(t, err) - - // Check the result is sent back. - select { - case <-time.After(time.Second): - t.Fatal("timeout waiting for subscriber to receive result") - - case result := <-resultChan: - // We expect the result to be TxFailed and the error is set in - // the result. - require.Equal(t, TxFailed, result.Event) - require.ErrorIs(t, result.Err, errDummy) - } + // Send the req and expect no error. + resultChan := tp.Broadcast(req) + require.NotNil(t, resultChan) - // Validate the record was removed. - require.Equal(t, 0, tp.records.Len()) - require.Equal(t, 0, tp.subscriberChans.Len()) + // Validate the record was removed due to an error returned in initial + // broadcast. + require.Empty(t, tp.records.Len()) + require.Empty(t, tp.subscriberChans.Len()) } // TestCreateAnPublishFail checks all the error cases are handled properly in @@ -1250,7 +1207,8 @@ func TestHandleTxConfirmed(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex) + requestID := uint64(1) + tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex) record, ok := tp.records.Load(requestID) require.True(t, ok) @@ -1330,7 +1288,8 @@ func TestHandleFeeBumpTx(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex) + requestID := uint64(1) + tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex) // Create a subscription to the event. subscriber := make(chan *BumpResult, 1) @@ -1531,3 +1490,183 @@ func TestProcessRecords(t *testing.T) { require.Equal(t, requestID2, result.requestID) } } + +// TestHandleInitialBroadcastSuccess checks `handleInitialBroadcast` method can +// successfully broadcast a tx based on the request. +func TestHandleInitialBroadcastSuccess(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test feerate. + feerate := chainfee.SatPerKWeight(1000) + + // Mock the fee estimator to return the testing fee rate. + // + // We are not testing `NewLinearFeeFunction` here, so the actual params + // used are irrelevant. + m.estimator.On("EstimateFeePerKW", mock.Anything).Return( + feerate, nil).Once() + m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once() + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to pass. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Mock the wallet to publish successfully. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(nil).Once() + + // Create a test request. + inp := createTestInput(1000, input.WitnessKeyHash) + + // Create a testing bump request. + req := &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + MaxFeeRate: feerate * 10, + DeadlineHeight: 10, + } + + // Register the testing record use `Broadcast`. + resultChan := tp.Broadcast(req) + + // Grab the monitor record from the map. + rid := tp.requestCounter.Load() + rec, ok := tp.records.Load(rid) + require.True(t, ok) + + // Call the method under test. + tp.wg.Add(1) + tp.handleInitialBroadcast(rec, rid) + + // Check the result is sent back. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-resultChan: + // We expect the first result to be TxPublished. + require.Equal(t, TxPublished, result.Event) + } + + // Validate the record was stored. + require.Equal(t, 1, tp.records.Len()) + require.Equal(t, 1, tp.subscriberChans.Len()) +} + +// TestHandleInitialBroadcastFail checks `handleInitialBroadcast` returns the +// error or a failed result when the broadcast fails. +func TestHandleInitialBroadcastFail(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test feerate. + feerate := chainfee.SatPerKWeight(1000) + + // Create a test request. + inp := createTestInput(1000, input.WitnessKeyHash) + + // Create a testing bump request. + req := &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + MaxFeeRate: feerate * 10, + DeadlineHeight: 10, + } + + // Mock the fee estimator to return the testing fee rate. + // + // We are not testing `NewLinearFeeFunction` here, so the actual params + // used are irrelevant. + m.estimator.On("EstimateFeePerKW", mock.Anything).Return( + feerate, nil).Twice() + m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice() + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to return an error. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(errDummy).Once() + + // Register the testing record use `Broadcast`. + resultChan := tp.Broadcast(req) + + // Grab the monitor record from the map. + rid := tp.requestCounter.Load() + rec, ok := tp.records.Load(rid) + require.True(t, ok) + + // Call the method under test and expect an error returned. + tp.wg.Add(1) + tp.handleInitialBroadcast(rec, rid) + + // Check the result is sent back. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-resultChan: + // We expect the first result to be TxFatal. + require.Equal(t, TxFatal, result.Event) + } + + // Validate the record was NOT stored. + require.Equal(t, 0, tp.records.Len()) + require.Equal(t, 0, tp.subscriberChans.Len()) + + // Mock the testmempoolaccept again, this time it passes. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Mock the wallet to fail on publish. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(errDummy).Once() + + // Register the testing record use `Broadcast`. + resultChan = tp.Broadcast(req) + + // Grab the monitor record from the map. + rid = tp.requestCounter.Load() + rec, ok = tp.records.Load(rid) + require.True(t, ok) + + // Call the method under test. + tp.wg.Add(1) + tp.handleInitialBroadcast(rec, rid) + + // Check the result is sent back. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-resultChan: + // We expect the result to be TxFailed and the error is set in + // the result. + require.Equal(t, TxFailed, result.Event) + require.ErrorIs(t, result.Err, errDummy) + } + + // Validate the record was removed. + require.Equal(t, 0, tp.records.Len()) + require.Equal(t, 0, tp.subscriberChans.Len()) +} diff --git a/sweep/fee_function.go b/sweep/fee_function.go index cbf283e37d..eb2ed4d6b1 100644 --- a/sweep/fee_function.go +++ b/sweep/fee_function.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" ) @@ -14,6 +14,9 @@ var ( // ErrMaxPosition is returned when trying to increase the position of // the fee function while it's already at its max. ErrMaxPosition = errors.New("position already at max") + + // ErrZeroFeeRateDelta is returned when the fee rate delta is zero. + ErrZeroFeeRateDelta = errors.New("fee rate delta is zero") ) // mSatPerKWeight represents a fee rate in msat/kw. @@ -169,7 +172,7 @@ func NewLinearFeeFunction(maxFeeRate chainfee.SatPerKWeight, "endingFeeRate=%v, width=%v, delta=%v", start, end, l.width, l.deltaFeeRate) - return nil, fmt.Errorf("fee rate delta is zero") + return nil, ErrZeroFeeRateDelta } // Attach the calculated values to the fee function. diff --git a/sweep/fee_function_test.go b/sweep/fee_function_test.go index c278bb7f06..a55ce79a78 100644 --- a/sweep/fee_function_test.go +++ b/sweep/fee_function_test.go @@ -3,7 +3,7 @@ package sweep import ( "testing" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/stretchr/testify/require" ) diff --git a/sweep/interface.go b/sweep/interface.go index f2fff84b08..6c8c2cfad2 100644 --- a/sweep/interface.go +++ b/sweep/interface.go @@ -4,7 +4,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/sweep/mock_test.go b/sweep/mock_test.go index 34202b1453..f9471f22a0 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -4,7 +4,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -268,6 +268,13 @@ func (m *MockInputSet) StartingFeeRate() fn.Option[chainfee.SatPerKWeight] { return args.Get(0).(fn.Option[chainfee.SatPerKWeight]) } +// Immediate returns whether the inputs should be swept immediately. +func (m *MockInputSet) Immediate() bool { + args := m.Called() + + return args.Bool(0) +} + // MockBumper is a mock implementation of the interface Bumper. type MockBumper struct { mock.Mock @@ -277,14 +284,14 @@ type MockBumper struct { var _ Bumper = (*MockBumper)(nil) // Broadcast broadcasts the transaction to the network. -func (m *MockBumper) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { +func (m *MockBumper) Broadcast(req *BumpRequest) <-chan *BumpResult { args := m.Called(req) if args.Get(0) == nil { - return nil, args.Error(1) + return nil } - return args.Get(0).(chan *BumpResult), args.Error(1) + return args.Get(0).(chan *BumpResult) } // MockFeeFunction is a mock implementation of the FeeFunction interface. diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 6257faac1f..d49f104af0 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -10,8 +10,9 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwallet" @@ -222,6 +223,35 @@ func (p *SweeperInput) terminated() bool { } } +// isMature returns a boolean indicating whether the input has a timelock that +// has been reached or not. The locktime found is also returned. +func (p *SweeperInput) isMature(currentHeight uint32) (bool, uint32) { + locktime, _ := p.RequiredLockTime() + if currentHeight < locktime { + log.Debugf("Input %v has locktime=%v, current height is %v", + p.OutPoint(), locktime, currentHeight) + + return false, locktime + } + + // If the input has a CSV that's not yet reached, we will skip + // this input and wait for the expiry. + // + // NOTE: We need to consider whether this input can be included in the + // next block or not, which means the CSV will be checked against the + // currentHeight plus one. + locktime = p.BlocksToMaturity() + p.HeightHint() + if currentHeight+1 < locktime { + log.Debugf("Input %v has CSV expiry=%v, current height is %v, "+ + "skipped sweeping", p.OutPoint(), locktime, + currentHeight) + + return false, locktime + } + + return true, locktime +} + // InputsMap is a type alias for a set of pending inputs. type InputsMap = map[wire.OutPoint]*SweeperInput @@ -280,6 +310,10 @@ type UtxoSweeper struct { started uint32 // To be used atomically. stopped uint32 // To be used atomically. + // Embed the blockbeat consumer struct to get access to the method + // `NotifyBlockProcessed` and the `BlockbeatChan`. + chainio.BeatConsumer + cfg *UtxoSweeperConfig newInputs chan *sweepInputMessage @@ -309,11 +343,14 @@ type UtxoSweeper struct { // updated whenever a new block epoch is received. currentHeight int32 - // bumpResultChan is a channel that receives broadcast results from the + // bumpRespChan is a channel that receives broadcast results from the // TxPublisher. - bumpResultChan chan *BumpResult + bumpRespChan chan *bumpResp } +// Compile-time check for the chainio.Consumer interface. +var _ chainio.Consumer = (*UtxoSweeper)(nil) + // UtxoSweeperConfig contains dependencies of UtxoSweeper. type UtxoSweeperConfig struct { // GenSweepScript generates a P2WKH script belonging to the wallet where @@ -387,7 +424,7 @@ type sweepInputMessage struct { // New returns a new Sweeper instance. func New(cfg *UtxoSweeperConfig) *UtxoSweeper { - return &UtxoSweeper{ + s := &UtxoSweeper{ cfg: cfg, newInputs: make(chan *sweepInputMessage), spendChan: make(chan *chainntnfs.SpendDetail), @@ -395,12 +432,17 @@ func New(cfg *UtxoSweeperConfig) *UtxoSweeper { pendingSweepsReqs: make(chan *pendingSweepsReq), quit: make(chan struct{}), inputs: make(InputsMap), - bumpResultChan: make(chan *BumpResult, 100), + bumpRespChan: make(chan *bumpResp, 100), } + + // Mount the block consumer. + s.BeatConsumer = chainio.NewBeatConsumer(s.quit, s.Name()) + + return s } // Start starts the process of constructing and publish sweep txes. -func (s *UtxoSweeper) Start() error { +func (s *UtxoSweeper) Start(beat chainio.Blockbeat) error { if !atomic.CompareAndSwapUint32(&s.started, 0, 1) { return nil } @@ -411,49 +453,12 @@ func (s *UtxoSweeper) Start() error { // not change from here on. s.relayFeeRate = s.cfg.FeeEstimator.RelayFeePerKW() - // We need to register for block epochs and retry sweeping every block. - // We should get a notification with the current best block immediately - // if we don't provide any epoch. We'll wait for that in the collector. - blockEpochs, err := s.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return fmt.Errorf("register block epoch ntfn: %w", err) - } + // Set the current height. + s.currentHeight = beat.Height() // Start sweeper main loop. s.wg.Add(1) - go func() { - defer blockEpochs.Cancel() - defer s.wg.Done() - - s.collector(blockEpochs.Epochs) - - // The collector exited and won't longer handle incoming - // requests. This can happen on shutdown, when the block - // notifier shuts down before the sweeper and its clients. In - // order to not deadlock the clients waiting for their requests - // being handled, we handle them here and immediately return an - // error. When the sweeper finally is shut down we can exit as - // the clients will be notified. - for { - select { - case inp := <-s.newInputs: - inp.resultChan <- Result{ - Err: ErrSweeperShuttingDown, - } - - case req := <-s.pendingSweepsReqs: - req.errChan <- ErrSweeperShuttingDown - - case req := <-s.updateReqs: - req.responseChan <- &updateResp{ - err: ErrSweeperShuttingDown, - } - - case <-s.quit: - return - } - } - }() + go s.collector() return nil } @@ -480,6 +485,11 @@ func (s *UtxoSweeper) Stop() error { return nil } +// NOTE: part of the `chainio.Consumer` interface. +func (s *UtxoSweeper) Name() string { + return "UtxoSweeper" +} + // SweepInput sweeps inputs back into the wallet. The inputs will be batched and // swept after the batch time window ends. A custom fee preference can be // provided to determine what fee rate should be used for the input. Note that @@ -502,7 +512,7 @@ func (s *UtxoSweeper) SweepInput(inp input.Input, } absoluteTimeLock, _ := inp.RequiredLockTime() - log.Infof("Sweep request received: out_point=%v, witness_type=%v, "+ + log.Debugf("Sweep request received: out_point=%v, witness_type=%v, "+ "relative_time_lock=%v, absolute_time_lock=%v, amount=%v, "+ "parent=(%v), params=(%v)", inp.OutPoint(), inp.WitnessType(), inp.BlocksToMaturity(), absoluteTimeLock, @@ -611,17 +621,8 @@ func (s *UtxoSweeper) removeConflictSweepDescendants( // collector is the sweeper main loop. It processes new inputs, spend // notifications and counts down to publication of the sweep tx. -func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { - // We registered for the block epochs with a nil request. The notifier - // should send us the current best block immediately. So we need to wait - // for it here because we need to know the current best height. - select { - case bestBlock := <-blockEpochs: - s.currentHeight = bestBlock.Height - - case <-s.quit: - return - } +func (s *UtxoSweeper) collector() { + defer s.wg.Done() for { // Clean inputs, which will remove inputs that are swept, @@ -681,9 +682,9 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { s.sweepPendingInputs(inputs) } - case result := <-s.bumpResultChan: + case resp := <-s.bumpRespChan: // Handle the bump event. - err := s.handleBumpEvent(result) + err := s.handleBumpEvent(resp) if err != nil { log.Errorf("Failed to handle bump event: %v", err) @@ -691,28 +692,33 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { // A new block comes in, update the bestHeight, perform a check // over all pending inputs and publish sweeping txns if needed. - case epoch, ok := <-blockEpochs: - if !ok { - // We should stop the sweeper before stopping - // the chain service. Otherwise it indicates an - // error. - log.Error("Block epoch channel closed") - - return - } - + case beat := <-s.BlockbeatChan: // Update the sweeper to the best height. - s.currentHeight = epoch.Height + s.currentHeight = beat.Height() // Update the inputs with the latest height. inputs := s.updateSweeperInputs() log.Debugf("Received new block: height=%v, attempt "+ - "sweeping %d inputs", epoch.Height, len(inputs)) + "sweeping %d inputs:\n%s", + s.currentHeight, len(inputs), + lnutils.NewLogClosure(func() string { + inps := make( + []input.Input, 0, len(inputs), + ) + for _, in := range inputs { + inps = append(inps, in) + } + + return inputTypeSummary(inps) + })) // Attempt to sweep any pending inputs. s.sweepPendingInputs(inputs) + // Notify we've processed the block. + s.NotifyBlockProcessed(beat, nil) + case <-s.quit: return } @@ -827,6 +833,7 @@ func (s *UtxoSweeper) sweep(set InputSet) error { DeliveryAddress: sweepAddr, MaxFeeRate: s.cfg.MaxFeeRate.FeePerKWeight(), StartingFeeRate: set.StartingFeeRate(), + Immediate: set.Immediate(), // TODO(yy): pass the strategy here. } @@ -837,27 +844,13 @@ func (s *UtxoSweeper) sweep(set InputSet) error { // Broadcast will return a read-only chan that we will listen to for // this publish result and future RBF attempt. - resp, err := s.cfg.Publisher.Broadcast(req) - if err != nil { - outpoints := make([]wire.OutPoint, len(set.Inputs())) - for i, inp := range set.Inputs() { - outpoints[i] = inp.OutPoint() - } - - log.Errorf("Initial broadcast failed: %v, inputs=\n%v", err, - inputTypeSummary(set.Inputs())) - - // TODO(yy): find out which input is causing the failure. - s.markInputsPublishFailed(outpoints) - - return err - } + resp := s.cfg.Publisher.Broadcast(req) // Successfully sent the broadcast attempt, we now handle the result by // subscribing to the result chan and listen for future updates about // this tx. s.wg.Add(1) - go s.monitorFeeBumpResult(resp) + go s.monitorFeeBumpResult(set, resp) return nil } @@ -867,14 +860,14 @@ func (s *UtxoSweeper) sweep(set InputSet) error { func (s *UtxoSweeper) markInputsPendingPublish(set InputSet) { // Reschedule sweep. for _, input := range set.Inputs() { - pi, ok := s.inputs[input.OutPoint()] + op := input.OutPoint() + pi, ok := s.inputs[op] if !ok { // It could be that this input is an additional wallet // input that was attached. In that case there also // isn't a pending input to update. log.Tracef("Skipped marking input as pending "+ - "published: %v not found in pending inputs", - input.OutPoint()) + "published: %v not found in pending inputs", op) continue } @@ -885,8 +878,7 @@ func (s *UtxoSweeper) markInputsPendingPublish(set InputSet) { // publish. if pi.terminated() { log.Errorf("Expect input %v to not have terminated "+ - "state, instead it has %v", - input.OutPoint, pi.state) + "state, instead it has %v", op, pi.state) continue } @@ -901,9 +893,7 @@ func (s *UtxoSweeper) markInputsPendingPublish(set InputSet) { // markInputsPublished updates the sweeping tx in db and marks the list of // inputs as published. -func (s *UtxoSweeper) markInputsPublished(tr *TxRecord, - inputs []*wire.TxIn) error { - +func (s *UtxoSweeper) markInputsPublished(tr *TxRecord, set InputSet) error { // Mark this tx in db once successfully published. // // NOTE: this will behave as an overwrite, which is fine as the record @@ -915,15 +905,15 @@ func (s *UtxoSweeper) markInputsPublished(tr *TxRecord, } // Reschedule sweep. - for _, input := range inputs { - pi, ok := s.inputs[input.PreviousOutPoint] + for _, input := range set.Inputs() { + op := input.OutPoint() + pi, ok := s.inputs[op] if !ok { // It could be that this input is an additional wallet // input that was attached. In that case there also // isn't a pending input to update. log.Tracef("Skipped marking input as published: %v "+ - "not found in pending inputs", - input.PreviousOutPoint) + "not found in pending inputs", op) continue } @@ -932,8 +922,7 @@ func (s *UtxoSweeper) markInputsPublished(tr *TxRecord, if pi.state != PendingPublish { // We may get a Published if this is a replacement tx. log.Debugf("Expect input %v to have %v, instead it "+ - "has %v", input.PreviousOutPoint, - PendingPublish, pi.state) + "has %v", op, PendingPublish, pi.state) continue } @@ -949,9 +938,10 @@ func (s *UtxoSweeper) markInputsPublished(tr *TxRecord, } // markInputsPublishFailed marks the list of inputs as failed to be published. -func (s *UtxoSweeper) markInputsPublishFailed(outpoints []wire.OutPoint) { +func (s *UtxoSweeper) markInputsPublishFailed(set InputSet) { // Reschedule sweep. - for _, op := range outpoints { + for _, inp := range set.Inputs() { + op := inp.OutPoint() pi, ok := s.inputs[op] if !ok { // It could be that this input is an additional wallet @@ -1054,6 +1044,12 @@ func (s *UtxoSweeper) handlePendingSweepsReq( resps := make(map[wire.OutPoint]*PendingInputResponse, len(s.inputs)) for _, inp := range s.inputs { + // Skip immature inputs for compatibility. + mature, _ := inp.isMature(uint32(s.currentHeight)) + if !mature { + continue + } + // Only the exported fields are set, as we expect the response // to only be consumed externally. op := inp.OutPoint() @@ -1189,17 +1185,34 @@ func (s *UtxoSweeper) mempoolLookup(op wire.OutPoint) fn.Option[wire.MsgTx] { return s.cfg.Mempool.LookupInputMempoolSpend(op) } -// handleNewInput processes a new input by registering spend notification and -// scheduling sweeping for it. -func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error { +// calculateDefaultDeadline calculates the default deadline height for a sweep +// request that has no deadline height specified. +func (s *UtxoSweeper) calculateDefaultDeadline(pi *SweeperInput) int32 { // Create a default deadline height, which will be used when there's no // DeadlineHeight specified for a given input. defaultDeadline := s.currentHeight + int32(s.cfg.NoDeadlineConfTarget) + // If the input is immature and has a locktime, we'll use the locktime + // height as the starting height. + matured, locktime := pi.isMature(uint32(s.currentHeight)) + if !matured { + defaultDeadline = int32(locktime + s.cfg.NoDeadlineConfTarget) + log.Debugf("Input %v is immature, using locktime=%v instead "+ + "of current height=%d as starting height", + pi.OutPoint(), locktime, s.currentHeight) + } + + return defaultDeadline +} + +// handleNewInput processes a new input by registering spend notification and +// scheduling sweeping for it. +func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error { outpoint := input.input.OutPoint() pi, pending := s.inputs[outpoint] if pending { - log.Debugf("Already has pending input %v received", outpoint) + log.Infof("Already has pending input %v received, old params: "+ + "%v, new params %v", outpoint, pi.params, input.params) s.handleExistingInput(input, pi) @@ -1220,15 +1233,22 @@ func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error { Input: input.input, params: input.params, rbf: rbfInfo, - // Set the acutal deadline height. - DeadlineHeight: input.params.DeadlineHeight.UnwrapOr( - defaultDeadline, - ), } + // Set the acutal deadline height. + pi.DeadlineHeight = input.params.DeadlineHeight.UnwrapOr( + s.calculateDefaultDeadline(pi), + ) + s.inputs[outpoint] = pi log.Tracef("input %v, state=%v, added to inputs", outpoint, pi.state) + log.Infof("Registered sweep request at block %d: out_point=%v, "+ + "witness_type=%v, amount=%v, deadline=%d, params=(%v)", + s.currentHeight, pi.OutPoint(), pi.WitnessType(), + btcutil.Amount(pi.SignDesc().Output.Value), pi.DeadlineHeight, + pi.params) + // Start watching for spend of this input, either by us or the remote // party. cancel, err := s.monitorSpend( @@ -1457,11 +1477,6 @@ func (s *UtxoSweeper) markInputFailed(pi *SweeperInput, err error) { pi.state = Failed - // Remove all other inputs in this exclusive group. - if pi.params.ExclusiveGroup != nil { - s.removeExclusiveGroup(*pi.params.ExclusiveGroup) - } - s.signalResult(pi, Result{Err: err}) } @@ -1479,6 +1494,8 @@ func (s *UtxoSweeper) updateSweeperInputs() InputsMap { // turn this inputs map into a SyncMap in case we wanna add concurrent // access to the map in the future. for op, input := range s.inputs { + log.Tracef("Checking input: %s, state=%v", input, input.state) + // If the input has reached a final state, that it's either // been swept, or failed, or excluded, we will remove it from // our sweeper. @@ -1506,20 +1523,9 @@ func (s *UtxoSweeper) updateSweeperInputs() InputsMap { // If the input has a locktime that's not yet reached, we will // skip this input and wait for the locktime to be reached. - locktime, _ := input.RequiredLockTime() - if uint32(s.currentHeight) < locktime { - log.Warnf("Skipping input %v due to locktime=%v not "+ - "reached, current height is %v", op, locktime, - s.currentHeight) - - continue - } - - // If the input has a CSV that's not yet reached, we will skip - // this input and wait for the expiry. - locktime = input.BlocksToMaturity() + input.HeightHint() - if s.currentHeight < int32(locktime)-1 { - log.Infof("Skipping input %v due to CSV expiry=%v not "+ + mature, locktime := input.isMature(uint32(s.currentHeight)) + if !mature { + log.Debugf("Skipping input %v due to locktime=%v not "+ "reached, current height is %v", op, locktime, s.currentHeight) @@ -1539,6 +1545,8 @@ func (s *UtxoSweeper) updateSweeperInputs() InputsMap { // sweepPendingInputs is called when the ticker fires. It will create clusters // and attempt to create and publish the sweeping transactions. func (s *UtxoSweeper) sweepPendingInputs(inputs InputsMap) { + log.Debugf("Sweeping %v inputs", len(inputs)) + // Cluster all of our inputs based on the specific Aggregator. sets := s.cfg.Aggregator.ClusterInputs(inputs) @@ -1580,11 +1588,24 @@ func (s *UtxoSweeper) sweepPendingInputs(inputs InputsMap) { } } +// bumpResp wraps the result of a bump attempt returned from the fee bumper and +// the inputs being used. +type bumpResp struct { + // result is the result of the bump attempt returned from the fee + // bumper. + result *BumpResult + + // set is the input set that was used in the bump attempt. + set InputSet +} + // monitorFeeBumpResult subscribes to the passed result chan to listen for // future updates about the sweeping tx. // // NOTE: must run as a goroutine. -func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) { +func (s *UtxoSweeper) monitorFeeBumpResult(set InputSet, + resultChan <-chan *BumpResult) { + defer s.wg.Done() for { @@ -1596,9 +1617,14 @@ func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) { continue } + resp := &bumpResp{ + result: r, + set: set, + } + // Send the result back to the main event loop. select { - case s.bumpResultChan <- r: + case s.bumpRespChan <- resp: case <-s.quit: log.Debug("Sweeper shutting down, skip " + "sending bump result") @@ -1613,6 +1639,14 @@ func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) { // in sweeper and rely solely on this event to mark // inputs as Swept? if r.Event == TxConfirmed || r.Event == TxFailed { + // Exit if the tx is failed to be created. + if r.Tx == nil { + log.Debugf("Received %v for nil tx, "+ + "exit monitor", r.Event) + + return + } + log.Debugf("Received %v for sweep tx %v, exit "+ "fee bump monitor", r.Event, r.Tx.TxHash()) @@ -1634,25 +1668,28 @@ func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) { // handleBumpEventTxFailed handles the case where the tx has been failed to // publish. -func (s *UtxoSweeper) handleBumpEventTxFailed(r *BumpResult) error { +func (s *UtxoSweeper) handleBumpEventTxFailed(resp *bumpResp) { + r := resp.result tx, err := r.Tx, r.Err - log.Errorf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(), err) - - outpoints := make([]wire.OutPoint, 0, len(tx.TxIn)) - for _, inp := range tx.TxIn { - outpoints = append(outpoints, inp.PreviousOutPoint) + if tx != nil { + log.Warnf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(), + err) } + // NOTE: When marking the inputs as failed, we are using the input set + // instead of the inputs found in the tx. This is fine for current + // version of the sweeper because we always create a tx using ALL of + // the inputs specified by the set. + // // TODO(yy): should we also remove the failed tx from db? - s.markInputsPublishFailed(outpoints) - - return err + s.markInputsPublishFailed(resp.set) } // handleBumpEventTxReplaced handles the case where the sweeping tx has been // replaced by a new one. -func (s *UtxoSweeper) handleBumpEventTxReplaced(r *BumpResult) error { +func (s *UtxoSweeper) handleBumpEventTxReplaced(resp *bumpResp) error { + r := resp.result oldTx := r.ReplacedTx newTx := r.Tx @@ -1692,12 +1729,13 @@ func (s *UtxoSweeper) handleBumpEventTxReplaced(r *BumpResult) error { } // Mark the inputs as published using the replacing tx. - return s.markInputsPublished(tr, r.Tx.TxIn) + return s.markInputsPublished(tr, resp.set) } // handleBumpEventTxPublished handles the case where the sweeping tx has been // successfully published. -func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error { +func (s *UtxoSweeper) handleBumpEventTxPublished(resp *bumpResp) error { + r := resp.result tx := r.Tx tr := &TxRecord{ Txid: tx.TxHash(), @@ -1707,7 +1745,7 @@ func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error { // Inputs have been successfully published so we update their // states. - err := s.markInputsPublished(tr, tx.TxIn) + err := s.markInputsPublished(tr, resp.set) if err != nil { return err } @@ -1723,15 +1761,71 @@ func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error { return nil } +// handleBumpEventTxFatal handles the case where there's an unexpected error +// when creating or publishing the sweeping tx. In this case, the tx will be +// removed from the sweeper store and the inputs will be marked as `Failed`, +// which means they will not be retried. +func (s *UtxoSweeper) handleBumpEventTxFatal(resp *bumpResp) error { + r := resp.result + + // Remove the tx from the sweeper store if there is one. Since this is + // a broadcast error, it's likely there isn't a tx here. + if r.Tx != nil { + txid := r.Tx.TxHash() + log.Infof("Tx=%v failed with unexpected error: %v", txid, r.Err) + + // Remove the tx from the sweeper db if it exists. + if err := s.cfg.Store.DeleteTx(txid); err != nil { + return fmt.Errorf("delete tx record for %v: %w", txid, + err) + } + } + + // Mark the inputs as failed. + s.markInputsFailed(resp.set, r.Err) + + return nil +} + +// markInputsFailed marks all inputs found in the tx as failed. It will also +// notify all the subscribers of these inputs. +func (s *UtxoSweeper) markInputsFailed(set InputSet, err error) { + for _, inp := range set.Inputs() { + outpoint := inp.OutPoint() + + input, ok := s.inputs[outpoint] + if !ok { + // It's very likely that a spending tx contains inputs + // that we don't know. + log.Tracef("Skipped marking input as failed: %v not "+ + "found in pending inputs", outpoint) + + continue + } + + // If the input is already in a terminal state, we don't want + // to rewrite it, which also indicates an error as we only get + // an error event during the initial broadcast. + if input.terminated() { + log.Errorf("Skipped marking input=%v as failed due to "+ + "unexpected state=%v", outpoint, input.state) + + continue + } + + s.markInputFailed(input, err) + } +} + // handleBumpEvent handles the result sent from the bumper based on its event // type. // // NOTE: TxConfirmed event is not handled, since we already subscribe to the // input's spending event, we don't need to do anything here. -func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error { - log.Debugf("Received bump event [%v] for tx %v", r.Event, r.Tx.TxHash()) +func (s *UtxoSweeper) handleBumpEvent(r *bumpResp) error { + log.Debugf("Received bump result %v", r.result) - switch r.Event { + switch r.result.Event { // The tx has been published, we update the inputs' state and create a // record to be stored in the sweeper db. case TxPublished: @@ -1739,12 +1833,18 @@ func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error { // The tx has failed, we update the inputs' state. case TxFailed: - return s.handleBumpEventTxFailed(r) + s.handleBumpEventTxFailed(r) + return nil // The tx has been replaced, we will remove the old tx and replace it // with the new one. case TxReplaced: return s.handleBumpEventTxReplaced(r) + + // There's a fatal error in creating the tx, we will remove the tx from + // the sweeper db and mark the inputs as failed. + case TxFatal: + return s.handleBumpEventTxFatal(r) } return nil diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 2b61f67933..16a4a46fbe 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -1,6 +1,7 @@ package sweep import ( + "crypto/rand" "errors" "testing" "time" @@ -10,8 +11,9 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/stretchr/testify/mock" @@ -33,6 +35,41 @@ var ( }) ) +// createMockInput creates a mock input and saves it to the sweeper's inputs +// map. The created input has the specified state and a random outpoint. It +// will assert the method `OutPoint` is called at least once. +func createMockInput(t *testing.T, s *UtxoSweeper, + state SweepState) *input.MockInput { + + inp := &input.MockInput{} + t.Cleanup(func() { + inp.AssertExpectations(t) + }) + + randBuf := make([]byte, lntypes.HashSize) + _, err := rand.Read(randBuf) + require.NoError(t, err, "internal error, cannot generate random bytes") + + randHash, err := chainhash.NewHash(randBuf) + require.NoError(t, err) + + inp.On("OutPoint").Return(wire.OutPoint{ + Hash: *randHash, + Index: 0, + }) + + // We don't do branch switches based on the witness type here so we + // just mock it. + inp.On("WitnessType").Return(input.CommitmentTimeLock).Maybe() + + s.inputs[inp.OutPoint()] = &SweeperInput{ + Input: inp, + state: state, + } + + return inp +} + // TestMarkInputsPendingPublish checks that given a list of inputs with // different states, only the non-terminal state will be marked as `Published`. func TestMarkInputsPendingPublish(t *testing.T) { @@ -47,50 +84,21 @@ func TestMarkInputsPendingPublish(t *testing.T) { set := &MockInputSet{} defer set.AssertExpectations(t) - // Create three testing inputs. - // - // inputNotExist specifies an input that's not found in the sweeper's - // `pendingInputs` map. - inputNotExist := &input.MockInput{} - defer inputNotExist.AssertExpectations(t) - - inputNotExist.On("OutPoint").Return(wire.OutPoint{Index: 0}) - - // inputInit specifies a newly created input. - inputInit := &input.MockInput{} - defer inputInit.AssertExpectations(t) - - inputInit.On("OutPoint").Return(wire.OutPoint{Index: 1}) - - s.inputs[inputInit.OutPoint()] = &SweeperInput{ - state: Init, - } - - // inputPendingPublish specifies an input that's about to be published. - inputPendingPublish := &input.MockInput{} - defer inputPendingPublish.AssertExpectations(t) - - inputPendingPublish.On("OutPoint").Return(wire.OutPoint{Index: 2}) - - s.inputs[inputPendingPublish.OutPoint()] = &SweeperInput{ - state: PendingPublish, - } - - // inputTerminated specifies an input that's terminated. - inputTerminated := &input.MockInput{} - defer inputTerminated.AssertExpectations(t) - - inputTerminated.On("OutPoint").Return(wire.OutPoint{Index: 3}) - - s.inputs[inputTerminated.OutPoint()] = &SweeperInput{ - state: Excluded, - } + // Create three inputs with different states. + // - inputInit specifies a newly created input. + // - inputPendingPublish specifies an input about to be published. + // - inputTerminated specifies an input that's terminated. + var ( + inputInit = createMockInput(t, s, Init) + inputPendingPublish = createMockInput(t, s, PendingPublish) + inputTerminated = createMockInput(t, s, Excluded) + ) // Mark the test inputs. We expect the non-exist input and the // inputTerminated to be skipped, and the rest to be marked as pending // publish. set.On("Inputs").Return([]input.Input{ - inputNotExist, inputInit, inputPendingPublish, inputTerminated, + inputInit, inputPendingPublish, inputTerminated, }) s.markInputsPendingPublish(set) @@ -122,36 +130,22 @@ func TestMarkInputsPublished(t *testing.T) { dummyTR := &TxRecord{} dummyErr := errors.New("dummy error") + // Create a mock input set. + set := &MockInputSet{} + defer set.AssertExpectations(t) + // Create a test sweeper. s := New(&UtxoSweeperConfig{ Store: mockStore, }) - // Create three testing inputs. - // - // inputNotExist specifies an input that's not found in the sweeper's - // `inputs` map. - inputNotExist := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 1}, - } - - // inputInit specifies a newly created input. When marking this as - // published, we should see an error log as this input hasn't been - // published yet. - inputInit := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 2}, - } - s.inputs[inputInit.PreviousOutPoint] = &SweeperInput{ - state: Init, - } - - // inputPendingPublish specifies an input that's about to be published. - inputPendingPublish := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 3}, - } - s.inputs[inputPendingPublish.PreviousOutPoint] = &SweeperInput{ - state: PendingPublish, - } + // Create two inputs with different states. + // - inputInit specifies a newly created input. + // - inputPendingPublish specifies an input about to be published. + var ( + inputInit = createMockInput(t, s, Init) + inputPendingPublish = createMockInput(t, s, PendingPublish) + ) // First, check that when an error is returned from db, it's properly // returned here. @@ -171,9 +165,9 @@ func TestMarkInputsPublished(t *testing.T) { // Mark the test inputs. We expect the non-exist input and the // inputInit to be skipped, and the final input to be marked as // published. - err = s.markInputsPublished(dummyTR, []*wire.TxIn{ - inputNotExist, inputInit, inputPendingPublish, - }) + set.On("Inputs").Return([]input.Input{inputInit, inputPendingPublish}) + + err = s.markInputsPublished(dummyTR, set) require.NoError(err) // We expect unchanged number of pending inputs. @@ -181,11 +175,11 @@ func TestMarkInputsPublished(t *testing.T) { // We expect the init input's state to stay unchanged. require.Equal(Init, - s.inputs[inputInit.PreviousOutPoint].state) + s.inputs[inputInit.OutPoint()].state) // We expect the pending-publish input's is now marked as published. require.Equal(Published, - s.inputs[inputPendingPublish.PreviousOutPoint].state) + s.inputs[inputPendingPublish.OutPoint()].state) // Assert mocked statements are executed as expected. mockStore.AssertExpectations(t) @@ -202,117 +196,75 @@ func TestMarkInputsPublishFailed(t *testing.T) { // Create a mock sweeper store. mockStore := NewMockSweeperStore() + // Create a mock input set. + set := &MockInputSet{} + defer set.AssertExpectations(t) + // Create a test sweeper. s := New(&UtxoSweeperConfig{ Store: mockStore, }) - // Create testing inputs for each state. - // - // inputNotExist specifies an input that's not found in the sweeper's - // `inputs` map. - inputNotExist := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 1}, - } - - // inputInit specifies a newly created input. When marking this as - // published, we should see an error log as this input hasn't been - // published yet. - inputInit := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 2}, - } - s.inputs[inputInit.PreviousOutPoint] = &SweeperInput{ - state: Init, - } - - // inputPendingPublish specifies an input that's about to be published. - inputPendingPublish := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 3}, - } - s.inputs[inputPendingPublish.PreviousOutPoint] = &SweeperInput{ - state: PendingPublish, - } - - // inputPublished specifies an input that's published. - inputPublished := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 4}, - } - s.inputs[inputPublished.PreviousOutPoint] = &SweeperInput{ - state: Published, - } - - // inputPublishFailed specifies an input that's failed to be published. - inputPublishFailed := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 5}, - } - s.inputs[inputPublishFailed.PreviousOutPoint] = &SweeperInput{ - state: PublishFailed, - } - - // inputSwept specifies an input that's swept. - inputSwept := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 6}, - } - s.inputs[inputSwept.PreviousOutPoint] = &SweeperInput{ - state: Swept, - } - - // inputExcluded specifies an input that's excluded. - inputExcluded := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 7}, - } - s.inputs[inputExcluded.PreviousOutPoint] = &SweeperInput{ - state: Excluded, - } - - // inputFailed specifies an input that's failed. - inputFailed := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 8}, - } - s.inputs[inputFailed.PreviousOutPoint] = &SweeperInput{ - state: Failed, - } + // Create inputs with different states. + // - inputInit specifies a newly created input. When marking this as + // published, we should see an error log as this input hasn't been + // published yet. + // - inputPendingPublish specifies an input about to be published. + // - inputPublished specifies an input that's published. + // - inputPublishFailed specifies an input that's failed to be + // published. + // - inputSwept specifies an input that's swept. + // - inputExcluded specifies an input that's excluded. + // - inputFailed specifies an input that's failed. + var ( + inputInit = createMockInput(t, s, Init) + inputPendingPublish = createMockInput(t, s, PendingPublish) + inputPublished = createMockInput(t, s, Published) + inputPublishFailed = createMockInput(t, s, PublishFailed) + inputSwept = createMockInput(t, s, Swept) + inputExcluded = createMockInput(t, s, Excluded) + inputFailed = createMockInput(t, s, Failed) + ) - // Gather all inputs' outpoints. - pendingOps := make([]wire.OutPoint, 0, len(s.inputs)+1) - for op := range s.inputs { - pendingOps = append(pendingOps, op) - } - pendingOps = append(pendingOps, inputNotExist.PreviousOutPoint) + // Gather all inputs. + set.On("Inputs").Return([]input.Input{ + inputInit, inputPendingPublish, inputPublished, + inputPublishFailed, inputSwept, inputExcluded, inputFailed, + }) // Mark the test inputs. We expect the non-exist input and the // inputInit to be skipped, and the final input to be marked as // published. - s.markInputsPublishFailed(pendingOps) + s.markInputsPublishFailed(set) // We expect unchanged number of pending inputs. require.Len(s.inputs, 7) // We expect the init input's state to stay unchanged. require.Equal(Init, - s.inputs[inputInit.PreviousOutPoint].state) + s.inputs[inputInit.OutPoint()].state) // We expect the pending-publish input's is now marked as publish // failed. require.Equal(PublishFailed, - s.inputs[inputPendingPublish.PreviousOutPoint].state) + s.inputs[inputPendingPublish.OutPoint()].state) // We expect the published input's is now marked as publish failed. require.Equal(PublishFailed, - s.inputs[inputPublished.PreviousOutPoint].state) + s.inputs[inputPublished.OutPoint()].state) // We expect the publish failed input to stay unchanged. require.Equal(PublishFailed, - s.inputs[inputPublishFailed.PreviousOutPoint].state) + s.inputs[inputPublishFailed.OutPoint()].state) // We expect the swept input to stay unchanged. - require.Equal(Swept, s.inputs[inputSwept.PreviousOutPoint].state) + require.Equal(Swept, s.inputs[inputSwept.OutPoint()].state) // We expect the excluded input to stay unchanged. - require.Equal(Excluded, s.inputs[inputExcluded.PreviousOutPoint].state) + require.Equal(Excluded, s.inputs[inputExcluded.OutPoint()].state) // We expect the failed input to stay unchanged. - require.Equal(Failed, s.inputs[inputFailed.PreviousOutPoint].state) + require.Equal(Failed, s.inputs[inputFailed.OutPoint()].state) // Assert mocked statements are executed as expected. mockStore.AssertExpectations(t) @@ -491,6 +443,7 @@ func TestUpdateSweeperInputs(t *testing.T) { // returned. inp2.On("RequiredLockTime").Return( uint32(s.currentHeight+1), true).Once() + inp2.On("OutPoint").Return(wire.OutPoint{Index: 2}).Maybe() input7 := &SweeperInput{state: Init, Input: inp2} // Mock the input to have a CSV expiry in the future so it will NOT be @@ -499,6 +452,7 @@ func TestUpdateSweeperInputs(t *testing.T) { uint32(s.currentHeight), false).Once() inp3.On("BlocksToMaturity").Return(uint32(2)).Once() inp3.On("HeightHint").Return(uint32(s.currentHeight)).Once() + inp3.On("OutPoint").Return(wire.OutPoint{Index: 3}).Maybe() input8 := &SweeperInput{state: Init, Input: inp3} // Add the inputs to the sweeper. After the update, we should see the @@ -704,11 +658,13 @@ func TestSweepPendingInputs(t *testing.T) { setNeedWallet.On("Budget").Return(btcutil.Amount(1)).Once() setNeedWallet.On("StartingFeeRate").Return( fn.None[chainfee.SatPerKWeight]()).Once() + setNeedWallet.On("Immediate").Return(false).Once() normalSet.On("Inputs").Return(nil).Maybe() normalSet.On("DeadlineHeight").Return(testHeight).Once() normalSet.On("Budget").Return(btcutil.Amount(1)).Once() normalSet.On("StartingFeeRate").Return( fn.None[chainfee.SatPerKWeight]()).Once() + normalSet.On("Immediate").Return(false).Once() // Make pending inputs for testing. We don't need real values here as // the returned clusters are mocked. @@ -719,13 +675,8 @@ func TestSweepPendingInputs(t *testing.T) { setNeedWallet, normalSet, }) - // Mock `Broadcast` to return an error. This should cause the - // `createSweepTx` inside `sweep` to fail. This is done so we can - // terminate the method early as we are only interested in testing the - // workflow in `sweepPendingInputs`. We don't need to test `sweep` here - // as it should be tested in its own unit test. - dummyErr := errors.New("dummy error") - publisher.On("Broadcast", mock.Anything).Return(nil, dummyErr).Twice() + // Mock `Broadcast` to return a result. + publisher.On("Broadcast", mock.Anything).Return(nil).Twice() // Call the method under test. s.sweepPendingInputs(pis) @@ -736,33 +687,33 @@ func TestSweepPendingInputs(t *testing.T) { func TestHandleBumpEventTxFailed(t *testing.T) { t.Parallel() + // Create a mock input set. + set := &MockInputSet{} + defer set.AssertExpectations(t) + // Create a test sweeper. s := New(&UtxoSweeperConfig{}) - var ( - // Create four testing outpoints. - op1 = wire.OutPoint{Hash: chainhash.Hash{1}} - op2 = wire.OutPoint{Hash: chainhash.Hash{2}} - op3 = wire.OutPoint{Hash: chainhash.Hash{3}} - opNotExist = wire.OutPoint{Hash: chainhash.Hash{4}} - ) + // inputNotExist specifies an input that's not found in the sweeper's + // `pendingInputs` map. + inputNotExist := &input.MockInput{} + defer inputNotExist.AssertExpectations(t) + inputNotExist.On("OutPoint").Return(wire.OutPoint{Index: 0}) + opNotExist := inputNotExist.OutPoint() // Create three mock inputs. - input1 := &input.MockInput{} - defer input1.AssertExpectations(t) - - input2 := &input.MockInput{} - defer input2.AssertExpectations(t) + var ( + input1 = createMockInput(t, s, PendingPublish) + input2 = createMockInput(t, s, PendingPublish) + input3 = createMockInput(t, s, PendingPublish) + ) - input3 := &input.MockInput{} - defer input3.AssertExpectations(t) + op1 := input1.OutPoint() + op2 := input2.OutPoint() + op3 := input3.OutPoint() // Construct the initial state for the sweeper. - s.inputs = InputsMap{ - op1: &SweeperInput{Input: input1, state: PendingPublish}, - op2: &SweeperInput{Input: input2, state: PendingPublish}, - op3: &SweeperInput{Input: input3, state: PendingPublish}, - } + set.On("Inputs").Return([]input.Input{input1, input2, input3}) // Create a testing tx that spends the first two inputs. tx := &wire.MsgTx{ @@ -780,16 +731,26 @@ func TestHandleBumpEventTxFailed(t *testing.T) { Err: errDummy, } + // Create a testing bump response. + resp := &bumpResp{ + result: br, + set: set, + } + // Call the method under test. - err := s.handleBumpEvent(br) - require.ErrorIs(t, err, errDummy) + err := s.handleBumpEvent(resp) + require.NoError(t, err) // Assert the states of the first two inputs are updated. require.Equal(t, PublishFailed, s.inputs[op1].state) require.Equal(t, PublishFailed, s.inputs[op2].state) - // Assert the state of the third input is not updated. - require.Equal(t, PendingPublish, s.inputs[op3].state) + // Assert the state of the third input. + // + // NOTE: Although the tx doesn't spend it, we still mark this input as + // failed as we are treating the input set as the single source of + // truth. + require.Equal(t, PublishFailed, s.inputs[op3].state) // Assert the non-existing input is not added to the pending inputs. require.NotContains(t, s.inputs, opNotExist) @@ -808,23 +769,21 @@ func TestHandleBumpEventTxReplaced(t *testing.T) { wallet := &MockWallet{} defer wallet.AssertExpectations(t) + // Create a mock input set. + set := &MockInputSet{} + defer set.AssertExpectations(t) + // Create a test sweeper. s := New(&UtxoSweeperConfig{ Store: store, Wallet: wallet, }) - // Create a testing outpoint. - op := wire.OutPoint{Hash: chainhash.Hash{1}} - // Create a mock input. - inp := &input.MockInput{} - defer inp.AssertExpectations(t) + inp := createMockInput(t, s, PendingPublish) + set.On("Inputs").Return([]input.Input{inp}) - // Construct the initial state for the sweeper. - s.inputs = InputsMap{ - op: &SweeperInput{Input: inp, state: PendingPublish}, - } + op := inp.OutPoint() // Create a testing tx that spends the input. tx := &wire.MsgTx{ @@ -849,12 +808,18 @@ func TestHandleBumpEventTxReplaced(t *testing.T) { Event: TxReplaced, } + // Create a testing bump response. + resp := &bumpResp{ + result: br, + set: set, + } + // Mock the store to return an error. dummyErr := errors.New("dummy error") store.On("GetTx", tx.TxHash()).Return(nil, dummyErr).Once() // Call the method under test and assert the error is returned. - err := s.handleBumpEventTxReplaced(br) + err := s.handleBumpEventTxReplaced(resp) require.ErrorIs(t, err, dummyErr) // Mock the store to return the old tx record. @@ -869,7 +834,7 @@ func TestHandleBumpEventTxReplaced(t *testing.T) { store.On("DeleteTx", tx.TxHash()).Return(dummyErr).Once() // Call the method under test and assert the error is returned. - err = s.handleBumpEventTxReplaced(br) + err = s.handleBumpEventTxReplaced(resp) require.ErrorIs(t, err, dummyErr) // Mock the store to return the old tx record and delete it without @@ -889,7 +854,7 @@ func TestHandleBumpEventTxReplaced(t *testing.T) { wallet.On("CancelRebroadcast", tx.TxHash()).Once() // Call the method under test. - err = s.handleBumpEventTxReplaced(br) + err = s.handleBumpEventTxReplaced(resp) require.NoError(t, err) // Assert the state of the input is updated. @@ -905,22 +870,20 @@ func TestHandleBumpEventTxPublished(t *testing.T) { store := &MockSweeperStore{} defer store.AssertExpectations(t) + // Create a mock input set. + set := &MockInputSet{} + defer set.AssertExpectations(t) + // Create a test sweeper. s := New(&UtxoSweeperConfig{ Store: store, }) - // Create a testing outpoint. - op := wire.OutPoint{Hash: chainhash.Hash{1}} - // Create a mock input. - inp := &input.MockInput{} - defer inp.AssertExpectations(t) + inp := createMockInput(t, s, PendingPublish) + set.On("Inputs").Return([]input.Input{inp}) - // Construct the initial state for the sweeper. - s.inputs = InputsMap{ - op: &SweeperInput{Input: inp, state: PendingPublish}, - } + op := inp.OutPoint() // Create a testing tx that spends the input. tx := &wire.MsgTx{ @@ -936,6 +899,12 @@ func TestHandleBumpEventTxPublished(t *testing.T) { Event: TxPublished, } + // Create a testing bump response. + resp := &bumpResp{ + result: br, + set: set, + } + // Mock the store to save the new tx record. store.On("StoreTx", &TxRecord{ Txid: tx.TxHash(), @@ -943,7 +912,7 @@ func TestHandleBumpEventTxPublished(t *testing.T) { }).Return(nil).Once() // Call the method under test. - err := s.handleBumpEventTxPublished(br) + err := s.handleBumpEventTxPublished(resp) require.NoError(t, err) // Assert the state of the input is updated. @@ -961,25 +930,21 @@ func TestMonitorFeeBumpResult(t *testing.T) { wallet := &MockWallet{} defer wallet.AssertExpectations(t) + // Create a mock input set. + set := &MockInputSet{} + defer set.AssertExpectations(t) + // Create a test sweeper. s := New(&UtxoSweeperConfig{ Store: store, Wallet: wallet, }) - // Create a testing outpoint. - op := wire.OutPoint{Hash: chainhash.Hash{1}} - // Create a mock input. - inp := &input.MockInput{} - defer inp.AssertExpectations(t) - - // Construct the initial state for the sweeper. - s.inputs = InputsMap{ - op: &SweeperInput{Input: inp, state: PendingPublish}, - } + inp := createMockInput(t, s, PendingPublish) // Create a testing tx that spends the input. + op := inp.OutPoint() tx := &wire.MsgTx{ LockTime: 1, TxIn: []*wire.TxIn{ @@ -1058,7 +1023,8 @@ func TestMonitorFeeBumpResult(t *testing.T) { return resultChan }, shouldExit: false, - }, { + }, + { // When the sweeper is shutting down, the monitor loop // should exit. name: "exit on sweeper shutdown", @@ -1085,7 +1051,7 @@ func TestMonitorFeeBumpResult(t *testing.T) { s.wg.Add(1) go func() { - s.monitorFeeBumpResult(resultChan) + s.monitorFeeBumpResult(set, resultChan) close(done) }() @@ -1111,3 +1077,125 @@ func TestMonitorFeeBumpResult(t *testing.T) { }) } } + +// TestMarkInputsFailed checks that given a list of inputs with different +// states, the method `markInputsFailed` correctly marks the inputs as failed. +func TestMarkInputsFailed(t *testing.T) { + t.Parallel() + + require := require.New(t) + + // Create a mock input set. + set := &MockInputSet{} + defer set.AssertExpectations(t) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{}) + + // Create testing inputs for each state. + // - inputInit specifies a newly created input. When marking this as + // published, we should see an error log as this input hasn't been + // published yet. + // - inputPendingPublish specifies an input about to be published. + // - inputPublished specifies an input that's published. + // - inputPublishFailed specifies an input that's failed to be + // published. + // - inputSwept specifies an input that's swept. + // - inputExcluded specifies an input that's excluded. + // - inputFailed specifies an input that's failed. + var ( + inputInit = createMockInput(t, s, Init) + inputPendingPublish = createMockInput(t, s, PendingPublish) + inputPublished = createMockInput(t, s, Published) + inputPublishFailed = createMockInput(t, s, PublishFailed) + inputSwept = createMockInput(t, s, Swept) + inputExcluded = createMockInput(t, s, Excluded) + inputFailed = createMockInput(t, s, Failed) + ) + + // Gather all inputs. + set.On("Inputs").Return([]input.Input{ + inputInit, inputPendingPublish, inputPublished, + inputPublishFailed, inputSwept, inputExcluded, inputFailed, + }) + + // Mark the test inputs. We expect the non-exist input and + // inputSwept/inputExcluded/inputFailed to be skipped. + s.markInputsFailed(set, errDummy) + + // We expect unchanged number of pending inputs. + require.Len(s.inputs, 7) + + // We expect the init input's to be marked as failed. + require.Equal(Failed, s.inputs[inputInit.OutPoint()].state) + + // We expect the pending-publish input to be marked as failed. + require.Equal(Failed, s.inputs[inputPendingPublish.OutPoint()].state) + + // We expect the published input to be marked as failed. + require.Equal(Failed, s.inputs[inputPublished.OutPoint()].state) + + // We expect the publish failed input to be markd as failed. + require.Equal(Failed, s.inputs[inputPublishFailed.OutPoint()].state) + + // We expect the swept input to stay unchanged. + require.Equal(Swept, s.inputs[inputSwept.OutPoint()].state) + + // We expect the excluded input to stay unchanged. + require.Equal(Excluded, s.inputs[inputExcluded.OutPoint()].state) + + // We expect the failed input to stay unchanged. + require.Equal(Failed, s.inputs[inputFailed.OutPoint()].state) +} + +// TestHandleBumpEventTxFatal checks that `handleBumpEventTxFatal` correctly +// handles a `TxFatal` event. +func TestHandleBumpEventTxFatal(t *testing.T) { + t.Parallel() + + rt := require.New(t) + + // Create a mock store. + store := &MockSweeperStore{} + defer store.AssertExpectations(t) + + // Create a mock input set. We are not testing `markInputFailed` here, + // so the actual set doesn't matter. + set := &MockInputSet{} + defer set.AssertExpectations(t) + set.On("Inputs").Return(nil) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: store, + }) + + // Create a dummy tx. + tx := &wire.MsgTx{ + LockTime: 1, + } + + // Create a testing bump response. + result := &BumpResult{ + Err: errDummy, + Tx: tx, + } + resp := &bumpResp{ + result: result, + set: set, + } + + // Mock the store to return an error. + store.On("DeleteTx", mock.Anything).Return(errDummy).Once() + + // Call the method under test and assert the error is returned. + err := s.handleBumpEventTxFatal(resp) + rt.ErrorIs(err, errDummy) + + // Mock the store to return nil. + store.On("DeleteTx", mock.Anything).Return(nil).Once() + + // Call the method under test and assert no error is returned. + err = s.handleBumpEventTxFatal(resp) + rt.NoError(err) +} diff --git a/sweep/tx_input_set.go b/sweep/tx_input_set.go index ce144a8eb3..b80d52b0ea 100644 --- a/sweep/tx_input_set.go +++ b/sweep/tx_input_set.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -64,6 +64,13 @@ type InputSet interface { // StartingFeeRate returns the max starting fee rate found in the // inputs. StartingFeeRate() fn.Option[chainfee.SatPerKWeight] + + // Immediate returns a boolean to indicate whether the tx made from + // this input set should be published immediately. + // + // TODO(yy): create a new method `Params` to combine the informational + // methods DeadlineHeight, Budget, StartingFeeRate and Immediate. + Immediate() bool } // createWalletTxInput converts a wallet utxo into an object that can be added @@ -141,7 +148,7 @@ func validateInputs(inputs []SweeperInput, deadlineHeight int32) error { // dedupInputs is a set used to track unique outpoints of the inputs. dedupInputs := fn.NewSet( // Iterate all the inputs and map the function. - fn.Map(func(inp SweeperInput) wire.OutPoint { + fn.Map(inputs, func(inp SweeperInput) wire.OutPoint { // If the input has a deadline height, we'll check if // it's the same as the specified. inp.params.DeadlineHeight.WhenSome(func(h int32) { @@ -156,7 +163,7 @@ func validateInputs(inputs []SweeperInput, deadlineHeight int32) error { }) return inp.OutPoint() - }, inputs)..., + })..., ) // Make sure the inputs share the same deadline height when there is @@ -414,3 +421,18 @@ func (b *BudgetInputSet) StartingFeeRate() fn.Option[chainfee.SatPerKWeight] { return startingFeeRate } + +// Immediate returns whether the inputs should be swept immediately. +// +// NOTE: part of the InputSet interface. +func (b *BudgetInputSet) Immediate() bool { + for _, inp := range b.inputs { + // As long as one of the inputs is immediate, the whole set is + // immediate. + if inp.params.Immediate { + return true + } + } + + return false +} diff --git a/sweep/tx_input_set_test.go b/sweep/tx_input_set_test.go index 8d0850b20d..73f056a964 100644 --- a/sweep/tx_input_set_test.go +++ b/sweep/tx_input_set_test.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/stretchr/testify/require" diff --git a/sweep/walletsweep.go b/sweep/walletsweep.go index 81458fbfb0..3f790dc66f 100644 --- a/sweep/walletsweep.go +++ b/sweep/walletsweep.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/wtxmgr" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" diff --git a/sweep/walletsweep_test.go b/sweep/walletsweep_test.go index 968d9cb4fb..c7a5dfc221 100644 --- a/sweep/walletsweep_test.go +++ b/sweep/walletsweep_test.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/wtxmgr" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" diff --git a/watchtower/blob/justice_kit.go b/watchtower/blob/justice_kit.go index 7780239f07..9dc1af6258 100644 --- a/watchtower/blob/justice_kit.go +++ b/watchtower/blob/justice_kit.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" secp "github.com/decred/dcrd/dcrec/secp256k1/v4" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" diff --git a/watchtower/blob/justice_kit_test.go b/watchtower/blob/justice_kit_test.go index a1d6ec9f2c..0d23e2e0fc 100644 --- a/watchtower/blob/justice_kit_test.go +++ b/watchtower/blob/justice_kit_test.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" diff --git a/watchtower/lookout/justice_descriptor_test.go b/watchtower/lookout/justice_descriptor_test.go index 5045b4a0f4..ded2cd6031 100644 --- a/watchtower/lookout/justice_descriptor_test.go +++ b/watchtower/lookout/justice_descriptor_test.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" secp "github.com/decred/dcrd/dcrec/secp256k1/v4" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index 7eb34f6e37..62d7609469 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index f3a4d5bf4e..e842876b65 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -19,7 +19,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" diff --git a/watchtower/wtclient/manager.go b/watchtower/wtclient/manager.go index 01a9fa01ef..7a39c8ff73 100644 --- a/watchtower/wtclient/manager.go +++ b/watchtower/wtclient/manager.go @@ -12,7 +12,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/subscribe" diff --git a/watchtower/wtdb/client_chan_summary.go b/watchtower/wtdb/client_chan_summary.go index 6fec34c842..9ab77377b4 100644 --- a/watchtower/wtdb/client_chan_summary.go +++ b/watchtower/wtdb/client_chan_summary.go @@ -3,7 +3,7 @@ package wtdb import ( "io" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index b6f6affce6..6e6adacc02 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -9,7 +9,7 @@ import ( "sync" "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" diff --git a/zpay32/decode.go b/zpay32/decode.go index 61099cf2f0..76c2c1ecf4 100644 --- a/zpay32/decode.go +++ b/zpay32/decode.go @@ -14,7 +14,7 @@ import ( "github.com/btcsuite/btcd/btcutil/bech32" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/zpay32/encode.go b/zpay32/encode.go index 3e2d799776..43ccd5ecb1 100644 --- a/zpay32/encode.go +++ b/zpay32/encode.go @@ -9,7 +9,7 @@ import ( "github.com/btcsuite/btcd/btcutil/bech32" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/zpay32/invoice.go b/zpay32/invoice.go index 7c18253eb0..9c5d86ce2f 100644 --- a/zpay32/invoice.go +++ b/zpay32/invoice.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/zpay32/invoice_test.go b/zpay32/invoice_test.go index a4753431e7..55718007db 100644 --- a/zpay32/invoice_test.go +++ b/zpay32/invoice_test.go @@ -17,7 +17,7 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" "github.com/stretchr/testify/require" )