diff --git a/aliasmgr/aliasmgr.go b/aliasmgr/aliasmgr.go
index f06cb53d79..a3227b18b8 100644
--- a/aliasmgr/aliasmgr.go
+++ b/aliasmgr/aliasmgr.go
@@ -5,7 +5,7 @@ import (
"fmt"
"sync"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
@@ -432,9 +432,9 @@ func (m *Manager) DeleteLocalAlias(alias,
}
// We'll filter the alias set and remove the alias from it.
- aliasSet = fn.Filter(func(a lnwire.ShortChannelID) bool {
+ aliasSet = fn.Filter(aliasSet, func(a lnwire.ShortChannelID) bool {
return a.ToUint64() != alias.ToUint64()
- }, aliasSet)
+ })
// If the alias set is empty, we'll delete the base SCID from the
// baseToSet map.
@@ -514,11 +514,17 @@ func (m *Manager) RequestAlias() (lnwire.ShortChannelID, error) {
// haveAlias returns true if the passed alias is already assigned to a
// channel in the baseToSet map.
haveAlias := func(maybeNextAlias lnwire.ShortChannelID) bool {
- return fn.Any(func(aliasList []lnwire.ShortChannelID) bool {
- return fn.Any(func(alias lnwire.ShortChannelID) bool {
- return alias == maybeNextAlias
- }, aliasList)
- }, maps.Values(m.baseToSet))
+ return fn.Any(
+ maps.Values(m.baseToSet),
+ func(aliasList []lnwire.ShortChannelID) bool {
+ return fn.Any(
+ aliasList,
+ func(alias lnwire.ShortChannelID) bool {
+ return alias == maybeNextAlias
+ },
+ )
+ },
+ )
}
err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error {
diff --git a/chainio/README.md b/chainio/README.md
new file mode 100644
index 0000000000..b11e38157c
--- /dev/null
+++ b/chainio/README.md
@@ -0,0 +1,152 @@
+# Chainio
+
+`chainio` is a package designed to provide blockchain data access to various
+subsystems within `lnd`. When a new block is received, it is encapsulated in a
+`Blockbeat` object and disseminated to all registered consumers. Consumers may
+receive these updates either concurrently or sequentially, based on their
+registration configuration, ensuring that each subsystem maintains a
+synchronized view of the current block state.
+
+The main components include:
+
+- `Blockbeat`: An interface that provides information about the block.
+
+- `Consumer`: An interface that specifies how subsystems handle the blockbeat.
+
+- `BlockbeatDispatcher`: The core service responsible for receiving each block
+ and distributing it to all consumers.
+
+Additionally, the `BeatConsumer` struct provides a partial implementation of
+the `Consumer` interface. This struct helps reduce code duplication, allowing
+subsystems to avoid re-implementing the `ProcessBlock` method and provides a
+commonly used `NotifyBlockProcessed` method.
+
+
+### Register a Consumer
+
+Consumers within the same queue are notified **sequentially**, while all queues
+are notified **concurrently**. A queue consists of a slice of consumers, which
+are notified in left-to-right order. Developers are responsible for determining
+dependencies in block consumption across subsystems: independent subsystems
+should be notified concurrently, whereas dependent subsystems should be
+notified sequentially.
+
+To notify the consumers concurrently, put them in different queues,
+```go
+// consumer1 and consumer2 will be notified concurrently.
+queue1 := []chainio.Consumer{consumer1}
+blockbeatDispatcher.RegisterQueue(consumer1)
+
+queue2 := []chainio.Consumer{consumer2}
+blockbeatDispatcher.RegisterQueue(consumer2)
+```
+
+To notify the consumers sequentially, put them in the same queue,
+```go
+// consumers will be notified sequentially via,
+// consumer1 -> consumer2 -> consumer3
+queue := []chainio.Consumer{
+ consumer1,
+ consumer2,
+ consumer3,
+}
+blockbeatDispatcher.RegisterQueue(queue)
+```
+
+### Implement the `Consumer` Interface
+
+Implementing the `Consumer` interface is straightforward. Below is an example
+of how
+[`sweep.TxPublisher`](https://github.com/lightningnetwork/lnd/blob/5cec466fad44c582a64cfaeb91f6d5fd302fcf85/sweep/fee_bumper.go#L310)
+implements this interface.
+
+To start, embed the partial implementation `chainio.BeatConsumer`, which
+already provides the `ProcessBlock` implementation and commonly used
+`NotifyBlockProcessed` method, and exposes `BlockbeatChan` for the consumer to
+receive blockbeats.
+
+```go
+type TxPublisher struct {
+ started atomic.Bool
+ stopped atomic.Bool
+
+ chainio.BeatConsumer
+
+ ...
+```
+
+We should also remember to initialize this `BeatConsumer`,
+
+```go
+...
+// Mount the block consumer.
+tp.BeatConsumer = chainio.NewBeatConsumer(tp.quit, tp.Name())
+```
+
+Finally, in the main event loop, read from `BlockbeatChan`, process the
+received blockbeat, and, crucially, call `tp.NotifyBlockProcessed` to inform
+the blockbeat dispatcher that processing is complete.
+
+```go
+for {
+ select {
+ case beat := <-tp.BlockbeatChan:
+ // Consume this blockbeat, usually it means updating the subsystem
+ // using the new block data.
+
+ // Notify we've processed the block.
+ tp.NotifyBlockProcessed(beat, nil)
+
+ ...
+```
+
+### Existing Queues
+
+Currently, we have a single queue of consumers dedicated to handling force
+closures. This queue includes `ChainArbitrator`, `UtxoSweeper`, and
+`TxPublisher`, with `ChainArbitrator` managing two internal consumers:
+`chainWatcher` and `ChannelArbitrator`. The blockbeat flows sequentially
+through the chain as follows: `ChainArbitrator => chainWatcher =>
+ChannelArbitrator => UtxoSweeper => TxPublisher`. The following diagram
+illustrates the flow within the public subsystems.
+
+```mermaid
+sequenceDiagram
+ autonumber
+ participant bb as BlockBeat
+ participant cc as ChainArb
+ participant us as UtxoSweeper
+ participant tp as TxPublisher
+
+ note left of bb: 0. received block x,
dispatching...
+
+ note over bb,cc: 1. send block x to ChainArb,
wait for its done signal
+ bb->>cc: block x
+ rect rgba(165, 0, 85, 0.8)
+ critical signal processed
+ cc->>bb: processed block
+ option Process error or timeout
+ bb->>bb: error and exit
+ end
+ end
+
+ note over bb,us: 2. send block x to UtxoSweeper, wait for its done signal
+ bb->>us: block x
+ rect rgba(165, 0, 85, 0.8)
+ critical signal processed
+ us->>bb: processed block
+ option Process error or timeout
+ bb->>bb: error and exit
+ end
+ end
+
+ note over bb,tp: 3. send block x to TxPublisher, wait for its done signal
+ bb->>tp: block x
+ rect rgba(165, 0, 85, 0.8)
+ critical signal processed
+ tp->>bb: processed block
+ option Process error or timeout
+ bb->>bb: error and exit
+ end
+ end
+```
diff --git a/chainio/blockbeat.go b/chainio/blockbeat.go
new file mode 100644
index 0000000000..79188657fe
--- /dev/null
+++ b/chainio/blockbeat.go
@@ -0,0 +1,54 @@
+package chainio
+
+import (
+ "fmt"
+
+ "github.com/btcsuite/btclog/v2"
+ "github.com/lightningnetwork/lnd/chainntnfs"
+)
+
+// Beat implements the Blockbeat interface. It contains the block epoch and a
+// customized logger.
+//
+// TODO(yy): extend this to check for confirmation status - which serves as the
+// single source of truth, to avoid the potential race between receiving blocks
+// and `GetTransactionDetails/RegisterSpendNtfn/RegisterConfirmationsNtfn`.
+type Beat struct {
+ // epoch is the current block epoch the blockbeat is aware of.
+ epoch chainntnfs.BlockEpoch
+
+ // log is the customized logger for the blockbeat which prints the
+ // block height.
+ log btclog.Logger
+}
+
+// Compile-time check to ensure Beat satisfies the Blockbeat interface.
+var _ Blockbeat = (*Beat)(nil)
+
+// NewBeat creates a new beat with the specified block epoch and a customized
+// logger.
+func NewBeat(epoch chainntnfs.BlockEpoch) *Beat {
+ b := &Beat{
+ epoch: epoch,
+ }
+
+ // Create a customized logger for the blockbeat.
+ logPrefix := fmt.Sprintf("Height[%6d]:", b.Height())
+ b.log = clog.WithPrefix(logPrefix)
+
+ return b
+}
+
+// Height returns the height of the block epoch.
+//
+// NOTE: Part of the Blockbeat interface.
+func (b *Beat) Height() int32 {
+ return b.epoch.Height
+}
+
+// logger returns the logger for the blockbeat.
+//
+// NOTE: Part of the private blockbeat interface.
+func (b *Beat) logger() btclog.Logger {
+ return b.log
+}
diff --git a/chainio/blockbeat_test.go b/chainio/blockbeat_test.go
new file mode 100644
index 0000000000..9326651b38
--- /dev/null
+++ b/chainio/blockbeat_test.go
@@ -0,0 +1,28 @@
+package chainio
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/lightningnetwork/lnd/chainntnfs"
+ "github.com/stretchr/testify/require"
+)
+
+var errDummy = errors.New("dummy error")
+
+// TestNewBeat tests the NewBeat and Height functions.
+func TestNewBeat(t *testing.T) {
+ t.Parallel()
+
+ // Create a testing epoch.
+ epoch := chainntnfs.BlockEpoch{
+ Height: 1,
+ }
+
+ // Create the beat and check the internal state.
+ beat := NewBeat(epoch)
+ require.Equal(t, epoch, beat.epoch)
+
+ // Check the height function.
+ require.Equal(t, epoch.Height, beat.Height())
+}
diff --git a/chainio/consumer.go b/chainio/consumer.go
new file mode 100644
index 0000000000..a9ec25745b
--- /dev/null
+++ b/chainio/consumer.go
@@ -0,0 +1,113 @@
+package chainio
+
+// BeatConsumer defines a supplementary component that should be used by
+// subsystems which implement the `Consumer` interface. It partially implements
+// the `Consumer` interface by providing the method `ProcessBlock` such that
+// subsystems don't need to re-implement it.
+//
+// While inheritance is not commonly used in Go, subsystems embedding this
+// struct cannot pass the interface check for `Consumer` because the `Name`
+// method is not implemented, which gives us a "mortise and tenon" structure.
+// In addition to reducing code duplication, this design allows `ProcessBlock`
+// to work on the concrete type `Beat` to access its internal states.
+type BeatConsumer struct {
+ // BlockbeatChan is a channel to receive blocks from Blockbeat. The
+ // received block contains the best known height and the txns confirmed
+ // in this block.
+ BlockbeatChan chan Blockbeat
+
+ // name is the name of the consumer which embeds the BlockConsumer.
+ name string
+
+ // quit is a channel that closes when the BlockConsumer is shutting
+ // down.
+ //
+ // NOTE: this quit channel should be mounted to the same quit channel
+ // used by the subsystem.
+ quit chan struct{}
+
+ // errChan is a buffered chan that receives an error returned from
+ // processing this block.
+ errChan chan error
+}
+
+// NewBeatConsumer creates a new BlockConsumer.
+func NewBeatConsumer(quit chan struct{}, name string) BeatConsumer {
+ // Refuse to start `lnd` if the quit channel is not initialized. We
+ // treat this case as if we are facing a nil pointer dereference, as
+ // there's no point to return an error here, which will cause the node
+ // to fail to be started anyway.
+ if quit == nil {
+ panic("quit channel is nil")
+ }
+
+ b := BeatConsumer{
+ BlockbeatChan: make(chan Blockbeat),
+ name: name,
+ errChan: make(chan error, 1),
+ quit: quit,
+ }
+
+ return b
+}
+
+// ProcessBlock takes a blockbeat and sends it to the consumer's blockbeat
+// channel. It will send it to the subsystem's BlockbeatChan, and block until
+// the processed result is received from the subsystem. The subsystem must call
+// `NotifyBlockProcessed` after it has finished processing the block.
+//
+// NOTE: part of the `chainio.Consumer` interface.
+func (b *BeatConsumer) ProcessBlock(beat Blockbeat) error {
+ // Update the current height.
+ beat.logger().Tracef("set current height for [%s]", b.name)
+
+ select {
+ // Send the beat to the blockbeat channel. It's expected that the
+ // consumer will read from this channel and process the block. Once
+ // processed, it should return the error or nil to the beat.Err chan.
+ case b.BlockbeatChan <- beat:
+ beat.logger().Tracef("Sent blockbeat to [%s]", b.name)
+
+ case <-b.quit:
+ beat.logger().Debugf("[%s] received shutdown before sending "+
+ "beat", b.name)
+
+ return nil
+ }
+
+ // Check the consumer's err chan. We expect the consumer to call
+ // `beat.NotifyBlockProcessed` to send the error back here.
+ select {
+ case err := <-b.errChan:
+ beat.logger().Debugf("[%s] processed beat: err=%v", b.name, err)
+
+ return err
+
+ case <-b.quit:
+ beat.logger().Debugf("[%s] received shutdown", b.name)
+ }
+
+ return nil
+}
+
+// NotifyBlockProcessed signals that the block has been processed. It takes the
+// blockbeat being processed and an error resulted from processing it. This
+// error is then sent back to the consumer's err chan to unblock
+// `ProcessBlock`.
+//
+// NOTE: This method must be called by the subsystem after it has finished
+// processing the block.
+func (b *BeatConsumer) NotifyBlockProcessed(beat Blockbeat, err error) {
+ // Update the current height.
+ beat.logger().Debugf("[%s]: notifying beat processed", b.name)
+
+ select {
+ case b.errChan <- err:
+ beat.logger().Debugf("[%s]: notified beat processed, err=%v",
+ b.name, err)
+
+ case <-b.quit:
+ beat.logger().Debugf("[%s] received shutdown before notifying "+
+ "beat processed", b.name)
+ }
+}
diff --git a/chainio/consumer_test.go b/chainio/consumer_test.go
new file mode 100644
index 0000000000..d1cabf3168
--- /dev/null
+++ b/chainio/consumer_test.go
@@ -0,0 +1,202 @@
+package chainio
+
+import (
+ "testing"
+ "time"
+
+ "github.com/lightningnetwork/lnd/fn/v2"
+ "github.com/stretchr/testify/require"
+)
+
+// TestNewBeatConsumer tests the NewBeatConsumer function.
+func TestNewBeatConsumer(t *testing.T) {
+ t.Parallel()
+
+ quitChan := make(chan struct{})
+ name := "test"
+
+ // Test the NewBeatConsumer function.
+ b := NewBeatConsumer(quitChan, name)
+
+ // Assert the state.
+ require.Equal(t, quitChan, b.quit)
+ require.Equal(t, name, b.name)
+ require.NotNil(t, b.BlockbeatChan)
+}
+
+// TestProcessBlockSuccess tests when the block is processed successfully, no
+// error is returned.
+func TestProcessBlockSuccess(t *testing.T) {
+ t.Parallel()
+
+ // Create a test consumer.
+ quitChan := make(chan struct{})
+ b := NewBeatConsumer(quitChan, "test")
+
+ // Create a mock beat.
+ mockBeat := &MockBlockbeat{}
+ defer mockBeat.AssertExpectations(t)
+ mockBeat.On("logger").Return(clog)
+
+ // Mock the consumer's err chan.
+ consumerErrChan := make(chan error, 1)
+ b.errChan = consumerErrChan
+
+ // Call the method under test.
+ resultChan := make(chan error, 1)
+ go func() {
+ resultChan <- b.ProcessBlock(mockBeat)
+ }()
+
+ // Assert the beat is sent to the blockbeat channel.
+ beat, err := fn.RecvOrTimeout(b.BlockbeatChan, time.Second)
+ require.NoError(t, err)
+ require.Equal(t, mockBeat, beat)
+
+ // Send nil to the consumer's error channel.
+ consumerErrChan <- nil
+
+ // Assert the result of ProcessBlock is nil.
+ result, err := fn.RecvOrTimeout(resultChan, time.Second)
+ require.NoError(t, err)
+ require.Nil(t, result)
+}
+
+// TestProcessBlockConsumerQuitBeforeSend tests when the consumer is quit
+// before sending the beat, the method returns immediately.
+func TestProcessBlockConsumerQuitBeforeSend(t *testing.T) {
+ t.Parallel()
+
+ // Create a test consumer.
+ quitChan := make(chan struct{})
+ b := NewBeatConsumer(quitChan, "test")
+
+ // Create a mock beat.
+ mockBeat := &MockBlockbeat{}
+ defer mockBeat.AssertExpectations(t)
+ mockBeat.On("logger").Return(clog)
+
+ // Call the method under test.
+ resultChan := make(chan error, 1)
+ go func() {
+ resultChan <- b.ProcessBlock(mockBeat)
+ }()
+
+ // Instead of reading the BlockbeatChan, close the quit channel.
+ close(quitChan)
+
+ // Assert ProcessBlock returned nil.
+ result, err := fn.RecvOrTimeout(resultChan, time.Second)
+ require.NoError(t, err)
+ require.Nil(t, result)
+}
+
+// TestProcessBlockConsumerQuitAfterSend tests when the consumer is quit after
+// sending the beat, the method returns immediately.
+func TestProcessBlockConsumerQuitAfterSend(t *testing.T) {
+ t.Parallel()
+
+ // Create a test consumer.
+ quitChan := make(chan struct{})
+ b := NewBeatConsumer(quitChan, "test")
+
+ // Create a mock beat.
+ mockBeat := &MockBlockbeat{}
+ defer mockBeat.AssertExpectations(t)
+ mockBeat.On("logger").Return(clog)
+
+ // Mock the consumer's err chan.
+ consumerErrChan := make(chan error, 1)
+ b.errChan = consumerErrChan
+
+ // Call the method under test.
+ resultChan := make(chan error, 1)
+ go func() {
+ resultChan <- b.ProcessBlock(mockBeat)
+ }()
+
+ // Assert the beat is sent to the blockbeat channel.
+ beat, err := fn.RecvOrTimeout(b.BlockbeatChan, time.Second)
+ require.NoError(t, err)
+ require.Equal(t, mockBeat, beat)
+
+ // Instead of sending nil to the consumer's error channel, close the
+ // quit chanel.
+ close(quitChan)
+
+ // Assert ProcessBlock returned nil.
+ result, err := fn.RecvOrTimeout(resultChan, time.Second)
+ require.NoError(t, err)
+ require.Nil(t, result)
+}
+
+// TestNotifyBlockProcessedSendErr asserts the error can be sent and read by
+// the beat via NotifyBlockProcessed.
+func TestNotifyBlockProcessedSendErr(t *testing.T) {
+ t.Parallel()
+
+ // Create a test consumer.
+ quitChan := make(chan struct{})
+ b := NewBeatConsumer(quitChan, "test")
+
+ // Create a mock beat.
+ mockBeat := &MockBlockbeat{}
+ defer mockBeat.AssertExpectations(t)
+ mockBeat.On("logger").Return(clog)
+
+ // Mock the consumer's err chan.
+ consumerErrChan := make(chan error, 1)
+ b.errChan = consumerErrChan
+
+ // Call the method under test.
+ done := make(chan error)
+ go func() {
+ defer close(done)
+ b.NotifyBlockProcessed(mockBeat, errDummy)
+ }()
+
+ // Assert the error is sent to the beat's err chan.
+ result, err := fn.RecvOrTimeout(consumerErrChan, time.Second)
+ require.NoError(t, err)
+ require.ErrorIs(t, result, errDummy)
+
+ // Assert the done channel is closed.
+ result, err = fn.RecvOrTimeout(done, time.Second)
+ require.NoError(t, err)
+ require.Nil(t, result)
+}
+
+// TestNotifyBlockProcessedOnQuit asserts NotifyBlockProcessed exits
+// immediately when the quit channel is closed.
+func TestNotifyBlockProcessedOnQuit(t *testing.T) {
+ t.Parallel()
+
+ // Create a test consumer.
+ quitChan := make(chan struct{})
+ b := NewBeatConsumer(quitChan, "test")
+
+ // Create a mock beat.
+ mockBeat := &MockBlockbeat{}
+ defer mockBeat.AssertExpectations(t)
+ mockBeat.On("logger").Return(clog)
+
+ // Mock the consumer's err chan - we don't buffer it so it will block
+ // on sending the error.
+ consumerErrChan := make(chan error)
+ b.errChan = consumerErrChan
+
+ // Call the method under test.
+ done := make(chan error)
+ go func() {
+ defer close(done)
+ b.NotifyBlockProcessed(mockBeat, errDummy)
+ }()
+
+ // Close the quit channel so the method will return.
+ close(b.quit)
+
+ // Assert the done channel is closed.
+ result, err := fn.RecvOrTimeout(done, time.Second)
+ require.NoError(t, err)
+ require.Nil(t, result)
+}
diff --git a/chainio/dispatcher.go b/chainio/dispatcher.go
new file mode 100644
index 0000000000..87bc21fbaa
--- /dev/null
+++ b/chainio/dispatcher.go
@@ -0,0 +1,296 @@
+package chainio
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/btcsuite/btclog/v2"
+ "github.com/lightningnetwork/lnd/chainntnfs"
+ "golang.org/x/sync/errgroup"
+)
+
+// DefaultProcessBlockTimeout is the timeout value used when waiting for one
+// consumer to finish processing the new block epoch.
+var DefaultProcessBlockTimeout = 60 * time.Second
+
+// ErrProcessBlockTimeout is the error returned when a consumer takes too long
+// to process the block.
+var ErrProcessBlockTimeout = errors.New("process block timeout")
+
+// BlockbeatDispatcher is a service that handles dispatching new blocks to
+// `lnd`'s subsystems. During startup, subsystems that are block-driven should
+// implement the `Consumer` interface and register themselves via
+// `RegisterQueue`. When two subsystems are independent of each other, they
+// should be registered in different queues so blocks are notified concurrently.
+// Otherwise, when living in the same queue, the subsystems are notified of the
+// new blocks sequentially, which means it's critical to understand the
+// relationship of these systems to properly handle the order.
+type BlockbeatDispatcher struct {
+ wg sync.WaitGroup
+
+ // notifier is used to receive new block epochs.
+ notifier chainntnfs.ChainNotifier
+
+ // beat is the latest blockbeat received.
+ beat Blockbeat
+
+ // consumerQueues is a map of consumers that will receive blocks. Its
+ // key is a unique counter and its value is a queue of consumers. Each
+ // queue is notified concurrently, and consumers in the same queue is
+ // notified sequentially.
+ consumerQueues map[uint32][]Consumer
+
+ // counter is used to assign a unique id to each queue.
+ counter atomic.Uint32
+
+ // quit is used to signal the BlockbeatDispatcher to stop.
+ quit chan struct{}
+}
+
+// NewBlockbeatDispatcher returns a new blockbeat dispatcher instance.
+func NewBlockbeatDispatcher(n chainntnfs.ChainNotifier) *BlockbeatDispatcher {
+ return &BlockbeatDispatcher{
+ notifier: n,
+ quit: make(chan struct{}),
+ consumerQueues: make(map[uint32][]Consumer),
+ }
+}
+
+// RegisterQueue takes a list of consumers and registers them in the same
+// queue.
+//
+// NOTE: these consumers are notified sequentially.
+func (b *BlockbeatDispatcher) RegisterQueue(consumers []Consumer) {
+ qid := b.counter.Add(1)
+
+ b.consumerQueues[qid] = append(b.consumerQueues[qid], consumers...)
+ clog.Infof("Registered queue=%d with %d blockbeat consumers", qid,
+ len(consumers))
+
+ for _, c := range consumers {
+ clog.Debugf("Consumer [%s] registered in queue %d", c.Name(),
+ qid)
+ }
+}
+
+// Start starts the blockbeat dispatcher - it registers a block notification
+// and monitors and dispatches new blocks in a goroutine. It will refuse to
+// start if there are no registered consumers.
+func (b *BlockbeatDispatcher) Start() error {
+ // Make sure consumers are registered.
+ if len(b.consumerQueues) == 0 {
+ return fmt.Errorf("no consumers registered")
+ }
+
+ // Start listening to new block epochs. We should get a notification
+ // with the current best block immediately.
+ blockEpochs, err := b.notifier.RegisterBlockEpochNtfn(nil)
+ if err != nil {
+ return fmt.Errorf("register block epoch ntfn: %w", err)
+ }
+
+ clog.Infof("BlockbeatDispatcher is starting with %d consumer queues",
+ len(b.consumerQueues))
+ defer clog.Debug("BlockbeatDispatcher started")
+
+ b.wg.Add(1)
+ go b.dispatchBlocks(blockEpochs)
+
+ return nil
+}
+
+// Stop shuts down the blockbeat dispatcher.
+func (b *BlockbeatDispatcher) Stop() {
+ clog.Info("BlockbeatDispatcher is stopping")
+ defer clog.Debug("BlockbeatDispatcher stopped")
+
+ // Signal the dispatchBlocks goroutine to stop.
+ close(b.quit)
+ b.wg.Wait()
+}
+
+func (b *BlockbeatDispatcher) log() btclog.Logger {
+ return b.beat.logger()
+}
+
+// dispatchBlocks listens to new block epoch and dispatches it to all the
+// consumers. Each queue is notified concurrently, and the consumers in the
+// same queue are notified sequentially.
+//
+// NOTE: Must be run as a goroutine.
+func (b *BlockbeatDispatcher) dispatchBlocks(
+ blockEpochs *chainntnfs.BlockEpochEvent) {
+
+ defer b.wg.Done()
+ defer blockEpochs.Cancel()
+
+ for {
+ select {
+ case blockEpoch, ok := <-blockEpochs.Epochs:
+ if !ok {
+ clog.Debugf("Block epoch channel closed")
+
+ return
+ }
+
+ clog.Infof("Received new block %v at height %d, "+
+ "notifying consumers...", blockEpoch.Hash,
+ blockEpoch.Height)
+
+ // Record the time it takes the consumer to process
+ // this block.
+ start := time.Now()
+
+ // Update the current block epoch.
+ b.beat = NewBeat(*blockEpoch)
+
+ // Notify all consumers.
+ err := b.notifyQueues()
+ if err != nil {
+ b.log().Errorf("Notify block failed: %v", err)
+ }
+
+ b.log().Infof("Notified all consumers on new block "+
+ "in %v", time.Since(start))
+
+ case <-b.quit:
+ b.log().Debugf("BlockbeatDispatcher quit signal " +
+ "received")
+
+ return
+ }
+ }
+}
+
+// notifyQueues notifies each queue concurrently about the latest block epoch.
+func (b *BlockbeatDispatcher) notifyQueues() error {
+ // errChans is a map of channels that will be used to receive errors
+ // returned from notifying the consumers.
+ errChans := make(map[uint32]chan error, len(b.consumerQueues))
+
+ // Notify each queue in goroutines.
+ for qid, consumers := range b.consumerQueues {
+ b.log().Debugf("Notifying queue=%d with %d consumers", qid,
+ len(consumers))
+
+ // Create a signal chan.
+ errChan := make(chan error, 1)
+ errChans[qid] = errChan
+
+ // Notify each queue concurrently.
+ go func(qid uint32, c []Consumer, beat Blockbeat) {
+ // Notify each consumer in this queue sequentially.
+ errChan <- DispatchSequential(beat, c)
+ }(qid, consumers, b.beat)
+ }
+
+ // Wait for all consumers in each queue to finish.
+ for qid, errChan := range errChans {
+ select {
+ case err := <-errChan:
+ if err != nil {
+ return fmt.Errorf("queue=%d got err: %w", qid,
+ err)
+ }
+
+ b.log().Debugf("Notified queue=%d", qid)
+
+ case <-b.quit:
+ b.log().Debugf("BlockbeatDispatcher quit signal " +
+ "received, exit notifyQueues")
+
+ return nil
+ }
+ }
+
+ return nil
+}
+
+// DispatchSequential takes a list of consumers and notify them about the new
+// epoch sequentially. It requires the consumer to finish processing the block
+// within the specified time, otherwise a timeout error is returned.
+func DispatchSequential(b Blockbeat, consumers []Consumer) error {
+ for _, c := range consumers {
+ // Send the beat to the consumer.
+ err := notifyAndWait(b, c, DefaultProcessBlockTimeout)
+ if err != nil {
+ b.logger().Errorf("Failed to process block: %v", err)
+
+ return err
+ }
+ }
+
+ return nil
+}
+
+// DispatchConcurrent notifies each consumer concurrently about the blockbeat.
+// It requires the consumer to finish processing the block within the specified
+// time, otherwise a timeout error is returned.
+func DispatchConcurrent(b Blockbeat, consumers []Consumer) error {
+ eg := &errgroup.Group{}
+
+ // Notify each queue in goroutines.
+ for _, c := range consumers {
+ // Notify each consumer concurrently.
+ eg.Go(func() error {
+ // Send the beat to the consumer.
+ err := notifyAndWait(b, c, DefaultProcessBlockTimeout)
+
+ // Exit early if there's no error.
+ if err == nil {
+ return nil
+ }
+
+ b.logger().Errorf("Consumer=%v failed to process "+
+ "block: %v", c.Name(), err)
+
+ return err
+ })
+ }
+
+ // Wait for all consumers in each queue to finish.
+ if err := eg.Wait(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// notifyAndWait sends the blockbeat to the specified consumer. It requires the
+// consumer to finish processing the block within the specified time, otherwise
+// a timeout error is returned.
+func notifyAndWait(b Blockbeat, c Consumer, timeout time.Duration) error {
+ b.logger().Debugf("Waiting for consumer[%s] to process it", c.Name())
+
+ // Record the time it takes the consumer to process this block.
+ start := time.Now()
+
+ errChan := make(chan error, 1)
+ go func() {
+ errChan <- c.ProcessBlock(b)
+ }()
+
+ // We expect the consumer to finish processing this block under 30s,
+ // otherwise a timeout error is returned.
+ select {
+ case err := <-errChan:
+ if err == nil {
+ break
+ }
+
+ return fmt.Errorf("%s got err in ProcessBlock: %w", c.Name(),
+ err)
+
+ case <-time.After(timeout):
+ return fmt.Errorf("consumer %s: %w", c.Name(),
+ ErrProcessBlockTimeout)
+ }
+
+ b.logger().Debugf("Consumer[%s] processed block in %v", c.Name(),
+ time.Since(start))
+
+ return nil
+}
diff --git a/chainio/dispatcher_test.go b/chainio/dispatcher_test.go
new file mode 100644
index 0000000000..11abbeb65e
--- /dev/null
+++ b/chainio/dispatcher_test.go
@@ -0,0 +1,383 @@
+package chainio
+
+import (
+ "testing"
+ "time"
+
+ "github.com/lightningnetwork/lnd/chainntnfs"
+ "github.com/lightningnetwork/lnd/fn/v2"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+)
+
+// TestNotifyAndWaitOnConsumerErr asserts when the consumer returns an error,
+// it's returned by notifyAndWait.
+func TestNotifyAndWaitOnConsumerErr(t *testing.T) {
+ t.Parallel()
+
+ // Create a mock consumer.
+ consumer := &MockConsumer{}
+ defer consumer.AssertExpectations(t)
+ consumer.On("Name").Return("mocker")
+
+ // Create a mock beat.
+ mockBeat := &MockBlockbeat{}
+ defer mockBeat.AssertExpectations(t)
+ mockBeat.On("logger").Return(clog)
+
+ // Mock ProcessBlock to return an error.
+ consumer.On("ProcessBlock", mockBeat).Return(errDummy).Once()
+
+ // Call the method under test.
+ err := notifyAndWait(mockBeat, consumer, DefaultProcessBlockTimeout)
+
+ // We expect the error to be returned.
+ require.ErrorIs(t, err, errDummy)
+}
+
+// TestNotifyAndWaitOnConsumerErr asserts when the consumer successfully
+// processed the beat, no error is returned.
+func TestNotifyAndWaitOnConsumerSuccess(t *testing.T) {
+ t.Parallel()
+
+ // Create a mock consumer.
+ consumer := &MockConsumer{}
+ defer consumer.AssertExpectations(t)
+ consumer.On("Name").Return("mocker")
+
+ // Create a mock beat.
+ mockBeat := &MockBlockbeat{}
+ defer mockBeat.AssertExpectations(t)
+ mockBeat.On("logger").Return(clog)
+
+ // Mock ProcessBlock to return nil.
+ consumer.On("ProcessBlock", mockBeat).Return(nil).Once()
+
+ // Call the method under test.
+ err := notifyAndWait(mockBeat, consumer, DefaultProcessBlockTimeout)
+
+ // We expect a nil error to be returned.
+ require.NoError(t, err)
+}
+
+// TestNotifyAndWaitOnConsumerTimeout asserts when the consumer times out
+// processing the block, the timeout error is returned.
+func TestNotifyAndWaitOnConsumerTimeout(t *testing.T) {
+ t.Parallel()
+
+ // Set timeout to be 10ms.
+ processBlockTimeout := 10 * time.Millisecond
+
+ // Create a mock consumer.
+ consumer := &MockConsumer{}
+ defer consumer.AssertExpectations(t)
+ consumer.On("Name").Return("mocker")
+
+ // Create a mock beat.
+ mockBeat := &MockBlockbeat{}
+ defer mockBeat.AssertExpectations(t)
+ mockBeat.On("logger").Return(clog)
+
+ // Mock ProcessBlock to return nil but blocks on returning.
+ consumer.On("ProcessBlock", mockBeat).Return(nil).Run(
+ func(args mock.Arguments) {
+ // Sleep one second to block on the method.
+ time.Sleep(processBlockTimeout * 100)
+ }).Once()
+
+ // Call the method under test.
+ err := notifyAndWait(mockBeat, consumer, processBlockTimeout)
+
+ // We expect a timeout error to be returned.
+ require.ErrorIs(t, err, ErrProcessBlockTimeout)
+}
+
+// TestDispatchSequential checks that the beat is sent to the consumers
+// sequentially.
+func TestDispatchSequential(t *testing.T) {
+ t.Parallel()
+
+ // Create three mock consumers.
+ consumer1 := &MockConsumer{}
+ defer consumer1.AssertExpectations(t)
+ consumer1.On("Name").Return("mocker1")
+
+ consumer2 := &MockConsumer{}
+ defer consumer2.AssertExpectations(t)
+ consumer2.On("Name").Return("mocker2")
+
+ consumer3 := &MockConsumer{}
+ defer consumer3.AssertExpectations(t)
+ consumer3.On("Name").Return("mocker3")
+
+ consumers := []Consumer{consumer1, consumer2, consumer3}
+
+ // Create a mock beat.
+ mockBeat := &MockBlockbeat{}
+ defer mockBeat.AssertExpectations(t)
+ mockBeat.On("logger").Return(clog)
+
+ // prevConsumer specifies the previous consumer that was called.
+ var prevConsumer string
+
+ // Mock the ProcessBlock on consumers to reutrn immediately.
+ consumer1.On("ProcessBlock", mockBeat).Return(nil).Run(
+ func(args mock.Arguments) {
+ // Check the order of the consumers.
+ //
+ // The first consumer should have no previous consumer.
+ require.Empty(t, prevConsumer)
+
+ // Set the consumer as the previous consumer.
+ prevConsumer = consumer1.Name()
+ }).Once()
+
+ consumer2.On("ProcessBlock", mockBeat).Return(nil).Run(
+ func(args mock.Arguments) {
+ // Check the order of the consumers.
+ //
+ // The second consumer should see consumer1.
+ require.Equal(t, consumer1.Name(), prevConsumer)
+
+ // Set the consumer as the previous consumer.
+ prevConsumer = consumer2.Name()
+ }).Once()
+
+ consumer3.On("ProcessBlock", mockBeat).Return(nil).Run(
+ func(args mock.Arguments) {
+ // Check the order of the consumers.
+ //
+ // The third consumer should see consumer2.
+ require.Equal(t, consumer2.Name(), prevConsumer)
+
+ // Set the consumer as the previous consumer.
+ prevConsumer = consumer3.Name()
+ }).Once()
+
+ // Call the method under test.
+ err := DispatchSequential(mockBeat, consumers)
+ require.NoError(t, err)
+
+ // Check the previous consumer is the last consumer.
+ require.Equal(t, consumer3.Name(), prevConsumer)
+}
+
+// TestRegisterQueue tests the RegisterQueue function.
+func TestRegisterQueue(t *testing.T) {
+ t.Parallel()
+
+ // Create two mock consumers.
+ consumer1 := &MockConsumer{}
+ defer consumer1.AssertExpectations(t)
+ consumer1.On("Name").Return("mocker1")
+
+ consumer2 := &MockConsumer{}
+ defer consumer2.AssertExpectations(t)
+ consumer2.On("Name").Return("mocker2")
+
+ consumers := []Consumer{consumer1, consumer2}
+
+ // Create a mock chain notifier.
+ mockNotifier := &chainntnfs.MockChainNotifier{}
+ defer mockNotifier.AssertExpectations(t)
+
+ // Create a new dispatcher.
+ b := NewBlockbeatDispatcher(mockNotifier)
+
+ // Register the consumers.
+ b.RegisterQueue(consumers)
+
+ // Assert that the consumers have been registered.
+ //
+ // We should have one queue.
+ require.Len(t, b.consumerQueues, 1)
+
+ // The queue should have two consumers.
+ queue, ok := b.consumerQueues[1]
+ require.True(t, ok)
+ require.Len(t, queue, 2)
+}
+
+// TestStartDispatcher tests the Start method.
+func TestStartDispatcher(t *testing.T) {
+ t.Parallel()
+
+ // Create a mock chain notifier.
+ mockNotifier := &chainntnfs.MockChainNotifier{}
+ defer mockNotifier.AssertExpectations(t)
+
+ // Create a new dispatcher.
+ b := NewBlockbeatDispatcher(mockNotifier)
+
+ // Start the dispatcher without consumers should return an error.
+ err := b.Start()
+ require.Error(t, err)
+
+ // Create a consumer and register it.
+ consumer := &MockConsumer{}
+ defer consumer.AssertExpectations(t)
+ consumer.On("Name").Return("mocker1")
+ b.RegisterQueue([]Consumer{consumer})
+
+ // Mock the chain notifier to return an error.
+ mockNotifier.On("RegisterBlockEpochNtfn",
+ mock.Anything).Return(nil, errDummy).Once()
+
+ // Start the dispatcher now should return the error.
+ err = b.Start()
+ require.ErrorIs(t, err, errDummy)
+
+ // Mock the chain notifier to return a valid notifier.
+ blockEpochs := &chainntnfs.BlockEpochEvent{}
+ mockNotifier.On("RegisterBlockEpochNtfn",
+ mock.Anything).Return(blockEpochs, nil).Once()
+
+ // Start the dispatcher now should not return an error.
+ err = b.Start()
+ require.NoError(t, err)
+}
+
+// TestDispatchBlocks asserts the blocks are properly dispatched to the queues.
+func TestDispatchBlocks(t *testing.T) {
+ t.Parallel()
+
+ // Create a mock chain notifier.
+ mockNotifier := &chainntnfs.MockChainNotifier{}
+ defer mockNotifier.AssertExpectations(t)
+
+ // Create a new dispatcher.
+ b := NewBlockbeatDispatcher(mockNotifier)
+
+ // Create the beat and attach it to the dispatcher.
+ epoch := chainntnfs.BlockEpoch{Height: 1}
+ beat := NewBeat(epoch)
+ b.beat = beat
+
+ // Create a consumer and register it.
+ consumer := &MockConsumer{}
+ defer consumer.AssertExpectations(t)
+ consumer.On("Name").Return("mocker1")
+ b.RegisterQueue([]Consumer{consumer})
+
+ // Mock the consumer to return nil error on ProcessBlock. This
+ // implictly asserts that the step `notifyQueues` is successfully
+ // reached in the `dispatchBlocks` method.
+ consumer.On("ProcessBlock", mock.Anything).Return(nil).Once()
+
+ // Create a test epoch chan.
+ epochChan := make(chan *chainntnfs.BlockEpoch, 1)
+ blockEpochs := &chainntnfs.BlockEpochEvent{
+ Epochs: epochChan,
+ Cancel: func() {},
+ }
+
+ // Call the method in a goroutine.
+ done := make(chan struct{})
+ b.wg.Add(1)
+ go func() {
+ defer close(done)
+ b.dispatchBlocks(blockEpochs)
+ }()
+
+ // Send an epoch.
+ epoch = chainntnfs.BlockEpoch{Height: 2}
+ epochChan <- &epoch
+
+ // Wait for the dispatcher to process the epoch.
+ time.Sleep(100 * time.Millisecond)
+
+ // Stop the dispatcher.
+ b.Stop()
+
+ // We expect the dispatcher to stop immediately.
+ _, err := fn.RecvOrTimeout(done, time.Second)
+ require.NoError(t, err)
+}
+
+// TestNotifyQueuesSuccess checks when the dispatcher successfully notifies all
+// the queues, no error is returned.
+func TestNotifyQueuesSuccess(t *testing.T) {
+ t.Parallel()
+
+ // Create two mock consumers.
+ consumer1 := &MockConsumer{}
+ defer consumer1.AssertExpectations(t)
+ consumer1.On("Name").Return("mocker1")
+
+ consumer2 := &MockConsumer{}
+ defer consumer2.AssertExpectations(t)
+ consumer2.On("Name").Return("mocker2")
+
+ // Create two queues.
+ queue1 := []Consumer{consumer1}
+ queue2 := []Consumer{consumer2}
+
+ // Create a mock chain notifier.
+ mockNotifier := &chainntnfs.MockChainNotifier{}
+ defer mockNotifier.AssertExpectations(t)
+
+ // Create a mock beat.
+ mockBeat := &MockBlockbeat{}
+ defer mockBeat.AssertExpectations(t)
+ mockBeat.On("logger").Return(clog)
+
+ // Create a new dispatcher.
+ b := NewBlockbeatDispatcher(mockNotifier)
+
+ // Register the queues.
+ b.RegisterQueue(queue1)
+ b.RegisterQueue(queue2)
+
+ // Attach the blockbeat.
+ b.beat = mockBeat
+
+ // Mock the consumers to return nil error on ProcessBlock for
+ // both calls.
+ consumer1.On("ProcessBlock", mockBeat).Return(nil).Once()
+ consumer2.On("ProcessBlock", mockBeat).Return(nil).Once()
+
+ // Notify the queues. The mockers will be asserted in the end to
+ // validate the calls.
+ err := b.notifyQueues()
+ require.NoError(t, err)
+}
+
+// TestNotifyQueuesError checks when one of the queue returns an error, this
+// error is returned by the method.
+func TestNotifyQueuesError(t *testing.T) {
+ t.Parallel()
+
+ // Create a mock consumer.
+ consumer := &MockConsumer{}
+ defer consumer.AssertExpectations(t)
+ consumer.On("Name").Return("mocker1")
+
+ // Create one queue.
+ queue := []Consumer{consumer}
+
+ // Create a mock chain notifier.
+ mockNotifier := &chainntnfs.MockChainNotifier{}
+ defer mockNotifier.AssertExpectations(t)
+
+ // Create a mock beat.
+ mockBeat := &MockBlockbeat{}
+ defer mockBeat.AssertExpectations(t)
+ mockBeat.On("logger").Return(clog)
+
+ // Create a new dispatcher.
+ b := NewBlockbeatDispatcher(mockNotifier)
+
+ // Register the queues.
+ b.RegisterQueue(queue)
+
+ // Attach the blockbeat.
+ b.beat = mockBeat
+
+ // Mock the consumer to return an error on ProcessBlock.
+ consumer.On("ProcessBlock", mockBeat).Return(errDummy).Once()
+
+ // Notify the queues. The mockers will be asserted in the end to
+ // validate the calls.
+ err := b.notifyQueues()
+ require.ErrorIs(t, err, errDummy)
+}
diff --git a/chainio/interface.go b/chainio/interface.go
new file mode 100644
index 0000000000..03c09faf7c
--- /dev/null
+++ b/chainio/interface.go
@@ -0,0 +1,53 @@
+package chainio
+
+import "github.com/btcsuite/btclog/v2"
+
+// Blockbeat defines an interface that can be used by subsystems to retrieve
+// block data. It is sent by the BlockbeatDispatcher to all the registered
+// consumers whenever a new block is received. Once the consumer finishes
+// processing the block, it must signal it by calling `NotifyBlockProcessed`.
+//
+// The blockchain is a state machine - whenever there's a state change, it's
+// manifested in a block. The blockbeat is a way to notify subsystems of this
+// state change, and to provide them with the data they need to process it. In
+// other words, subsystems must react to this state change and should consider
+// being driven by the blockbeat in their own state machines.
+type Blockbeat interface {
+ // blockbeat is a private interface that's only used in this package.
+ blockbeat
+
+ // Height returns the current block height.
+ Height() int32
+}
+
+// blockbeat defines a set of private methods used in this package to make
+// interaction with the blockbeat easier.
+type blockbeat interface {
+ // logger returns the internal logger used by the blockbeat which has a
+ // block height prefix.
+ logger() btclog.Logger
+}
+
+// Consumer defines a blockbeat consumer interface. Subsystems that need block
+// info must implement it.
+type Consumer interface {
+ // TODO(yy): We should also define the start methods used by the
+ // consumers such that when implementing the interface, the consumer
+ // will always be started with a blockbeat. This cannot be enforced at
+ // the moment as we need refactor all the start methods to only take a
+ // beat.
+ //
+ // Start(beat Blockbeat) error
+
+ // Name returns a human-readable string for this subsystem.
+ Name() string
+
+ // ProcessBlock takes a blockbeat and processes it. It should not
+ // return until the subsystem has updated its state based on the block
+ // data.
+ //
+ // NOTE: The consumer must try its best to NOT return an error. If an
+ // error is returned from processing the block, it means the subsystem
+ // cannot react to onchain state changes and lnd will shutdown.
+ ProcessBlock(b Blockbeat) error
+}
diff --git a/chainio/log.go b/chainio/log.go
new file mode 100644
index 0000000000..2d8c26f7a5
--- /dev/null
+++ b/chainio/log.go
@@ -0,0 +1,32 @@
+package chainio
+
+import (
+ "github.com/btcsuite/btclog/v2"
+ "github.com/lightningnetwork/lnd/build"
+)
+
+// Subsystem defines the logging code for this subsystem.
+const Subsystem = "CHIO"
+
+// clog is a logger that is initialized with no output filters. This means the
+// package will not perform any logging by default until the caller requests
+// it.
+var clog btclog.Logger
+
+// The default amount of logging is none.
+func init() {
+ UseLogger(build.NewSubLogger(Subsystem, nil))
+}
+
+// DisableLog disables all library log output. Logging output is disabled by
+// default until UseLogger is called.
+func DisableLog() {
+ UseLogger(btclog.Disabled)
+}
+
+// UseLogger uses a specified Logger to output package logging info. This
+// should be used in preference to SetLogWriter if the caller is also using
+// btclog.
+func UseLogger(logger btclog.Logger) {
+ clog = logger
+}
diff --git a/chainio/mocks.go b/chainio/mocks.go
new file mode 100644
index 0000000000..5677734e1d
--- /dev/null
+++ b/chainio/mocks.go
@@ -0,0 +1,50 @@
+package chainio
+
+import (
+ "github.com/btcsuite/btclog/v2"
+ "github.com/stretchr/testify/mock"
+)
+
+// MockConsumer is a mock implementation of the Consumer interface.
+type MockConsumer struct {
+ mock.Mock
+}
+
+// Compile-time constraint to ensure MockConsumer implements Consumer.
+var _ Consumer = (*MockConsumer)(nil)
+
+// Name returns a human-readable string for this subsystem.
+func (m *MockConsumer) Name() string {
+ args := m.Called()
+ return args.String(0)
+}
+
+// ProcessBlock takes a blockbeat and processes it. A receive-only error chan
+// must be returned.
+func (m *MockConsumer) ProcessBlock(b Blockbeat) error {
+ args := m.Called(b)
+
+ return args.Error(0)
+}
+
+// MockBlockbeat is a mock implementation of the Blockbeat interface.
+type MockBlockbeat struct {
+ mock.Mock
+}
+
+// Compile-time constraint to ensure MockBlockbeat implements Blockbeat.
+var _ Blockbeat = (*MockBlockbeat)(nil)
+
+// Height returns the current block height.
+func (m *MockBlockbeat) Height() int32 {
+ args := m.Called()
+
+ return args.Get(0).(int32)
+}
+
+// logger returns the logger for the blockbeat.
+func (m *MockBlockbeat) logger() btclog.Logger {
+ args := m.Called()
+
+ return args.Get(0).(btclog.Logger)
+}
diff --git a/chainntnfs/bitcoindnotify/bitcoind.go b/chainntnfs/bitcoindnotify/bitcoind.go
index fc20fbb857..59c03d5171 100644
--- a/chainntnfs/bitcoindnotify/bitcoind.go
+++ b/chainntnfs/bitcoindnotify/bitcoind.go
@@ -15,7 +15,7 @@ import (
"github.com/btcsuite/btcwallet/chain"
"github.com/lightningnetwork/lnd/blockcache"
"github.com/lightningnetwork/lnd/chainntnfs"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/queue"
)
diff --git a/chainntnfs/btcdnotify/btcd.go b/chainntnfs/btcdnotify/btcd.go
index c3a40a00bf..e3bff289cf 100644
--- a/chainntnfs/btcdnotify/btcd.go
+++ b/chainntnfs/btcdnotify/btcd.go
@@ -17,7 +17,7 @@ import (
"github.com/btcsuite/btcwallet/chain"
"github.com/lightningnetwork/lnd/blockcache"
"github.com/lightningnetwork/lnd/chainntnfs"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/queue"
)
diff --git a/chainntnfs/interface.go b/chainntnfs/interface.go
index b2383636aa..1b8a5acb50 100644
--- a/chainntnfs/interface.go
+++ b/chainntnfs/interface.go
@@ -13,7 +13,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
)
var (
diff --git a/chainntnfs/mocks.go b/chainntnfs/mocks.go
index d9ab9928d0..4a888b162e 100644
--- a/chainntnfs/mocks.go
+++ b/chainntnfs/mocks.go
@@ -3,7 +3,7 @@ package chainntnfs
import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/stretchr/testify/mock"
)
diff --git a/chainreg/chainregistry.go b/chainreg/chainregistry.go
index edc422482e..a9a9ede704 100644
--- a/chainreg/chainregistry.go
+++ b/chainreg/chainregistry.go
@@ -23,7 +23,7 @@ import (
"github.com/lightningnetwork/lnd/chainntnfs/btcdnotify"
"github.com/lightningnetwork/lnd/chainntnfs/neutrinonotify"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
diff --git a/chanbackup/backup.go b/chanbackup/backup.go
index 5853b37e45..afffe5a2e8 100644
--- a/chanbackup/backup.go
+++ b/chanbackup/backup.go
@@ -5,7 +5,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
)
// LiveChannelSource is an interface that allows us to query for the set of
diff --git a/chanbackup/single.go b/chanbackup/single.go
index b741320b07..01d14f6c07 100644
--- a/chanbackup/single.go
+++ b/chanbackup/single.go
@@ -12,7 +12,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnencrypt"
"github.com/lightningnetwork/lnd/lnwire"
diff --git a/chanbackup/single_test.go b/chanbackup/single_test.go
index d2212bd859..0fe402926d 100644
--- a/chanbackup/single_test.go
+++ b/chanbackup/single_test.go
@@ -13,7 +13,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnencrypt"
"github.com/lightningnetwork/lnd/lnwire"
diff --git a/channeldb/channel.go b/channeldb/channel.go
index 9ca57312aa..f4e99a6f8c 100644
--- a/channeldb/channel.go
+++ b/channeldb/channel.go
@@ -19,7 +19,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/walletdb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go
index 2cac0baced..b1ca100eb3 100644
--- a/channeldb/channel_test.go
+++ b/channeldb/channel_test.go
@@ -18,7 +18,7 @@ import (
_ "github.com/btcsuite/btcwallet/walletdb/bdb"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/clock"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
diff --git a/channeldb/migration/lnwire21/custom_records.go b/channeldb/migration/lnwire21/custom_records.go
index f0f59185e9..7771c8ec8b 100644
--- a/channeldb/migration/lnwire21/custom_records.go
+++ b/channeldb/migration/lnwire21/custom_records.go
@@ -6,7 +6,7 @@ import (
"io"
"sort"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/tlv"
)
@@ -163,9 +163,12 @@ func (c CustomRecords) SerializeTo(w io.Writer) error {
// ProduceRecordsSorted converts a slice of record producers into a slice of
// records and then sorts it by type.
func ProduceRecordsSorted(recordProducers ...tlv.RecordProducer) []tlv.Record {
- records := fn.Map(func(producer tlv.RecordProducer) tlv.Record {
- return producer.Record()
- }, recordProducers)
+ records := fn.Map(
+ recordProducers,
+ func(producer tlv.RecordProducer) tlv.Record {
+ return producer.Record()
+ },
+ )
// Ensure that the set of records are sorted before we attempt to
// decode from the stream, to ensure they're canonical.
@@ -196,9 +199,9 @@ func TlvMapToRecords(tlvMap tlv.TypeMap) []tlv.Record {
// RecordsAsProducers converts a slice of records into a slice of record
// producers.
func RecordsAsProducers(records []tlv.Record) []tlv.RecordProducer {
- return fn.Map(func(record tlv.Record) tlv.RecordProducer {
+ return fn.Map(records, func(record tlv.Record) tlv.RecordProducer {
return &record
- }, records)
+ })
}
// EncodeRecords encodes the given records into a byte slice.
diff --git a/channeldb/migration32/mission_control_store.go b/channeldb/migration32/mission_control_store.go
index 3ac9d6114c..76463eb6ca 100644
--- a/channeldb/migration32/mission_control_store.go
+++ b/channeldb/migration32/mission_control_store.go
@@ -8,7 +8,7 @@ import (
"github.com/btcsuite/btcd/wire"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/tlv"
)
@@ -371,7 +371,7 @@ func extractMCRoute(r *Route) *mcRoute {
// extractMCHops extracts the Hop fields that MC actually uses from a slice of
// Hops.
func extractMCHops(hops []*Hop) mcHops {
- return fn.Map(extractMCHop, hops)
+ return fn.Map(hops, extractMCHop)
}
// extractMCHop extracts the Hop fields that MC actually uses from a Hop.
diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go
index 3abc73f81e..ea6eaf13f2 100644
--- a/channeldb/revocation_log.go
+++ b/channeldb/revocation_log.go
@@ -7,7 +7,7 @@ import (
"math"
"github.com/btcsuite/btcd/btcutil"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go
index 4290552eee..2df6627e2c 100644
--- a/channeldb/revocation_log_test.go
+++ b/channeldb/revocation_log_test.go
@@ -8,7 +8,7 @@ import (
"testing"
"github.com/btcsuite/btcd/btcutil"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntest/channels"
"github.com/lightningnetwork/lnd/lnwire"
diff --git a/cmd/commands/cmd_macaroon.go b/cmd/commands/cmd_macaroon.go
index 15c29380a7..d7d6d5f9dc 100644
--- a/cmd/commands/cmd_macaroon.go
+++ b/cmd/commands/cmd_macaroon.go
@@ -10,7 +10,7 @@ import (
"strings"
"unicode"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lncfg"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/macaroons"
@@ -177,12 +177,15 @@ func bakeMacaroon(ctx *cli.Context) error {
"%w", err)
}
- ops := fn.Map(func(p *lnrpc.MacaroonPermission) bakery.Op {
- return bakery.Op{
- Entity: p.Entity,
- Action: p.Action,
- }
- }, parsedPermissions)
+ ops := fn.Map(
+ parsedPermissions,
+ func(p *lnrpc.MacaroonPermission) bakery.Op {
+ return bakery.Op{
+ Entity: p.Entity,
+ Action: p.Action,
+ }
+ },
+ )
rawMacaroon, err = macaroons.BakeFromRootKey(macRootKey, ops)
if err != nil {
diff --git a/config_builder.go b/config_builder.go
index 42650bb68b..afafbed754 100644
--- a/config_builder.go
+++ b/config_builder.go
@@ -33,9 +33,10 @@ import (
"github.com/lightningnetwork/lnd/chainreg"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/funding"
graphdb "github.com/lightningnetwork/lnd/graph/db"
+ "github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
@@ -47,7 +48,6 @@ import (
"github.com/lightningnetwork/lnd/lnwallet/rpcwallet"
"github.com/lightningnetwork/lnd/macaroons"
"github.com/lightningnetwork/lnd/msgmux"
- "github.com/lightningnetwork/lnd/routing"
"github.com/lightningnetwork/lnd/rpcperms"
"github.com/lightningnetwork/lnd/signal"
"github.com/lightningnetwork/lnd/sqldb"
@@ -166,7 +166,7 @@ type AuxComponents struct {
// TrafficShaper is an optional traffic shaper that can be used to
// control the outgoing channel of a payment.
- TrafficShaper fn.Option[routing.TlvTrafficShaper]
+ TrafficShaper fn.Option[htlcswitch.AuxTrafficShaper]
// MsgRouter is an optional message router that if set will be used in
// place of a new blank default message router.
diff --git a/contractcourt/anchor_resolver.go b/contractcourt/anchor_resolver.go
index b4d6877202..eb42ad3cb4 100644
--- a/contractcourt/anchor_resolver.go
+++ b/contractcourt/anchor_resolver.go
@@ -2,6 +2,7 @@ package contractcourt
import (
"errors"
+ "fmt"
"io"
"sync"
@@ -9,7 +10,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/sweep"
)
@@ -23,9 +24,6 @@ type anchorResolver struct {
// anchor is the outpoint on the commitment transaction.
anchor wire.OutPoint
- // resolved reflects if the contract has been fully resolved or not.
- resolved bool
-
// broadcastHeight is the height that the original contract was
// broadcast to the main-chain at. We'll use this value to bound any
// historical queries to the chain for spends/confirmations.
@@ -71,7 +69,7 @@ func newAnchorResolver(anchorSignDescriptor input.SignDescriptor,
currentReport: report,
}
- r.initLogger(r)
+ r.initLogger(fmt.Sprintf("%T(%v)", r, r.anchor))
return r
}
@@ -83,49 +81,12 @@ func (c *anchorResolver) ResolverKey() []byte {
return nil
}
-// Resolve offers the anchor output to the sweeper and waits for it to be swept.
-func (c *anchorResolver) Resolve(_ bool) (ContractResolver, error) {
- // Attempt to update the sweep parameters to the post-confirmation
- // situation. We don't want to force sweep anymore, because the anchor
- // lost its special purpose to get the commitment confirmed. It is just
- // an output that we want to sweep only if it is economical to do so.
- //
- // An exclusive group is not necessary anymore, because we know that
- // this is the only anchor that can be swept.
- //
- // We also clear the parent tx information for cpfp, because the
- // commitment tx is confirmed.
- //
- // After a restart or when the remote force closes, the sweeper is not
- // yet aware of the anchor. In that case, it will be added as new input
- // to the sweeper.
- witnessType := input.CommitmentAnchor
-
- // For taproot channels, we need to use the proper witness type.
- if c.chanType.IsTaproot() {
- witnessType = input.TaprootAnchorSweepSpend
- }
-
- anchorInput := input.MakeBaseInput(
- &c.anchor, witnessType, &c.anchorSignDescriptor,
- c.broadcastHeight, nil,
- )
-
- resultChan, err := c.Sweeper.SweepInput(
- &anchorInput,
- sweep.Params{
- // For normal anchor sweeping, the budget is 330 sats.
- Budget: btcutil.Amount(
- anchorInput.SignDesc().Output.Value,
- ),
-
- // There's no rush to sweep the anchor, so we use a nil
- // deadline here.
- DeadlineHeight: fn.None[int32](),
- },
- )
- if err != nil {
- return nil, err
+// Resolve waits for the output to be swept.
+func (c *anchorResolver) Resolve() (ContractResolver, error) {
+ // If we're already resolved, then we can exit early.
+ if c.IsResolved() {
+ c.log.Errorf("already resolved")
+ return nil, nil
}
var (
@@ -134,7 +95,7 @@ func (c *anchorResolver) Resolve(_ bool) (ContractResolver, error) {
)
select {
- case sweepRes := <-resultChan:
+ case sweepRes := <-c.sweepResultChan:
switch sweepRes.Err {
// Anchor was swept successfully.
case nil:
@@ -160,6 +121,8 @@ func (c *anchorResolver) Resolve(_ bool) (ContractResolver, error) {
return nil, errResolverShuttingDown
}
+ c.log.Infof("resolved in tx %v", spendTx)
+
// Update report to reflect that funds are no longer in limbo.
c.reportLock.Lock()
if outcome == channeldb.ResolverOutcomeClaimed {
@@ -171,7 +134,7 @@ func (c *anchorResolver) Resolve(_ bool) (ContractResolver, error) {
)
c.reportLock.Unlock()
- c.resolved = true
+ c.markResolved()
return nil, c.PutResolverReport(nil, report)
}
@@ -180,15 +143,10 @@ func (c *anchorResolver) Resolve(_ bool) (ContractResolver, error) {
//
// NOTE: Part of the ContractResolver interface.
func (c *anchorResolver) Stop() {
- close(c.quit)
-}
+ c.log.Debugf("stopping...")
+ defer c.log.Debugf("stopped")
-// IsResolved returns true if the stored state in the resolve is fully
-// resolved. In this case the target output can be forgotten.
-//
-// NOTE: Part of the ContractResolver interface.
-func (c *anchorResolver) IsResolved() bool {
- return c.resolved
+ close(c.quit)
}
// SupplementState allows the user of a ContractResolver to supplement it with
@@ -215,3 +173,68 @@ func (c *anchorResolver) Encode(w io.Writer) error {
// A compile time assertion to ensure anchorResolver meets the
// ContractResolver interface.
var _ ContractResolver = (*anchorResolver)(nil)
+
+// Launch offers the anchor output to the sweeper.
+func (c *anchorResolver) Launch() error {
+ if c.isLaunched() {
+ c.log.Tracef("already launched")
+ return nil
+ }
+
+ c.log.Debugf("launching resolver...")
+ c.markLaunched()
+
+ // If we're already resolved, then we can exit early.
+ if c.IsResolved() {
+ c.log.Errorf("already resolved")
+ return nil
+ }
+
+ // Attempt to update the sweep parameters to the post-confirmation
+ // situation. We don't want to force sweep anymore, because the anchor
+ // lost its special purpose to get the commitment confirmed. It is just
+ // an output that we want to sweep only if it is economical to do so.
+ //
+ // An exclusive group is not necessary anymore, because we know that
+ // this is the only anchor that can be swept.
+ //
+ // We also clear the parent tx information for cpfp, because the
+ // commitment tx is confirmed.
+ //
+ // After a restart or when the remote force closes, the sweeper is not
+ // yet aware of the anchor. In that case, it will be added as new input
+ // to the sweeper.
+ witnessType := input.CommitmentAnchor
+
+ // For taproot channels, we need to use the proper witness type.
+ if c.chanType.IsTaproot() {
+ witnessType = input.TaprootAnchorSweepSpend
+ }
+
+ anchorInput := input.MakeBaseInput(
+ &c.anchor, witnessType, &c.anchorSignDescriptor,
+ c.broadcastHeight, nil,
+ )
+
+ resultChan, err := c.Sweeper.SweepInput(
+ &anchorInput,
+ sweep.Params{
+ // For normal anchor sweeping, the budget is 330 sats.
+ Budget: btcutil.Amount(
+ anchorInput.SignDesc().Output.Value,
+ ),
+
+ // There's no rush to sweep the anchor, so we use a nil
+ // deadline here.
+ DeadlineHeight: fn.None[int32](),
+ },
+ )
+
+ if err != nil {
+ return err
+ }
+
+ c.sweepResultChan = resultChan
+
+ return nil
+}
diff --git a/contractcourt/breach_arbitrator.go b/contractcourt/breach_arbitrator.go
index d59829b5e5..33bc7f7e33 100644
--- a/contractcourt/breach_arbitrator.go
+++ b/contractcourt/breach_arbitrator.go
@@ -15,7 +15,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/kvdb"
@@ -1537,9 +1537,9 @@ func (b *BreachArbitrator) createSweepTx(
// outputs from the regular, BTC only outputs. So we only need one such
// output, which'll carry the custom channel "valuables" from both the
// breached commitment and HTLC outputs.
- hasBlobs := fn.Any(func(i input.Input) bool {
+ hasBlobs := fn.Any(inputs, func(i input.Input) bool {
return i.ResolutionBlob().IsSome()
- }, inputs)
+ })
if hasBlobs {
weightEstimate.AddP2TROutput()
}
@@ -1624,7 +1624,7 @@ func (b *BreachArbitrator) sweepSpendableOutputsTxn(txWeight lntypes.WeightUnit,
// First, we'll add the extra sweep output if it exists, subtracting the
// amount from the sweep amt.
if b.cfg.AuxSweeper.IsSome() {
- extraChangeOut.WhenResult(func(o sweep.SweepOutput) {
+ extraChangeOut.WhenOk(func(o sweep.SweepOutput) {
sweepAmt -= o.Value
txn.AddTxOut(&o.TxOut)
@@ -1697,7 +1697,7 @@ func (b *BreachArbitrator) sweepSpendableOutputsTxn(txWeight lntypes.WeightUnit,
return &justiceTxCtx{
justiceTx: txn,
sweepAddr: pkScript,
- extraTxOut: extraChangeOut.Option(),
+ extraTxOut: extraChangeOut.OkToSome(),
fee: txFee,
inputs: inputs,
}, nil
diff --git a/contractcourt/breach_arbitrator_test.go b/contractcourt/breach_arbitrator_test.go
index 576009eda4..99ed852696 100644
--- a/contractcourt/breach_arbitrator_test.go
+++ b/contractcourt/breach_arbitrator_test.go
@@ -22,7 +22,7 @@ import (
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntest/channels"
@@ -36,7 +36,7 @@ import (
)
var (
- defaultTimeout = 30 * time.Second
+ defaultTimeout = 10 * time.Second
breachOutPoints = []wire.OutPoint{
{
diff --git a/contractcourt/breach_resolver.go b/contractcourt/breach_resolver.go
index 740b4471d5..5644e60fad 100644
--- a/contractcourt/breach_resolver.go
+++ b/contractcourt/breach_resolver.go
@@ -2,6 +2,7 @@ package contractcourt
import (
"encoding/binary"
+ "fmt"
"io"
"github.com/lightningnetwork/lnd/channeldb"
@@ -11,9 +12,6 @@ import (
// future, this will likely take over the duties the current BreachArbitrator
// has.
type breachResolver struct {
- // resolved reflects if the contract has been fully resolved or not.
- resolved bool
-
// subscribed denotes whether or not the breach resolver has subscribed
// to the BreachArbitrator for breach resolution.
subscribed bool
@@ -32,7 +30,7 @@ func newBreachResolver(resCfg ResolverConfig) *breachResolver {
replyChan: make(chan struct{}),
}
- r.initLogger(r)
+ r.initLogger(fmt.Sprintf("%T(%v)", r, r.ChanPoint))
return r
}
@@ -47,7 +45,7 @@ func (b *breachResolver) ResolverKey() []byte {
// been broadcast.
//
// TODO(yy): let sweeper handle the breach inputs.
-func (b *breachResolver) Resolve(_ bool) (ContractResolver, error) {
+func (b *breachResolver) Resolve() (ContractResolver, error) {
if !b.subscribed {
complete, err := b.SubscribeBreachComplete(
&b.ChanPoint, b.replyChan,
@@ -59,7 +57,7 @@ func (b *breachResolver) Resolve(_ bool) (ContractResolver, error) {
// If the breach resolution process is already complete, then
// we can cleanup and checkpoint the resolved state.
if complete {
- b.resolved = true
+ b.markResolved()
return nil, b.Checkpoint(b)
}
@@ -72,8 +70,9 @@ func (b *breachResolver) Resolve(_ bool) (ContractResolver, error) {
// The replyChan has been closed, signalling that the breach
// has been fully resolved. Checkpoint the resolved state and
// exit.
- b.resolved = true
+ b.markResolved()
return nil, b.Checkpoint(b)
+
case <-b.quit:
}
@@ -82,22 +81,17 @@ func (b *breachResolver) Resolve(_ bool) (ContractResolver, error) {
// Stop signals the breachResolver to stop.
func (b *breachResolver) Stop() {
+ b.log.Debugf("stopping...")
close(b.quit)
}
-// IsResolved returns true if the breachResolver is fully resolved and cleanup
-// can occur.
-func (b *breachResolver) IsResolved() bool {
- return b.resolved
-}
-
// SupplementState adds additional state to the breachResolver.
func (b *breachResolver) SupplementState(_ *channeldb.OpenChannel) {
}
// Encode encodes the breachResolver to the passed writer.
func (b *breachResolver) Encode(w io.Writer) error {
- return binary.Write(w, endian, b.resolved)
+ return binary.Write(w, endian, b.IsResolved())
}
// newBreachResolverFromReader attempts to decode an encoded breachResolver
@@ -110,11 +104,15 @@ func newBreachResolverFromReader(r io.Reader, resCfg ResolverConfig) (
replyChan: make(chan struct{}),
}
- if err := binary.Read(r, endian, &b.resolved); err != nil {
+ var resolved bool
+ if err := binary.Read(r, endian, &resolved); err != nil {
return nil, err
}
+ if resolved {
+ b.markResolved()
+ }
- b.initLogger(b)
+ b.initLogger(fmt.Sprintf("%T(%v)", b, b.ChanPoint))
return b, nil
}
@@ -122,3 +120,16 @@ func newBreachResolverFromReader(r io.Reader, resCfg ResolverConfig) (
// A compile time assertion to ensure breachResolver meets the ContractResolver
// interface.
var _ ContractResolver = (*breachResolver)(nil)
+
+// TODO(yy): implement it once the outputs are offered to the sweeper.
+func (b *breachResolver) Launch() error {
+ if b.isLaunched() {
+ b.log.Tracef("already launched")
+ return nil
+ }
+
+ b.log.Debugf("launching resolver...")
+ b.markLaunched()
+
+ return nil
+}
diff --git a/contractcourt/briefcase.go b/contractcourt/briefcase.go
index a0908ea3fa..7d199c5c28 100644
--- a/contractcourt/briefcase.go
+++ b/contractcourt/briefcase.go
@@ -10,7 +10,7 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwallet"
diff --git a/contractcourt/briefcase_test.go b/contractcourt/briefcase_test.go
index 0f44db2abb..aa2e711efc 100644
--- a/contractcourt/briefcase_test.go
+++ b/contractcourt/briefcase_test.go
@@ -14,7 +14,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnmock"
@@ -206,8 +206,8 @@ func assertResolversEqual(t *testing.T, originalResolver ContractResolver,
ogRes.outputIncubating, diskRes.outputIncubating)
}
if ogRes.resolved != diskRes.resolved {
- t.Fatalf("expected %v, got %v", ogRes.resolved,
- diskRes.resolved)
+ t.Fatalf("expected %v, got %v", ogRes.resolved.Load(),
+ diskRes.resolved.Load())
}
if ogRes.broadcastHeight != diskRes.broadcastHeight {
t.Fatalf("expected %v, got %v",
@@ -229,8 +229,8 @@ func assertResolversEqual(t *testing.T, originalResolver ContractResolver,
ogRes.outputIncubating, diskRes.outputIncubating)
}
if ogRes.resolved != diskRes.resolved {
- t.Fatalf("expected %v, got %v", ogRes.resolved,
- diskRes.resolved)
+ t.Fatalf("expected %v, got %v", ogRes.resolved.Load(),
+ diskRes.resolved.Load())
}
if ogRes.broadcastHeight != diskRes.broadcastHeight {
t.Fatalf("expected %v, got %v",
@@ -275,8 +275,8 @@ func assertResolversEqual(t *testing.T, originalResolver ContractResolver,
ogRes.commitResolution, diskRes.commitResolution)
}
if ogRes.resolved != diskRes.resolved {
- t.Fatalf("expected %v, got %v", ogRes.resolved,
- diskRes.resolved)
+ t.Fatalf("expected %v, got %v", ogRes.resolved.Load(),
+ diskRes.resolved.Load())
}
if ogRes.broadcastHeight != diskRes.broadcastHeight {
t.Fatalf("expected %v, got %v",
@@ -312,13 +312,14 @@ func TestContractInsertionRetrieval(t *testing.T) {
SweepSignDesc: testSignDesc,
},
outputIncubating: true,
- resolved: true,
broadcastHeight: 102,
htlc: channeldb.HTLC{
HtlcIndex: 12,
},
}
- successResolver := htlcSuccessResolver{
+ timeoutResolver.resolved.Store(true)
+
+ successResolver := &htlcSuccessResolver{
htlcResolution: lnwallet.IncomingHtlcResolution{
Preimage: testPreimage,
SignedSuccessTx: nil,
@@ -327,40 +328,49 @@ func TestContractInsertionRetrieval(t *testing.T) {
SweepSignDesc: testSignDesc,
},
outputIncubating: true,
- resolved: true,
broadcastHeight: 109,
htlc: channeldb.HTLC{
RHash: testPreimage,
},
}
- resolvers := []ContractResolver{
- &timeoutResolver,
- &successResolver,
- &commitSweepResolver{
- commitResolution: lnwallet.CommitOutputResolution{
- SelfOutPoint: testChanPoint2,
- SelfOutputSignDesc: testSignDesc,
- MaturityDelay: 99,
- },
- resolved: false,
- broadcastHeight: 109,
- chanPoint: testChanPoint1,
+ successResolver.resolved.Store(true)
+
+ commitResolver := &commitSweepResolver{
+ commitResolution: lnwallet.CommitOutputResolution{
+ SelfOutPoint: testChanPoint2,
+ SelfOutputSignDesc: testSignDesc,
+ MaturityDelay: 99,
},
+ broadcastHeight: 109,
+ chanPoint: testChanPoint1,
+ }
+ commitResolver.resolved.Store(false)
+
+ resolvers := []ContractResolver{
+ &timeoutResolver, successResolver, commitResolver,
}
// All resolvers require a unique ResolverKey() output. To achieve this
// for the composite resolvers, we'll mutate the underlying resolver
// with a new outpoint.
- contestTimeout := timeoutResolver
- contestTimeout.htlcResolution.ClaimOutpoint = randOutPoint()
+ contestTimeout := htlcTimeoutResolver{
+ htlcResolution: lnwallet.OutgoingHtlcResolution{
+ ClaimOutpoint: randOutPoint(),
+ SweepSignDesc: testSignDesc,
+ },
+ }
resolvers = append(resolvers, &htlcOutgoingContestResolver{
htlcTimeoutResolver: &contestTimeout,
})
- contestSuccess := successResolver
- contestSuccess.htlcResolution.ClaimOutpoint = randOutPoint()
+ contestSuccess := &htlcSuccessResolver{
+ htlcResolution: lnwallet.IncomingHtlcResolution{
+ ClaimOutpoint: randOutPoint(),
+ SweepSignDesc: testSignDesc,
+ },
+ }
resolvers = append(resolvers, &htlcIncomingContestResolver{
htlcExpiry: 100,
- htlcSuccessResolver: &contestSuccess,
+ htlcSuccessResolver: contestSuccess,
})
// For quick lookup during the test, we'll create this map which allow
@@ -438,12 +448,12 @@ func TestContractResolution(t *testing.T) {
SweepSignDesc: testSignDesc,
},
outputIncubating: true,
- resolved: true,
broadcastHeight: 192,
htlc: channeldb.HTLC{
HtlcIndex: 9912,
},
}
+ timeoutResolver.resolved.Store(true)
// First, we'll insert the resolver into the database and ensure that
// we get the same resolver out the other side. We do not need to apply
@@ -491,12 +501,13 @@ func TestContractSwapping(t *testing.T) {
SweepSignDesc: testSignDesc,
},
outputIncubating: true,
- resolved: true,
broadcastHeight: 102,
htlc: channeldb.HTLC{
HtlcIndex: 12,
},
}
+ timeoutResolver.resolved.Store(true)
+
contestResolver := &htlcOutgoingContestResolver{
htlcTimeoutResolver: timeoutResolver,
}
diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go
index 6d9b30d208..011b5225cd 100644
--- a/contractcourt/chain_arbitrator.go
+++ b/contractcourt/chain_arbitrator.go
@@ -11,10 +11,11 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/walletdb"
+ "github.com/lightningnetwork/lnd/chainio"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/kvdb"
@@ -244,6 +245,10 @@ type ChainArbitrator struct {
started int32 // To be used atomically.
stopped int32 // To be used atomically.
+ // Embed the blockbeat consumer struct to get access to the method
+ // `NotifyBlockProcessed` and the `BlockbeatChan`.
+ chainio.BeatConsumer
+
sync.Mutex
// activeChannels is a map of all the active contracts that are still
@@ -262,6 +267,9 @@ type ChainArbitrator struct {
// active channels that it must still watch over.
chanSource *channeldb.DB
+ // beat is the current best known blockbeat.
+ beat chainio.Blockbeat
+
quit chan struct{}
wg sync.WaitGroup
@@ -272,15 +280,23 @@ type ChainArbitrator struct {
func NewChainArbitrator(cfg ChainArbitratorConfig,
db *channeldb.DB) *ChainArbitrator {
- return &ChainArbitrator{
+ c := &ChainArbitrator{
cfg: cfg,
activeChannels: make(map[wire.OutPoint]*ChannelArbitrator),
activeWatchers: make(map[wire.OutPoint]*chainWatcher),
chanSource: db,
quit: make(chan struct{}),
}
+
+ // Mount the block consumer.
+ c.BeatConsumer = chainio.NewBeatConsumer(c.quit, c.Name())
+
+ return c
}
+// Compile-time check for the chainio.Consumer interface.
+var _ chainio.Consumer = (*ChainArbitrator)(nil)
+
// arbChannel is a wrapper around an open channel that channel arbitrators
// interact with.
type arbChannel struct {
@@ -554,147 +570,30 @@ func (c *ChainArbitrator) ResolveContract(chanPoint wire.OutPoint) error {
}
// Start launches all goroutines that the ChainArbitrator needs to operate.
-func (c *ChainArbitrator) Start() error {
+func (c *ChainArbitrator) Start(beat chainio.Blockbeat) error {
if !atomic.CompareAndSwapInt32(&c.started, 0, 1) {
return nil
}
- log.Infof("ChainArbitrator starting with config: budget=[%v]",
- &c.cfg.Budget)
+ // Set the current beat.
+ c.beat = beat
+
+ log.Infof("ChainArbitrator starting at height %d with budget=[%v]",
+ &c.cfg.Budget, c.beat.Height())
// First, we'll fetch all the channels that are still open, in order to
// collect them within our set of active contracts.
- openChannels, err := c.chanSource.ChannelStateDB().FetchAllChannels()
- if err != nil {
+ if err := c.loadOpenChannels(); err != nil {
return err
}
- if len(openChannels) > 0 {
- log.Infof("Creating ChannelArbitrators for %v active channels",
- len(openChannels))
- }
-
- // For each open channel, we'll configure then launch a corresponding
- // ChannelArbitrator.
- for _, channel := range openChannels {
- chanPoint := channel.FundingOutpoint
- channel := channel
-
- // First, we'll create an active chainWatcher for this channel
- // to ensure that we detect any relevant on chain events.
- breachClosure := func(ret *lnwallet.BreachRetribution) error {
- return c.cfg.ContractBreach(chanPoint, ret)
- }
-
- chainWatcher, err := newChainWatcher(
- chainWatcherConfig{
- chanState: channel,
- notifier: c.cfg.Notifier,
- signer: c.cfg.Signer,
- isOurAddr: c.cfg.IsOurAddress,
- contractBreach: breachClosure,
- extractStateNumHint: lnwallet.GetStateNumHint,
- auxLeafStore: c.cfg.AuxLeafStore,
- auxResolver: c.cfg.AuxResolver,
- },
- )
- if err != nil {
- return err
- }
-
- c.activeWatchers[chanPoint] = chainWatcher
- channelArb, err := newActiveChannelArbitrator(
- channel, c, chainWatcher.SubscribeChannelEvents(),
- )
- if err != nil {
- return err
- }
-
- c.activeChannels[chanPoint] = channelArb
-
- // Republish any closing transactions for this channel.
- err = c.republishClosingTxs(channel)
- if err != nil {
- log.Errorf("Failed to republish closing txs for "+
- "channel %v", chanPoint)
- }
- }
-
// In addition to the channels that we know to be open, we'll also
// launch arbitrators to finishing resolving any channels that are in
// the pending close state.
- closingChannels, err := c.chanSource.ChannelStateDB().FetchClosedChannels(
- true,
- )
- if err != nil {
+ if err := c.loadPendingCloseChannels(); err != nil {
return err
}
- if len(closingChannels) > 0 {
- log.Infof("Creating ChannelArbitrators for %v closing channels",
- len(closingChannels))
- }
-
- // Next, for each channel is the closing state, we'll launch a
- // corresponding more restricted resolver, as we don't have to watch
- // the chain any longer, only resolve the contracts on the confirmed
- // commitment.
- //nolint:ll
- for _, closeChanInfo := range closingChannels {
- // We can leave off the CloseContract and ForceCloseChan
- // methods as the channel is already closed at this point.
- chanPoint := closeChanInfo.ChanPoint
- arbCfg := ChannelArbitratorConfig{
- ChanPoint: chanPoint,
- ShortChanID: closeChanInfo.ShortChanID,
- ChainArbitratorConfig: c.cfg,
- ChainEvents: &ChainEventSubscription{},
- IsPendingClose: true,
- ClosingHeight: closeChanInfo.CloseHeight,
- CloseType: closeChanInfo.CloseType,
- PutResolverReport: func(tx kvdb.RwTx,
- report *channeldb.ResolverReport) error {
-
- return c.chanSource.PutResolverReport(
- tx, c.cfg.ChainHash, &chanPoint, report,
- )
- },
- FetchHistoricalChannel: func() (*channeldb.OpenChannel, error) {
- chanStateDB := c.chanSource.ChannelStateDB()
- return chanStateDB.FetchHistoricalChannel(&chanPoint)
- },
- FindOutgoingHTLCDeadline: func(
- htlc channeldb.HTLC) fn.Option[int32] {
-
- return c.FindOutgoingHTLCDeadline(
- closeChanInfo.ShortChanID, htlc,
- )
- },
- }
- chanLog, err := newBoltArbitratorLog(
- c.chanSource.Backend, arbCfg, c.cfg.ChainHash, chanPoint,
- )
- if err != nil {
- return err
- }
- arbCfg.MarkChannelResolved = func() error {
- if c.cfg.NotifyFullyResolvedChannel != nil {
- c.cfg.NotifyFullyResolvedChannel(chanPoint)
- }
-
- return c.ResolveContract(chanPoint)
- }
-
- // We create an empty map of HTLC's here since it's possible
- // that the channel is in StateDefault and updateActiveHTLCs is
- // called. We want to avoid writing to an empty map. Since the
- // channel is already in the process of being resolved, no new
- // HTLCs will be added.
- c.activeChannels[chanPoint] = NewChannelArbitrator(
- arbCfg, make(map[HtlcSetKey]htlcSet), chanLog,
- )
- }
-
// Now, we'll start all chain watchers in parallel to shorten start up
// duration. In neutrino mode, this allows spend registrations to take
// advantage of batch spend reporting, instead of doing a single rescan
@@ -746,7 +645,7 @@ func (c *ChainArbitrator) Start() error {
// transaction.
var startStates map[wire.OutPoint]*chanArbStartState
- err = kvdb.View(c.chanSource, func(tx walletdb.ReadTx) error {
+ err := kvdb.View(c.chanSource, func(tx walletdb.ReadTx) error {
for _, arbitrator := range c.activeChannels {
startState, err := arbitrator.getStartState(tx)
if err != nil {
@@ -778,24 +677,17 @@ func (c *ChainArbitrator) Start() error {
arbitrator.cfg.ChanPoint)
}
- if err := arbitrator.Start(startState); err != nil {
+ if err := arbitrator.Start(startState, c.beat); err != nil {
stopAndLog()
return err
}
}
- // Subscribe to a single stream of block epoch notifications that we
- // will dispatch to all active arbitrators.
- blockEpoch, err := c.cfg.Notifier.RegisterBlockEpochNtfn(nil)
- if err != nil {
- return err
- }
-
// Start our goroutine which will dispatch blocks to each arbitrator.
c.wg.Add(1)
go func() {
defer c.wg.Done()
- c.dispatchBlocks(blockEpoch)
+ c.dispatchBlocks()
}()
// TODO(roasbeef): eventually move all breach watching here
@@ -803,94 +695,22 @@ func (c *ChainArbitrator) Start() error {
return nil
}
-// blockRecipient contains the information we need to dispatch a block to a
-// channel arbitrator.
-type blockRecipient struct {
- // chanPoint is the funding outpoint of the channel.
- chanPoint wire.OutPoint
-
- // blocks is the channel that new block heights are sent into. This
- // channel should be sufficiently buffered as to not block the sender.
- blocks chan<- int32
-
- // quit is closed if the receiving entity is shutting down.
- quit chan struct{}
-}
-
// dispatchBlocks consumes a block epoch notification stream and dispatches
// blocks to each of the chain arb's active channel arbitrators. This function
// must be run in a goroutine.
-func (c *ChainArbitrator) dispatchBlocks(
- blockEpoch *chainntnfs.BlockEpochEvent) {
-
- // getRecipients is a helper function which acquires the chain arb
- // lock and returns a set of block recipients which can be used to
- // dispatch blocks.
- getRecipients := func() []blockRecipient {
- c.Lock()
- blocks := make([]blockRecipient, 0, len(c.activeChannels))
- for _, channel := range c.activeChannels {
- blocks = append(blocks, blockRecipient{
- chanPoint: channel.cfg.ChanPoint,
- blocks: channel.blocks,
- quit: channel.quit,
- })
- }
- c.Unlock()
-
- return blocks
- }
-
- // On exit, cancel our blocks subscription and close each block channel
- // so that the arbitrators know they will no longer be receiving blocks.
- defer func() {
- blockEpoch.Cancel()
-
- recipients := getRecipients()
- for _, recipient := range recipients {
- close(recipient.blocks)
- }
- }()
-
+func (c *ChainArbitrator) dispatchBlocks() {
// Consume block epochs until we receive the instruction to shutdown.
for {
select {
// Consume block epochs, exiting if our subscription is
// terminated.
- case block, ok := <-blockEpoch.Epochs:
- if !ok {
- log.Trace("dispatchBlocks block epoch " +
- "cancelled")
- return
- }
+ case beat := <-c.BlockbeatChan:
+ // Set the current blockbeat.
+ c.beat = beat
- // Get the set of currently active channels block
- // subscription channels and dispatch the block to
- // each.
- for _, recipient := range getRecipients() {
- select {
- // Deliver the block to the arbitrator.
- case recipient.blocks <- block.Height:
-
- // If the recipient is shutting down, exit
- // without delivering the block. This may be
- // the case when two blocks are mined in quick
- // succession, and the arbitrator resolves
- // after the first block, and does not need to
- // consume the second block.
- case <-recipient.quit:
- log.Debugf("channel: %v exit without "+
- "receiving block: %v",
- recipient.chanPoint,
- block.Height)
-
- // If the chain arb is shutting down, we don't
- // need to deliver any more blocks (everything
- // will be shutting down).
- case <-c.quit:
- return
- }
- }
+ // Send this blockbeat to all the active channels and
+ // wait for them to finish processing it.
+ c.handleBlockbeat(beat)
// Exit if the chain arbitrator is shutting down.
case <-c.quit:
@@ -899,6 +719,32 @@ func (c *ChainArbitrator) dispatchBlocks(
}
}
+// handleBlockbeat sends the blockbeat to all active channel arbitrator in
+// parallel and wait for them to finish processing it.
+func (c *ChainArbitrator) handleBlockbeat(beat chainio.Blockbeat) {
+ // Read the active channels in a lock.
+ c.Lock()
+
+ // Create a slice to record active channel arbitrator.
+ channels := make([]chainio.Consumer, 0, len(c.activeChannels))
+
+ // Copy the active channels to the slice.
+ for _, channel := range c.activeChannels {
+ channels = append(channels, channel)
+ }
+
+ c.Unlock()
+
+ // Iterate all the copied channels and send the blockbeat to them.
+ //
+ // NOTE: This method will timeout if the processing of blocks of the
+ // subsystems is too long (60s).
+ err := chainio.DispatchConcurrent(beat, channels)
+
+ // Notify the chain arbitrator has processed the block.
+ c.NotifyBlockProcessed(beat, err)
+}
+
// republishClosingTxs will load any stored cooperative or unilateral closing
// transactions and republish them. This helps ensure propagation of the
// transactions in the event that prior publications failed.
@@ -1248,7 +1094,7 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error
// arbitrators, then launch it.
c.activeChannels[chanPoint] = channelArb
- if err := channelArb.Start(nil); err != nil {
+ if err := channelArb.Start(nil, c.beat); err != nil {
return err
}
@@ -1361,3 +1207,152 @@ func (c *ChainArbitrator) FindOutgoingHTLCDeadline(scid lnwire.ShortChannelID,
// TODO(roasbeef): arbitration reports
// * types: contested, waiting for success conf, etc
+
+// NOTE: part of the `chainio.Consumer` interface.
+func (c *ChainArbitrator) Name() string {
+ return "ChainArbitrator"
+}
+
+// loadOpenChannels loads all channels that are currently open in the database
+// and registers them with the chainWatcher for future notification.
+func (c *ChainArbitrator) loadOpenChannels() error {
+ openChannels, err := c.chanSource.ChannelStateDB().FetchAllChannels()
+ if err != nil {
+ return err
+ }
+
+ if len(openChannels) == 0 {
+ return nil
+ }
+
+ log.Infof("Creating ChannelArbitrators for %v active channels",
+ len(openChannels))
+
+ // For each open channel, we'll configure then launch a corresponding
+ // ChannelArbitrator.
+ for _, channel := range openChannels {
+ chanPoint := channel.FundingOutpoint
+ channel := channel
+
+ // First, we'll create an active chainWatcher for this channel
+ // to ensure that we detect any relevant on chain events.
+ breachClosure := func(ret *lnwallet.BreachRetribution) error {
+ return c.cfg.ContractBreach(chanPoint, ret)
+ }
+
+ chainWatcher, err := newChainWatcher(
+ chainWatcherConfig{
+ chanState: channel,
+ notifier: c.cfg.Notifier,
+ signer: c.cfg.Signer,
+ isOurAddr: c.cfg.IsOurAddress,
+ contractBreach: breachClosure,
+ extractStateNumHint: lnwallet.GetStateNumHint,
+ auxLeafStore: c.cfg.AuxLeafStore,
+ auxResolver: c.cfg.AuxResolver,
+ },
+ )
+ if err != nil {
+ return err
+ }
+
+ c.activeWatchers[chanPoint] = chainWatcher
+ channelArb, err := newActiveChannelArbitrator(
+ channel, c, chainWatcher.SubscribeChannelEvents(),
+ )
+ if err != nil {
+ return err
+ }
+
+ c.activeChannels[chanPoint] = channelArb
+
+ // Republish any closing transactions for this channel.
+ err = c.republishClosingTxs(channel)
+ if err != nil {
+ log.Errorf("Failed to republish closing txs for "+
+ "channel %v", chanPoint)
+ }
+ }
+
+ return nil
+}
+
+// loadPendingCloseChannels loads all channels that are currently pending
+// closure in the database and registers them with the ChannelArbitrator to
+// continue the resolution process.
+func (c *ChainArbitrator) loadPendingCloseChannels() error {
+ chanStateDB := c.chanSource.ChannelStateDB()
+
+ closingChannels, err := chanStateDB.FetchClosedChannels(true)
+ if err != nil {
+ return err
+ }
+
+ if len(closingChannels) == 0 {
+ return nil
+ }
+
+ log.Infof("Creating ChannelArbitrators for %v closing channels",
+ len(closingChannels))
+
+ // Next, for each channel is the closing state, we'll launch a
+ // corresponding more restricted resolver, as we don't have to watch
+ // the chain any longer, only resolve the contracts on the confirmed
+ // commitment.
+ //nolint:ll
+ for _, closeChanInfo := range closingChannels {
+ // We can leave off the CloseContract and ForceCloseChan
+ // methods as the channel is already closed at this point.
+ chanPoint := closeChanInfo.ChanPoint
+ arbCfg := ChannelArbitratorConfig{
+ ChanPoint: chanPoint,
+ ShortChanID: closeChanInfo.ShortChanID,
+ ChainArbitratorConfig: c.cfg,
+ ChainEvents: &ChainEventSubscription{},
+ IsPendingClose: true,
+ ClosingHeight: closeChanInfo.CloseHeight,
+ CloseType: closeChanInfo.CloseType,
+ PutResolverReport: func(tx kvdb.RwTx,
+ report *channeldb.ResolverReport) error {
+
+ return c.chanSource.PutResolverReport(
+ tx, c.cfg.ChainHash, &chanPoint, report,
+ )
+ },
+ FetchHistoricalChannel: func() (*channeldb.OpenChannel, error) {
+ return chanStateDB.FetchHistoricalChannel(&chanPoint)
+ },
+ FindOutgoingHTLCDeadline: func(
+ htlc channeldb.HTLC) fn.Option[int32] {
+
+ return c.FindOutgoingHTLCDeadline(
+ closeChanInfo.ShortChanID, htlc,
+ )
+ },
+ }
+ chanLog, err := newBoltArbitratorLog(
+ c.chanSource.Backend, arbCfg, c.cfg.ChainHash, chanPoint,
+ )
+ if err != nil {
+ return err
+ }
+ arbCfg.MarkChannelResolved = func() error {
+ if c.cfg.NotifyFullyResolvedChannel != nil {
+ c.cfg.NotifyFullyResolvedChannel(chanPoint)
+ }
+
+ return c.ResolveContract(chanPoint)
+ }
+
+ // We create an empty map of HTLC's here since it's possible
+ // that the channel is in StateDefault and updateActiveHTLCs is
+ // called. We want to avoid writing to an empty map. Since the
+ // channel is already in the process of being resolved, no new
+ // HTLCs will be added.
+ c.activeChannels[chanPoint] = NewChannelArbitrator(
+ arbCfg, make(map[HtlcSetKey]htlcSet), chanLog,
+ )
+ }
+
+ return nil
+}
diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go
index fe2603ca5a..622686f76c 100644
--- a/contractcourt/chain_arbitrator_test.go
+++ b/contractcourt/chain_arbitrator_test.go
@@ -77,7 +77,6 @@ func TestChainArbitratorRepublishCloses(t *testing.T) {
ChainIO: &mock.ChainIO{},
Notifier: &mock.ChainNotifier{
SpendChan: make(chan *chainntnfs.SpendDetail),
- EpochChan: make(chan *chainntnfs.BlockEpoch),
ConfChan: make(chan *chainntnfs.TxConfirmation),
},
PublishTx: func(tx *wire.MsgTx, _ string) error {
@@ -91,7 +90,8 @@ func TestChainArbitratorRepublishCloses(t *testing.T) {
chainArbCfg, db,
)
- if err := chainArb.Start(); err != nil {
+ beat := newBeatFromHeight(0)
+ if err := chainArb.Start(beat); err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
@@ -158,7 +158,6 @@ func TestResolveContract(t *testing.T) {
ChainIO: &mock.ChainIO{},
Notifier: &mock.ChainNotifier{
SpendChan: make(chan *chainntnfs.SpendDetail),
- EpochChan: make(chan *chainntnfs.BlockEpoch),
ConfChan: make(chan *chainntnfs.TxConfirmation),
},
PublishTx: func(tx *wire.MsgTx, _ string) error {
@@ -175,7 +174,8 @@ func TestResolveContract(t *testing.T) {
chainArb := NewChainArbitrator(
chainArbCfg, db,
)
- if err := chainArb.Start(); err != nil {
+ beat := newBeatFromHeight(0)
+ if err := chainArb.Start(beat); err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go
index e79c8d546b..e29f21e7f4 100644
--- a/contractcourt/chain_watcher.go
+++ b/contractcourt/chain_watcher.go
@@ -18,7 +18,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
@@ -451,7 +451,7 @@ func (c *chainWatcher) handleUnknownLocalState(
leaseExpiry = c.cfg.chanState.ThawHeight
}
- remoteAuxLeaf := fn.ChainOption(
+ remoteAuxLeaf := fn.FlatMapOption(
func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf {
return l.RemoteAuxLeaf
},
@@ -468,7 +468,7 @@ func (c *chainWatcher) handleUnknownLocalState(
// Next, we'll derive our script that includes the revocation base for
// the remote party allowing them to claim this output before the CSV
// delay if we breach.
- localAuxLeaf := fn.ChainOption(
+ localAuxLeaf := fn.FlatMapOption(
func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf {
return l.LocalAuxLeaf
},
@@ -1062,15 +1062,15 @@ func (c *chainWatcher) toSelfAmount(tx *wire.MsgTx) btcutil.Amount {
return false
}
- return fn.Any(c.cfg.isOurAddr, addrs)
+ return fn.Any(addrs, c.cfg.isOurAddr)
}
// Grab all of the outputs that correspond with our delivery address
// or our wallet is aware of.
- outs := fn.Filter(fn.PredOr(isDeliveryOutput, isWalletOutput), tx.TxOut)
+ outs := fn.Filter(tx.TxOut, fn.PredOr(isDeliveryOutput, isWalletOutput))
// Grab the values for those outputs.
- vals := fn.Map(func(o *wire.TxOut) int64 { return o.Value }, outs)
+ vals := fn.Map(outs, func(o *wire.TxOut) int64 { return o.Value })
// Return the sum.
return btcutil.Amount(fn.Sum(vals))
diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go
index 319b437e4e..9b554cb648 100644
--- a/contractcourt/channel_arbitrator.go
+++ b/contractcourt/channel_arbitrator.go
@@ -14,8 +14,9 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
+ "github.com/lightningnetwork/lnd/chainio"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/input"
@@ -330,6 +331,10 @@ type ChannelArbitrator struct {
started int32 // To be used atomically.
stopped int32 // To be used atomically.
+ // Embed the blockbeat consumer struct to get access to the method
+ // `NotifyBlockProcessed` and the `BlockbeatChan`.
+ chainio.BeatConsumer
+
// startTimestamp is the time when this ChannelArbitrator was started.
startTimestamp time.Time
@@ -352,11 +357,6 @@ type ChannelArbitrator struct {
// to do its duty.
cfg ChannelArbitratorConfig
- // blocks is a channel that the arbitrator will receive new blocks on.
- // This channel should be buffered by so that it does not block the
- // sender.
- blocks chan int32
-
// signalUpdates is a channel that any new live signals for the channel
// we're watching over will be sent.
signalUpdates chan *signalUpdateMsg
@@ -404,9 +404,8 @@ func NewChannelArbitrator(cfg ChannelArbitratorConfig,
unmerged[RemotePendingHtlcSet] = htlcSets[RemotePendingHtlcSet]
}
- return &ChannelArbitrator{
+ c := &ChannelArbitrator{
log: log,
- blocks: make(chan int32, arbitratorBlockBufferSize),
signalUpdates: make(chan *signalUpdateMsg),
resolutionSignal: make(chan struct{}),
forceCloseReqs: make(chan *forceCloseReq),
@@ -415,8 +414,16 @@ func NewChannelArbitrator(cfg ChannelArbitratorConfig,
cfg: cfg,
quit: make(chan struct{}),
}
+
+ // Mount the block consumer.
+ c.BeatConsumer = chainio.NewBeatConsumer(c.quit, c.Name())
+
+ return c
}
+// Compile-time check for the chainio.Consumer interface.
+var _ chainio.Consumer = (*ChannelArbitrator)(nil)
+
// chanArbStartState contains the information from disk that we need to start
// up a channel arbitrator.
type chanArbStartState struct {
@@ -455,7 +462,9 @@ func (c *ChannelArbitrator) getStartState(tx kvdb.RTx) (*chanArbStartState,
// Start starts all the goroutines that the ChannelArbitrator needs to operate.
// If takes a start state, which will be looked up on disk if it is not
// provided.
-func (c *ChannelArbitrator) Start(state *chanArbStartState) error {
+func (c *ChannelArbitrator) Start(state *chanArbStartState,
+ beat chainio.Blockbeat) error {
+
if !atomic.CompareAndSwapInt32(&c.started, 0, 1) {
return nil
}
@@ -477,10 +486,22 @@ func (c *ChannelArbitrator) Start(state *chanArbStartState) error {
// Set our state from our starting state.
c.state = state.currentState
- _, bestHeight, err := c.cfg.ChainIO.GetBestBlock()
- if err != nil {
- return err
- }
+ // Get the starting height.
+ bestHeight := beat.Height()
+
+ c.wg.Add(1)
+ go c.channelAttendant(bestHeight, state.commitSet)
+
+ return nil
+}
+
+// progressStateMachineAfterRestart attempts to progress the state machine
+// after a restart. This makes sure that if the state transition failed, we
+// will try to progress the state machine again. Moreover it will relaunch
+// resolvers if the channel is still in the pending close state and has not
+// been fully resolved yet.
+func (c *ChannelArbitrator) progressStateMachineAfterRestart(bestHeight int32,
+ commitSet *CommitSet) error {
// If the channel has been marked pending close in the database, and we
// haven't transitioned the state machine to StateContractClosed (or a
@@ -527,7 +548,7 @@ func (c *ChannelArbitrator) Start(state *chanArbStartState) error {
// on-chain state, and our set of active contracts.
startingState := c.state
nextState, _, err := c.advanceState(
- triggerHeight, trigger, state.commitSet,
+ triggerHeight, trigger, commitSet,
)
if err != nil {
switch err {
@@ -564,14 +585,12 @@ func (c *ChannelArbitrator) Start(state *chanArbStartState) error {
// receive a chain event from the chain watcher that the
// commitment has been confirmed on chain, and before we
// advance our state step, we call InsertConfirmedCommitSet.
- err := c.relaunchResolvers(state.commitSet, triggerHeight)
+ err := c.relaunchResolvers(commitSet, triggerHeight)
if err != nil {
return err
}
}
- c.wg.Add(1)
- go c.channelAttendant(bestHeight)
return nil
}
@@ -797,7 +816,7 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet,
// TODO(roasbeef): this isn't re-launched?
}
- c.launchResolvers(unresolvedContracts, true)
+ c.resolveContracts(unresolvedContracts)
return nil
}
@@ -997,7 +1016,7 @@ func (c *ChannelArbitrator) stateStep(
getIdx := func(htlc channeldb.HTLC) uint64 {
return htlc.HtlcIndex
}
- dustHTLCSet := fn.NewSet(fn.Map(getIdx, dustHTLCs)...)
+ dustHTLCSet := fn.NewSet(fn.Map(dustHTLCs, getIdx)...)
err = c.abandonForwards(dustHTLCSet)
if err != nil {
return StateError, closeTx, err
@@ -1306,7 +1325,7 @@ func (c *ChannelArbitrator) stateStep(
return htlc.HtlcIndex
}
remoteDangling := fn.NewSet(fn.Map(
- getIdx, htlcActions[HtlcFailDanglingAction],
+ htlcActions[HtlcFailDanglingAction], getIdx,
)...)
err := c.abandonForwards(remoteDangling)
if err != nil {
@@ -1336,7 +1355,7 @@ func (c *ChannelArbitrator) stateStep(
// Finally, we'll launch all the required contract resolvers.
// Once they're all resolved, we're no longer needed.
- c.launchResolvers(resolvers, false)
+ c.resolveContracts(resolvers)
nextState = StateWaitingFullResolution
@@ -1559,17 +1578,72 @@ func (c *ChannelArbitrator) findCommitmentDeadlineAndValue(heightHint uint32,
return fn.Some(int32(deadline)), valueLeft, nil
}
-// launchResolvers updates the activeResolvers list and starts the resolvers.
-func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver,
- immediate bool) {
-
+// resolveContracts updates the activeResolvers list and starts to resolve each
+// contract concurrently, and launches them.
+func (c *ChannelArbitrator) resolveContracts(resolvers []ContractResolver) {
c.activeResolversLock.Lock()
- defer c.activeResolversLock.Unlock()
-
c.activeResolvers = resolvers
+ c.activeResolversLock.Unlock()
+
+ // Launch all resolvers.
+ c.launchResolvers()
+
for _, contract := range resolvers {
c.wg.Add(1)
- go c.resolveContract(contract, immediate)
+ go c.resolveContract(contract)
+ }
+}
+
+// launchResolvers launches all the active resolvers concurrently.
+func (c *ChannelArbitrator) launchResolvers() {
+ c.activeResolversLock.Lock()
+ resolvers := c.activeResolvers
+ c.activeResolversLock.Unlock()
+
+ // errChans is a map of channels that will be used to receive errors
+ // returned from launching the resolvers.
+ errChans := make(map[ContractResolver]chan error, len(resolvers))
+
+ // Launch each resolver in goroutines.
+ for _, r := range resolvers {
+ // If the contract is already resolved, there's no need to
+ // launch it again.
+ if r.IsResolved() {
+ log.Debugf("ChannelArbitrator(%v): skipping resolver "+
+ "%T as it's already resolved", c.cfg.ChanPoint,
+ r)
+
+ continue
+ }
+
+ // Create a signal chan.
+ errChan := make(chan error, 1)
+ errChans[r] = errChan
+
+ go func() {
+ err := r.Launch()
+ errChan <- err
+ }()
+ }
+
+ // Wait for all resolvers to finish launching.
+ for r, errChan := range errChans {
+ select {
+ case err := <-errChan:
+ if err == nil {
+ continue
+ }
+
+ log.Errorf("ChannelArbitrator(%v): unable to launch "+
+ "contract resolver(%T): %v", c.cfg.ChanPoint, r,
+ err)
+
+ case <-c.quit:
+ log.Debugf("ChannelArbitrator quit signal received, " +
+ "exit launchResolvers")
+
+ return
+ }
}
}
@@ -1593,8 +1667,8 @@ func (c *ChannelArbitrator) advanceState(
for {
priorState = c.state
log.Debugf("ChannelArbitrator(%v): attempting state step with "+
- "trigger=%v from state=%v", c.cfg.ChanPoint, trigger,
- priorState)
+ "trigger=%v from state=%v at height=%v",
+ c.cfg.ChanPoint, trigger, priorState, triggerHeight)
nextState, closeTx, err := c.stateStep(
triggerHeight, trigger, confCommitSet,
@@ -2541,9 +2615,7 @@ func (c *ChannelArbitrator) replaceResolver(oldResolver,
// contracts.
//
// NOTE: This MUST be run as a goroutine.
-func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver,
- immediate bool) {
-
+func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver) {
defer c.wg.Done()
log.Debugf("ChannelArbitrator(%v): attempting to resolve %T",
@@ -2564,7 +2636,7 @@ func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver,
default:
// Otherwise, we'll attempt to resolve the current
// contract.
- nextContract, err := currentContract.Resolve(immediate)
+ nextContract, err := currentContract.Resolve()
if err != nil {
if err == errResolverShuttingDown {
return
@@ -2613,6 +2685,13 @@ func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver,
// loop.
currentContract = nextContract
+ // Launch the new contract.
+ err = currentContract.Launch()
+ if err != nil {
+ log.Errorf("Failed to launch %T: %v",
+ currentContract, err)
+ }
+
// If this contract is actually fully resolved, then
// we'll mark it as such within the database.
case currentContract.IsResolved():
@@ -2716,44 +2795,49 @@ func (c *ChannelArbitrator) updateActiveHTLCs() {
// Nursery for incubation, and ultimate sweeping.
//
// NOTE: This MUST be run as a goroutine.
-func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
+//
+//nolint:funlen
+func (c *ChannelArbitrator) channelAttendant(bestHeight int32,
+ commitSet *CommitSet) {
// TODO(roasbeef): tell top chain arb we're done
defer func() {
c.wg.Done()
}()
+ err := c.progressStateMachineAfterRestart(bestHeight, commitSet)
+ if err != nil {
+ // In case of an error, we return early but we do not shutdown
+ // LND, because there might be other channels that still can be
+ // resolved and we don't want to interfere with that.
+ // We continue to run the channel attendant in case the channel
+ // closes via other means for example the remote pary force
+ // closes the channel. So we log the error and continue.
+ log.Errorf("Unable to progress state machine after "+
+ "restart: %v", err)
+ }
+
for {
select {
// A new block has arrived, we'll examine all the active HTLC's
// to see if any of them have expired, and also update our
// track of the best current height.
- case blockHeight, ok := <-c.blocks:
- if !ok {
- return
- }
- bestHeight = blockHeight
+ case beat := <-c.BlockbeatChan:
+ bestHeight = beat.Height()
- // If we're not in the default state, then we can
- // ignore this signal as we're waiting for contract
- // resolution.
- if c.state != StateDefault {
- continue
- }
+ log.Debugf("ChannelArbitrator(%v): new block height=%v",
+ c.cfg.ChanPoint, bestHeight)
- // Now that a new block has arrived, we'll attempt to
- // advance our state forward.
- nextState, _, err := c.advanceState(
- uint32(bestHeight), chainTrigger, nil,
- )
+ err := c.handleBlockbeat(beat)
if err != nil {
- log.Errorf("Unable to advance state: %v", err)
+ log.Errorf("Handle block=%v got err: %v",
+ bestHeight, err)
}
// If as a result of this trigger, the contract is
// fully resolved, then well exit.
- if nextState == StateFullyResolved {
+ if c.state == StateFullyResolved {
return
}
@@ -2802,14 +2886,12 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
// We have broadcasted our commitment, and it is now confirmed
// on-chain.
case closeInfo := <-c.cfg.ChainEvents.LocalUnilateralClosure:
- log.Infof("ChannelArbitrator(%v): local on-chain "+
- "channel close", c.cfg.ChanPoint)
-
if c.state != StateCommitmentBroadcasted {
log.Errorf("ChannelArbitrator(%v): unexpected "+
"local on-chain channel close",
c.cfg.ChanPoint)
}
+
closeTx := closeInfo.CloseTx
resolutions, err := closeInfo.ContractResolutions.
@@ -2837,6 +2919,10 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
return
}
+ log.Infof("ChannelArbitrator(%v): local force close "+
+ "tx=%v confirmed", c.cfg.ChanPoint,
+ closeTx.TxHash())
+
contractRes := &ContractResolutions{
CommitHash: closeTx.TxHash(),
CommitResolution: resolutions.CommitResolution,
@@ -3104,6 +3190,37 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) {
}
}
+// handleBlockbeat processes a newly received blockbeat by advancing the
+// arbitrator's internal state using the received block height.
+func (c *ChannelArbitrator) handleBlockbeat(beat chainio.Blockbeat) error {
+ // Notify we've processed the block.
+ defer c.NotifyBlockProcessed(beat, nil)
+
+ // Try to advance the state if we are in StateDefault.
+ if c.state == StateDefault {
+ // Now that a new block has arrived, we'll attempt to advance
+ // our state forward.
+ _, _, err := c.advanceState(
+ uint32(beat.Height()), chainTrigger, nil,
+ )
+ if err != nil {
+ return fmt.Errorf("unable to advance state: %w", err)
+ }
+ }
+
+ // Launch all active resolvers when a new blockbeat is received.
+ c.launchResolvers()
+
+ return nil
+}
+
+// Name returns a human-readable string for this subsystem.
+//
+// NOTE: Part of chainio.Consumer interface.
+func (c *ChannelArbitrator) Name() string {
+ return fmt.Sprintf("ChannelArbitrator(%v)", c.cfg.ChanPoint)
+}
+
// checkLegacyBreach returns StateFullyResolved if the channel was closed with
// a breach transaction before the channel arbitrator launched its own breach
// resolver. StateContractClosed is returned if this is a modern breach close
diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go
index 92ad608eb9..827e10b5c7 100644
--- a/contractcourt/channel_arbitrator_test.go
+++ b/contractcourt/channel_arbitrator_test.go
@@ -13,14 +13,17 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
+ "github.com/davecgh/go-spew/spew"
+ "github.com/lightningnetwork/lnd/chainio"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntest/mock"
+ "github.com/lightningnetwork/lnd/lntest/wait"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
@@ -226,6 +229,15 @@ func (c *chanArbTestCtx) CleanUp() {
}
}
+// receiveBlockbeat mocks the behavior of a blockbeat being sent by the
+// BlockbeatDispatcher, which essentially mocks the method `ProcessBlock`.
+func (c *chanArbTestCtx) receiveBlockbeat(height int) {
+ go func() {
+ beat := newBeatFromHeight(int32(height))
+ c.chanArb.BlockbeatChan <- beat
+ }()
+}
+
// AssertStateTransitions asserts that the state machine steps through the
// passed states in order.
func (c *chanArbTestCtx) AssertStateTransitions(expectedStates ...ArbitratorState) {
@@ -285,7 +297,8 @@ func (c *chanArbTestCtx) Restart(restartClosure func(*chanArbTestCtx)) (*chanArb
restartClosure(newCtx)
}
- if err := newCtx.chanArb.Start(nil); err != nil {
+ beat := newBeatFromHeight(0)
+ if err := newCtx.chanArb.Start(nil, beat); err != nil {
return nil, err
}
@@ -512,7 +525,8 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) {
chanArbCtx, err := createTestChannelArbitrator(t, log)
require.NoError(t, err, "unable to create ChannelArbitrator")
- if err := chanArbCtx.chanArb.Start(nil); err != nil {
+ beat := newBeatFromHeight(0)
+ if err := chanArbCtx.chanArb.Start(nil, beat); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
t.Cleanup(func() {
@@ -570,7 +584,8 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) {
require.NoError(t, err, "unable to create ChannelArbitrator")
chanArb := chanArbCtx.chanArb
- if err := chanArb.Start(nil); err != nil {
+ beat := newBeatFromHeight(0)
+ if err := chanArb.Start(nil, beat); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
defer chanArb.Stop()
@@ -623,7 +638,8 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) {
require.NoError(t, err, "unable to create ChannelArbitrator")
chanArb := chanArbCtx.chanArb
- if err := chanArb.Start(nil); err != nil {
+ beat := newBeatFromHeight(0)
+ if err := chanArb.Start(nil, beat); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
defer chanArb.Stop()
@@ -735,7 +751,8 @@ func TestChannelArbitratorBreachClose(t *testing.T) {
chanArb.cfg.PreimageDB = newMockWitnessBeacon()
chanArb.cfg.Registry = &mockRegistry{}
- if err := chanArb.Start(nil); err != nil {
+ beat := newBeatFromHeight(0)
+ if err := chanArb.Start(nil, beat); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
t.Cleanup(func() {
@@ -862,7 +879,8 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
chanArb.cfg.PreimageDB = newMockWitnessBeacon()
chanArb.cfg.Registry = &mockRegistry{}
- if err := chanArb.Start(nil); err != nil {
+ beat := newBeatFromHeight(0)
+ if err := chanArb.Start(nil, beat); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
defer chanArb.Stop()
@@ -965,6 +983,7 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
},
},
}
+ closeTxid := closeTx.TxHash()
htlcOp := wire.OutPoint{
Hash: closeTx.TxHash(),
@@ -1036,7 +1055,7 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
}
require.Equal(t, expectedFinalHtlcs, chanArbCtx.finalHtlcs)
- // We'll no re-create the resolver, notice that we use the existing
+ // We'll now re-create the resolver, notice that we use the existing
// arbLog so it carries over the same on-disk state.
chanArbCtxNew, err := chanArbCtx.Restart(nil)
require.NoError(t, err, "unable to create ChannelArbitrator")
@@ -1045,10 +1064,19 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
// Post restart, it should be the case that our resolver was properly
// supplemented, and we only have a single resolver in the final set.
- if len(chanArb.activeResolvers) != 1 {
- t.Fatalf("expected single resolver, instead got: %v",
- len(chanArb.activeResolvers))
- }
+ // The resolvers are added concurrently so we need to wait here.
+ err = wait.NoError(func() error {
+ chanArb.activeResolversLock.Lock()
+ defer chanArb.activeResolversLock.Unlock()
+
+ if len(chanArb.activeResolvers) != 1 {
+ return fmt.Errorf("expected single resolver, instead "+
+ "got: %v", len(chanArb.activeResolvers))
+ }
+
+ return nil
+ }, defaultTimeout)
+ require.NoError(t, err)
// We'll now examine the in-memory state of the active resolvers to
// ensure t hey were populated properly.
@@ -1086,7 +1114,11 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
// Notify resolver that the HTLC output of the commitment has been
// spent.
- oldNotifier.SpendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx}
+ oldNotifier.SpendChan <- &chainntnfs.SpendDetail{
+ SpendingTx: closeTx,
+ SpentOutPoint: &wire.OutPoint{},
+ SpenderTxHash: &closeTxid,
+ }
// Finally, we should also receive a resolution message instructing the
// switch to cancel back the HTLC.
@@ -1113,8 +1145,12 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
default:
}
- // Notify resolver that the second level transaction is spent.
- oldNotifier.SpendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx}
+ // Notify resolver that the output of the timeout tx has been spent.
+ oldNotifier.SpendChan <- &chainntnfs.SpendDetail{
+ SpendingTx: closeTx,
+ SpentOutPoint: &wire.OutPoint{},
+ SpenderTxHash: &closeTxid,
+ }
// At this point channel should be marked as resolved.
chanArbCtxNew.AssertStateTransitions(StateFullyResolved)
@@ -1138,7 +1174,8 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) {
require.NoError(t, err, "unable to create ChannelArbitrator")
chanArb := chanArbCtx.chanArb
- if err := chanArb.Start(nil); err != nil {
+ beat := newBeatFromHeight(0)
+ if err := chanArb.Start(nil, beat); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
defer chanArb.Stop()
@@ -1245,7 +1282,8 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) {
require.NoError(t, err, "unable to create ChannelArbitrator")
chanArb := chanArbCtx.chanArb
- if err := chanArb.Start(nil); err != nil {
+ beat := newBeatFromHeight(0)
+ if err := chanArb.Start(nil, beat); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
defer chanArb.Stop()
@@ -1351,7 +1389,8 @@ func TestChannelArbitratorPersistence(t *testing.T) {
require.NoError(t, err, "unable to create ChannelArbitrator")
chanArb := chanArbCtx.chanArb
- if err := chanArb.Start(nil); err != nil {
+ beat := newBeatFromHeight(0)
+ if err := chanArb.Start(nil, beat); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
@@ -1469,7 +1508,8 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) {
require.NoError(t, err, "unable to create ChannelArbitrator")
chanArb := chanArbCtx.chanArb
- if err := chanArb.Start(nil); err != nil {
+ beat := newBeatFromHeight(0)
+ if err := chanArb.Start(nil, beat); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
@@ -1656,7 +1696,8 @@ func TestChannelArbitratorCommitFailure(t *testing.T) {
}
chanArb := chanArbCtx.chanArb
- if err := chanArb.Start(nil); err != nil {
+ beat := newBeatFromHeight(0)
+ if err := chanArb.Start(nil, beat); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
@@ -1740,7 +1781,8 @@ func TestChannelArbitratorEmptyResolutions(t *testing.T) {
chanArb.cfg.ClosingHeight = 100
chanArb.cfg.CloseType = channeldb.RemoteForceClose
- if err := chanArb.Start(nil); err != nil {
+ beat := newBeatFromHeight(100)
+ if err := chanArb.Start(nil, beat); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
@@ -1770,7 +1812,8 @@ func TestChannelArbitratorAlreadyForceClosed(t *testing.T) {
chanArbCtx, err := createTestChannelArbitrator(t, log)
require.NoError(t, err, "unable to create ChannelArbitrator")
chanArb := chanArbCtx.chanArb
- if err := chanArb.Start(nil); err != nil {
+ beat := newBeatFromHeight(0)
+ if err := chanArb.Start(nil, beat); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
defer chanArb.Stop()
@@ -1868,9 +1911,10 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) {
t.Fatalf("unable to create ChannelArbitrator: %v", err)
}
chanArb := chanArbCtx.chanArb
- if err := chanArb.Start(nil); err != nil {
- t.Fatalf("unable to start ChannelArbitrator: %v", err)
- }
+ beat := newBeatFromHeight(0)
+ err = chanArb.Start(nil, beat)
+ require.NoError(t, err)
+
defer chanArb.Stop()
// Now that our channel arb has started, we'll set up
@@ -1914,7 +1958,8 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) {
// now mine a block (height 5), which is 5 blocks away
// (our grace delta) from the expiry of that HTLC.
case testCase.htlcExpired:
- chanArbCtx.chanArb.blocks <- 5
+ beat := newBeatFromHeight(5)
+ chanArbCtx.chanArb.BlockbeatChan <- beat
// Otherwise, we'll just trigger a regular force close
// request.
@@ -2026,8 +2071,7 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) {
// so instead, we'll mine another block which'll cause
// it to re-examine its state and realize there're no
// more HTLCs.
- chanArbCtx.chanArb.blocks <- 6
- chanArbCtx.AssertStateTransitions(StateFullyResolved)
+ chanArbCtx.receiveBlockbeat(6)
})
}
}
@@ -2064,7 +2108,8 @@ func TestChannelArbitratorPendingExpiredHTLC(t *testing.T) {
return false
}
- if err := chanArb.Start(nil); err != nil {
+ beat := newBeatFromHeight(0)
+ if err := chanArb.Start(nil, beat); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
t.Cleanup(func() {
@@ -2098,13 +2143,15 @@ func TestChannelArbitratorPendingExpiredHTLC(t *testing.T) {
// We will advance the uptime to 10 seconds which should be still within
// the grace period and should not trigger going to chain.
testClock.SetTime(startTime.Add(time.Second * 10))
- chanArbCtx.chanArb.blocks <- 5
+ beat = newBeatFromHeight(5)
+ chanArbCtx.chanArb.BlockbeatChan <- beat
chanArbCtx.AssertState(StateDefault)
// We will advance the uptime to 16 seconds which should trigger going
// to chain.
testClock.SetTime(startTime.Add(time.Second * 16))
- chanArbCtx.chanArb.blocks <- 6
+ beat = newBeatFromHeight(6)
+ chanArbCtx.chanArb.BlockbeatChan <- beat
chanArbCtx.AssertStateTransitions(
StateBroadcastCommit,
StateCommitmentBroadcasted,
@@ -2217,8 +2264,8 @@ func TestRemoteCloseInitiator(t *testing.T) {
"ChannelArbitrator: %v", err)
}
chanArb := chanArbCtx.chanArb
-
- if err := chanArb.Start(nil); err != nil {
+ beat := newBeatFromHeight(0)
+ if err := chanArb.Start(nil, beat); err != nil {
t.Fatalf("unable to start "+
"ChannelArbitrator: %v", err)
}
@@ -2472,7 +2519,7 @@ func TestSweepAnchors(t *testing.T) {
// Set current block height.
heightHint := uint32(1000)
- chanArbCtx.chanArb.blocks <- int32(heightHint)
+ chanArbCtx.receiveBlockbeat(int(heightHint))
htlcIndexBase := uint64(99)
deadlineDelta := uint32(10)
@@ -2635,7 +2682,7 @@ func TestSweepLocalAnchor(t *testing.T) {
// Set current block height.
heightHint := uint32(1000)
- chanArbCtx.chanArb.blocks <- int32(heightHint)
+ chanArbCtx.receiveBlockbeat(int(heightHint))
htlcIndex := uint64(99)
deadlineDelta := uint32(10)
@@ -2769,7 +2816,9 @@ func TestChannelArbitratorAnchors(t *testing.T) {
},
}
- if err := chanArb.Start(nil); err != nil {
+ heightHint := uint32(1000)
+ beat := newBeatFromHeight(int32(heightHint))
+ if err := chanArb.Start(nil, beat); err != nil {
t.Fatalf("unable to start ChannelArbitrator: %v", err)
}
t.Cleanup(func() {
@@ -2781,27 +2830,28 @@ func TestChannelArbitratorAnchors(t *testing.T) {
}
chanArb.UpdateContractSignals(signals)
- // Set current block height.
- heightHint := uint32(1000)
- chanArbCtx.chanArb.blocks <- int32(heightHint)
-
htlcAmt := lnwire.MilliSatoshi(1_000_000)
// Create testing HTLCs.
- deadlineDelta := uint32(10)
- deadlinePreimageDelta := deadlineDelta + 2
+ spendingHeight := uint32(beat.Height())
+ deadlineDelta := uint32(100)
+
+ deadlinePreimageDelta := deadlineDelta
htlcWithPreimage := channeldb.HTLC{
- HtlcIndex: 99,
- RefundTimeout: heightHint + deadlinePreimageDelta,
+ HtlcIndex: 99,
+ // RefundTimeout is 101.
+ RefundTimeout: spendingHeight + deadlinePreimageDelta,
RHash: rHash,
Incoming: true,
Amt: htlcAmt,
}
+ expectedDeadline := deadlineDelta/2 + spendingHeight
- deadlineHTLCdelta := deadlineDelta + 3
+ deadlineHTLCdelta := deadlineDelta + 40
htlc := channeldb.HTLC{
- HtlcIndex: 100,
- RefundTimeout: heightHint + deadlineHTLCdelta,
+ HtlcIndex: 100,
+ // RefundTimeout is 141.
+ RefundTimeout: spendingHeight + deadlineHTLCdelta,
Amt: htlcAmt,
}
@@ -2886,7 +2936,9 @@ func TestChannelArbitratorAnchors(t *testing.T) {
//nolint:ll
chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{
- SpendDetail: &chainntnfs.SpendDetail{},
+ SpendDetail: &chainntnfs.SpendDetail{
+ SpendingHeight: int32(spendingHeight),
+ },
LocalForceCloseSummary: &lnwallet.LocalForceCloseSummary{
CloseTx: closeTx,
ContractResolutions: fn.Some(lnwallet.ContractResolutions{
@@ -2950,12 +3002,14 @@ func TestChannelArbitratorAnchors(t *testing.T) {
// to htlcWithPreimage's CLTV.
require.Equal(t, 2, len(chanArbCtx.sweeper.deadlines))
require.EqualValues(t,
- heightHint+deadlinePreimageDelta/2,
- chanArbCtx.sweeper.deadlines[0],
+ expectedDeadline,
+ chanArbCtx.sweeper.deadlines[0], "want %d, got %d",
+ expectedDeadline, chanArbCtx.sweeper.deadlines[0],
)
require.EqualValues(t,
- heightHint+deadlinePreimageDelta/2,
- chanArbCtx.sweeper.deadlines[1],
+ expectedDeadline,
+ chanArbCtx.sweeper.deadlines[1], "want %d, got %d",
+ expectedDeadline, chanArbCtx.sweeper.deadlines[1],
)
}
@@ -3000,9 +3054,12 @@ func TestChannelArbitratorStartForceCloseFail(t *testing.T) {
{
name: "Commitment is rejected with an " +
"unmatched error",
- broadcastErr: fmt.Errorf("Reject Commitment Tx"),
- expectedState: StateBroadcastCommit,
- expectedStartup: false,
+ broadcastErr: fmt.Errorf("Reject Commitment Tx"),
+ expectedState: StateBroadcastCommit,
+ // We should still be able to start up since we other
+ // channels might be closing as well and we should
+ // resolve the contracts.
+ expectedStartup: true,
},
// We started after the DLP was triggered, and try to force
@@ -3054,7 +3111,8 @@ func TestChannelArbitratorStartForceCloseFail(t *testing.T) {
return test.broadcastErr
}
- err = chanArb.Start(nil)
+ beat := newBeatFromHeight(0)
+ err = chanArb.Start(nil, beat)
if !test.expectedStartup {
require.ErrorIs(t, err, test.broadcastErr)
@@ -3102,7 +3160,8 @@ func assertResolverReport(t *testing.T, reports chan *channeldb.ResolverReport,
select {
case report := <-reports:
if !reflect.DeepEqual(report, expected) {
- t.Fatalf("expected: %v, got: %v", expected, report)
+ t.Fatalf("expected: %v, got: %v", spew.Sdump(expected),
+ spew.Sdump(report))
}
case <-time.After(defaultTimeout):
@@ -3133,3 +3192,11 @@ func (m *mockChannel) ForceCloseChan() (*wire.MsgTx, error) {
return &wire.MsgTx{}, nil
}
+
+func newBeatFromHeight(height int32) *chainio.Beat {
+ epoch := chainntnfs.BlockEpoch{
+ Height: height,
+ }
+
+ return chainio.NewBeat(epoch)
+}
diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go
index 4b47a34294..55ee08e5d8 100644
--- a/contractcourt/commit_sweep_resolver.go
+++ b/contractcourt/commit_sweep_resolver.go
@@ -13,7 +13,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/sweep"
@@ -39,9 +39,6 @@ type commitSweepResolver struct {
// this HTLC on-chain.
commitResolution lnwallet.CommitOutputResolution
- // resolved reflects if the contract has been fully resolved or not.
- resolved bool
-
// broadcastHeight is the height that the original contract was
// broadcast to the main-chain at. We'll use this value to bound any
// historical queries to the chain for spends/confirmations.
@@ -88,7 +85,7 @@ func newCommitSweepResolver(res lnwallet.CommitOutputResolution,
chanPoint: chanPoint,
}
- r.initLogger(r)
+ r.initLogger(fmt.Sprintf("%T(%v)", r, r.commitResolution.SelfOutPoint))
r.initReport()
return r
@@ -101,36 +98,6 @@ func (c *commitSweepResolver) ResolverKey() []byte {
return key[:]
}
-// waitForHeight registers for block notifications and waits for the provided
-// block height to be reached.
-func waitForHeight(waitHeight uint32, notifier chainntnfs.ChainNotifier,
- quit <-chan struct{}) error {
-
- // Register for block epochs. After registration, the current height
- // will be sent on the channel immediately.
- blockEpochs, err := notifier.RegisterBlockEpochNtfn(nil)
- if err != nil {
- return err
- }
- defer blockEpochs.Cancel()
-
- for {
- select {
- case newBlock, ok := <-blockEpochs.Epochs:
- if !ok {
- return errResolverShuttingDown
- }
- height := newBlock.Height
- if height >= int32(waitHeight) {
- return nil
- }
-
- case <-quit:
- return errResolverShuttingDown
- }
- }
-}
-
// waitForSpend waits for the given outpoint to be spent, and returns the
// details of the spending tx.
func waitForSpend(op *wire.OutPoint, pkScript []byte, heightHint uint32,
@@ -195,203 +162,17 @@ func (c *commitSweepResolver) getCommitTxConfHeight() (uint32, error) {
// returned.
//
// NOTE: This function MUST be run as a goroutine.
+
+// TODO(yy): fix the funlen in the next PR.
//
//nolint:funlen
-func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) {
+func (c *commitSweepResolver) Resolve() (ContractResolver, error) {
// If we're already resolved, then we can exit early.
- if c.resolved {
+ if c.IsResolved() {
+ c.log.Errorf("already resolved")
return nil, nil
}
- confHeight, err := c.getCommitTxConfHeight()
- if err != nil {
- return nil, err
- }
-
- // Wait up until the CSV expires, unless we also have a CLTV that
- // expires after.
- unlockHeight := confHeight + c.commitResolution.MaturityDelay
- if c.hasCLTV() {
- unlockHeight = uint32(math.Max(
- float64(unlockHeight), float64(c.leaseExpiry),
- ))
- }
-
- c.log.Debugf("commit conf_height=%v, unlock_height=%v",
- confHeight, unlockHeight)
-
- // Update report now that we learned the confirmation height.
- c.reportLock.Lock()
- c.currentReport.MaturityHeight = unlockHeight
- c.reportLock.Unlock()
-
- // If there is a csv/cltv lock, we'll wait for that.
- if c.commitResolution.MaturityDelay > 0 || c.hasCLTV() {
- // Determine what height we should wait until for the locks to
- // expire.
- var waitHeight uint32
- switch {
- // If we have both a csv and cltv lock, we'll need to look at
- // both and see which expires later.
- case c.commitResolution.MaturityDelay > 0 && c.hasCLTV():
- c.log.Debugf("waiting for CSV and CLTV lock to expire "+
- "at height %v", unlockHeight)
- // If the CSV expires after the CLTV, or there is no
- // CLTV, then we can broadcast a sweep a block before.
- // Otherwise, we need to broadcast at our expected
- // unlock height.
- waitHeight = uint32(math.Max(
- float64(unlockHeight-1), float64(c.leaseExpiry),
- ))
-
- // If we only have a csv lock, wait for the height before the
- // lock expires as the spend path should be unlocked by then.
- case c.commitResolution.MaturityDelay > 0:
- c.log.Debugf("waiting for CSV lock to expire at "+
- "height %v", unlockHeight)
- waitHeight = unlockHeight - 1
- }
-
- err := waitForHeight(waitHeight, c.Notifier, c.quit)
- if err != nil {
- return nil, err
- }
- }
-
- var (
- isLocalCommitTx bool
-
- signDesc = c.commitResolution.SelfOutputSignDesc
- )
-
- switch {
- // For taproot channels, we'll know if this is the local commit based
- // on the timelock value. For remote commitment transactions, the
- // witness script has a timelock of 1.
- case c.chanType.IsTaproot():
- delayKey := c.localChanCfg.DelayBasePoint.PubKey
- nonDelayKey := c.localChanCfg.PaymentBasePoint.PubKey
-
- signKey := c.commitResolution.SelfOutputSignDesc.KeyDesc.PubKey
-
- // If the key in the script is neither of these, we shouldn't
- // proceed. This should be impossible.
- if !signKey.IsEqual(delayKey) && !signKey.IsEqual(nonDelayKey) {
- return nil, fmt.Errorf("unknown sign key %v", signKey)
- }
-
- // The commitment transaction is ours iff the signing key is
- // the delay key.
- isLocalCommitTx = signKey.IsEqual(delayKey)
-
- // The output is on our local commitment if the script starts with
- // OP_IF for the revocation clause. On the remote commitment it will
- // either be a regular P2WKH or a simple sig spend with a CSV delay.
- default:
- isLocalCommitTx = signDesc.WitnessScript[0] == txscript.OP_IF
- }
- isDelayedOutput := c.commitResolution.MaturityDelay != 0
-
- c.log.Debugf("isDelayedOutput=%v, isLocalCommitTx=%v", isDelayedOutput,
- isLocalCommitTx)
-
- // There're three types of commitments, those that have tweaks for the
- // remote key (us in this case), those that don't, and a third where
- // there is no tweak and the output is delayed. On the local commitment
- // our output will always be delayed. We'll rely on the presence of the
- // commitment tweak to discern which type of commitment this is.
- var witnessType input.WitnessType
- switch {
- // The local delayed output for a taproot channel.
- case isLocalCommitTx && c.chanType.IsTaproot():
- witnessType = input.TaprootLocalCommitSpend
-
- // The CSV 1 delayed output for a taproot channel.
- case !isLocalCommitTx && c.chanType.IsTaproot():
- witnessType = input.TaprootRemoteCommitSpend
-
- // Delayed output to us on our local commitment for a channel lease in
- // which we are the initiator.
- case isLocalCommitTx && c.hasCLTV():
- witnessType = input.LeaseCommitmentTimeLock
-
- // Delayed output to us on our local commitment.
- case isLocalCommitTx:
- witnessType = input.CommitmentTimeLock
-
- // A confirmed output to us on the remote commitment for a channel lease
- // in which we are the initiator.
- case isDelayedOutput && c.hasCLTV():
- witnessType = input.LeaseCommitmentToRemoteConfirmed
-
- // A confirmed output to us on the remote commitment.
- case isDelayedOutput:
- witnessType = input.CommitmentToRemoteConfirmed
-
- // A non-delayed output on the remote commitment where the key is
- // tweakless.
- case c.commitResolution.SelfOutputSignDesc.SingleTweak == nil:
- witnessType = input.CommitSpendNoDelayTweakless
-
- // A non-delayed output on the remote commitment where the key is
- // tweaked.
- default:
- witnessType = input.CommitmentNoDelay
- }
-
- c.log.Infof("Sweeping with witness type: %v", witnessType)
-
- // We'll craft an input with all the information required for the
- // sweeper to create a fully valid sweeping transaction to recover
- // these coins.
- var inp *input.BaseInput
- if c.hasCLTV() {
- inp = input.NewCsvInputWithCltv(
- &c.commitResolution.SelfOutPoint, witnessType,
- &c.commitResolution.SelfOutputSignDesc,
- c.broadcastHeight, c.commitResolution.MaturityDelay,
- c.leaseExpiry,
- input.WithResolutionBlob(
- c.commitResolution.ResolutionBlob,
- ),
- )
- } else {
- inp = input.NewCsvInput(
- &c.commitResolution.SelfOutPoint, witnessType,
- &c.commitResolution.SelfOutputSignDesc,
- c.broadcastHeight, c.commitResolution.MaturityDelay,
- input.WithResolutionBlob(
- c.commitResolution.ResolutionBlob,
- ),
- )
- }
-
- // TODO(roasbeef): instead of ading ctrl block to the sign desc, make
- // new input type, have sweeper set it?
-
- // Calculate the budget for the sweeping this input.
- budget := calculateBudget(
- btcutil.Amount(inp.SignDesc().Output.Value),
- c.Budget.ToLocalRatio, c.Budget.ToLocal,
- )
- c.log.Infof("Sweeping commit output using budget=%v", budget)
-
- // With our input constructed, we'll now offer it to the sweeper.
- resultChan, err := c.Sweeper.SweepInput(
- inp, sweep.Params{
- Budget: budget,
-
- // Specify a nil deadline here as there's no time
- // pressure.
- DeadlineHeight: fn.None[int32](),
- },
- )
- if err != nil {
- c.log.Errorf("unable to sweep input: %v", err)
-
- return nil, err
- }
-
var sweepTxID chainhash.Hash
// Sweeper is going to join this input with other inputs if possible
@@ -400,7 +181,7 @@ func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) {
// happen.
outcome := channeldb.ResolverOutcomeClaimed
select {
- case sweepResult := <-resultChan:
+ case sweepResult := <-c.sweepResultChan:
switch sweepResult.Err {
// If the remote party was able to sweep this output it's
// likely what we sent was actually a revoked commitment.
@@ -440,7 +221,7 @@ func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) {
report := c.currentReport.resolverReport(
&sweepTxID, channeldb.ResolverTypeCommit, outcome,
)
- c.resolved = true
+ c.markResolved()
// Checkpoint the resolver with a closure that will write the outcome
// of the resolver and its sweep transaction to disk.
@@ -452,17 +233,11 @@ func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) {
//
// NOTE: Part of the ContractResolver interface.
func (c *commitSweepResolver) Stop() {
+ c.log.Debugf("stopping...")
+ defer c.log.Debugf("stopped")
close(c.quit)
}
-// IsResolved returns true if the stored state in the resolve is fully
-// resolved. In this case the target output can be forgotten.
-//
-// NOTE: Part of the ContractResolver interface.
-func (c *commitSweepResolver) IsResolved() bool {
- return c.resolved
-}
-
// SupplementState allows the user of a ContractResolver to supplement it with
// state required for the proper resolution of a contract.
//
@@ -491,7 +266,7 @@ func (c *commitSweepResolver) Encode(w io.Writer) error {
return err
}
- if err := binary.Write(w, endian, c.resolved); err != nil {
+ if err := binary.Write(w, endian, c.IsResolved()); err != nil {
return err
}
if err := binary.Write(w, endian, c.broadcastHeight); err != nil {
@@ -526,9 +301,14 @@ func newCommitSweepResolverFromReader(r io.Reader, resCfg ResolverConfig) (
return nil, err
}
- if err := binary.Read(r, endian, &c.resolved); err != nil {
+ var resolved bool
+ if err := binary.Read(r, endian, &resolved); err != nil {
return nil, err
}
+ if resolved {
+ c.markResolved()
+ }
+
if err := binary.Read(r, endian, &c.broadcastHeight); err != nil {
return nil, err
}
@@ -545,7 +325,7 @@ func newCommitSweepResolverFromReader(r io.Reader, resCfg ResolverConfig) (
// removed this, but keep in mind that this data may still be present in
// the database.
- c.initLogger(c)
+ c.initLogger(fmt.Sprintf("%T(%v)", c, c.commitResolution.SelfOutPoint))
c.initReport()
return c, nil
@@ -585,3 +365,181 @@ func (c *commitSweepResolver) initReport() {
// A compile time assertion to ensure commitSweepResolver meets the
// ContractResolver interface.
var _ reportingContractResolver = (*commitSweepResolver)(nil)
+
+// Launch constructs a commit input and offers it to the sweeper.
+func (c *commitSweepResolver) Launch() error {
+ if c.isLaunched() {
+ c.log.Tracef("already launched")
+ return nil
+ }
+
+ c.log.Debugf("launching resolver...")
+ c.markLaunched()
+
+ // If we're already resolved, then we can exit early.
+ if c.IsResolved() {
+ c.log.Errorf("already resolved")
+ return nil
+ }
+
+ confHeight, err := c.getCommitTxConfHeight()
+ if err != nil {
+ return err
+ }
+
+ // Wait up until the CSV expires, unless we also have a CLTV that
+ // expires after.
+ unlockHeight := confHeight + c.commitResolution.MaturityDelay
+ if c.hasCLTV() {
+ unlockHeight = uint32(math.Max(
+ float64(unlockHeight), float64(c.leaseExpiry),
+ ))
+ }
+
+ // Update report now that we learned the confirmation height.
+ c.reportLock.Lock()
+ c.currentReport.MaturityHeight = unlockHeight
+ c.reportLock.Unlock()
+
+ // Derive the witness type for this input.
+ witnessType, err := c.decideWitnessType()
+ if err != nil {
+ return err
+ }
+
+ // We'll craft an input with all the information required for the
+ // sweeper to create a fully valid sweeping transaction to recover
+ // these coins.
+ var inp *input.BaseInput
+ if c.hasCLTV() {
+ inp = input.NewCsvInputWithCltv(
+ &c.commitResolution.SelfOutPoint, witnessType,
+ &c.commitResolution.SelfOutputSignDesc,
+ c.broadcastHeight, c.commitResolution.MaturityDelay,
+ c.leaseExpiry,
+ )
+ } else {
+ inp = input.NewCsvInput(
+ &c.commitResolution.SelfOutPoint, witnessType,
+ &c.commitResolution.SelfOutputSignDesc,
+ c.broadcastHeight, c.commitResolution.MaturityDelay,
+ )
+ }
+
+ // TODO(roasbeef): instead of ading ctrl block to the sign desc, make
+ // new input type, have sweeper set it?
+
+ // Calculate the budget for the sweeping this input.
+ budget := calculateBudget(
+ btcutil.Amount(inp.SignDesc().Output.Value),
+ c.Budget.ToLocalRatio, c.Budget.ToLocal,
+ )
+ c.log.Infof("sweeping commit output %v using budget=%v", witnessType,
+ budget)
+
+ // With our input constructed, we'll now offer it to the sweeper.
+ resultChan, err := c.Sweeper.SweepInput(
+ inp, sweep.Params{
+ Budget: budget,
+
+ // Specify a nil deadline here as there's no time
+ // pressure.
+ DeadlineHeight: fn.None[int32](),
+ },
+ )
+ if err != nil {
+ c.log.Errorf("unable to sweep input: %v", err)
+
+ return err
+ }
+
+ c.sweepResultChan = resultChan
+
+ return nil
+}
+
+// decideWitnessType returns the witness type for the input.
+func (c *commitSweepResolver) decideWitnessType() (input.WitnessType, error) {
+ var (
+ isLocalCommitTx bool
+ signDesc = c.commitResolution.SelfOutputSignDesc
+ )
+
+ switch {
+ // For taproot channels, we'll know if this is the local commit based
+ // on the timelock value. For remote commitment transactions, the
+ // witness script has a timelock of 1.
+ case c.chanType.IsTaproot():
+ delayKey := c.localChanCfg.DelayBasePoint.PubKey
+ nonDelayKey := c.localChanCfg.PaymentBasePoint.PubKey
+
+ signKey := c.commitResolution.SelfOutputSignDesc.KeyDesc.PubKey
+
+ // If the key in the script is neither of these, we shouldn't
+ // proceed. This should be impossible.
+ if !signKey.IsEqual(delayKey) && !signKey.IsEqual(nonDelayKey) {
+ return nil, fmt.Errorf("unknown sign key %v", signKey)
+ }
+
+ // The commitment transaction is ours iff the signing key is
+ // the delay key.
+ isLocalCommitTx = signKey.IsEqual(delayKey)
+
+ // The output is on our local commitment if the script starts with
+ // OP_IF for the revocation clause. On the remote commitment it will
+ // either be a regular P2WKH or a simple sig spend with a CSV delay.
+ default:
+ isLocalCommitTx = signDesc.WitnessScript[0] == txscript.OP_IF
+ }
+
+ isDelayedOutput := c.commitResolution.MaturityDelay != 0
+
+ c.log.Debugf("isDelayedOutput=%v, isLocalCommitTx=%v", isDelayedOutput,
+ isLocalCommitTx)
+
+ // There're three types of commitments, those that have tweaks for the
+ // remote key (us in this case), those that don't, and a third where
+ // there is no tweak and the output is delayed. On the local commitment
+ // our output will always be delayed. We'll rely on the presence of the
+ // commitment tweak to discern which type of commitment this is.
+ var witnessType input.WitnessType
+ switch {
+ // The local delayed output for a taproot channel.
+ case isLocalCommitTx && c.chanType.IsTaproot():
+ witnessType = input.TaprootLocalCommitSpend
+
+ // The CSV 1 delayed output for a taproot channel.
+ case !isLocalCommitTx && c.chanType.IsTaproot():
+ witnessType = input.TaprootRemoteCommitSpend
+
+ // Delayed output to us on our local commitment for a channel lease in
+ // which we are the initiator.
+ case isLocalCommitTx && c.hasCLTV():
+ witnessType = input.LeaseCommitmentTimeLock
+
+ // Delayed output to us on our local commitment.
+ case isLocalCommitTx:
+ witnessType = input.CommitmentTimeLock
+
+ // A confirmed output to us on the remote commitment for a channel lease
+ // in which we are the initiator.
+ case isDelayedOutput && c.hasCLTV():
+ witnessType = input.LeaseCommitmentToRemoteConfirmed
+
+ // A confirmed output to us on the remote commitment.
+ case isDelayedOutput:
+ witnessType = input.CommitmentToRemoteConfirmed
+
+ // A non-delayed output on the remote commitment where the key is
+ // tweakless.
+ case c.commitResolution.SelfOutputSignDesc.SingleTweak == nil:
+ witnessType = input.CommitSpendNoDelayTweakless
+
+ // A non-delayed output on the remote commitment where the key is
+ // tweaked.
+ default:
+ witnessType = input.CommitmentNoDelay
+ }
+
+ return witnessType, nil
+}
diff --git a/contractcourt/commit_sweep_resolver_test.go b/contractcourt/commit_sweep_resolver_test.go
index 077fb8f82c..6855fddcd3 100644
--- a/contractcourt/commit_sweep_resolver_test.go
+++ b/contractcourt/commit_sweep_resolver_test.go
@@ -15,6 +15,7 @@ import (
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/sweep"
+ "github.com/stretchr/testify/require"
)
type commitSweepResolverTestContext struct {
@@ -82,7 +83,10 @@ func (i *commitSweepResolverTestContext) resolve() {
// Start resolver.
i.resolverResultChan = make(chan resolveResult, 1)
go func() {
- nextResolver, err := i.resolver.Resolve(false)
+ err := i.resolver.Launch()
+ require.NoError(i.t, err)
+
+ nextResolver, err := i.resolver.Resolve()
i.resolverResultChan <- resolveResult{
nextResolver: nextResolver,
err: err,
@@ -90,12 +94,6 @@ func (i *commitSweepResolverTestContext) resolve() {
}()
}
-func (i *commitSweepResolverTestContext) notifyEpoch(height int32) {
- i.notifier.EpochChan <- &chainntnfs.BlockEpoch{
- Height: height,
- }
-}
-
func (i *commitSweepResolverTestContext) waitForResult() {
i.t.Helper()
@@ -292,22 +290,10 @@ func testCommitSweepResolverDelay(t *testing.T, sweepErr error) {
t.Fatal("report maturity height incorrect")
}
- // Notify initial block height. The csv lock is still in effect, so we
- // don't expect any sweep to happen yet.
- ctx.notifyEpoch(testInitialBlockHeight)
-
- select {
- case <-ctx.sweeper.sweptInputs:
- t.Fatal("no sweep expected")
- case <-time.After(sweepProcessInterval):
- }
-
- // A new block arrives. The commit tx confirmed at height -1 and the csv
- // is 3, so a spend will be valid in the first block after height +1.
- ctx.notifyEpoch(testInitialBlockHeight + 1)
-
- <-ctx.sweeper.sweptInputs
-
+ // Notify initial block height. Although the csv lock is still in
+ // effect, we expect the input being sent to the sweeper before the csv
+ // lock expires.
+ //
// Set the resolution report outcome based on whether our sweep
// succeeded.
outcome := channeldb.ResolverOutcomeClaimed
diff --git a/contractcourt/contract_resolver.go b/contractcourt/contract_resolver.go
index 53f4f680d0..d11bd2f597 100644
--- a/contractcourt/contract_resolver.go
+++ b/contractcourt/contract_resolver.go
@@ -5,11 +5,13 @@ import (
"errors"
"fmt"
"io"
+ "sync/atomic"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
+ "github.com/lightningnetwork/lnd/sweep"
)
var (
@@ -35,6 +37,17 @@ type ContractResolver interface {
// resides within.
ResolverKey() []byte
+ // Launch starts the resolver by constructing an input and offering it
+ // to the sweeper. Once offered, it's expected to monitor the sweeping
+ // result in a goroutine invoked by calling Resolve.
+ //
+ // NOTE: We can call `Resolve` inside a goroutine at the end of this
+ // method to avoid calling it in the ChannelArbitrator. However, there
+ // are some DB-related operations such as SwapContract/ResolveContract
+ // which need to be done inside the resolvers instead, which needs a
+ // deeper refactoring.
+ Launch() error
+
// Resolve instructs the contract resolver to resolve the output
// on-chain. Once the output has been *fully* resolved, the function
// should return immediately with a nil ContractResolver value for the
@@ -42,7 +55,7 @@ type ContractResolver interface {
// resolution, then another resolve is returned.
//
// NOTE: This function MUST be run as a goroutine.
- Resolve(immediate bool) (ContractResolver, error)
+ Resolve() (ContractResolver, error)
// SupplementState allows the user of a ContractResolver to supplement
// it with state required for the proper resolution of a contract.
@@ -109,6 +122,21 @@ type contractResolverKit struct {
log btclog.Logger
quit chan struct{}
+
+ // sweepResultChan is the result chan returned from calling
+ // `SweepInput`. It should be mounted to the specific resolver once the
+ // input has been offered to the sweeper.
+ sweepResultChan chan sweep.Result
+
+ // launched specifies whether the resolver has been launched. Calling
+ // `Launch` will be a no-op if this is true. This value is not saved to
+ // db, as it's fine to relaunch a resolver after a restart. It's only
+ // used to avoid resending requests to the sweeper when a new blockbeat
+ // is received.
+ launched atomic.Bool
+
+ // resolved reflects if the contract has been fully resolved or not.
+ resolved atomic.Bool
}
// newContractResolverKit instantiates the mix-in struct.
@@ -120,11 +148,36 @@ func newContractResolverKit(cfg ResolverConfig) *contractResolverKit {
}
// initLogger initializes the resolver-specific logger.
-func (r *contractResolverKit) initLogger(resolver ContractResolver) {
- logPrefix := fmt.Sprintf("%T(%v):", resolver, r.ChanPoint)
+func (r *contractResolverKit) initLogger(prefix string) {
+ logPrefix := fmt.Sprintf("ChannelArbitrator(%v): %s:", r.ChanPoint,
+ prefix)
+
r.log = log.WithPrefix(logPrefix)
}
+// IsResolved returns true if the stored state in the resolve is fully
+// resolved. In this case the target output can be forgotten.
+//
+// NOTE: Part of the ContractResolver interface.
+func (r *contractResolverKit) IsResolved() bool {
+ return r.resolved.Load()
+}
+
+// markResolved marks the resolver as resolved.
+func (r *contractResolverKit) markResolved() {
+ r.resolved.Store(true)
+}
+
+// isLaunched returns true if the resolver has been launched.
+func (r *contractResolverKit) isLaunched() bool {
+ return r.launched.Load()
+}
+
+// markLaunched marks the resolver as launched.
+func (r *contractResolverKit) markLaunched() {
+ r.launched.Store(true)
+}
+
var (
// errResolverShuttingDown is returned when the resolver stops
// progressing because it received the quit signal.
diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go
index 73841eb88c..bc5948487d 100644
--- a/contractcourt/htlc_incoming_contest_resolver.go
+++ b/contractcourt/htlc_incoming_contest_resolver.go
@@ -10,7 +10,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/txscript"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/invoices"
@@ -78,6 +78,37 @@ func (h *htlcIncomingContestResolver) processFinalHtlcFail() error {
return nil
}
+// Launch will call the inner resolver's launch method if the preimage can be
+// found, otherwise it's a no-op.
+func (h *htlcIncomingContestResolver) Launch() error {
+ // NOTE: we don't mark this resolver as launched as the inner resolver
+ // will set it when it's launched.
+ if h.isLaunched() {
+ h.log.Tracef("already launched")
+ return nil
+ }
+
+ h.log.Debugf("launching contest resolver...")
+
+ // Query the preimage and apply it if we already know it.
+ applied, err := h.findAndapplyPreimage()
+ if err != nil {
+ return err
+ }
+
+ // No preimage found, leave it to be handled by the resolver.
+ if !applied {
+ return nil
+ }
+
+ h.log.Debugf("found preimage for htlc=%x, transforming into success "+
+ "resolver and launching it", h.htlc.RHash)
+
+ // Once we've applied the preimage, we'll launch the inner resolver to
+ // attempt to claim the HTLC.
+ return h.htlcSuccessResolver.Launch()
+}
+
// Resolve attempts to resolve this contract. As we don't yet know of the
// preimage for the contract, we'll wait for one of two things to happen:
//
@@ -90,12 +121,11 @@ func (h *htlcIncomingContestResolver) processFinalHtlcFail() error {
// as we have no remaining actions left at our disposal.
//
// NOTE: Part of the ContractResolver interface.
-func (h *htlcIncomingContestResolver) Resolve(
- _ bool) (ContractResolver, error) {
-
+func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) {
// If we're already full resolved, then we don't have anything further
// to do.
- if h.resolved {
+ if h.IsResolved() {
+ h.log.Errorf("already resolved")
return nil, nil
}
@@ -103,15 +133,14 @@ func (h *htlcIncomingContestResolver) Resolve(
// now.
payload, nextHopOnionBlob, err := h.decodePayload()
if err != nil {
- log.Debugf("ChannelArbitrator(%v): cannot decode payload of "+
- "htlc %v", h.ChanPoint, h.HtlcPoint())
+ h.log.Debugf("cannot decode payload of htlc %v", h.HtlcPoint())
// If we've locked in an htlc with an invalid payload on our
// commitment tx, we don't need to resolve it. The other party
// will time it out and get their funds back. This situation
// can present itself when we crash before processRemoteAdds in
// the link has ran.
- h.resolved = true
+ h.markResolved()
if err := h.processFinalHtlcFail(); err != nil {
return nil, err
@@ -164,7 +193,7 @@ func (h *htlcIncomingContestResolver) Resolve(
log.Infof("%T(%v): HTLC has timed out (expiry=%v, height=%v), "+
"abandoning", h, h.htlcResolution.ClaimOutpoint,
h.htlcExpiry, currentHeight)
- h.resolved = true
+ h.markResolved()
if err := h.processFinalHtlcFail(); err != nil {
return nil, err
@@ -179,65 +208,6 @@ func (h *htlcIncomingContestResolver) Resolve(
return nil, h.Checkpoint(h, report)
}
- // applyPreimage is a helper function that will populate our internal
- // resolver with the preimage we learn of. This should be called once
- // the preimage is revealed so the inner resolver can properly complete
- // its duties. The error return value indicates whether the preimage
- // was properly applied.
- applyPreimage := func(preimage lntypes.Preimage) error {
- // Sanity check to see if this preimage matches our htlc. At
- // this point it should never happen that it does not match.
- if !preimage.Matches(h.htlc.RHash) {
- return errors.New("preimage does not match hash")
- }
-
- // Update htlcResolution with the matching preimage.
- h.htlcResolution.Preimage = preimage
-
- log.Infof("%T(%v): applied preimage=%v", h,
- h.htlcResolution.ClaimOutpoint, preimage)
-
- isSecondLevel := h.htlcResolution.SignedSuccessTx != nil
-
- // If we didn't have to go to the second level to claim (this
- // is the remote commitment transaction), then we don't need to
- // modify our canned witness.
- if !isSecondLevel {
- return nil
- }
-
- isTaproot := txscript.IsPayToTaproot(
- h.htlcResolution.SignedSuccessTx.TxOut[0].PkScript,
- )
-
- // If this is our commitment transaction, then we'll need to
- // populate the witness for the second-level HTLC transaction.
- switch {
- // For taproot channels, the witness for sweeping with success
- // looks like:
- // -
- //
- //
- // So we'll insert it at the 3rd index of the witness.
- case isTaproot:
- //nolint:ll
- h.htlcResolution.SignedSuccessTx.TxIn[0].Witness[2] = preimage[:]
-
- // Within the witness for the success transaction, the
- // preimage is the 4th element as it looks like:
- //
- // * <0>
- //
- // We'll populate it within the witness, as since this
- // was a "contest" resolver, we didn't yet know of the
- // preimage.
- case !isTaproot:
- h.htlcResolution.SignedSuccessTx.TxIn[0].Witness[3] = preimage[:]
- }
-
- return nil
- }
-
// Define a closure to process htlc resolutions either directly or
// triggered by future notifications.
processHtlcResolution := func(e invoices.HtlcResolution) (
@@ -249,7 +219,7 @@ func (h *htlcIncomingContestResolver) Resolve(
// If the htlc resolution was a settle, apply the
// preimage and return a success resolver.
case *invoices.HtlcSettleResolution:
- err := applyPreimage(resolution.Preimage)
+ err := h.applyPreimage(resolution.Preimage)
if err != nil {
return nil, err
}
@@ -264,7 +234,7 @@ func (h *htlcIncomingContestResolver) Resolve(
h.htlcResolution.ClaimOutpoint,
h.htlcExpiry, currentHeight)
- h.resolved = true
+ h.markResolved()
if err := h.processFinalHtlcFail(); err != nil {
return nil, err
@@ -314,6 +284,9 @@ func (h *htlcIncomingContestResolver) Resolve(
return nil, err
}
+ h.log.Debugf("received resolution from registry: %v",
+ resolution)
+
defer func() {
h.Registry.HodlUnsubscribeAll(hodlQueue.ChanIn())
@@ -371,7 +344,9 @@ func (h *htlcIncomingContestResolver) Resolve(
// However, we don't know how to ourselves, so we'll
// return our inner resolver which has the knowledge to
// do so.
- if err := applyPreimage(preimage); err != nil {
+ h.log.Debugf("Found preimage for htlc=%x", h.htlc.RHash)
+
+ if err := h.applyPreimage(preimage); err != nil {
return nil, err
}
@@ -390,7 +365,10 @@ func (h *htlcIncomingContestResolver) Resolve(
continue
}
- if err := applyPreimage(preimage); err != nil {
+ h.log.Debugf("Received preimage for htlc=%x",
+ h.htlc.RHash)
+
+ if err := h.applyPreimage(preimage); err != nil {
return nil, err
}
@@ -417,7 +395,8 @@ func (h *htlcIncomingContestResolver) Resolve(
"(expiry=%v, height=%v), abandoning", h,
h.htlcResolution.ClaimOutpoint,
h.htlcExpiry, currentHeight)
- h.resolved = true
+
+ h.markResolved()
if err := h.processFinalHtlcFail(); err != nil {
return nil, err
@@ -437,6 +416,76 @@ func (h *htlcIncomingContestResolver) Resolve(
}
}
+// applyPreimage is a helper function that will populate our internal resolver
+// with the preimage we learn of. This should be called once the preimage is
+// revealed so the inner resolver can properly complete its duties. The error
+// return value indicates whether the preimage was properly applied.
+func (h *htlcIncomingContestResolver) applyPreimage(
+ preimage lntypes.Preimage) error {
+
+ // Sanity check to see if this preimage matches our htlc. At this point
+ // it should never happen that it does not match.
+ if !preimage.Matches(h.htlc.RHash) {
+ return errors.New("preimage does not match hash")
+ }
+
+ // We may already have the preimage since both the `Launch` and
+ // `Resolve` methods will look for it.
+ if h.htlcResolution.Preimage != lntypes.ZeroHash {
+ h.log.Debugf("already applied preimage for htlc=%x",
+ h.htlc.RHash)
+
+ return nil
+ }
+
+ // Update htlcResolution with the matching preimage.
+ h.htlcResolution.Preimage = preimage
+
+ log.Infof("%T(%v): applied preimage=%v", h,
+ h.htlcResolution.ClaimOutpoint, preimage)
+
+ isSecondLevel := h.htlcResolution.SignedSuccessTx != nil
+
+ // If we didn't have to go to the second level to claim (this
+ // is the remote commitment transaction), then we don't need to
+ // modify our canned witness.
+ if !isSecondLevel {
+ return nil
+ }
+
+ isTaproot := txscript.IsPayToTaproot(
+ h.htlcResolution.SignedSuccessTx.TxOut[0].PkScript,
+ )
+
+ // If this is our commitment transaction, then we'll need to
+ // populate the witness for the second-level HTLC transaction.
+ switch {
+ // For taproot channels, the witness for sweeping with success
+ // looks like:
+ // -
+ //
+ //
+ // So we'll insert it at the 3rd index of the witness.
+ case isTaproot:
+ //nolint:ll
+ h.htlcResolution.SignedSuccessTx.TxIn[0].Witness[2] = preimage[:]
+
+ // Within the witness for the success transaction, the
+ // preimage is the 4th element as it looks like:
+ //
+ // * <0>
+ //
+ // We'll populate it within the witness, as since this
+ // was a "contest" resolver, we didn't yet know of the
+ // preimage.
+ case !isTaproot:
+ //nolint:ll
+ h.htlcResolution.SignedSuccessTx.TxIn[0].Witness[3] = preimage[:]
+ }
+
+ return nil
+}
+
// report returns a report on the resolution state of the contract.
func (h *htlcIncomingContestResolver) report() *ContractReport {
// No locking needed as these values are read-only.
@@ -463,17 +512,11 @@ func (h *htlcIncomingContestResolver) report() *ContractReport {
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcIncomingContestResolver) Stop() {
+ h.log.Debugf("stopping...")
+ defer h.log.Debugf("stopped")
close(h.quit)
}
-// IsResolved returns true if the stored state in the resolve is fully
-// resolved. In this case the target output can be forgotten.
-//
-// NOTE: Part of the ContractResolver interface.
-func (h *htlcIncomingContestResolver) IsResolved() bool {
- return h.resolved
-}
-
// Encode writes an encoded version of the ContractResolver into the passed
// Writer.
//
@@ -562,3 +605,82 @@ func (h *htlcIncomingContestResolver) decodePayload() (*hop.Payload,
// A compile time assertion to ensure htlcIncomingContestResolver meets the
// ContractResolver interface.
var _ htlcContractResolver = (*htlcIncomingContestResolver)(nil)
+
+// findAndapplyPreimage performs a non-blocking read to find the preimage for
+// the incoming HTLC. If found, it will be applied to the resolver. This method
+// is used for the resolver to decide whether it wants to transform into a
+// success resolver during launching.
+//
+// NOTE: Since we have two places to query the preimage, we need to check both
+// the preimage db and the invoice db to look up the preimage.
+func (h *htlcIncomingContestResolver) findAndapplyPreimage() (bool, error) {
+ // Query to see if we already know the preimage.
+ preimage, ok := h.PreimageDB.LookupPreimage(h.htlc.RHash)
+
+ // If the preimage is known, we'll apply it.
+ if ok {
+ if err := h.applyPreimage(preimage); err != nil {
+ return false, err
+ }
+
+ // Successfully applied the preimage, we can now return.
+ return true, nil
+ }
+
+ // First try to parse the payload.
+ payload, _, err := h.decodePayload()
+ if err != nil {
+ h.log.Errorf("Cannot decode payload of htlc %v", h.HtlcPoint())
+
+ // If we cannot decode the payload, we will return a nil error
+ // and let it to be handled in `Resolve`.
+ return false, nil
+ }
+
+ // Exit early if this is not the exit hop, which means we are not the
+ // payment receiver and don't have preimage.
+ if payload.FwdInfo.NextHop != hop.Exit {
+ return false, nil
+ }
+
+ // Notify registry that we are potentially resolving as an exit hop
+ // on-chain. If this HTLC indeed pays to an existing invoice, the
+ // invoice registry will tell us what to do with the HTLC. This is
+ // identical to HTLC resolution in the link.
+ circuitKey := models.CircuitKey{
+ ChanID: h.ShortChanID,
+ HtlcID: h.htlc.HtlcIndex,
+ }
+
+ // Try get the resolution - if it doesn't give us a resolution
+ // immediately, we'll assume we don't know it yet and let the `Resolve`
+ // handle the waiting.
+ //
+ // NOTE: we use a nil subscriber here and a zero current height as we
+ // are only interested in the settle resolution.
+ //
+ // TODO(yy): move this logic to link and let the preimage be accessed
+ // via the preimage beacon.
+ resolution, err := h.Registry.NotifyExitHopHtlc(
+ h.htlc.RHash, h.htlc.Amt, h.htlcExpiry, 0,
+ circuitKey, nil, nil, payload,
+ )
+ if err != nil {
+ return false, err
+ }
+
+ res, ok := resolution.(*invoices.HtlcSettleResolution)
+
+ // Exit early if it's not a settle resolution.
+ if !ok {
+ return false, nil
+ }
+
+ // Otherwise we have a settle resolution, apply the preimage.
+ err = h.applyPreimage(res.Preimage)
+ if err != nil {
+ return false, err
+ }
+
+ return true, nil
+}
diff --git a/contractcourt/htlc_incoming_contest_resolver_test.go b/contractcourt/htlc_incoming_contest_resolver_test.go
index 22280f953e..f17190e96e 100644
--- a/contractcourt/htlc_incoming_contest_resolver_test.go
+++ b/contractcourt/htlc_incoming_contest_resolver_test.go
@@ -5,11 +5,13 @@ import (
"io"
"testing"
+ "github.com/btcsuite/btcd/wire"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
+ "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnmock"
@@ -356,6 +358,7 @@ func newIncomingResolverTestContext(t *testing.T, isExit bool) *incomingResolver
return nil
},
+ Sweeper: newMockSweeper(),
},
PutResolverReport: func(_ kvdb.RwTx,
_ *channeldb.ResolverReport) error {
@@ -374,10 +377,16 @@ func newIncomingResolverTestContext(t *testing.T, isExit bool) *incomingResolver
},
}
+ res := lnwallet.IncomingHtlcResolution{
+ SweepSignDesc: input.SignDescriptor{
+ Output: &wire.TxOut{},
+ },
+ }
+
c.resolver = &htlcIncomingContestResolver{
htlcSuccessResolver: &htlcSuccessResolver{
contractResolverKit: *newContractResolverKit(cfg),
- htlcResolution: lnwallet.IncomingHtlcResolution{},
+ htlcResolution: res,
htlc: channeldb.HTLC{
Amt: lnwire.MilliSatoshi(testHtlcAmount),
RHash: testResHash,
@@ -386,6 +395,7 @@ func newIncomingResolverTestContext(t *testing.T, isExit bool) *incomingResolver
},
htlcExpiry: testHtlcExpiry,
}
+ c.resolver.initLogger("htlcIncomingContestResolver")
return c
}
@@ -395,7 +405,11 @@ func (i *incomingResolverTestContext) resolve() {
i.resolveErr = make(chan error, 1)
go func() {
var err error
- i.nextResolver, err = i.resolver.Resolve(false)
+
+ err = i.resolver.Launch()
+ require.NoError(i.t, err)
+
+ i.nextResolver, err = i.resolver.Resolve()
i.resolveErr <- err
}()
diff --git a/contractcourt/htlc_lease_resolver.go b/contractcourt/htlc_lease_resolver.go
index 53fa893553..6230f96777 100644
--- a/contractcourt/htlc_lease_resolver.go
+++ b/contractcourt/htlc_lease_resolver.go
@@ -6,7 +6,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/tlv"
)
@@ -57,10 +57,10 @@ func (h *htlcLeaseResolver) makeSweepInput(op *wire.OutPoint,
signDesc *input.SignDescriptor, csvDelay, broadcastHeight uint32,
payHash [32]byte, resBlob fn.Option[tlv.Blob]) *input.BaseInput {
- if h.hasCLTV() {
- log.Infof("%T(%x): CSV and CLTV locks expired, offering "+
- "second-layer output to sweeper: %v", h, payHash, op)
+ log.Infof("%T(%x): offering second-layer output to sweeper: %v", h,
+ payHash, op)
+ if h.hasCLTV() {
return input.NewCsvInputWithCltv(
op, cltvWtype, signDesc,
broadcastHeight, csvDelay,
diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go
index 2466544c98..b66a3fdf0b 100644
--- a/contractcourt/htlc_outgoing_contest_resolver.go
+++ b/contractcourt/htlc_outgoing_contest_resolver.go
@@ -1,12 +1,11 @@
package contractcourt
import (
- "fmt"
"io"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwallet"
)
@@ -36,6 +35,37 @@ func newOutgoingContestResolver(res lnwallet.OutgoingHtlcResolution,
}
}
+// Launch will call the inner resolver's launch method if the expiry height has
+// been reached, otherwise it's a no-op.
+func (h *htlcOutgoingContestResolver) Launch() error {
+ // NOTE: we don't mark this resolver as launched as the inner resolver
+ // will set it when it's launched.
+ if h.isLaunched() {
+ h.log.Tracef("already launched")
+ return nil
+ }
+
+ h.log.Debugf("launching contest resolver...")
+
+ _, bestHeight, err := h.ChainIO.GetBestBlock()
+ if err != nil {
+ return err
+ }
+
+ if uint32(bestHeight) < h.htlcResolution.Expiry {
+ return nil
+ }
+
+ // If the current height is >= expiry, then a timeout path spend will
+ // be valid to be included in the next block, and we can immediately
+ // return the resolver.
+ h.log.Infof("expired (height=%v, expiry=%v), transforming into "+
+ "timeout resolver and launching it", bestHeight,
+ h.htlcResolution.Expiry)
+
+ return h.htlcTimeoutResolver.Launch()
+}
+
// Resolve commences the resolution of this contract. As this contract hasn't
// yet timed out, we'll wait for one of two things to happen
//
@@ -49,12 +79,11 @@ func newOutgoingContestResolver(res lnwallet.OutgoingHtlcResolution,
// When either of these two things happens, we'll create a new resolver which
// is able to handle the final resolution of the contract. We're only the pivot
// point.
-func (h *htlcOutgoingContestResolver) Resolve(
- _ bool) (ContractResolver, error) {
-
+func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) {
// If we're already full resolved, then we don't have anything further
// to do.
- if h.resolved {
+ if h.IsResolved() {
+ h.log.Errorf("already resolved")
return nil, nil
}
@@ -88,8 +117,7 @@ func (h *htlcOutgoingContestResolver) Resolve(
return nil, errResolverShuttingDown
}
- // TODO(roasbeef): Checkpoint?
- return h.claimCleanUp(commitSpend)
+ return nil, h.claimCleanUp(commitSpend)
// If it hasn't, then we'll watch for both the expiration, and the
// sweeping out this output.
@@ -126,12 +154,21 @@ func (h *htlcOutgoingContestResolver) Resolve(
// finalized` will be returned and the broadcast will
// fail.
newHeight := uint32(newBlock.Height)
- if newHeight >= h.htlcResolution.Expiry {
- log.Infof("%T(%v): HTLC has expired "+
+ expiry := h.htlcResolution.Expiry
+
+ // Check if the expiry height is about to be reached.
+ // We offer this HTLC one block earlier to make sure
+ // when the next block arrives, the sweeper will pick
+ // up this input and sweep it immediately. The sweeper
+ // will handle the waiting for the one last block till
+ // expiry.
+ if newHeight >= expiry-1 {
+ h.log.Infof("HTLC about to expire "+
"(height=%v, expiry=%v), transforming "+
"into timeout resolver", h,
h.htlcResolution.ClaimOutpoint,
newHeight, h.htlcResolution.Expiry)
+
return h.htlcTimeoutResolver, nil
}
@@ -146,10 +183,10 @@ func (h *htlcOutgoingContestResolver) Resolve(
// party is by revealing the preimage. So we'll perform
// our duties to clean up the contract once it has been
// claimed.
- return h.claimCleanUp(commitSpend)
+ return nil, h.claimCleanUp(commitSpend)
case <-h.quit:
- return nil, fmt.Errorf("resolver canceled")
+ return nil, errResolverShuttingDown
}
}
}
@@ -180,17 +217,11 @@ func (h *htlcOutgoingContestResolver) report() *ContractReport {
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcOutgoingContestResolver) Stop() {
+ h.log.Debugf("stopping...")
+ defer h.log.Debugf("stopped")
close(h.quit)
}
-// IsResolved returns true if the stored state in the resolve is fully
-// resolved. In this case the target output can be forgotten.
-//
-// NOTE: Part of the ContractResolver interface.
-func (h *htlcOutgoingContestResolver) IsResolved() bool {
- return h.resolved
-}
-
// Encode writes an encoded version of the ContractResolver into the passed
// Writer.
//
diff --git a/contractcourt/htlc_outgoing_contest_resolver_test.go b/contractcourt/htlc_outgoing_contest_resolver_test.go
index 6608a6fb51..625df60bf1 100644
--- a/contractcourt/htlc_outgoing_contest_resolver_test.go
+++ b/contractcourt/htlc_outgoing_contest_resolver_test.go
@@ -15,6 +15,7 @@ import (
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
+ "github.com/stretchr/testify/require"
)
const (
@@ -159,6 +160,7 @@ func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext {
return nil
},
+ ChainIO: &mock.ChainIO{},
},
PutResolverReport: func(_ kvdb.RwTx,
_ *channeldb.ResolverReport) error {
@@ -195,6 +197,7 @@ func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext {
},
},
}
+ resolver.initLogger("htlcOutgoingContestResolver")
return &outgoingResolverTestContext{
resolver: resolver,
@@ -209,7 +212,10 @@ func (i *outgoingResolverTestContext) resolve() {
// Start resolver.
i.resolverResultChan = make(chan resolveResult, 1)
go func() {
- nextResolver, err := i.resolver.Resolve(false)
+ err := i.resolver.Launch()
+ require.NoError(i.t, err)
+
+ nextResolver, err := i.resolver.Resolve()
i.resolverResultChan <- resolveResult{
nextResolver: nextResolver,
err: err,
diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go
index b2716ad305..a4d27ba4e8 100644
--- a/contractcourt/htlc_success_resolver.go
+++ b/contractcourt/htlc_success_resolver.go
@@ -2,6 +2,7 @@ package contractcourt
import (
"encoding/binary"
+ "fmt"
"io"
"sync"
@@ -9,10 +10,8 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
- "github.com/davecgh/go-spew/spew"
- "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/labels"
@@ -43,9 +42,6 @@ type htlcSuccessResolver struct {
// second-level output (true).
outputIncubating bool
- // resolved reflects if the contract has been fully resolved or not.
- resolved bool
-
// broadcastHeight is the height that the original contract was
// broadcast to the main-chain at. We'll use this value to bound any
// historical queries to the chain for spends/confirmations.
@@ -81,27 +77,30 @@ func newSuccessResolver(res lnwallet.IncomingHtlcResolution,
}
h.initReport()
+ h.initLogger(fmt.Sprintf("%T(%v)", h, h.outpoint()))
return h
}
-// ResolverKey returns an identifier which should be globally unique for this
-// particular resolver within the chain the original contract resides within.
-//
-// NOTE: Part of the ContractResolver interface.
-func (h *htlcSuccessResolver) ResolverKey() []byte {
+// outpoint returns the outpoint of the HTLC output we're attempting to sweep.
+func (h *htlcSuccessResolver) outpoint() wire.OutPoint {
// The primary key for this resolver will be the outpoint of the HTLC
// on the commitment transaction itself. If this is our commitment,
// then the output can be found within the signed success tx,
// otherwise, it's just the ClaimOutpoint.
- var op wire.OutPoint
if h.htlcResolution.SignedSuccessTx != nil {
- op = h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint
- } else {
- op = h.htlcResolution.ClaimOutpoint
+ return h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint
}
- key := newResolverID(op)
+ return h.htlcResolution.ClaimOutpoint
+}
+
+// ResolverKey returns an identifier which should be globally unique for this
+// particular resolver within the chain the original contract resides within.
+//
+// NOTE: Part of the ContractResolver interface.
+func (h *htlcSuccessResolver) ResolverKey() []byte {
+ key := newResolverID(h.outpoint())
return key[:]
}
@@ -112,423 +111,66 @@ func (h *htlcSuccessResolver) ResolverKey() []byte {
// anymore. Every HTLC has already passed through the incoming contest resolver
// and in there the invoice was already marked as settled.
//
-// TODO(roasbeef): create multi to batch
-//
// NOTE: Part of the ContractResolver interface.
-func (h *htlcSuccessResolver) Resolve(
- immediate bool) (ContractResolver, error) {
-
- // If we're already resolved, then we can exit early.
- if h.resolved {
- return nil, nil
- }
-
- // If we don't have a success transaction, then this means that this is
- // an output on the remote party's commitment transaction.
- if h.htlcResolution.SignedSuccessTx == nil {
- return h.resolveRemoteCommitOutput(immediate)
- }
-
- // Otherwise this an output on our own commitment, and we must start by
- // broadcasting the second-level success transaction.
- secondLevelOutpoint, err := h.broadcastSuccessTx(immediate)
- if err != nil {
- return nil, err
- }
-
- // To wrap this up, we'll wait until the second-level transaction has
- // been spent, then fully resolve the contract.
- log.Infof("%T(%x): waiting for second-level HTLC output to be spent "+
- "after csv_delay=%v", h, h.htlc.RHash[:], h.htlcResolution.CsvDelay)
-
- spend, err := waitForSpend(
- secondLevelOutpoint,
- h.htlcResolution.SweepSignDesc.Output.PkScript,
- h.broadcastHeight, h.Notifier, h.quit,
- )
- if err != nil {
- return nil, err
- }
-
- h.reportLock.Lock()
- h.currentReport.RecoveredBalance = h.currentReport.LimboBalance
- h.currentReport.LimboBalance = 0
- h.reportLock.Unlock()
-
- h.resolved = true
- return nil, h.checkpointClaim(
- spend.SpenderTxHash, channeldb.ResolverOutcomeClaimed,
- )
-}
-
-// broadcastSuccessTx handles an HTLC output on our local commitment by
-// broadcasting the second-level success transaction. It returns the ultimate
-// outpoint of the second-level tx, that we must wait to be spent for the
-// resolver to be fully resolved.
-func (h *htlcSuccessResolver) broadcastSuccessTx(
- immediate bool) (*wire.OutPoint, error) {
-
- // If we have non-nil SignDetails, this means that have a 2nd level
- // HTLC transaction that is signed using sighash SINGLE|ANYONECANPAY
- // (the case for anchor type channels). In this case we can re-sign it
- // and attach fees at will. We let the sweeper handle this job. We use
- // the checkpointed outputIncubating field to determine if we already
- // swept the HTLC output into the second level transaction.
- if h.htlcResolution.SignDetails != nil {
- return h.broadcastReSignedSuccessTx(immediate)
- }
-
- // Otherwise we'll publish the second-level transaction directly and
- // offer the resolution to the nursery to handle.
- log.Infof("%T(%x): broadcasting second-layer transition tx: %v",
- h, h.htlc.RHash[:], spew.Sdump(h.htlcResolution.SignedSuccessTx))
-
- // We'll now broadcast the second layer transaction so we can kick off
- // the claiming process.
- //
- // TODO(roasbeef): after changing sighashes send to tx bundler
- label := labels.MakeLabel(
- labels.LabelTypeChannelClose, &h.ShortChanID,
- )
- err := h.PublishTx(h.htlcResolution.SignedSuccessTx, label)
- if err != nil {
- return nil, err
- }
-
- // Otherwise, this is an output on our commitment transaction. In this
- // case, we'll send it to the incubator, but only if we haven't already
- // done so.
- if !h.outputIncubating {
- log.Infof("%T(%x): incubating incoming htlc output",
- h, h.htlc.RHash[:])
-
- err := h.IncubateOutputs(
- h.ChanPoint, fn.None[lnwallet.OutgoingHtlcResolution](),
- fn.Some(h.htlcResolution),
- h.broadcastHeight, fn.Some(int32(h.htlc.RefundTimeout)),
- )
- if err != nil {
- return nil, err
- }
-
- h.outputIncubating = true
-
- if err := h.Checkpoint(h); err != nil {
- log.Errorf("unable to Checkpoint: %v", err)
- return nil, err
- }
- }
-
- return &h.htlcResolution.ClaimOutpoint, nil
-}
-
-// broadcastReSignedSuccessTx handles the case where we have non-nil
-// SignDetails, and offers the second level transaction to the Sweeper, that
-// will re-sign it and attach fees at will.
//
-//nolint:funlen
-func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) (
- *wire.OutPoint, error) {
-
- // Keep track of the tx spending the HTLC output on the commitment, as
- // this will be the confirmed second-level tx we'll ultimately sweep.
- var commitSpend *chainntnfs.SpendDetail
-
- // We will have to let the sweeper re-sign the success tx and wait for
- // it to confirm, if we haven't already.
- isTaproot := txscript.IsPayToTaproot(
- h.htlcResolution.SweepSignDesc.Output.PkScript,
- )
- if !h.outputIncubating {
- var secondLevelInput input.HtlcSecondLevelAnchorInput
- if isTaproot {
- //nolint:ll
- secondLevelInput = input.MakeHtlcSecondLevelSuccessTaprootInput(
- h.htlcResolution.SignedSuccessTx,
- h.htlcResolution.SignDetails, h.htlcResolution.Preimage,
- h.broadcastHeight,
- input.WithResolutionBlob(
- h.htlcResolution.ResolutionBlob,
- ),
- )
- } else {
- //nolint:ll
- secondLevelInput = input.MakeHtlcSecondLevelSuccessAnchorInput(
- h.htlcResolution.SignedSuccessTx,
- h.htlcResolution.SignDetails, h.htlcResolution.Preimage,
- h.broadcastHeight,
- )
- }
-
- // Calculate the budget for this sweep.
- value := btcutil.Amount(
- secondLevelInput.SignDesc().Output.Value,
- )
- budget := calculateBudget(
- value, h.Budget.DeadlineHTLCRatio,
- h.Budget.DeadlineHTLC,
- )
-
- // The deadline would be the CLTV in this HTLC output. If we
- // are the initiator of this force close, with the default
- // `IncomingBroadcastDelta`, it means we have 10 blocks left
- // when going onchain. Given we need to mine one block to
- // confirm the force close tx, and one more block to trigger
- // the sweep, we have 8 blocks left to sweep the HTLC.
- deadline := fn.Some(int32(h.htlc.RefundTimeout))
-
- log.Infof("%T(%x): offering second-level HTLC success tx to "+
- "sweeper with deadline=%v, budget=%v", h,
- h.htlc.RHash[:], h.htlc.RefundTimeout, budget)
-
- // We'll now offer the second-level transaction to the sweeper.
- _, err := h.Sweeper.SweepInput(
- &secondLevelInput,
- sweep.Params{
- Budget: budget,
- DeadlineHeight: deadline,
- Immediate: immediate,
- },
- )
- if err != nil {
- return nil, err
- }
-
- log.Infof("%T(%x): waiting for second-level HTLC success "+
- "transaction to confirm", h, h.htlc.RHash[:])
-
- // Wait for the second level transaction to confirm.
- commitSpend, err = waitForSpend(
- &h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint,
- h.htlcResolution.SignDetails.SignDesc.Output.PkScript,
- h.broadcastHeight, h.Notifier, h.quit,
- )
- if err != nil {
- return nil, err
- }
+// TODO(yy): refactor the interface method to return an error only.
+func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) {
+ var err error
- // Now that the second-level transaction has confirmed, we
- // checkpoint the state so we'll go to the next stage in case
- // of restarts.
- h.outputIncubating = true
- if err := h.Checkpoint(h); err != nil {
- log.Errorf("unable to Checkpoint: %v", err)
- return nil, err
- }
-
- log.Infof("%T(%x): second-level HTLC success transaction "+
- "confirmed!", h, h.htlc.RHash[:])
- }
-
- // If we ended up here after a restart, we must again get the
- // spend notification.
- if commitSpend == nil {
- var err error
- commitSpend, err = waitForSpend(
- &h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint,
- h.htlcResolution.SignDetails.SignDesc.Output.PkScript,
- h.broadcastHeight, h.Notifier, h.quit,
- )
- if err != nil {
- return nil, err
- }
- }
-
- // The HTLC success tx has a CSV lock that we must wait for, and if
- // this is a lease enforced channel and we're the imitator, we may need
- // to wait for longer.
- waitHeight := h.deriveWaitHeight(
- h.htlcResolution.CsvDelay, commitSpend,
- )
-
- // Now that the sweeper has broadcasted the second-level transaction,
- // it has confirmed, and we have checkpointed our state, we'll sweep
- // the second level output. We report the resolver has moved the next
- // stage.
- h.reportLock.Lock()
- h.currentReport.Stage = 2
- h.currentReport.MaturityHeight = waitHeight
- h.reportLock.Unlock()
-
- if h.hasCLTV() {
- log.Infof("%T(%x): waiting for CSV and CLTV lock to "+
- "expire at height %v", h, h.htlc.RHash[:],
- waitHeight)
- } else {
- log.Infof("%T(%x): waiting for CSV lock to expire at "+
- "height %v", h, h.htlc.RHash[:], waitHeight)
- }
-
- // Deduct one block so this input is offered to the sweeper one block
- // earlier since the sweeper will wait for one block to trigger the
- // sweeping.
- //
- // TODO(yy): this is done so the outputs can be aggregated
- // properly. Suppose CSV locks of five 2nd-level outputs all
- // expire at height 840000, there is a race in block digestion
- // between contractcourt and sweeper:
- // - G1: block 840000 received in contractcourt, it now offers
- // the outputs to the sweeper.
- // - G2: block 840000 received in sweeper, it now starts to
- // sweep the received outputs - there's no guarantee all
- // fives have been received.
- // To solve this, we either offer the outputs earlier, or
- // implement `blockbeat`, and force contractcourt and sweeper
- // to consume each block sequentially.
- waitHeight--
-
- // TODO(yy): let sweeper handles the wait?
- err := waitForHeight(waitHeight, h.Notifier, h.quit)
- if err != nil {
- return nil, err
- }
-
- // We'll use this input index to determine the second-level output
- // index on the transaction, as the signatures requires the indexes to
- // be the same. We don't look for the second-level output script
- // directly, as there might be more than one HTLC output to the same
- // pkScript.
- op := &wire.OutPoint{
- Hash: *commitSpend.SpenderTxHash,
- Index: commitSpend.SpenderInputIndex,
- }
-
- // Let the sweeper sweep the second-level output now that the
- // CSV/CLTV locks have expired.
- var witType input.StandardWitnessType
- if isTaproot {
- witType = input.TaprootHtlcAcceptedSuccessSecondLevel
- } else {
- witType = input.HtlcAcceptedSuccessSecondLevel
- }
- inp := h.makeSweepInput(
- op, witType,
- input.LeaseHtlcAcceptedSuccessSecondLevel,
- &h.htlcResolution.SweepSignDesc,
- h.htlcResolution.CsvDelay, uint32(commitSpend.SpendingHeight),
- h.htlc.RHash, h.htlcResolution.ResolutionBlob,
- )
-
- // Calculate the budget for this sweep.
- budget := calculateBudget(
- btcutil.Amount(inp.SignDesc().Output.Value),
- h.Budget.NoDeadlineHTLCRatio,
- h.Budget.NoDeadlineHTLC,
- )
+ switch {
+ // If we're already resolved, then we can exit early.
+ case h.IsResolved():
+ h.log.Errorf("already resolved")
- log.Infof("%T(%x): offering second-level success tx output to sweeper "+
- "with no deadline and budget=%v at height=%v", h,
- h.htlc.RHash[:], budget, waitHeight)
+ // If this is an output on the remote party's commitment transaction,
+ // use the direct-spend path to sweep the htlc.
+ case h.isRemoteCommitOutput():
+ err = h.resolveRemoteCommitOutput()
- // TODO(roasbeef): need to update above for leased types
- _, err = h.Sweeper.SweepInput(
- inp,
- sweep.Params{
- Budget: budget,
+ // If this is an output on our commitment transaction using post-anchor
+ // channel type, it will be handled by the sweeper.
+ case h.isZeroFeeOutput():
+ err = h.resolveSuccessTx()
- // For second level success tx, there's no rush to get
- // it confirmed, so we use a nil deadline.
- DeadlineHeight: fn.None[int32](),
- },
- )
- if err != nil {
- return nil, err
+ // If this is an output on our own commitment using pre-anchor channel
+ // type, we will publish the success tx and offer the output to the
+ // nursery.
+ default:
+ err = h.resolveLegacySuccessTx()
}
- // Will return this outpoint, when this is spent the resolver is fully
- // resolved.
- return op, nil
+ return nil, err
}
// resolveRemoteCommitOutput handles sweeping an HTLC output on the remote
// commitment with the preimage. In this case we can sweep the output directly,
// and don't have to broadcast a second-level transaction.
-func (h *htlcSuccessResolver) resolveRemoteCommitOutput(immediate bool) (
- ContractResolver, error) {
-
- isTaproot := txscript.IsPayToTaproot(
- h.htlcResolution.SweepSignDesc.Output.PkScript,
- )
-
- // Before we can craft out sweeping transaction, we need to
- // create an input which contains all the items required to add
- // this input to a sweeping transaction, and generate a
- // witness.
- var inp input.Input
- if isTaproot {
- inp = lnutils.Ptr(input.MakeTaprootHtlcSucceedInput(
- &h.htlcResolution.ClaimOutpoint,
- &h.htlcResolution.SweepSignDesc,
- h.htlcResolution.Preimage[:],
- h.broadcastHeight,
- h.htlcResolution.CsvDelay,
- input.WithResolutionBlob(
- h.htlcResolution.ResolutionBlob,
- ),
- ))
- } else {
- inp = lnutils.Ptr(input.MakeHtlcSucceedInput(
- &h.htlcResolution.ClaimOutpoint,
- &h.htlcResolution.SweepSignDesc,
- h.htlcResolution.Preimage[:],
- h.broadcastHeight,
- h.htlcResolution.CsvDelay,
- ))
- }
-
- // Calculate the budget for this sweep.
- budget := calculateBudget(
- btcutil.Amount(inp.SignDesc().Output.Value),
- h.Budget.DeadlineHTLCRatio,
- h.Budget.DeadlineHTLC,
- )
-
- deadline := fn.Some(int32(h.htlc.RefundTimeout))
-
- log.Infof("%T(%x): offering direct-preimage HTLC output to sweeper "+
- "with deadline=%v, budget=%v", h, h.htlc.RHash[:],
- h.htlc.RefundTimeout, budget)
-
- // We'll now offer the direct preimage HTLC to the sweeper.
- _, err := h.Sweeper.SweepInput(
- inp,
- sweep.Params{
- Budget: budget,
- DeadlineHeight: deadline,
- Immediate: immediate,
- },
- )
- if err != nil {
- return nil, err
- }
+func (h *htlcSuccessResolver) resolveRemoteCommitOutput() error {
+ h.log.Info("waiting for direct-preimage spend of the htlc to confirm")
// Wait for the direct-preimage HTLC sweep tx to confirm.
+ //
+ // TODO(yy): use the result chan returned from `SweepInput`.
sweepTxDetails, err := waitForSpend(
&h.htlcResolution.ClaimOutpoint,
h.htlcResolution.SweepSignDesc.Output.PkScript,
h.broadcastHeight, h.Notifier, h.quit,
)
if err != nil {
- return nil, err
+ return err
}
- // Once the transaction has received a sufficient number of
- // confirmations, we'll mark ourselves as fully resolved and exit.
- h.resolved = true
+ // TODO(yy): should also update the `RecoveredBalance` and
+ // `LimboBalance` like other paths?
// Checkpoint the resolver, and write the outcome to disk.
- return nil, h.checkpointClaim(
- sweepTxDetails.SpenderTxHash,
- channeldb.ResolverOutcomeClaimed,
- )
+ return h.checkpointClaim(sweepTxDetails.SpenderTxHash)
}
// checkpointClaim checkpoints the success resolver with the reports it needs.
// If this htlc was claimed two stages, it will write reports for both stages,
// otherwise it will just write for the single htlc claim.
-func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash,
- outcome channeldb.ResolverOutcome) error {
-
+func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash) error {
// Mark the htlc as final settled.
err := h.ChainArbitratorConfig.PutFinalHtlcOutcome(
h.ChannelArbitratorConfig.ShortChanID, h.htlc.HtlcIndex, true,
@@ -556,7 +198,7 @@ func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash,
OutPoint: h.htlcResolution.ClaimOutpoint,
Amount: amt,
ResolverType: channeldb.ResolverTypeIncomingHtlc,
- ResolverOutcome: outcome,
+ ResolverOutcome: channeldb.ResolverOutcomeClaimed,
SpendTxID: spendTx,
},
}
@@ -581,6 +223,7 @@ func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash,
}
// Finally, we checkpoint the resolver with our report(s).
+ h.markResolved()
return h.Checkpoint(h, reports...)
}
@@ -589,15 +232,10 @@ func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash,
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcSuccessResolver) Stop() {
- close(h.quit)
-}
+ h.log.Debugf("stopping...")
+ defer h.log.Debugf("stopped")
-// IsResolved returns true if the stored state in the resolve is fully
-// resolved. In this case the target output can be forgotten.
-//
-// NOTE: Part of the ContractResolver interface.
-func (h *htlcSuccessResolver) IsResolved() bool {
- return h.resolved
+ close(h.quit)
}
// report returns a report on the resolution state of the contract.
@@ -649,7 +287,7 @@ func (h *htlcSuccessResolver) Encode(w io.Writer) error {
if err := binary.Write(w, endian, h.outputIncubating); err != nil {
return err
}
- if err := binary.Write(w, endian, h.resolved); err != nil {
+ if err := binary.Write(w, endian, h.IsResolved()); err != nil {
return err
}
if err := binary.Write(w, endian, h.broadcastHeight); err != nil {
@@ -688,9 +326,15 @@ func newSuccessResolverFromReader(r io.Reader, resCfg ResolverConfig) (
if err := binary.Read(r, endian, &h.outputIncubating); err != nil {
return nil, err
}
- if err := binary.Read(r, endian, &h.resolved); err != nil {
+
+ var resolved bool
+ if err := binary.Read(r, endian, &resolved); err != nil {
return nil, err
}
+ if resolved {
+ h.markResolved()
+ }
+
if err := binary.Read(r, endian, &h.broadcastHeight); err != nil {
return nil, err
}
@@ -709,6 +353,7 @@ func newSuccessResolverFromReader(r io.Reader, resCfg ResolverConfig) (
}
h.initReport()
+ h.initLogger(fmt.Sprintf("%T(%v)", h, h.outpoint()))
return h, nil
}
@@ -737,3 +382,391 @@ func (h *htlcSuccessResolver) SupplementDeadline(_ fn.Option[int32]) {
// A compile time assertion to ensure htlcSuccessResolver meets the
// ContractResolver interface.
var _ htlcContractResolver = (*htlcSuccessResolver)(nil)
+
+// isRemoteCommitOutput returns a bool to indicate whether the htlc output is
+// on the remote commitment.
+func (h *htlcSuccessResolver) isRemoteCommitOutput() bool {
+ // If we don't have a success transaction, then this means that this is
+ // an output on the remote party's commitment transaction.
+ return h.htlcResolution.SignedSuccessTx == nil
+}
+
+// isZeroFeeOutput returns a boolean indicating whether the htlc output is from
+// a anchor-enabled channel, which uses the sighash SINGLE|ANYONECANPAY.
+func (h *htlcSuccessResolver) isZeroFeeOutput() bool {
+ // If we have non-nil SignDetails, this means it has a 2nd level HTLC
+ // transaction that is signed using sighash SINGLE|ANYONECANPAY (the
+ // case for anchor type channels). In this case we can re-sign it and
+ // attach fees at will.
+ return h.htlcResolution.SignedSuccessTx != nil &&
+ h.htlcResolution.SignDetails != nil
+}
+
+// isTaproot returns true if the resolver is for a taproot output.
+func (h *htlcSuccessResolver) isTaproot() bool {
+ return txscript.IsPayToTaproot(
+ h.htlcResolution.SweepSignDesc.Output.PkScript,
+ )
+}
+
+// sweepRemoteCommitOutput creates a sweep request to sweep the HTLC output on
+// the remote commitment via the direct preimage-spend.
+func (h *htlcSuccessResolver) sweepRemoteCommitOutput() error {
+ // Before we can craft out sweeping transaction, we need to create an
+ // input which contains all the items required to add this input to a
+ // sweeping transaction, and generate a witness.
+ var inp input.Input
+
+ if h.isTaproot() {
+ inp = lnutils.Ptr(input.MakeTaprootHtlcSucceedInput(
+ &h.htlcResolution.ClaimOutpoint,
+ &h.htlcResolution.SweepSignDesc,
+ h.htlcResolution.Preimage[:],
+ h.broadcastHeight,
+ h.htlcResolution.CsvDelay,
+ input.WithResolutionBlob(
+ h.htlcResolution.ResolutionBlob,
+ ),
+ ))
+ } else {
+ inp = lnutils.Ptr(input.MakeHtlcSucceedInput(
+ &h.htlcResolution.ClaimOutpoint,
+ &h.htlcResolution.SweepSignDesc,
+ h.htlcResolution.Preimage[:],
+ h.broadcastHeight,
+ h.htlcResolution.CsvDelay,
+ ))
+ }
+
+ // Calculate the budget for this sweep.
+ budget := calculateBudget(
+ btcutil.Amount(inp.SignDesc().Output.Value),
+ h.Budget.DeadlineHTLCRatio,
+ h.Budget.DeadlineHTLC,
+ )
+
+ deadline := fn.Some(int32(h.htlc.RefundTimeout))
+
+ log.Infof("%T(%x): offering direct-preimage HTLC output to sweeper "+
+ "with deadline=%v, budget=%v", h, h.htlc.RHash[:],
+ h.htlc.RefundTimeout, budget)
+
+ // We'll now offer the direct preimage HTLC to the sweeper.
+ _, err := h.Sweeper.SweepInput(
+ inp,
+ sweep.Params{
+ Budget: budget,
+ DeadlineHeight: deadline,
+ },
+ )
+
+ return err
+}
+
+// sweepSuccessTx attempts to sweep the second level success tx.
+func (h *htlcSuccessResolver) sweepSuccessTx() error {
+ var secondLevelInput input.HtlcSecondLevelAnchorInput
+ if h.isTaproot() {
+ secondLevelInput = input.MakeHtlcSecondLevelSuccessTaprootInput(
+ h.htlcResolution.SignedSuccessTx,
+ h.htlcResolution.SignDetails, h.htlcResolution.Preimage,
+ h.broadcastHeight, input.WithResolutionBlob(
+ h.htlcResolution.ResolutionBlob,
+ ),
+ )
+ } else {
+ secondLevelInput = input.MakeHtlcSecondLevelSuccessAnchorInput(
+ h.htlcResolution.SignedSuccessTx,
+ h.htlcResolution.SignDetails, h.htlcResolution.Preimage,
+ h.broadcastHeight,
+ )
+ }
+
+ // Calculate the budget for this sweep.
+ value := btcutil.Amount(secondLevelInput.SignDesc().Output.Value)
+ budget := calculateBudget(
+ value, h.Budget.DeadlineHTLCRatio, h.Budget.DeadlineHTLC,
+ )
+
+ // The deadline would be the CLTV in this HTLC output. If we are the
+ // initiator of this force close, with the default
+ // `IncomingBroadcastDelta`, it means we have 10 blocks left when going
+ // onchain.
+ deadline := fn.Some(int32(h.htlc.RefundTimeout))
+
+ h.log.Infof("offering second-level HTLC success tx to sweeper with "+
+ "deadline=%v, budget=%v", h.htlc.RefundTimeout, budget)
+
+ // We'll now offer the second-level transaction to the sweeper.
+ _, err := h.Sweeper.SweepInput(
+ &secondLevelInput,
+ sweep.Params{
+ Budget: budget,
+ DeadlineHeight: deadline,
+ },
+ )
+
+ return err
+}
+
+// sweepSuccessTxOutput attempts to sweep the output of the second level
+// success tx.
+func (h *htlcSuccessResolver) sweepSuccessTxOutput() error {
+ h.log.Debugf("sweeping output %v from 2nd-level HTLC success tx",
+ h.htlcResolution.ClaimOutpoint)
+
+ // This should be non-blocking as we will only attempt to sweep the
+ // output when the second level tx has already been confirmed. In other
+ // words, waitForSpend will return immediately.
+ commitSpend, err := waitForSpend(
+ &h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint,
+ h.htlcResolution.SignDetails.SignDesc.Output.PkScript,
+ h.broadcastHeight, h.Notifier, h.quit,
+ )
+ if err != nil {
+ return err
+ }
+
+ // The HTLC success tx has a CSV lock that we must wait for, and if
+ // this is a lease enforced channel and we're the imitator, we may need
+ // to wait for longer.
+ waitHeight := h.deriveWaitHeight(h.htlcResolution.CsvDelay, commitSpend)
+
+ // Now that the sweeper has broadcasted the second-level transaction,
+ // it has confirmed, and we have checkpointed our state, we'll sweep
+ // the second level output. We report the resolver has moved the next
+ // stage.
+ h.reportLock.Lock()
+ h.currentReport.Stage = 2
+ h.currentReport.MaturityHeight = waitHeight
+ h.reportLock.Unlock()
+
+ if h.hasCLTV() {
+ log.Infof("%T(%x): waiting for CSV and CLTV lock to expire at "+
+ "height %v", h, h.htlc.RHash[:], waitHeight)
+ } else {
+ log.Infof("%T(%x): waiting for CSV lock to expire at height %v",
+ h, h.htlc.RHash[:], waitHeight)
+ }
+
+ // We'll use this input index to determine the second-level output
+ // index on the transaction, as the signatures requires the indexes to
+ // be the same. We don't look for the second-level output script
+ // directly, as there might be more than one HTLC output to the same
+ // pkScript.
+ op := &wire.OutPoint{
+ Hash: *commitSpend.SpenderTxHash,
+ Index: commitSpend.SpenderInputIndex,
+ }
+
+ // Let the sweeper sweep the second-level output now that the
+ // CSV/CLTV locks have expired.
+ var witType input.StandardWitnessType
+ if h.isTaproot() {
+ witType = input.TaprootHtlcAcceptedSuccessSecondLevel
+ } else {
+ witType = input.HtlcAcceptedSuccessSecondLevel
+ }
+ inp := h.makeSweepInput(
+ op, witType,
+ input.LeaseHtlcAcceptedSuccessSecondLevel,
+ &h.htlcResolution.SweepSignDesc,
+ h.htlcResolution.CsvDelay, uint32(commitSpend.SpendingHeight),
+ h.htlc.RHash, h.htlcResolution.ResolutionBlob,
+ )
+
+ // Calculate the budget for this sweep.
+ budget := calculateBudget(
+ btcutil.Amount(inp.SignDesc().Output.Value),
+ h.Budget.NoDeadlineHTLCRatio,
+ h.Budget.NoDeadlineHTLC,
+ )
+
+ log.Infof("%T(%x): offering second-level success tx output to sweeper "+
+ "with no deadline and budget=%v at height=%v", h,
+ h.htlc.RHash[:], budget, waitHeight)
+
+ // TODO(yy): use the result chan returned from SweepInput.
+ _, err = h.Sweeper.SweepInput(
+ inp,
+ sweep.Params{
+ Budget: budget,
+
+ // For second level success tx, there's no rush to get
+ // it confirmed, so we use a nil deadline.
+ DeadlineHeight: fn.None[int32](),
+ },
+ )
+
+ return err
+}
+
+// resolveLegacySuccessTx handles an HTLC output from a pre-anchor type channel
+// by broadcasting the second-level success transaction.
+func (h *htlcSuccessResolver) resolveLegacySuccessTx() error {
+ // Otherwise we'll publish the second-level transaction directly and
+ // offer the resolution to the nursery to handle.
+ h.log.Infof("broadcasting legacy second-level success tx: %v",
+ h.htlcResolution.SignedSuccessTx.TxHash())
+
+ // We'll now broadcast the second layer transaction so we can kick off
+ // the claiming process.
+ //
+ // TODO(yy): offer it to the sweeper instead.
+ label := labels.MakeLabel(
+ labels.LabelTypeChannelClose, &h.ShortChanID,
+ )
+ err := h.PublishTx(h.htlcResolution.SignedSuccessTx, label)
+ if err != nil {
+ return err
+ }
+
+ // Fast-forward to resolve the output from the success tx if the it has
+ // already been sent to the UtxoNursery.
+ if h.outputIncubating {
+ return h.resolveSuccessTxOutput(h.htlcResolution.ClaimOutpoint)
+ }
+
+ h.log.Infof("incubating incoming htlc output")
+
+ // Send the output to the incubator.
+ err = h.IncubateOutputs(
+ h.ChanPoint, fn.None[lnwallet.OutgoingHtlcResolution](),
+ fn.Some(h.htlcResolution),
+ h.broadcastHeight, fn.Some(int32(h.htlc.RefundTimeout)),
+ )
+ if err != nil {
+ return err
+ }
+
+ // Mark the output as incubating and checkpoint it.
+ h.outputIncubating = true
+ if err := h.Checkpoint(h); err != nil {
+ return err
+ }
+
+ // Move to resolve the output.
+ return h.resolveSuccessTxOutput(h.htlcResolution.ClaimOutpoint)
+}
+
+// resolveSuccessTx waits for the sweeping tx of the second-level success tx to
+// confirm and offers the output from the success tx to the sweeper.
+func (h *htlcSuccessResolver) resolveSuccessTx() error {
+ h.log.Infof("waiting for 2nd-level HTLC success transaction to confirm")
+
+ // Create aliases to make the code more readable.
+ outpoint := h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint
+ pkScript := h.htlcResolution.SignDetails.SignDesc.Output.PkScript
+
+ // Wait for the second level transaction to confirm.
+ commitSpend, err := waitForSpend(
+ &outpoint, pkScript, h.broadcastHeight, h.Notifier, h.quit,
+ )
+ if err != nil {
+ return err
+ }
+
+ // We'll use this input index to determine the second-level output
+ // index on the transaction, as the signatures requires the indexes to
+ // be the same. We don't look for the second-level output script
+ // directly, as there might be more than one HTLC output to the same
+ // pkScript.
+ op := wire.OutPoint{
+ Hash: *commitSpend.SpenderTxHash,
+ Index: commitSpend.SpenderInputIndex,
+ }
+
+ // If the 2nd-stage sweeping has already been started, we can
+ // fast-forward to start the resolving process for the stage two
+ // output.
+ if h.outputIncubating {
+ return h.resolveSuccessTxOutput(op)
+ }
+
+ // Now that the second-level transaction has confirmed, we checkpoint
+ // the state so we'll go to the next stage in case of restarts.
+ h.outputIncubating = true
+ if err := h.Checkpoint(h); err != nil {
+ log.Errorf("unable to Checkpoint: %v", err)
+ return err
+ }
+
+ h.log.Infof("2nd-level HTLC success tx=%v confirmed",
+ commitSpend.SpenderTxHash)
+
+ // Send the sweep request for the output from the success tx.
+ if err := h.sweepSuccessTxOutput(); err != nil {
+ return err
+ }
+
+ return h.resolveSuccessTxOutput(op)
+}
+
+// resolveSuccessTxOutput waits for the spend of the output from the 2nd-level
+// success tx.
+func (h *htlcSuccessResolver) resolveSuccessTxOutput(op wire.OutPoint) error {
+ // To wrap this up, we'll wait until the second-level transaction has
+ // been spent, then fully resolve the contract.
+ log.Infof("%T(%x): waiting for second-level HTLC output to be spent "+
+ "after csv_delay=%v", h, h.htlc.RHash[:],
+ h.htlcResolution.CsvDelay)
+
+ spend, err := waitForSpend(
+ &op, h.htlcResolution.SweepSignDesc.Output.PkScript,
+ h.broadcastHeight, h.Notifier, h.quit,
+ )
+ if err != nil {
+ return err
+ }
+
+ h.reportLock.Lock()
+ h.currentReport.RecoveredBalance = h.currentReport.LimboBalance
+ h.currentReport.LimboBalance = 0
+ h.reportLock.Unlock()
+
+ return h.checkpointClaim(spend.SpenderTxHash)
+}
+
+// Launch creates an input based on the details of the incoming htlc resolution
+// and offers it to the sweeper.
+func (h *htlcSuccessResolver) Launch() error {
+ if h.isLaunched() {
+ h.log.Tracef("already launched")
+ return nil
+ }
+
+ h.log.Debugf("launching resolver...")
+ h.markLaunched()
+
+ switch {
+ // If we're already resolved, then we can exit early.
+ case h.IsResolved():
+ h.log.Errorf("already resolved")
+ return nil
+
+ // If this is an output on the remote party's commitment transaction,
+ // use the direct-spend path.
+ case h.isRemoteCommitOutput():
+ return h.sweepRemoteCommitOutput()
+
+ // If this is an anchor type channel, we now sweep either the
+ // second-level success tx or the output from the second-level success
+ // tx.
+ case h.isZeroFeeOutput():
+ // If the second-level success tx has already been swept, we
+ // can go ahead and sweep its output.
+ if h.outputIncubating {
+ return h.sweepSuccessTxOutput()
+ }
+
+ // Otherwise, sweep the second level tx.
+ return h.sweepSuccessTx()
+
+ // If this is a legacy channel type, the output is handled by the
+ // nursery via the Resolve so we do nothing here.
+ //
+ // TODO(yy): handle the legacy output by offering it to the sweeper.
+ default:
+ return nil
+ }
+}
diff --git a/contractcourt/htlc_success_resolver_test.go b/contractcourt/htlc_success_resolver_test.go
index 23023729fa..fe6ee1ad0e 100644
--- a/contractcourt/htlc_success_resolver_test.go
+++ b/contractcourt/htlc_success_resolver_test.go
@@ -5,6 +5,7 @@ import (
"fmt"
"reflect"
"testing"
+ "time"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
@@ -12,7 +13,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/kvdb"
@@ -20,6 +21,7 @@ import (
"github.com/lightningnetwork/lnd/lntest/mock"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
+ "github.com/stretchr/testify/require"
)
var testHtlcAmt = lnwire.MilliSatoshi(200000)
@@ -39,6 +41,15 @@ type htlcResolverTestContext struct {
t *testing.T
}
+func newHtlcResolverTestContextFromReader(t *testing.T,
+ newResolver func(htlc channeldb.HTLC,
+ cfg ResolverConfig) ContractResolver) *htlcResolverTestContext {
+
+ ctx := newHtlcResolverTestContext(t, newResolver)
+
+ return ctx
+}
+
func newHtlcResolverTestContext(t *testing.T,
newResolver func(htlc channeldb.HTLC,
cfg ResolverConfig) ContractResolver) *htlcResolverTestContext {
@@ -133,8 +144,12 @@ func newHtlcResolverTestContext(t *testing.T,
func (i *htlcResolverTestContext) resolve() {
// Start resolver.
i.resolverResultChan = make(chan resolveResult, 1)
+
go func() {
- nextResolver, err := i.resolver.Resolve(false)
+ err := i.resolver.Launch()
+ require.NoError(i.t, err)
+
+ nextResolver, err := i.resolver.Resolve()
i.resolverResultChan <- resolveResult{
nextResolver: nextResolver,
err: err,
@@ -192,6 +207,7 @@ func TestHtlcSuccessSingleStage(t *testing.T) {
// sweeper.
details := &chainntnfs.SpendDetail{
SpendingTx: sweepTx,
+ SpentOutPoint: &htlcOutpoint,
SpenderTxHash: &sweepTxid,
}
ctx.notifier.SpendChan <- details
@@ -215,8 +231,8 @@ func TestHtlcSuccessSingleStage(t *testing.T) {
)
}
-// TestSecondStageResolution tests successful sweep of a second stage htlc
-// claim, going through the Nursery.
+// TestHtlcSuccessSecondStageResolution tests successful sweep of a second
+// stage htlc claim, going through the Nursery.
func TestHtlcSuccessSecondStageResolution(t *testing.T) {
commitOutpoint := wire.OutPoint{Index: 2}
htlcOutpoint := wire.OutPoint{Index: 3}
@@ -279,6 +295,7 @@ func TestHtlcSuccessSecondStageResolution(t *testing.T) {
ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
SpendingTx: sweepTx,
+ SpentOutPoint: &htlcOutpoint,
SpenderTxHash: &sweepHash,
}
@@ -302,6 +319,8 @@ func TestHtlcSuccessSecondStageResolution(t *testing.T) {
// TestHtlcSuccessSecondStageResolutionSweeper test that a resolver with
// non-nil SignDetails will offer the second-level transaction to the sweeper
// for re-signing.
+//
+//nolint:ll
func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) {
commitOutpoint := wire.OutPoint{Index: 2}
htlcOutpoint := wire.OutPoint{Index: 3}
@@ -399,7 +418,20 @@ func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) {
_ bool) error {
resolver := ctx.resolver.(*htlcSuccessResolver)
- inp := <-resolver.Sweeper.(*mockSweeper).sweptInputs
+
+ var (
+ inp input.Input
+ ok bool
+ )
+
+ select {
+ case inp, ok = <-resolver.Sweeper.(*mockSweeper).sweptInputs:
+ require.True(t, ok)
+
+ case <-time.After(1 * time.Second):
+ t.Fatal("expected input to be swept")
+ }
+
op := inp.OutPoint()
if op != commitOutpoint {
return fmt.Errorf("outpoint %v swept, "+
@@ -412,6 +444,7 @@ func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) {
SpenderTxHash: &reSignedHash,
SpenderInputIndex: 1,
SpendingHeight: 10,
+ SpentOutPoint: &commitOutpoint,
}
return nil
},
@@ -434,17 +467,37 @@ func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) {
SpenderTxHash: &reSignedHash,
SpenderInputIndex: 1,
SpendingHeight: 10,
+ SpentOutPoint: &commitOutpoint,
}
}
- ctx.notifier.EpochChan <- &chainntnfs.BlockEpoch{
- Height: 13,
- }
-
// We expect it to sweep the second-level
// transaction we notfied about above.
resolver := ctx.resolver.(*htlcSuccessResolver)
- inp := <-resolver.Sweeper.(*mockSweeper).sweptInputs
+
+ // Mock `waitForSpend` to return the commit
+ // spend.
+ ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
+ SpendingTx: reSignedSuccessTx,
+ SpenderTxHash: &reSignedHash,
+ SpenderInputIndex: 1,
+ SpendingHeight: 10,
+ SpentOutPoint: &commitOutpoint,
+ }
+
+ var (
+ inp input.Input
+ ok bool
+ )
+
+ select {
+ case inp, ok = <-resolver.Sweeper.(*mockSweeper).sweptInputs:
+ require.True(t, ok)
+
+ case <-time.After(1 * time.Second):
+ t.Fatal("expected input to be swept")
+ }
+
op := inp.OutPoint()
exp := wire.OutPoint{
Hash: reSignedHash,
@@ -461,6 +514,7 @@ func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) {
SpendingTx: sweepTx,
SpenderTxHash: &sweepHash,
SpendingHeight: 14,
+ SpentOutPoint: &op,
}
return nil
@@ -508,11 +562,14 @@ func testHtlcSuccess(t *testing.T, resolution lnwallet.IncomingHtlcResolution,
// for the next portion of the test.
ctx := newHtlcResolverTestContext(t,
func(htlc channeldb.HTLC, cfg ResolverConfig) ContractResolver {
- return &htlcSuccessResolver{
+ r := &htlcSuccessResolver{
contractResolverKit: *newContractResolverKit(cfg),
htlc: htlc,
htlcResolution: resolution,
}
+ r.initLogger("htlcSuccessResolver")
+
+ return r
},
)
@@ -562,11 +619,11 @@ func runFromCheckpoint(t *testing.T, ctx *htlcResolverTestContext,
var resolved, incubating bool
if h, ok := resolver.(*htlcSuccessResolver); ok {
- resolved = h.resolved
+ resolved = h.resolved.Load()
incubating = h.outputIncubating
}
if h, ok := resolver.(*htlcTimeoutResolver); ok {
- resolved = h.resolved
+ resolved = h.resolved.Load()
incubating = h.outputIncubating
}
@@ -610,7 +667,12 @@ func runFromCheckpoint(t *testing.T, ctx *htlcResolverTestContext,
checkpointedState = append(checkpointedState, b.Bytes())
nextCheckpoint++
- checkpointChan <- struct{}{}
+ select {
+ case checkpointChan <- struct{}{}:
+ case <-time.After(1 * time.Second):
+ t.Fatal("checkpoint timeout")
+ }
+
return nil
}
@@ -621,6 +683,8 @@ func runFromCheckpoint(t *testing.T, ctx *htlcResolverTestContext,
// preCheckpoint logic if needed.
resumed := true
for i, cp := range expectedCheckpoints {
+ t.Logf("Running checkpoint %d", i)
+
if cp.preCheckpoint != nil {
if err := cp.preCheckpoint(ctx, resumed); err != nil {
t.Fatalf("failure at stage %d: %v", i, err)
@@ -629,15 +693,15 @@ func runFromCheckpoint(t *testing.T, ctx *htlcResolverTestContext,
resumed = false
// Wait for the resolver to have checkpointed its state.
- <-checkpointChan
+ select {
+ case <-checkpointChan:
+ case <-time.After(1 * time.Second):
+ t.Fatalf("resolver did not checkpoint at stage %d", i)
+ }
}
// Wait for the resolver to fully complete.
ctx.waitForResult()
- if nextCheckpoint < len(expectedCheckpoints) {
- t.Fatalf("not all checkpoints hit")
- }
-
return checkpointedState
}
diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go
index 9954c3c0db..1782cfb3ba 100644
--- a/contractcourt/htlc_timeout_resolver.go
+++ b/contractcourt/htlc_timeout_resolver.go
@@ -7,12 +7,13 @@ import (
"sync"
"github.com/btcsuite/btcd/btcutil"
+ "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
@@ -37,9 +38,6 @@ type htlcTimeoutResolver struct {
// incubator (utxo nursery).
outputIncubating bool
- // resolved reflects if the contract has been fully resolved or not.
- resolved bool
-
// broadcastHeight is the height that the original contract was
// broadcast to the main-chain at. We'll use this value to bound any
// historical queries to the chain for spends/confirmations.
@@ -82,6 +80,7 @@ func newTimeoutResolver(res lnwallet.OutgoingHtlcResolution,
}
h.initReport()
+ h.initLogger(fmt.Sprintf("%T(%v)", h, h.outpoint()))
return h
}
@@ -93,23 +92,25 @@ func (h *htlcTimeoutResolver) isTaproot() bool {
)
}
-// ResolverKey returns an identifier which should be globally unique for this
-// particular resolver within the chain the original contract resides within.
-//
-// NOTE: Part of the ContractResolver interface.
-func (h *htlcTimeoutResolver) ResolverKey() []byte {
+// outpoint returns the outpoint of the HTLC output we're attempting to sweep.
+func (h *htlcTimeoutResolver) outpoint() wire.OutPoint {
// The primary key for this resolver will be the outpoint of the HTLC
// on the commitment transaction itself. If this is our commitment,
// then the output can be found within the signed timeout tx,
// otherwise, it's just the ClaimOutpoint.
- var op wire.OutPoint
if h.htlcResolution.SignedTimeoutTx != nil {
- op = h.htlcResolution.SignedTimeoutTx.TxIn[0].PreviousOutPoint
- } else {
- op = h.htlcResolution.ClaimOutpoint
+ return h.htlcResolution.SignedTimeoutTx.TxIn[0].PreviousOutPoint
}
- key := newResolverID(op)
+ return h.htlcResolution.ClaimOutpoint
+}
+
+// ResolverKey returns an identifier which should be globally unique for this
+// particular resolver within the chain the original contract resides within.
+//
+// NOTE: Part of the ContractResolver interface.
+func (h *htlcTimeoutResolver) ResolverKey() []byte {
+ key := newResolverID(h.outpoint())
return key[:]
}
@@ -157,7 +158,7 @@ const (
// by the remote party. It'll extract the preimage, add it to the global cache,
// and finally send the appropriate clean up message.
func (h *htlcTimeoutResolver) claimCleanUp(
- commitSpend *chainntnfs.SpendDetail) (ContractResolver, error) {
+ commitSpend *chainntnfs.SpendDetail) error {
// Depending on if this is our commitment or not, then we'll be looking
// for a different witness pattern.
@@ -192,7 +193,7 @@ func (h *htlcTimeoutResolver) claimCleanUp(
// element, then we're actually on the losing side of a breach
// attempt...
case h.isTaproot() && len(spendingInput.Witness) == 1:
- return nil, fmt.Errorf("breach attempt failed")
+ return fmt.Errorf("breach attempt failed")
// Otherwise, they'll be spending directly from our commitment output.
// In which case the witness stack looks like:
@@ -209,8 +210,8 @@ func (h *htlcTimeoutResolver) claimCleanUp(
preimage, err := lntypes.MakePreimage(preimageBytes)
if err != nil {
- return nil, fmt.Errorf("unable to create pre-image from "+
- "witness: %v", err)
+ return fmt.Errorf("unable to create pre-image from witness: %w",
+ err)
}
log.Infof("%T(%v): extracting preimage=%v from on-chain "+
@@ -232,9 +233,9 @@ func (h *htlcTimeoutResolver) claimCleanUp(
HtlcIndex: h.htlc.HtlcIndex,
PreImage: &pre,
}); err != nil {
- return nil, err
+ return err
}
- h.resolved = true
+ h.markResolved()
// Checkpoint our resolver with a report which reflects the preimage
// claim by the remote party.
@@ -247,7 +248,7 @@ func (h *htlcTimeoutResolver) claimCleanUp(
SpendTxID: commitSpend.SpenderTxHash,
}
- return nil, h.Checkpoint(h, report)
+ return h.Checkpoint(h, report)
}
// chainDetailsToWatch returns the output and script which we use to watch for
@@ -418,70 +419,33 @@ func checkSizeAndIndex(witness wire.TxWitness, size, index int) bool {
// see a direct sweep via the timeout clause.
//
// NOTE: Part of the ContractResolver interface.
-func (h *htlcTimeoutResolver) Resolve(
- immediate bool) (ContractResolver, error) {
-
+func (h *htlcTimeoutResolver) Resolve() (ContractResolver, error) {
// If we're already resolved, then we can exit early.
- if h.resolved {
+ if h.IsResolved() {
+ h.log.Errorf("already resolved")
return nil, nil
}
- // Start by spending the HTLC output, either by broadcasting the
- // second-level timeout transaction, or directly if this is the remote
- // commitment.
- commitSpend, err := h.spendHtlcOutput(immediate)
- if err != nil {
- return nil, err
+ // If this is an output on the remote party's commitment transaction,
+ // use the direct-spend path to sweep the htlc.
+ if h.isRemoteCommitOutput() {
+ return nil, h.resolveRemoteCommitOutput()
}
- // If the spend reveals the pre-image, then we'll enter the clean up
- // workflow to pass the pre-image back to the incoming link, add it to
- // the witness cache, and exit.
- if isPreimageSpend(
- h.isTaproot(), commitSpend,
- h.htlcResolution.SignedTimeoutTx != nil,
- ) {
-
- log.Infof("%T(%v): HTLC has been swept with pre-image by "+
- "remote party during timeout flow! Adding pre-image to "+
- "witness cache", h, h.htlc.RHash[:],
- h.htlcResolution.ClaimOutpoint)
-
- return h.claimCleanUp(commitSpend)
- }
-
- // At this point, the second-level transaction is sufficiently
- // confirmed, or a transaction directly spending the output is.
- // Therefore, we can now send back our clean up message, failing the
- // HTLC on the incoming link.
- //
- // NOTE: This can be called twice if the outgoing resolver restarts
- // before the second-stage timeout transaction is confirmed.
- log.Infof("%T(%v): resolving htlc with incoming fail msg, "+
- "fully confirmed", h, h.htlcResolution.ClaimOutpoint)
-
- failureMsg := &lnwire.FailPermanentChannelFailure{}
- err = h.DeliverResolutionMsg(ResolutionMsg{
- SourceChan: h.ShortChanID,
- HtlcIndex: h.htlc.HtlcIndex,
- Failure: failureMsg,
- })
- if err != nil {
- return nil, err
+ // If this is a zero-fee HTLC, we now handle the spend from our
+ // commitment transaction.
+ if h.isZeroFeeOutput() {
+ return nil, h.resolveTimeoutTx()
}
- // Depending on whether this was a local or remote commit, we must
- // handle the spending transaction accordingly.
- return h.handleCommitSpend(commitSpend)
+ // If this is an output on our own commitment using pre-anchor channel
+ // type, we will let the utxo nursery handle it.
+ return nil, h.resolveSecondLevelTxLegacy()
}
-// sweepSecondLevelTx sends a second level timeout transaction to the sweeper.
+// sweepTimeoutTx sends a second level timeout transaction to the sweeper.
// This transaction uses the SINLGE|ANYONECANPAY flag.
-func (h *htlcTimeoutResolver) sweepSecondLevelTx(immediate bool) error {
- log.Infof("%T(%x): offering second-layer timeout tx to sweeper: %v",
- h, h.htlc.RHash[:],
- spew.Sdump(h.htlcResolution.SignedTimeoutTx))
-
+func (h *htlcTimeoutResolver) sweepTimeoutTx() error {
var inp input.Input
if h.isTaproot() {
inp = lnutils.Ptr(input.MakeHtlcSecondLevelTimeoutTaprootInput(
@@ -512,33 +476,17 @@ func (h *htlcTimeoutResolver) sweepSecondLevelTx(immediate bool) error {
btcutil.Amount(inp.SignDesc().Output.Value), 2, 0,
)
+ h.log.Infof("offering 2nd-level HTLC timeout tx to sweeper "+
+ "with deadline=%v, budget=%v", h.incomingHTLCExpiryHeight,
+ budget)
+
// For an outgoing HTLC, it must be swept before the RefundTimeout of
// its incoming HTLC is reached.
- //
- // TODO(yy): we may end up mixing inputs with different time locks.
- // Suppose we have two outgoing HTLCs,
- // - HTLC1: nLocktime is 800000, CLTV delta is 80.
- // - HTLC2: nLocktime is 800001, CLTV delta is 79.
- // This means they would both have an incoming HTLC that expires at
- // 800080, hence they share the same deadline but different locktimes.
- // However, with current design, when we are at block 800000, HTLC1 is
- // offered to the sweeper. When block 800001 is reached, HTLC1's
- // sweeping process is already started, while HTLC2 is being offered to
- // the sweeper, so they won't be mixed. This can become an issue tho,
- // if we decide to sweep per X blocks. Or the contractcourt sees the
- // block first while the sweeper is only aware of the last block. To
- // properly fix it, we need `blockbeat` to make sure subsystems are in
- // sync.
- log.Infof("%T(%x): offering second-level HTLC timeout tx to sweeper "+
- "with deadline=%v, budget=%v", h, h.htlc.RHash[:],
- h.incomingHTLCExpiryHeight, budget)
-
_, err := h.Sweeper.SweepInput(
inp,
sweep.Params{
Budget: budget,
DeadlineHeight: h.incomingHTLCExpiryHeight,
- Immediate: immediate,
},
)
if err != nil {
@@ -548,12 +496,13 @@ func (h *htlcTimeoutResolver) sweepSecondLevelTx(immediate bool) error {
return err
}
-// sendSecondLevelTxLegacy sends a second level timeout transaction to the utxo
-// nursery. This transaction uses the legacy SIGHASH_ALL flag.
-func (h *htlcTimeoutResolver) sendSecondLevelTxLegacy() error {
- log.Debugf("%T(%v): incubating htlc output", h,
- h.htlcResolution.ClaimOutpoint)
+// resolveSecondLevelTxLegacy sends a second level timeout transaction to the
+// utxo nursery. This transaction uses the legacy SIGHASH_ALL flag.
+func (h *htlcTimeoutResolver) resolveSecondLevelTxLegacy() error {
+ h.log.Debug("incubating htlc output")
+ // The utxo nursery will take care of broadcasting the second-level
+ // timeout tx and sweeping its output once it confirms.
err := h.IncubateOutputs(
h.ChanPoint, fn.Some(h.htlcResolution),
fn.None[lnwallet.IncomingHtlcResolution](),
@@ -563,16 +512,14 @@ func (h *htlcTimeoutResolver) sendSecondLevelTxLegacy() error {
return err
}
- h.outputIncubating = true
-
- return h.Checkpoint(h)
+ return h.resolveTimeoutTx()
}
// sweepDirectHtlcOutput sends the direct spend of the HTLC output to the
// sweeper. This is used when the remote party goes on chain, and we're able to
// sweep an HTLC we offered after a timeout. Only the CLTV encumbered outputs
// are resolved via this path.
-func (h *htlcTimeoutResolver) sweepDirectHtlcOutput(immediate bool) error {
+func (h *htlcTimeoutResolver) sweepDirectHtlcOutput() error {
var htlcWitnessType input.StandardWitnessType
if h.isTaproot() {
htlcWitnessType = input.TaprootHtlcOfferedRemoteTimeout
@@ -612,7 +559,6 @@ func (h *htlcTimeoutResolver) sweepDirectHtlcOutput(immediate bool) error {
// This is an outgoing HTLC, so we want to make sure
// that we sweep it before the incoming HTLC expires.
DeadlineHeight: h.incomingHTLCExpiryHeight,
- Immediate: immediate,
},
)
if err != nil {
@@ -622,53 +568,6 @@ func (h *htlcTimeoutResolver) sweepDirectHtlcOutput(immediate bool) error {
return nil
}
-// spendHtlcOutput handles the initial spend of an HTLC output via the timeout
-// clause. If this is our local commitment, the second-level timeout TX will be
-// used to spend the output into the next stage. If this is the remote
-// commitment, the output will be swept directly without the timeout
-// transaction.
-func (h *htlcTimeoutResolver) spendHtlcOutput(
- immediate bool) (*chainntnfs.SpendDetail, error) {
-
- switch {
- // If we have non-nil SignDetails, this means that have a 2nd level
- // HTLC transaction that is signed using sighash SINGLE|ANYONECANPAY
- // (the case for anchor type channels). In this case we can re-sign it
- // and attach fees at will. We let the sweeper handle this job.
- case h.htlcResolution.SignDetails != nil && !h.outputIncubating:
- if err := h.sweepSecondLevelTx(immediate); err != nil {
- log.Errorf("Sending timeout tx to sweeper: %v", err)
-
- return nil, err
- }
-
- // If this is a remote commitment there's no second level timeout txn,
- // and we can just send this directly to the sweeper.
- case h.htlcResolution.SignedTimeoutTx == nil && !h.outputIncubating:
- if err := h.sweepDirectHtlcOutput(immediate); err != nil {
- log.Errorf("Sending direct spend to sweeper: %v", err)
-
- return nil, err
- }
-
- // If we have a SignedTimeoutTx but no SignDetails, this is a local
- // commitment for a non-anchor channel, so we'll send it to the utxo
- // nursery.
- case h.htlcResolution.SignDetails == nil && !h.outputIncubating:
- if err := h.sendSecondLevelTxLegacy(); err != nil {
- log.Errorf("Sending timeout tx to nursery: %v", err)
-
- return nil, err
- }
- }
-
- // Now that we've handed off the HTLC to the nursery or sweeper, we'll
- // watch for a spend of the output, and make our next move off of that.
- // Depending on if this is our commitment, or the remote party's
- // commitment, we'll be watching a different outpoint and script.
- return h.watchHtlcSpend()
-}
-
// watchHtlcSpend watches for a spend of the HTLC output. For neutrino backend,
// it will check blocks for the confirmed spend. For btcd and bitcoind, it will
// check both the mempool and the blocks.
@@ -697,9 +596,6 @@ func (h *htlcTimeoutResolver) watchHtlcSpend() (*chainntnfs.SpendDetail,
func (h *htlcTimeoutResolver) waitForConfirmedSpend(op *wire.OutPoint,
pkScript []byte) (*chainntnfs.SpendDetail, error) {
- log.Infof("%T(%v): waiting for spent of HTLC output %v to be "+
- "fully confirmed", h, h.htlcResolution.ClaimOutpoint, op)
-
// We'll block here until either we exit, or the HTLC output on the
// commitment transaction has been spent.
spend, err := waitForSpend(
@@ -709,239 +605,18 @@ func (h *htlcTimeoutResolver) waitForConfirmedSpend(op *wire.OutPoint,
return nil, err
}
- // Once confirmed, persist the state on disk.
- if err := h.checkPointSecondLevelTx(); err != nil {
- return nil, err
- }
-
return spend, err
}
-// checkPointSecondLevelTx persists the state of a second level HTLC tx to disk
-// if it's published by the sweeper.
-func (h *htlcTimeoutResolver) checkPointSecondLevelTx() error {
- // If this was the second level transaction published by the sweeper,
- // we can checkpoint the resolver now that it's confirmed.
- if h.htlcResolution.SignDetails != nil && !h.outputIncubating {
- h.outputIncubating = true
- if err := h.Checkpoint(h); err != nil {
- log.Errorf("unable to Checkpoint: %v", err)
- return err
- }
- }
-
- return nil
-}
-
-// handleCommitSpend handles the spend of the HTLC output on the commitment
-// transaction. If this was our local commitment, the spend will be he
-// confirmed second-level timeout transaction, and we'll sweep that into our
-// wallet. If the was a remote commitment, the resolver will resolve
-// immetiately.
-func (h *htlcTimeoutResolver) handleCommitSpend(
- commitSpend *chainntnfs.SpendDetail) (ContractResolver, error) {
-
- var (
- // claimOutpoint will be the outpoint of the second level
- // transaction, or on the remote commitment directly. It will
- // start out as set in the resolution, but we'll update it if
- // the second-level goes through the sweeper and changes its
- // txid.
- claimOutpoint = h.htlcResolution.ClaimOutpoint
-
- // spendTxID will be the ultimate spend of the claimOutpoint.
- // We set it to the commit spend for now, as this is the
- // ultimate spend in case this is a remote commitment. If we go
- // through the second-level transaction, we'll update this
- // accordingly.
- spendTxID = commitSpend.SpenderTxHash
-
- reports []*channeldb.ResolverReport
- )
-
- switch {
-
- // If we swept an HTLC directly off the remote party's commitment
- // transaction, then we can exit here as there's no second level sweep
- // to do.
- case h.htlcResolution.SignedTimeoutTx == nil:
- break
-
- // If the sweeper is handling the second level transaction, wait for
- // the CSV and possible CLTV lock to expire, before sweeping the output
- // on the second-level.
- case h.htlcResolution.SignDetails != nil:
- waitHeight := h.deriveWaitHeight(
- h.htlcResolution.CsvDelay, commitSpend,
- )
-
- h.reportLock.Lock()
- h.currentReport.Stage = 2
- h.currentReport.MaturityHeight = waitHeight
- h.reportLock.Unlock()
-
- if h.hasCLTV() {
- log.Infof("%T(%x): waiting for CSV and CLTV lock to "+
- "expire at height %v", h, h.htlc.RHash[:],
- waitHeight)
- } else {
- log.Infof("%T(%x): waiting for CSV lock to expire at "+
- "height %v", h, h.htlc.RHash[:], waitHeight)
- }
-
- // Deduct one block so this input is offered to the sweeper one
- // block earlier since the sweeper will wait for one block to
- // trigger the sweeping.
- //
- // TODO(yy): this is done so the outputs can be aggregated
- // properly. Suppose CSV locks of five 2nd-level outputs all
- // expire at height 840000, there is a race in block digestion
- // between contractcourt and sweeper:
- // - G1: block 840000 received in contractcourt, it now offers
- // the outputs to the sweeper.
- // - G2: block 840000 received in sweeper, it now starts to
- // sweep the received outputs - there's no guarantee all
- // fives have been received.
- // To solve this, we either offer the outputs earlier, or
- // implement `blockbeat`, and force contractcourt and sweeper
- // to consume each block sequentially.
- waitHeight--
-
- // TODO(yy): let sweeper handles the wait?
- err := waitForHeight(waitHeight, h.Notifier, h.quit)
- if err != nil {
- return nil, err
- }
-
- // We'll use this input index to determine the second-level
- // output index on the transaction, as the signatures requires
- // the indexes to be the same. We don't look for the
- // second-level output script directly, as there might be more
- // than one HTLC output to the same pkScript.
- op := &wire.OutPoint{
- Hash: *commitSpend.SpenderTxHash,
- Index: commitSpend.SpenderInputIndex,
- }
-
- var csvWitnessType input.StandardWitnessType
- if h.isTaproot() {
- //nolint:ll
- csvWitnessType = input.TaprootHtlcOfferedTimeoutSecondLevel
- } else {
- csvWitnessType = input.HtlcOfferedTimeoutSecondLevel
- }
-
- // Let the sweeper sweep the second-level output now that the
- // CSV/CLTV locks have expired.
- inp := h.makeSweepInput(
- op, csvWitnessType,
- input.LeaseHtlcOfferedTimeoutSecondLevel,
- &h.htlcResolution.SweepSignDesc,
- h.htlcResolution.CsvDelay,
- uint32(commitSpend.SpendingHeight), h.htlc.RHash,
- h.htlcResolution.ResolutionBlob,
- )
-
- // Calculate the budget for this sweep.
- budget := calculateBudget(
- btcutil.Amount(inp.SignDesc().Output.Value),
- h.Budget.NoDeadlineHTLCRatio,
- h.Budget.NoDeadlineHTLC,
- )
-
- log.Infof("%T(%x): offering second-level timeout tx output to "+
- "sweeper with no deadline and budget=%v at height=%v",
- h, h.htlc.RHash[:], budget, waitHeight)
-
- _, err = h.Sweeper.SweepInput(
- inp,
- sweep.Params{
- Budget: budget,
-
- // For second level success tx, there's no rush
- // to get it confirmed, so we use a nil
- // deadline.
- DeadlineHeight: fn.None[int32](),
- },
- )
- if err != nil {
- return nil, err
- }
-
- // Update the claim outpoint to point to the second-level
- // transaction created by the sweeper.
- claimOutpoint = *op
- fallthrough
-
- // Finally, if this was an output on our commitment transaction, we'll
- // wait for the second-level HTLC output to be spent, and for that
- // transaction itself to confirm.
- case h.htlcResolution.SignedTimeoutTx != nil:
- log.Infof("%T(%v): waiting for nursery/sweeper to spend CSV "+
- "delayed output", h, claimOutpoint)
-
- sweepTx, err := waitForSpend(
- &claimOutpoint,
- h.htlcResolution.SweepSignDesc.Output.PkScript,
- h.broadcastHeight, h.Notifier, h.quit,
- )
- if err != nil {
- return nil, err
- }
-
- // Update the spend txid to the hash of the sweep transaction.
- spendTxID = sweepTx.SpenderTxHash
-
- // Once our sweep of the timeout tx has confirmed, we add a
- // resolution for our timeoutTx tx first stage transaction.
- timeoutTx := commitSpend.SpendingTx
- index := commitSpend.SpenderInputIndex
- spendHash := commitSpend.SpenderTxHash
-
- reports = append(reports, &channeldb.ResolverReport{
- OutPoint: timeoutTx.TxIn[index].PreviousOutPoint,
- Amount: h.htlc.Amt.ToSatoshis(),
- ResolverType: channeldb.ResolverTypeOutgoingHtlc,
- ResolverOutcome: channeldb.ResolverOutcomeFirstStage,
- SpendTxID: spendHash,
- })
- }
-
- // With the clean up message sent, we'll now mark the contract
- // resolved, update the recovered balance, record the timeout and the
- // sweep txid on disk, and wait.
- h.resolved = true
- h.reportLock.Lock()
- h.currentReport.RecoveredBalance = h.currentReport.LimboBalance
- h.currentReport.LimboBalance = 0
- h.reportLock.Unlock()
-
- amt := btcutil.Amount(h.htlcResolution.SweepSignDesc.Output.Value)
- reports = append(reports, &channeldb.ResolverReport{
- OutPoint: claimOutpoint,
- Amount: amt,
- ResolverType: channeldb.ResolverTypeOutgoingHtlc,
- ResolverOutcome: channeldb.ResolverOutcomeTimeout,
- SpendTxID: spendTxID,
- })
-
- return nil, h.Checkpoint(h, reports...)
-}
-
// Stop signals the resolver to cancel any current resolution processes, and
// suspend.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcTimeoutResolver) Stop() {
- close(h.quit)
-}
+ h.log.Debugf("stopping...")
+ defer h.log.Debugf("stopped")
-// IsResolved returns true if the stored state in the resolve is fully
-// resolved. In this case the target output can be forgotten.
-//
-// NOTE: Part of the ContractResolver interface.
-func (h *htlcTimeoutResolver) IsResolved() bool {
- return h.resolved
+ close(h.quit)
}
// report returns a report on the resolution state of the contract.
@@ -1003,7 +678,7 @@ func (h *htlcTimeoutResolver) Encode(w io.Writer) error {
if err := binary.Write(w, endian, h.outputIncubating); err != nil {
return err
}
- if err := binary.Write(w, endian, h.resolved); err != nil {
+ if err := binary.Write(w, endian, h.IsResolved()); err != nil {
return err
}
if err := binary.Write(w, endian, h.broadcastHeight); err != nil {
@@ -1044,9 +719,15 @@ func newTimeoutResolverFromReader(r io.Reader, resCfg ResolverConfig) (
if err := binary.Read(r, endian, &h.outputIncubating); err != nil {
return nil, err
}
- if err := binary.Read(r, endian, &h.resolved); err != nil {
+
+ var resolved bool
+ if err := binary.Read(r, endian, &resolved); err != nil {
return nil, err
}
+ if resolved {
+ h.markResolved()
+ }
+
if err := binary.Read(r, endian, &h.broadcastHeight); err != nil {
return nil, err
}
@@ -1066,6 +747,7 @@ func newTimeoutResolverFromReader(r io.Reader, resCfg ResolverConfig) (
}
h.initReport()
+ h.initLogger(fmt.Sprintf("%T(%v)", h, h.outpoint()))
return h, nil
}
@@ -1173,12 +855,6 @@ func (h *htlcTimeoutResolver) consumeSpendEvents(resultChan chan *spendResult,
// Create a result chan to hold the results.
result := &spendResult{}
- // hasMempoolSpend is a flag that indicates whether we have found a
- // preimage spend from the mempool. This is used to determine whether
- // to checkpoint the resolver or not when later we found the
- // corresponding block spend.
- hasMempoolSpent := false
-
// Wait for a spend event to arrive.
for {
select {
@@ -1206,23 +882,6 @@ func (h *htlcTimeoutResolver) consumeSpendEvents(resultChan chan *spendResult,
// Once confirmed, persist the state on disk if
// we haven't seen the output's spending tx in
// mempool before.
- //
- // NOTE: we don't checkpoint the resolver if
- // it's spending tx has already been found in
- // mempool - the resolver will take care of the
- // checkpoint in its `claimCleanUp`. If we do
- // checkpoint here, however, we'd create a new
- // record in db for the same htlc resolver
- // which won't be cleaned up later, resulting
- // the channel to stay in unresolved state.
- //
- // TODO(yy): when fee bumper is implemented, we
- // need to further check whether this is a
- // preimage spend. Also need to refactor here
- // to save us some indentation.
- if !hasMempoolSpent {
- result.err = h.checkPointSecondLevelTx()
- }
}
// Send the result and exit the loop.
@@ -1256,7 +915,7 @@ func (h *htlcTimeoutResolver) consumeSpendEvents(resultChan chan *spendResult,
// continue the loop.
hasPreimage := isPreimageSpend(
h.isTaproot(), spendDetail,
- h.htlcResolution.SignedTimeoutTx != nil,
+ !h.isRemoteCommitOutput(),
)
if !hasPreimage {
log.Debugf("HTLC output %s spent doesn't "+
@@ -1269,10 +928,6 @@ func (h *htlcTimeoutResolver) consumeSpendEvents(resultChan chan *spendResult,
result.spend = spendDetail
resultChan <- result
- // Set the hasMempoolSpent flag to true so we won't
- // checkpoint the resolver again in db.
- hasMempoolSpent = true
-
continue
// If the resolver exits, we exit the goroutine.
@@ -1284,3 +939,379 @@ func (h *htlcTimeoutResolver) consumeSpendEvents(resultChan chan *spendResult,
}
}
}
+
+// isRemoteCommitOutput returns a bool to indicate whether the htlc output is
+// on the remote commitment.
+func (h *htlcTimeoutResolver) isRemoteCommitOutput() bool {
+ // If we don't have a timeout transaction, then this means that this is
+ // an output on the remote party's commitment transaction.
+ return h.htlcResolution.SignedTimeoutTx == nil
+}
+
+// isZeroFeeOutput returns a boolean indicating whether the htlc output is from
+// a anchor-enabled channel, which uses the sighash SINGLE|ANYONECANPAY.
+func (h *htlcTimeoutResolver) isZeroFeeOutput() bool {
+ // If we have non-nil SignDetails, this means it has a 2nd level HTLC
+ // transaction that is signed using sighash SINGLE|ANYONECANPAY (the
+ // case for anchor type channels). In this case we can re-sign it and
+ // attach fees at will.
+ return h.htlcResolution.SignedTimeoutTx != nil &&
+ h.htlcResolution.SignDetails != nil
+}
+
+// waitHtlcSpendAndCheckPreimage waits for the htlc output to be spent and
+// checks whether the spending reveals the preimage. If the preimage is found,
+// it will be added to the preimage beacon to settle the incoming link, and a
+// nil spend details will be returned. Otherwise, the spend details will be
+// returned, indicating this is a non-preimage spend.
+func (h *htlcTimeoutResolver) waitHtlcSpendAndCheckPreimage() (
+ *chainntnfs.SpendDetail, error) {
+
+ // Wait for the htlc output to be spent, which can happen in one of the
+ // paths,
+ // 1. The remote party spends the htlc output using the preimage.
+ // 2. The local party spends the htlc timeout tx from the local
+ // commitment.
+ // 3. The local party spends the htlc output directlt from the remote
+ // commitment.
+ spend, err := h.watchHtlcSpend()
+ if err != nil {
+ return nil, err
+ }
+
+ // If the spend reveals the pre-image, then we'll enter the clean up
+ // workflow to pass the preimage back to the incoming link, add it to
+ // the witness cache, and exit.
+ if isPreimageSpend(h.isTaproot(), spend, !h.isRemoteCommitOutput()) {
+ return nil, h.claimCleanUp(spend)
+ }
+
+ return spend, nil
+}
+
+// sweepTimeoutTxOutput attempts to sweep the output of the second level
+// timeout tx.
+func (h *htlcTimeoutResolver) sweepTimeoutTxOutput() error {
+ h.log.Debugf("sweeping output %v from 2nd-level HTLC timeout tx",
+ h.htlcResolution.ClaimOutpoint)
+
+ // This should be non-blocking as we will only attempt to sweep the
+ // output when the second level tx has already been confirmed. In other
+ // words, waitHtlcSpendAndCheckPreimage will return immediately.
+ commitSpend, err := h.waitHtlcSpendAndCheckPreimage()
+ if err != nil {
+ return err
+ }
+
+ // Exit early if the spend is nil, as this means it's a remote spend
+ // using the preimage path, which is handled in claimCleanUp.
+ if commitSpend == nil {
+ h.log.Infof("preimage spend detected, skipping 2nd-level " +
+ "HTLC output sweep")
+
+ return nil
+ }
+
+ waitHeight := h.deriveWaitHeight(h.htlcResolution.CsvDelay, commitSpend)
+
+ // Now that the sweeper has broadcasted the second-level transaction,
+ // it has confirmed, and we have checkpointed our state, we'll sweep
+ // the second level output. We report the resolver has moved the next
+ // stage.
+ h.reportLock.Lock()
+ h.currentReport.Stage = 2
+ h.currentReport.MaturityHeight = waitHeight
+ h.reportLock.Unlock()
+
+ if h.hasCLTV() {
+ h.log.Infof("waiting for CSV and CLTV lock to expire at "+
+ "height %v", waitHeight)
+ } else {
+ h.log.Infof("waiting for CSV lock to expire at height %v",
+ waitHeight)
+ }
+
+ // We'll use this input index to determine the second-level output
+ // index on the transaction, as the signatures requires the indexes to
+ // be the same. We don't look for the second-level output script
+ // directly, as there might be more than one HTLC output to the same
+ // pkScript.
+ op := &wire.OutPoint{
+ Hash: *commitSpend.SpenderTxHash,
+ Index: commitSpend.SpenderInputIndex,
+ }
+
+ var witType input.StandardWitnessType
+ if h.isTaproot() {
+ witType = input.TaprootHtlcOfferedTimeoutSecondLevel
+ } else {
+ witType = input.HtlcOfferedTimeoutSecondLevel
+ }
+
+ // Let the sweeper sweep the second-level output now that the CSV/CLTV
+ // locks have expired.
+ inp := h.makeSweepInput(
+ op, witType,
+ input.LeaseHtlcOfferedTimeoutSecondLevel,
+ &h.htlcResolution.SweepSignDesc,
+ h.htlcResolution.CsvDelay, uint32(commitSpend.SpendingHeight),
+ h.htlc.RHash, h.htlcResolution.ResolutionBlob,
+ )
+
+ // Calculate the budget for this sweep.
+ budget := calculateBudget(
+ btcutil.Amount(inp.SignDesc().Output.Value),
+ h.Budget.NoDeadlineHTLCRatio,
+ h.Budget.NoDeadlineHTLC,
+ )
+
+ h.log.Infof("offering output from 2nd-level timeout tx to sweeper "+
+ "with no deadline and budget=%v", budget)
+
+ // TODO(yy): use the result chan returned from SweepInput to get the
+ // confirmation status of this sweeping tx so we don't need to make
+ // anothe subscription via `RegisterSpendNtfn` for this outpoint here
+ // in the resolver.
+ _, err = h.Sweeper.SweepInput(
+ inp,
+ sweep.Params{
+ Budget: budget,
+
+ // For second level success tx, there's no rush
+ // to get it confirmed, so we use a nil
+ // deadline.
+ DeadlineHeight: fn.None[int32](),
+ },
+ )
+
+ return err
+}
+
+// checkpointStageOne creates a checkpoint for the first stage of the htlc
+// timeout transaction. This is used to ensure that the resolver can resume
+// watching for the second stage spend in case of a restart.
+func (h *htlcTimeoutResolver) checkpointStageOne(
+ spendTxid chainhash.Hash) error {
+
+ h.log.Debugf("checkpoint stage one spend of HTLC output %v, spent "+
+ "in tx %v", h.outpoint(), spendTxid)
+
+ // Now that the second-level transaction has confirmed, we checkpoint
+ // the state so we'll go to the next stage in case of restarts.
+ h.outputIncubating = true
+
+ // Create stage-one report.
+ report := &channeldb.ResolverReport{
+ OutPoint: h.outpoint(),
+ Amount: h.htlc.Amt.ToSatoshis(),
+ ResolverType: channeldb.ResolverTypeOutgoingHtlc,
+ ResolverOutcome: channeldb.ResolverOutcomeFirstStage,
+ SpendTxID: &spendTxid,
+ }
+
+ // At this point, the second-level transaction is sufficiently
+ // confirmed. We can now send back our clean up message, failing the
+ // HTLC on the incoming link.
+ failureMsg := &lnwire.FailPermanentChannelFailure{}
+ err := h.DeliverResolutionMsg(ResolutionMsg{
+ SourceChan: h.ShortChanID,
+ HtlcIndex: h.htlc.HtlcIndex,
+ Failure: failureMsg,
+ })
+ if err != nil {
+ return err
+ }
+
+ return h.Checkpoint(h, report)
+}
+
+// checkpointClaim checkpoints the timeout resolver with the reports it needs.
+func (h *htlcTimeoutResolver) checkpointClaim(
+ spendDetail *chainntnfs.SpendDetail) error {
+
+ h.log.Infof("resolving htlc with incoming fail msg, output=%v "+
+ "confirmed in tx=%v", spendDetail.SpentOutPoint,
+ spendDetail.SpenderTxHash)
+
+ // Create a resolver report for the claiming of the HTLC.
+ amt := btcutil.Amount(h.htlcResolution.SweepSignDesc.Output.Value)
+ report := &channeldb.ResolverReport{
+ OutPoint: *spendDetail.SpentOutPoint,
+ Amount: amt,
+ ResolverType: channeldb.ResolverTypeOutgoingHtlc,
+ ResolverOutcome: channeldb.ResolverOutcomeTimeout,
+ SpendTxID: spendDetail.SpenderTxHash,
+ }
+
+ // Finally, we checkpoint the resolver with our report(s).
+ h.markResolved()
+
+ return h.Checkpoint(h, report)
+}
+
+// resolveRemoteCommitOutput handles sweeping an HTLC output on the remote
+// commitment with via the timeout path. In this case we can sweep the output
+// directly, and don't have to broadcast a second-level transaction.
+func (h *htlcTimeoutResolver) resolveRemoteCommitOutput() error {
+ h.log.Debug("waiting for direct-timeout spend of the htlc to confirm")
+
+ // Wait for the direct-timeout HTLC sweep tx to confirm.
+ spend, err := h.watchHtlcSpend()
+ if err != nil {
+ return err
+ }
+
+ // If the spend reveals the preimage, then we'll enter the clean up
+ // workflow to pass the preimage back to the incoming link, add it to
+ // the witness cache, and exit.
+ if isPreimageSpend(h.isTaproot(), spend, !h.isRemoteCommitOutput()) {
+ return h.claimCleanUp(spend)
+ }
+
+ // Send the clean up msg to fail the incoming HTLC.
+ failureMsg := &lnwire.FailPermanentChannelFailure{}
+ err = h.DeliverResolutionMsg(ResolutionMsg{
+ SourceChan: h.ShortChanID,
+ HtlcIndex: h.htlc.HtlcIndex,
+ Failure: failureMsg,
+ })
+ if err != nil {
+ return err
+ }
+
+ // TODO(yy): should also update the `RecoveredBalance` and
+ // `LimboBalance` like other paths?
+
+ // Checkpoint the resolver, and write the outcome to disk.
+ return h.checkpointClaim(spend)
+}
+
+// resolveTimeoutTx waits for the sweeping tx of the second-level
+// timeout tx to confirm and offers the output from the timeout tx to the
+// sweeper.
+func (h *htlcTimeoutResolver) resolveTimeoutTx() error {
+ h.log.Debug("waiting for first-stage 2nd-level HTLC timeout tx to " +
+ "confirm")
+
+ // Wait for the second level transaction to confirm.
+ spend, err := h.watchHtlcSpend()
+ if err != nil {
+ return err
+ }
+
+ // If the spend reveals the preimage, then we'll enter the clean up
+ // workflow to pass the preimage back to the incoming link, add it to
+ // the witness cache, and exit.
+ if isPreimageSpend(h.isTaproot(), spend, !h.isRemoteCommitOutput()) {
+ return h.claimCleanUp(spend)
+ }
+
+ op := h.htlcResolution.ClaimOutpoint
+ spenderTxid := *spend.SpenderTxHash
+
+ // If the timeout tx is a re-signed tx, we will need to find the actual
+ // spent outpoint from the spending tx.
+ if h.isZeroFeeOutput() {
+ op = wire.OutPoint{
+ Hash: spenderTxid,
+ Index: spend.SpenderInputIndex,
+ }
+ }
+
+ // If the 2nd-stage sweeping has already been started, we can
+ // fast-forward to start the resolving process for the stage two
+ // output.
+ if h.outputIncubating {
+ return h.resolveTimeoutTxOutput(op)
+ }
+
+ h.log.Infof("2nd-level HTLC timeout tx=%v confirmed", spenderTxid)
+
+ // Start the process to sweep the output from the timeout tx.
+ if h.isZeroFeeOutput() {
+ err = h.sweepTimeoutTxOutput()
+ if err != nil {
+ return err
+ }
+ }
+
+ // Create a checkpoint since the timeout tx is confirmed and the sweep
+ // request has been made.
+ if err := h.checkpointStageOne(spenderTxid); err != nil {
+ return err
+ }
+
+ // Start the resolving process for the stage two output.
+ return h.resolveTimeoutTxOutput(op)
+}
+
+// resolveTimeoutTxOutput waits for the spend of the output from the 2nd-level
+// timeout tx.
+func (h *htlcTimeoutResolver) resolveTimeoutTxOutput(op wire.OutPoint) error {
+ h.log.Debugf("waiting for second-stage 2nd-level timeout tx output %v "+
+ "to be spent after csv_delay=%v", op, h.htlcResolution.CsvDelay)
+
+ spend, err := waitForSpend(
+ &op, h.htlcResolution.SweepSignDesc.Output.PkScript,
+ h.broadcastHeight, h.Notifier, h.quit,
+ )
+ if err != nil {
+ return err
+ }
+
+ h.reportLock.Lock()
+ h.currentReport.RecoveredBalance = h.currentReport.LimboBalance
+ h.currentReport.LimboBalance = 0
+ h.reportLock.Unlock()
+
+ return h.checkpointClaim(spend)
+}
+
+// Launch creates an input based on the details of the outgoing htlc resolution
+// and offers it to the sweeper.
+func (h *htlcTimeoutResolver) Launch() error {
+ if h.isLaunched() {
+ h.log.Tracef("already launched")
+ return nil
+ }
+
+ h.log.Debugf("launching resolver...")
+ h.launched.Store(true)
+
+ switch {
+ // If we're already resolved, then we can exit early.
+ case h.IsResolved():
+ h.log.Errorf("already resolved")
+ return nil
+
+ // If this is an output on the remote party's commitment transaction,
+ // use the direct timeout spend path.
+ //
+ // NOTE: When the outputIncubating is false, it means that the output
+ // has been offered to the utxo nursery as starting in 0.18.4, we
+ // stopped marking this flag for direct timeout spends (#9062). In that
+ // case, we will do nothing and let the utxo nursery handle it.
+ case h.isRemoteCommitOutput() && !h.outputIncubating:
+ return h.sweepDirectHtlcOutput()
+
+ // If this is an anchor type channel, we now sweep either the
+ // second-level timeout tx or the output from the second-level timeout
+ // tx.
+ case h.isZeroFeeOutput():
+ // If the second-level timeout tx has already been swept, we
+ // can go ahead and sweep its output.
+ if h.outputIncubating {
+ return h.sweepTimeoutTxOutput()
+ }
+
+ // Otherwise, sweep the second level tx.
+ return h.sweepTimeoutTx()
+
+ // If this is an output on our own commitment using pre-anchor channel
+ // type, we will let the utxo nursery handle it via Resolve.
+ //
+ // TODO(yy): handle the legacy output by offering it to the sweeper.
+ default:
+ return nil
+ }
+}
diff --git a/contractcourt/htlc_timeout_resolver_test.go b/contractcourt/htlc_timeout_resolver_test.go
index f3f23c385c..017d3d3886 100644
--- a/contractcourt/htlc_timeout_resolver_test.go
+++ b/contractcourt/htlc_timeout_resolver_test.go
@@ -14,7 +14,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/input"
@@ -40,7 +40,7 @@ type mockWitnessBeacon struct {
func newMockWitnessBeacon() *mockWitnessBeacon {
return &mockWitnessBeacon{
preImageUpdates: make(chan lntypes.Preimage, 1),
- newPreimages: make(chan []lntypes.Preimage),
+ newPreimages: make(chan []lntypes.Preimage, 1),
lookupPreimage: make(map[lntypes.Hash]lntypes.Preimage),
}
}
@@ -280,7 +280,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) {
notifier := &mock.ChainNotifier{
EpochChan: make(chan *chainntnfs.BlockEpoch),
- SpendChan: make(chan *chainntnfs.SpendDetail),
+ SpendChan: make(chan *chainntnfs.SpendDetail, 1),
ConfChan: make(chan *chainntnfs.TxConfirmation),
}
@@ -321,6 +321,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) {
return nil
},
+ HtlcNotifier: &mockHTLCNotifier{},
},
PutResolverReport: func(_ kvdb.RwTx,
_ *channeldb.ResolverReport) error {
@@ -356,6 +357,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) {
Amt: testHtlcAmt,
},
}
+ resolver.initLogger("timeoutResolver")
var reports []*channeldb.ResolverReport
@@ -390,7 +392,12 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) {
go func() {
defer wg.Done()
- _, err := resolver.Resolve(false)
+ err := resolver.Launch()
+ if err != nil {
+ resolveErr <- err
+ }
+
+ _, err = resolver.Resolve()
if err != nil {
resolveErr <- err
}
@@ -406,8 +413,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) {
sweepChan = mockSweeper.sweptInputs
}
- // The output should be offered to either the sweeper or
- // the nursery.
+ // The output should be offered to either the sweeper or the nursery.
select {
case <-incubateChan:
case <-sweepChan:
@@ -431,6 +437,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) {
case notifier.SpendChan <- &chainntnfs.SpendDetail{
SpendingTx: spendingTx,
SpenderTxHash: &spendTxHash,
+ SpentOutPoint: &testChanPoint2,
}:
case <-time.After(time.Second * 5):
t.Fatalf("failed to request spend ntfn")
@@ -487,6 +494,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) {
case notifier.SpendChan <- &chainntnfs.SpendDetail{
SpendingTx: spendingTx,
SpenderTxHash: &spendTxHash,
+ SpentOutPoint: &testChanPoint2,
}:
case <-time.After(time.Second * 5):
t.Fatalf("failed to request spend ntfn")
@@ -524,7 +532,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) {
wg.Wait()
// Finally, the resolver should be marked as resolved.
- if !resolver.resolved {
+ if !resolver.resolved.Load() {
t.Fatalf("resolver should be marked as resolved")
}
}
@@ -549,6 +557,8 @@ func TestHtlcTimeoutResolver(t *testing.T) {
// TestHtlcTimeoutSingleStage tests a remote commitment confirming, and the
// local node sweeping the HTLC output directly after timeout.
+//
+//nolint:ll
func TestHtlcTimeoutSingleStage(t *testing.T) {
commitOutpoint := wire.OutPoint{Index: 3}
@@ -573,6 +583,12 @@ func TestHtlcTimeoutSingleStage(t *testing.T) {
SpendTxID: &sweepTxid,
}
+ sweepSpend := &chainntnfs.SpendDetail{
+ SpendingTx: sweepTx,
+ SpentOutPoint: &commitOutpoint,
+ SpenderTxHash: &sweepTxid,
+ }
+
checkpoints := []checkpoint{
{
// We send a confirmation the sweep tx from published
@@ -582,9 +598,10 @@ func TestHtlcTimeoutSingleStage(t *testing.T) {
// The nursery will create and publish a sweep
// tx.
- ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
- SpendingTx: sweepTx,
- SpenderTxHash: &sweepTxid,
+ select {
+ case ctx.notifier.SpendChan <- sweepSpend:
+ case <-time.After(time.Second * 5):
+ t.Fatalf("failed to send spend ntfn")
}
// The resolver should deliver a failure
@@ -620,7 +637,9 @@ func TestHtlcTimeoutSingleStage(t *testing.T) {
// TestHtlcTimeoutSecondStage tests a local commitment being confirmed, and the
// local node claiming the HTLC output using the second-level timeout tx.
-func TestHtlcTimeoutSecondStage(t *testing.T) {
+//
+//nolint:ll
+func TestHtlcTimeoutSecondStagex(t *testing.T) {
commitOutpoint := wire.OutPoint{Index: 2}
htlcOutpoint := wire.OutPoint{Index: 3}
@@ -678,23 +697,57 @@ func TestHtlcTimeoutSecondStage(t *testing.T) {
SpendTxID: &sweepHash,
}
+ timeoutSpend := &chainntnfs.SpendDetail{
+ SpendingTx: timeoutTx,
+ SpentOutPoint: &commitOutpoint,
+ SpenderTxHash: &timeoutTxid,
+ }
+
+ sweepSpend := &chainntnfs.SpendDetail{
+ SpendingTx: sweepTx,
+ SpentOutPoint: &htlcOutpoint,
+ SpenderTxHash: &sweepHash,
+ }
+
checkpoints := []checkpoint{
{
+ preCheckpoint: func(ctx *htlcResolverTestContext,
+ _ bool) error {
+
+ // Deliver spend of timeout tx.
+ ctx.notifier.SpendChan <- timeoutSpend
+
+ return nil
+ },
+
// Output should be handed off to the nursery.
incubating: true,
+ reports: []*channeldb.ResolverReport{
+ firstStage,
+ },
},
{
// We send a confirmation for our sweep tx to indicate
// that our sweep succeeded.
preCheckpoint: func(ctx *htlcResolverTestContext,
- _ bool) error {
+ resumed bool) error {
- // The nursery will publish the timeout tx.
- ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
- SpendingTx: timeoutTx,
- SpenderTxHash: &timeoutTxid,
+ // When it's reloaded from disk, we need to
+ // re-send the notification to mock the first
+ // `watchHtlcSpend`.
+ if resumed {
+ // Deliver spend of timeout tx.
+ ctx.notifier.SpendChan <- timeoutSpend
+
+ // Deliver spend of timeout tx output.
+ ctx.notifier.SpendChan <- sweepSpend
+
+ return nil
}
+ // Deliver spend of timeout tx output.
+ ctx.notifier.SpendChan <- sweepSpend
+
// The resolver should deliver a failure
// resolution message (indicating we
// successfully timed out the HTLC).
@@ -707,12 +760,6 @@ func TestHtlcTimeoutSecondStage(t *testing.T) {
t.Fatalf("resolution not sent")
}
- // Deliver spend of timeout tx.
- ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
- SpendingTx: sweepTx,
- SpenderTxHash: &sweepHash,
- }
-
return nil
},
@@ -722,7 +769,7 @@ func TestHtlcTimeoutSecondStage(t *testing.T) {
incubating: true,
resolved: true,
reports: []*channeldb.ResolverReport{
- firstStage, secondState,
+ secondState,
},
},
}
@@ -796,10 +843,6 @@ func TestHtlcTimeoutSingleStageRemoteSpend(t *testing.T) {
}
checkpoints := []checkpoint{
- {
- // Output should be handed off to the nursery.
- incubating: true,
- },
{
// We send a spend notification for a remote spend with
// the preimage.
@@ -812,6 +855,7 @@ func TestHtlcTimeoutSingleStageRemoteSpend(t *testing.T) {
// the preimage.
ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
SpendingTx: spendTx,
+ SpentOutPoint: &commitOutpoint,
SpenderTxHash: &spendTxHash,
}
@@ -847,7 +891,7 @@ func TestHtlcTimeoutSingleStageRemoteSpend(t *testing.T) {
// After the success tx has confirmed, we expect the
// checkpoint to be resolved, and with the above
// report.
- incubating: true,
+ incubating: false,
resolved: true,
reports: []*channeldb.ResolverReport{
claim,
@@ -914,6 +958,7 @@ func TestHtlcTimeoutSecondStageRemoteSpend(t *testing.T) {
ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
SpendingTx: remoteSuccessTx,
+ SpentOutPoint: &commitOutpoint,
SpenderTxHash: &successTxid,
}
@@ -967,20 +1012,15 @@ func TestHtlcTimeoutSecondStageRemoteSpend(t *testing.T) {
// TestHtlcTimeoutSecondStageSweeper tests that for anchor channels, when a
// local commitment confirms, the timeout tx is handed to the sweeper to claim
// the HTLC output.
+//
+//nolint:ll
func TestHtlcTimeoutSecondStageSweeper(t *testing.T) {
- commitOutpoint := wire.OutPoint{Index: 2}
htlcOutpoint := wire.OutPoint{Index: 3}
- sweepTx := &wire.MsgTx{
- TxIn: []*wire.TxIn{{}},
- TxOut: []*wire.TxOut{{}},
- }
- sweepHash := sweepTx.TxHash()
-
timeoutTx := &wire.MsgTx{
TxIn: []*wire.TxIn{
{
- PreviousOutPoint: commitOutpoint,
+ PreviousOutPoint: htlcOutpoint,
},
},
TxOut: []*wire.TxOut{
@@ -1027,11 +1067,16 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) {
},
}
reSignedHash := reSignedTimeoutTx.TxHash()
- reSignedOutPoint := wire.OutPoint{
+
+ timeoutTxOutpoint := wire.OutPoint{
Hash: reSignedHash,
Index: 1,
}
+ // Make a copy so `isPreimageSpend` can easily pass.
+ sweepTx := reSignedTimeoutTx.Copy()
+ sweepHash := sweepTx.TxHash()
+
// twoStageResolution is a resolution for a htlc on the local
// party's commitment, where the timeout tx can be re-signed.
twoStageResolution := lnwallet.OutgoingHtlcResolution{
@@ -1045,7 +1090,7 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) {
}
firstStage := &channeldb.ResolverReport{
- OutPoint: commitOutpoint,
+ OutPoint: htlcOutpoint,
Amount: testHtlcAmt.ToSatoshis(),
ResolverType: channeldb.ResolverTypeOutgoingHtlc,
ResolverOutcome: channeldb.ResolverOutcomeFirstStage,
@@ -1053,12 +1098,45 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) {
}
secondState := &channeldb.ResolverReport{
- OutPoint: reSignedOutPoint,
+ OutPoint: timeoutTxOutpoint,
Amount: btcutil.Amount(testSignDesc.Output.Value),
ResolverType: channeldb.ResolverTypeOutgoingHtlc,
ResolverOutcome: channeldb.ResolverOutcomeTimeout,
SpendTxID: &sweepHash,
}
+ // mockTimeoutTxSpend is a helper closure to mock `waitForSpend` to
+ // return the commit spend in `sweepTimeoutTxOutput`.
+ mockTimeoutTxSpend := func(ctx *htlcResolverTestContext) {
+ select {
+ case ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
+ SpendingTx: reSignedTimeoutTx,
+ SpenderInputIndex: 1,
+ SpenderTxHash: &reSignedHash,
+ SpendingHeight: 10,
+ SpentOutPoint: &htlcOutpoint,
+ }:
+
+ case <-time.After(time.Second * 1):
+ t.Fatalf("spend not sent")
+ }
+ }
+
+ // mockSweepTxSpend is a helper closure to mock `waitForSpend` to
+ // return the commit spend in `sweepTimeoutTxOutput`.
+ mockSweepTxSpend := func(ctx *htlcResolverTestContext) {
+ select {
+ case ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
+ SpendingTx: sweepTx,
+ SpenderInputIndex: 1,
+ SpenderTxHash: &sweepHash,
+ SpendingHeight: 10,
+ SpentOutPoint: &timeoutTxOutpoint,
+ }:
+
+ case <-time.After(time.Second * 1):
+ t.Fatalf("spend not sent")
+ }
+ }
checkpoints := []checkpoint{
{
@@ -1067,28 +1145,40 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) {
_ bool) error {
resolver := ctx.resolver.(*htlcTimeoutResolver)
- inp := <-resolver.Sweeper.(*mockSweeper).sweptInputs
+
+ var (
+ inp input.Input
+ ok bool
+ )
+
+ select {
+ case inp, ok = <-resolver.Sweeper.(*mockSweeper).sweptInputs:
+ require.True(t, ok)
+
+ case <-time.After(1 * time.Second):
+ t.Fatal("expected input to be swept")
+ }
+
op := inp.OutPoint()
- if op != commitOutpoint {
+ if op != htlcOutpoint {
return fmt.Errorf("outpoint %v swept, "+
- "expected %v", op,
- commitOutpoint)
+ "expected %v", op, htlcOutpoint)
}
- // Emulat the sweeper spending using the
- // re-signed timeout tx.
- ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
- SpendingTx: reSignedTimeoutTx,
- SpenderInputIndex: 1,
- SpenderTxHash: &reSignedHash,
- SpendingHeight: 10,
- }
+ // Mock `waitForSpend` twice, called in,
+ // - `resolveReSignedTimeoutTx`
+ // - `sweepTimeoutTxOutput`.
+ mockTimeoutTxSpend(ctx)
+ mockTimeoutTxSpend(ctx)
return nil
},
// incubating=true is used to signal that the
// second-level transaction was confirmed.
incubating: true,
+ reports: []*channeldb.ResolverReport{
+ firstStage,
+ },
},
{
// We send a confirmation for our sweep tx to indicate
@@ -1096,18 +1186,18 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) {
preCheckpoint: func(ctx *htlcResolverTestContext,
resumed bool) error {
- // If we are resuming from a checkpoint, we
- // expect the resolver to re-subscribe to a
- // spend, hence we must resend it.
+ // Mock `waitForSpend` to return the commit
+ // spend.
if resumed {
- ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
- SpendingTx: reSignedTimeoutTx,
- SpenderInputIndex: 1,
- SpenderTxHash: &reSignedHash,
- SpendingHeight: 10,
- }
+ mockTimeoutTxSpend(ctx)
+ mockTimeoutTxSpend(ctx)
+ mockSweepTxSpend(ctx)
+
+ return nil
}
+ mockSweepTxSpend(ctx)
+
// The resolver should deliver a failure
// resolution message (indicating we
// successfully timed out the HTLC).
@@ -1120,15 +1210,23 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) {
t.Fatalf("resolution not sent")
}
- // Mimic CSV lock expiring.
- ctx.notifier.EpochChan <- &chainntnfs.BlockEpoch{
- Height: 13,
- }
-
// The timeout tx output should now be given to
// the sweeper.
resolver := ctx.resolver.(*htlcTimeoutResolver)
- inp := <-resolver.Sweeper.(*mockSweeper).sweptInputs
+
+ var (
+ inp input.Input
+ ok bool
+ )
+
+ select {
+ case inp, ok = <-resolver.Sweeper.(*mockSweeper).sweptInputs:
+ require.True(t, ok)
+
+ case <-time.After(1 * time.Second):
+ t.Fatal("expected input to be swept")
+ }
+
op := inp.OutPoint()
exp := wire.OutPoint{
Hash: reSignedHash,
@@ -1138,14 +1236,6 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) {
return fmt.Errorf("wrong outpoint swept")
}
- // Notify about the spend, which should resolve
- // the resolver.
- ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
- SpendingTx: sweepTx,
- SpenderTxHash: &sweepHash,
- SpendingHeight: 14,
- }
-
return nil
},
@@ -1155,7 +1245,6 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) {
incubating: true,
resolved: true,
reports: []*channeldb.ResolverReport{
- firstStage,
secondState,
},
},
@@ -1236,33 +1325,6 @@ func TestHtlcTimeoutSecondStageSweeperRemoteSpend(t *testing.T) {
}
checkpoints := []checkpoint{
- {
- // The output should be given to the sweeper.
- preCheckpoint: func(ctx *htlcResolverTestContext,
- _ bool) error {
-
- resolver := ctx.resolver.(*htlcTimeoutResolver)
- inp := <-resolver.Sweeper.(*mockSweeper).sweptInputs
- op := inp.OutPoint()
- if op != commitOutpoint {
- return fmt.Errorf("outpoint %v swept, "+
- "expected %v", op,
- commitOutpoint)
- }
-
- // Emulate the remote sweeping the output with the preimage.
- // re-signed timeout tx.
- ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
- SpendingTx: spendTx,
- SpenderTxHash: &spendTxHash,
- }
-
- return nil
- },
- // incubating=true is used to signal that the
- // second-level transaction was confirmed.
- incubating: true,
- },
{
// We send a confirmation for our sweep tx to indicate
// that our sweep succeeded.
@@ -1277,6 +1339,7 @@ func TestHtlcTimeoutSecondStageSweeperRemoteSpend(t *testing.T) {
ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
SpendingTx: spendTx,
SpenderTxHash: &spendTxHash,
+ SpentOutPoint: &commitOutpoint,
}
}
@@ -1314,7 +1377,7 @@ func TestHtlcTimeoutSecondStageSweeperRemoteSpend(t *testing.T) {
// After the sweep has confirmed, we expect the
// checkpoint to be resolved, and with the above
// reports.
- incubating: true,
+ incubating: false,
resolved: true,
reports: []*channeldb.ResolverReport{
claim,
@@ -1339,21 +1402,26 @@ func testHtlcTimeout(t *testing.T, resolution lnwallet.OutgoingHtlcResolution,
// for the next portion of the test.
ctx := newHtlcResolverTestContext(t,
func(htlc channeldb.HTLC, cfg ResolverConfig) ContractResolver {
- return &htlcTimeoutResolver{
+ r := &htlcTimeoutResolver{
contractResolverKit: *newContractResolverKit(cfg),
htlc: htlc,
htlcResolution: resolution,
}
+ r.initLogger("htlcTimeoutResolver")
+
+ return r
},
)
checkpointedState := runFromCheckpoint(t, ctx, checkpoints)
+ t.Log("Running resolver to completion after restart")
+
// Now, from every checkpoint created, we re-create the resolver, and
// run the test from that checkpoint.
for i := range checkpointedState {
cp := bytes.NewReader(checkpointedState[i])
- ctx := newHtlcResolverTestContext(t,
+ ctx := newHtlcResolverTestContextFromReader(t,
func(htlc channeldb.HTLC, cfg ResolverConfig) ContractResolver {
resolver, err := newTimeoutResolverFromReader(cp, cfg)
if err != nil {
@@ -1361,7 +1429,8 @@ func testHtlcTimeout(t *testing.T, resolution lnwallet.OutgoingHtlcResolution,
}
resolver.Supplement(htlc)
- resolver.htlcResolution = resolution
+ resolver.initLogger("htlcTimeoutResolver")
+
return resolver
},
)
diff --git a/contractcourt/mock_registry_test.go b/contractcourt/mock_registry_test.go
index 5bba11afcb..0530ab51dd 100644
--- a/contractcourt/mock_registry_test.go
+++ b/contractcourt/mock_registry_test.go
@@ -29,6 +29,11 @@ func (r *mockRegistry) NotifyExitHopHtlc(payHash lntypes.Hash,
wireCustomRecords lnwire.CustomRecords,
payload invoices.Payload) (invoices.HtlcResolution, error) {
+ // Exit early if the notification channel is nil.
+ if hodlChan == nil {
+ return r.notifyResolution, r.notifyErr
+ }
+
r.notifyChan <- notifyExitHopData{
hodlChan: hodlChan,
payHash: payHash,
diff --git a/contractcourt/utxonursery.go b/contractcourt/utxonursery.go
index aef906a0ad..f78be9fa49 100644
--- a/contractcourt/utxonursery.go
+++ b/contractcourt/utxonursery.go
@@ -15,7 +15,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/labels"
@@ -794,7 +794,7 @@ func (u *UtxoNursery) graduateClass(classHeight uint32) error {
return err
}
- utxnLog.Infof("Attempting to graduate height=%v: num_kids=%v, "+
+ utxnLog.Debugf("Attempting to graduate height=%v: num_kids=%v, "+
"num_babies=%v", classHeight, len(kgtnOutputs), len(cribOutputs))
// Offer the outputs to the sweeper and set up notifications that will
diff --git a/contractcourt/utxonursery_test.go b/contractcourt/utxonursery_test.go
index 796d1ed239..f1b47cc2ca 100644
--- a/contractcourt/utxonursery_test.go
+++ b/contractcourt/utxonursery_test.go
@@ -18,7 +18,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntest/mock"
"github.com/lightningnetwork/lnd/lnwallet"
diff --git a/discovery/gossiper.go b/discovery/gossiper.go
index aafadba479..9c51734396 100644
--- a/discovery/gossiper.go
+++ b/discovery/gossiper.go
@@ -19,7 +19,7 @@ import (
"github.com/lightningnetwork/lnd/batch"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go
index 85a4e0657e..b74f69bf0a 100644
--- a/discovery/gossiper_test.go
+++ b/discovery/gossiper_test.go
@@ -24,7 +24,7 @@ import (
"github.com/lightningnetwork/lnd/batch"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
diff --git a/docs/grpc/ruby.md b/docs/grpc/ruby.md
index 599dd2bc7a..457246423f 100644
--- a/docs/grpc/ruby.md
+++ b/docs/grpc/ruby.md
@@ -58,7 +58,7 @@ $:.unshift(File.dirname(__FILE__))
require 'grpc'
require 'lightning_services_pb'
-# Due to updated ECDSA generated tls.cert we need to let gprc know that
+# Due to updated ECDSA generated tls.cert we need to let grpc know that
# we need to use that cipher suite otherwise there will be a handshake
# error when we communicate with the lnd rpc server.
ENV['GRPC_SSL_CIPHER_SUITES'] = "HIGH+ECDSA"
diff --git a/docs/release-notes/release-notes-0.18.4.md b/docs/release-notes/release-notes-0.18.4.md
index 1fd299f3d7..ab72c01c31 100644
--- a/docs/release-notes/release-notes-0.18.4.md
+++ b/docs/release-notes/release-notes-0.18.4.md
@@ -23,6 +23,10 @@
cause a nil pointer dereference during the probing of a payment request that
does not contain a payment address.
+* [Fixed a bug](https://github.com/lightningnetwork/lnd/pull/9324) to prevent
+ potential deadlocks when LND depends on external components (e.g. aux
+ components, hooks).
+
# New Features
The main channel state machine and database now allow for processing and storing
diff --git a/docs/release-notes/release-notes-0.19.0.md b/docs/release-notes/release-notes-0.19.0.md
index db5ca738ef..ee0b0c1360 100644
--- a/docs/release-notes/release-notes-0.19.0.md
+++ b/docs/release-notes/release-notes-0.19.0.md
@@ -90,6 +90,10 @@
* [The `walletrpc.FundPsbt` method now has a new option to specify the maximum
fee to output amounts ratio.](https://github.com/lightningnetwork/lnd/pull/8600)
+* When returning the response from list invoices RPC, the `lnrpc.Invoice.Htlcs`
+ are now [sorted](https://github.com/lightningnetwork/lnd/pull/9337) based on
+ the `InvoiceHTLC.HtlcIndex`.
+
## lncli Additions
* [A pre-generated macaroon root key can now be specified in `lncli create` and
diff --git a/funding/aux_funding.go b/funding/aux_funding.go
index 492612145a..c7ef653f47 100644
--- a/funding/aux_funding.go
+++ b/funding/aux_funding.go
@@ -2,7 +2,7 @@ package funding
import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/msgmux"
diff --git a/funding/manager.go b/funding/manager.go
index c8a54d9588..395cccb2a6 100644
--- a/funding/manager.go
+++ b/funding/manager.go
@@ -23,7 +23,7 @@ import (
"github.com/lightningnetwork/lnd/chanacceptor"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/discovery"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/input"
diff --git a/funding/manager_test.go b/funding/manager_test.go
index 525f69f9a5..b6130176d1 100644
--- a/funding/manager_test.go
+++ b/funding/manager_test.go
@@ -27,7 +27,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channelnotifier"
"github.com/lightningnetwork/lnd/discovery"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
diff --git a/go.mod b/go.mod
index 1330f9a84a..bbb421de40 100644
--- a/go.mod
+++ b/go.mod
@@ -36,13 +36,13 @@ require (
github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb
github.com/lightningnetwork/lnd/cert v1.2.2
github.com/lightningnetwork/lnd/clock v1.1.1
- github.com/lightningnetwork/lnd/fn v1.2.5
+ github.com/lightningnetwork/lnd/fn/v2 v2.0.2
github.com/lightningnetwork/lnd/healthcheck v1.2.6
github.com/lightningnetwork/lnd/kvdb v1.4.11
github.com/lightningnetwork/lnd/queue v1.1.1
github.com/lightningnetwork/lnd/sqldb v1.0.5
github.com/lightningnetwork/lnd/ticker v1.1.1
- github.com/lightningnetwork/lnd/tlv v1.2.6
+ github.com/lightningnetwork/lnd/tlv v1.3.0
github.com/lightningnetwork/lnd/tor v1.1.4
github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796
github.com/miekg/dns v1.1.43
diff --git a/go.sum b/go.sum
index aa04dc5fce..4c452df735 100644
--- a/go.sum
+++ b/go.sum
@@ -456,8 +456,8 @@ github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf
github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U=
github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0=
github.com/lightningnetwork/lnd/clock v1.1.1/go.mod h1:mGnAhPyjYZQJmebS7aevElXKTFDuO+uNFFfMXK1W8xQ=
-github.com/lightningnetwork/lnd/fn v1.2.5 h1:pGMz0BDUxrhvOtShD4FIysdVy+ulfFAnFvTKjZO5Pp8=
-github.com/lightningnetwork/lnd/fn v1.2.5/go.mod h1:SyFohpVrARPKH3XVAJZlXdVe+IwMYc4OMAvrDY32kw0=
+github.com/lightningnetwork/lnd/fn/v2 v2.0.2 h1:M7o2lYrh/zCp+lntPB3WP/rWTu5U+4ssyHW+kqNJ0fs=
+github.com/lightningnetwork/lnd/fn/v2 v2.0.2/go.mod h1:TOzwrhjB/Azw1V7aa8t21ufcQmdsQOQMDtxVOQWNl8s=
github.com/lightningnetwork/lnd/healthcheck v1.2.6 h1:1sWhqr93GdkWy4+6U7JxBfcyZIE78MhIHTJZfPx7qqI=
github.com/lightningnetwork/lnd/healthcheck v1.2.6/go.mod h1:Mu02um4CWY/zdTOvFje7WJgJcHyX2zq/FG3MhOAiGaQ=
github.com/lightningnetwork/lnd/kvdb v1.4.11 h1:fk1HMVFrsVK3xqU7q+JWHRgBltw/a2qIg1E3zazMb/8=
@@ -468,8 +468,8 @@ github.com/lightningnetwork/lnd/sqldb v1.0.5 h1:ax5vBPf44tN/uD6C5+hBPBjOJ7cRMrUL
github.com/lightningnetwork/lnd/sqldb v1.0.5/go.mod h1:OG09zL/PHPaBJefp4HsPz2YLUJ+zIQHbpgCtLnOx8I4=
github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM=
github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA=
-github.com/lightningnetwork/lnd/tlv v1.2.6 h1:icvQG2yDr6k3ZuZzfRdG3EJp6pHurcuh3R6dg0gv/Mw=
-github.com/lightningnetwork/lnd/tlv v1.2.6/go.mod h1:/CmY4VbItpOldksocmGT4lxiJqRP9oLxwSZOda2kzNQ=
+github.com/lightningnetwork/lnd/tlv v1.3.0 h1:exS/KCPEgpOgviIttfiXAPaUqw2rHQrnUOpP7HPBPiY=
+github.com/lightningnetwork/lnd/tlv v1.3.0/go.mod h1:pJuiBj1ecr1WWLOtcZ+2+hu9Ey25aJWFIsjmAoPPnmc=
github.com/lightningnetwork/lnd/tor v1.1.4 h1:TUW27EXqoZCcCAQPlD4aaDfh8jMbBS9CghNz50qqwtA=
github.com/lightningnetwork/lnd/tor v1.1.4/go.mod h1:qSRB8llhAK+a6kaTPWOLLXSZc6Hg8ZC0mq1sUQ/8JfI=
github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796 h1:sjOGyegMIhvgfq5oaue6Td+hxZuf3tDC8lAPrFldqFw=
diff --git a/graph/builder.go b/graph/builder.go
index c0133e02ec..d6984af709 100644
--- a/graph/builder.go
+++ b/graph/builder.go
@@ -16,7 +16,7 @@ import (
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/batch"
"github.com/lightningnetwork/lnd/chainntnfs"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/input"
diff --git a/graph/db/models/channel_edge_info.go b/graph/db/models/channel_edge_info.go
index 0f91e2bbec..6aa67acc6a 100644
--- a/graph/db/models/channel_edge_info.go
+++ b/graph/db/models/channel_edge_info.go
@@ -8,7 +8,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
)
// ChannelEdgeInfo represents a fully authenticated channel along with all its
diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go
index c48436173f..6414c9f802 100644
--- a/htlcswitch/interceptable_switch.go
+++ b/htlcswitch/interceptable_switch.go
@@ -8,7 +8,7 @@ import (
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntypes"
diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go
index d8f55afc69..78cb019892 100644
--- a/htlcswitch/interfaces.go
+++ b/htlcswitch/interfaces.go
@@ -6,7 +6,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/lntypes"
@@ -206,6 +206,11 @@ const (
Outgoing LinkDirection = true
)
+// OptionalBandwidth is a type alias for the result of a bandwidth query that
+// may return a bandwidth value or fn.None if the bandwidth is not available or
+// not applicable.
+type OptionalBandwidth = fn.Option[lnwire.MilliSatoshi]
+
// ChannelLink is an interface which represents the subsystem for managing the
// incoming htlc requests, applying the changes to the channel, and also
// propagating/forwarding it to htlc switch.
@@ -267,10 +272,10 @@ type ChannelLink interface {
// in order to signal to the source of the HTLC, the policy consistency
// issue.
CheckHtlcForward(payHash [32]byte, incomingAmt lnwire.MilliSatoshi,
- amtToForward lnwire.MilliSatoshi,
- incomingTimeout, outgoingTimeout uint32,
- inboundFee models.InboundFee,
- heightNow uint32, scid lnwire.ShortChannelID) *LinkError
+ amtToForward lnwire.MilliSatoshi, incomingTimeout,
+ outgoingTimeout uint32, inboundFee models.InboundFee,
+ heightNow uint32, scid lnwire.ShortChannelID,
+ customRecords lnwire.CustomRecords) *LinkError
// CheckHtlcTransit should return a nil error if the passed HTLC details
// satisfy the current channel policy. Otherwise, a LinkError with a
@@ -278,14 +283,15 @@ type ChannelLink interface {
// the violation. This call is intended to be used for locally initiated
// payments for which there is no corresponding incoming htlc.
CheckHtlcTransit(payHash [32]byte, amt lnwire.MilliSatoshi,
- timeout uint32, heightNow uint32) *LinkError
+ timeout uint32, heightNow uint32,
+ customRecords lnwire.CustomRecords) *LinkError
// Stats return the statistics of channel link. Number of updates,
// total sent/received milli-satoshis.
Stats() (uint64, lnwire.MilliSatoshi, lnwire.MilliSatoshi)
- // Peer returns the serialized public key of remote peer with which we
- // have the channel link opened.
+ // PeerPubKey returns the serialized public key of remote peer with
+ // which we have the channel link opened.
PeerPubKey() [33]byte
// AttachMailBox delivers an active MailBox to the link. The MailBox may
@@ -302,9 +308,18 @@ type ChannelLink interface {
// commitment of the channel that this link is associated with.
CommitmentCustomBlob() fn.Option[tlv.Blob]
- // Start/Stop are used to initiate the start/stop of the channel link
- // functioning.
+ // AuxBandwidth returns the bandwidth that can be used for a channel,
+ // expressed in milli-satoshi. This might be different from the regular
+ // BTC bandwidth for custom channels. This will always return fn.None()
+ // for a regular (non-custom) channel.
+ AuxBandwidth(amount lnwire.MilliSatoshi, cid lnwire.ShortChannelID,
+ htlcBlob fn.Option[tlv.Blob],
+ ts AuxTrafficShaper) fn.Result[OptionalBandwidth]
+
+ // Start starts the channel link.
Start() error
+
+ // Stop requests the channel link to be shut down.
Stop()
}
@@ -440,7 +455,7 @@ type htlcNotifier interface {
NotifyForwardingEvent(key HtlcKey, info HtlcInfo,
eventType HtlcEventType)
- // NotifyIncomingLinkFailEvent notifies that a htlc has failed on our
+ // NotifyLinkFailEvent notifies that a htlc has failed on our
// incoming link. It takes an isReceive bool to differentiate between
// our node's receives and forwards.
NotifyLinkFailEvent(key HtlcKey, info HtlcInfo,
@@ -461,3 +476,36 @@ type htlcNotifier interface {
NotifyFinalHtlcEvent(key models.CircuitKey,
info channeldb.FinalHtlcInfo)
}
+
+// AuxHtlcModifier is an interface that allows the sender to modify the outgoing
+// HTLC of a payment by changing the amount or the wire message tlv records.
+type AuxHtlcModifier interface {
+ // ProduceHtlcExtraData is a function that, based on the previous extra
+ // data blob of an HTLC, may produce a different blob or modify the
+ // amount of bitcoin this htlc should carry.
+ ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi,
+ htlcCustomRecords lnwire.CustomRecords) (lnwire.MilliSatoshi,
+ lnwire.CustomRecords, error)
+}
+
+// AuxTrafficShaper is an interface that allows the sender to determine if a
+// payment should be carried by a channel based on the TLV records that may be
+// present in the `update_add_htlc` message or the channel commitment itself.
+type AuxTrafficShaper interface {
+ AuxHtlcModifier
+
+ // ShouldHandleTraffic is called in order to check if the channel
+ // identified by the provided channel ID may have external mechanisms
+ // that would allow it to carry out the payment.
+ ShouldHandleTraffic(cid lnwire.ShortChannelID,
+ fundingBlob fn.Option[tlv.Blob]) (bool, error)
+
+ // PaymentBandwidth returns the available bandwidth for a custom channel
+ // decided by the given channel aux blob and HTLC blob. A return value
+ // of 0 means there is no bandwidth available. To find out if a channel
+ // is a custom channel that should be handled by the traffic shaper, the
+ // ShouldHandleTraffic method should be called first.
+ PaymentBandwidth(htlcBlob, commitmentBlob fn.Option[tlv.Blob],
+ linkBandwidth,
+ htlcAmt lnwire.MilliSatoshi) (lnwire.MilliSatoshi, error)
+}
diff --git a/htlcswitch/link.go b/htlcswitch/link.go
index 60062862ef..6e67e87def 100644
--- a/htlcswitch/link.go
+++ b/htlcswitch/link.go
@@ -16,7 +16,7 @@ import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hodl"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
@@ -293,6 +293,10 @@ type ChannelLinkConfig struct {
// ShouldFwdExpEndorsement is a closure that indicates whether the link
// should forward experimental endorsement signals.
ShouldFwdExpEndorsement func() bool
+
+ // AuxTrafficShaper is an optional auxiliary traffic shaper that can be
+ // used to manage the bandwidth of the link.
+ AuxTrafficShaper fn.Option[AuxTrafficShaper]
}
// channelLink is the service which drives a channel's commitment update
@@ -3233,11 +3237,11 @@ func (l *channelLink) UpdateForwardingPolicy(
// issue.
//
// NOTE: Part of the ChannelLink interface.
-func (l *channelLink) CheckHtlcForward(payHash [32]byte,
- incomingHtlcAmt, amtToForward lnwire.MilliSatoshi,
- incomingTimeout, outgoingTimeout uint32,
- inboundFee models.InboundFee,
- heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError {
+func (l *channelLink) CheckHtlcForward(payHash [32]byte, incomingHtlcAmt,
+ amtToForward lnwire.MilliSatoshi, incomingTimeout,
+ outgoingTimeout uint32, inboundFee models.InboundFee,
+ heightNow uint32, originalScid lnwire.ShortChannelID,
+ customRecords lnwire.CustomRecords) *LinkError {
l.RLock()
policy := l.cfg.FwrdingPolicy
@@ -3286,7 +3290,7 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte,
// Check whether the outgoing htlc satisfies the channel policy.
err := l.canSendHtlc(
policy, payHash, amtToForward, outgoingTimeout, heightNow,
- originalScid,
+ originalScid, customRecords,
)
if err != nil {
return err
@@ -3322,8 +3326,8 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte,
// the violation. This call is intended to be used for locally initiated
// payments for which there is no corresponding incoming htlc.
func (l *channelLink) CheckHtlcTransit(payHash [32]byte,
- amt lnwire.MilliSatoshi, timeout uint32,
- heightNow uint32) *LinkError {
+ amt lnwire.MilliSatoshi, timeout uint32, heightNow uint32,
+ customRecords lnwire.CustomRecords) *LinkError {
l.RLock()
policy := l.cfg.FwrdingPolicy
@@ -3334,6 +3338,7 @@ func (l *channelLink) CheckHtlcTransit(payHash [32]byte,
// to occur.
return l.canSendHtlc(
policy, payHash, amt, timeout, heightNow, hop.Source,
+ customRecords,
)
}
@@ -3341,7 +3346,8 @@ func (l *channelLink) CheckHtlcTransit(payHash [32]byte,
// the channel's amount and time lock constraints.
func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy,
payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32,
- heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError {
+ heightNow uint32, originalScid lnwire.ShortChannelID,
+ customRecords lnwire.CustomRecords) *LinkError {
// As our first sanity check, we'll ensure that the passed HTLC isn't
// too small for the next hop. If so, then we'll cancel the HTLC
@@ -3399,8 +3405,38 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy,
return NewLinkError(&lnwire.FailExpiryTooFar{})
}
+ // We now check the available bandwidth to see if this HTLC can be
+ // forwarded.
+ availableBandwidth := l.Bandwidth()
+ auxBandwidth, err := fn.MapOptionZ(
+ l.cfg.AuxTrafficShaper,
+ func(ts AuxTrafficShaper) fn.Result[OptionalBandwidth] {
+ var htlcBlob fn.Option[tlv.Blob]
+ blob, err := customRecords.Serialize()
+ if err != nil {
+ return fn.Err[OptionalBandwidth](
+ fmt.Errorf("unable to serialize "+
+ "custom records: %w", err))
+ }
+
+ if len(blob) > 0 {
+ htlcBlob = fn.Some(blob)
+ }
+
+ return l.AuxBandwidth(amt, originalScid, htlcBlob, ts)
+ },
+ ).Unpack()
+ if err != nil {
+ l.log.Errorf("Unable to determine aux bandwidth: %v", err)
+ return NewLinkError(&lnwire.FailTemporaryNodeFailure{})
+ }
+
+ auxBandwidth.WhenSome(func(bandwidth lnwire.MilliSatoshi) {
+ availableBandwidth = bandwidth
+ })
+
// Check to see if there is enough balance in this channel.
- if amt > l.Bandwidth() {
+ if amt > availableBandwidth {
l.log.Warnf("insufficient bandwidth to route htlc: %v is "+
"larger than %v", amt, l.Bandwidth())
cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage {
@@ -3415,6 +3451,48 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy,
return nil
}
+// AuxBandwidth returns the bandwidth that can be used for a channel, expressed
+// in milli-satoshi. This might be different from the regular BTC bandwidth for
+// custom channels. This will always return fn.None() for a regular (non-custom)
+// channel.
+func (l *channelLink) AuxBandwidth(amount lnwire.MilliSatoshi,
+ cid lnwire.ShortChannelID, htlcBlob fn.Option[tlv.Blob],
+ ts AuxTrafficShaper) fn.Result[OptionalBandwidth] {
+
+ unknownBandwidth := fn.None[lnwire.MilliSatoshi]()
+
+ fundingBlob := l.FundingCustomBlob()
+ shouldHandle, err := ts.ShouldHandleTraffic(cid, fundingBlob)
+ if err != nil {
+ return fn.Err[OptionalBandwidth](fmt.Errorf("traffic shaper "+
+ "failed to decide whether to handle traffic: %w", err))
+ }
+
+ log.Debugf("ShortChannelID=%v: aux traffic shaper is handling "+
+ "traffic: %v", cid, shouldHandle)
+
+ // If this channel isn't handled by the aux traffic shaper, we'll return
+ // early.
+ if !shouldHandle {
+ return fn.Ok(unknownBandwidth)
+ }
+
+ // Ask for a specific bandwidth to be used for the channel.
+ commitmentBlob := l.CommitmentCustomBlob()
+ auxBandwidth, err := ts.PaymentBandwidth(
+ htlcBlob, commitmentBlob, l.Bandwidth(), amount,
+ )
+ if err != nil {
+ return fn.Err[OptionalBandwidth](fmt.Errorf("failed to get "+
+ "bandwidth from external traffic shaper: %w", err))
+ }
+
+ log.Debugf("ShortChannelID=%v: aux traffic shaper reported available "+
+ "bandwidth: %v", cid, auxBandwidth)
+
+ return fn.Ok(fn.Some(auxBandwidth))
+}
+
// Stats returns the statistics of channel link.
//
// NOTE: Part of the ChannelLink interface.
diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go
index 80632b07e9..1747105597 100644
--- a/htlcswitch/link_test.go
+++ b/htlcswitch/link_test.go
@@ -26,7 +26,7 @@ import (
"github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hodl"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
@@ -6243,9 +6243,9 @@ func TestCheckHtlcForward(t *testing.T) {
var hash [32]byte
t.Run("satisfied", func(t *testing.T) {
- result := link.CheckHtlcForward(hash, 1500, 1000,
- 200, 150, models.InboundFee{}, 0,
- lnwire.ShortChannelID{},
+ result := link.CheckHtlcForward(
+ hash, 1500, 1000, 200, 150, models.InboundFee{}, 0,
+ lnwire.ShortChannelID{}, nil,
)
if result != nil {
t.Fatalf("expected policy to be satisfied")
@@ -6253,9 +6253,9 @@ func TestCheckHtlcForward(t *testing.T) {
})
t.Run("below minhtlc", func(t *testing.T) {
- result := link.CheckHtlcForward(hash, 100, 50,
- 200, 150, models.InboundFee{}, 0,
- lnwire.ShortChannelID{},
+ result := link.CheckHtlcForward(
+ hash, 100, 50, 200, 150, models.InboundFee{}, 0,
+ lnwire.ShortChannelID{}, nil,
)
if _, ok := result.WireMessage().(*lnwire.FailAmountBelowMinimum); !ok {
t.Fatalf("expected FailAmountBelowMinimum failure code")
@@ -6263,9 +6263,9 @@ func TestCheckHtlcForward(t *testing.T) {
})
t.Run("above maxhtlc", func(t *testing.T) {
- result := link.CheckHtlcForward(hash, 1500, 1200,
- 200, 150, models.InboundFee{}, 0,
- lnwire.ShortChannelID{},
+ result := link.CheckHtlcForward(
+ hash, 1500, 1200, 200, 150, models.InboundFee{}, 0,
+ lnwire.ShortChannelID{}, nil,
)
if _, ok := result.WireMessage().(*lnwire.FailTemporaryChannelFailure); !ok {
t.Fatalf("expected FailTemporaryChannelFailure failure code")
@@ -6273,9 +6273,9 @@ func TestCheckHtlcForward(t *testing.T) {
})
t.Run("insufficient fee", func(t *testing.T) {
- result := link.CheckHtlcForward(hash, 1005, 1000,
- 200, 150, models.InboundFee{}, 0,
- lnwire.ShortChannelID{},
+ result := link.CheckHtlcForward(
+ hash, 1005, 1000, 200, 150, models.InboundFee{}, 0,
+ lnwire.ShortChannelID{}, nil,
)
if _, ok := result.WireMessage().(*lnwire.FailFeeInsufficient); !ok {
t.Fatalf("expected FailFeeInsufficient failure code")
@@ -6288,17 +6288,17 @@ func TestCheckHtlcForward(t *testing.T) {
t.Parallel()
result := link.CheckHtlcForward(
- hash, 100005, 100000, 200,
- 150, models.InboundFee{}, 0, lnwire.ShortChannelID{},
+ hash, 100005, 100000, 200, 150, models.InboundFee{}, 0,
+ lnwire.ShortChannelID{}, nil,
)
_, ok := result.WireMessage().(*lnwire.FailFeeInsufficient)
require.True(t, ok, "expected FailFeeInsufficient failure code")
})
t.Run("expiry too soon", func(t *testing.T) {
- result := link.CheckHtlcForward(hash, 1500, 1000,
- 200, 150, models.InboundFee{}, 190,
- lnwire.ShortChannelID{},
+ result := link.CheckHtlcForward(
+ hash, 1500, 1000, 200, 150, models.InboundFee{}, 190,
+ lnwire.ShortChannelID{}, nil,
)
if _, ok := result.WireMessage().(*lnwire.FailExpiryTooSoon); !ok {
t.Fatalf("expected FailExpiryTooSoon failure code")
@@ -6306,9 +6306,9 @@ func TestCheckHtlcForward(t *testing.T) {
})
t.Run("incorrect cltv expiry", func(t *testing.T) {
- result := link.CheckHtlcForward(hash, 1500, 1000,
- 200, 190, models.InboundFee{}, 0,
- lnwire.ShortChannelID{},
+ result := link.CheckHtlcForward(
+ hash, 1500, 1000, 200, 190, models.InboundFee{}, 0,
+ lnwire.ShortChannelID{}, nil,
)
if _, ok := result.WireMessage().(*lnwire.FailIncorrectCltvExpiry); !ok {
t.Fatalf("expected FailIncorrectCltvExpiry failure code")
@@ -6318,9 +6318,9 @@ func TestCheckHtlcForward(t *testing.T) {
t.Run("cltv expiry too far in the future", func(t *testing.T) {
// Check that expiry isn't too far in the future.
- result := link.CheckHtlcForward(hash, 1500, 1000,
- 10200, 10100, models.InboundFee{}, 0,
- lnwire.ShortChannelID{},
+ result := link.CheckHtlcForward(
+ hash, 1500, 1000, 10200, 10100, models.InboundFee{}, 0,
+ lnwire.ShortChannelID{}, nil,
)
if _, ok := result.WireMessage().(*lnwire.FailExpiryTooFar); !ok {
t.Fatalf("expected FailExpiryTooFar failure code")
@@ -6330,9 +6330,11 @@ func TestCheckHtlcForward(t *testing.T) {
t.Run("inbound fee satisfied", func(t *testing.T) {
t.Parallel()
- result := link.CheckHtlcForward(hash, 1000+10-2-1, 1000,
- 200, 150, models.InboundFee{Base: -2, Rate: -1_000},
- 0, lnwire.ShortChannelID{})
+ result := link.CheckHtlcForward(
+ hash, 1000+10-2-1, 1000, 200, 150,
+ models.InboundFee{Base: -2, Rate: -1_000},
+ 0, lnwire.ShortChannelID{}, nil,
+ )
if result != nil {
t.Fatalf("expected policy to be satisfied")
}
@@ -6341,9 +6343,11 @@ func TestCheckHtlcForward(t *testing.T) {
t.Run("inbound fee insufficient", func(t *testing.T) {
t.Parallel()
- result := link.CheckHtlcForward(hash, 1000+10-10-101-1, 1000,
+ result := link.CheckHtlcForward(
+ hash, 1000+10-10-101-1, 1000,
200, 150, models.InboundFee{Base: -10, Rate: -100_000},
- 0, lnwire.ShortChannelID{})
+ 0, lnwire.ShortChannelID{}, nil,
+ )
msg := result.WireMessage()
if _, ok := msg.(*lnwire.FailFeeInsufficient); !ok {
diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go
index ce791bef32..5cf7966573 100644
--- a/htlcswitch/mock.go
+++ b/htlcswitch/mock.go
@@ -24,7 +24,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/contractcourt"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/invoices"
@@ -846,14 +846,14 @@ func (f *mockChannelLink) UpdateForwardingPolicy(_ models.ForwardingPolicy) {
}
func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliSatoshi,
lnwire.MilliSatoshi, uint32, uint32, models.InboundFee, uint32,
- lnwire.ShortChannelID) *LinkError {
+ lnwire.ShortChannelID, lnwire.CustomRecords) *LinkError {
return f.checkHtlcForwardResult
}
func (f *mockChannelLink) CheckHtlcTransit(payHash [32]byte,
amt lnwire.MilliSatoshi, timeout uint32,
- heightNow uint32) *LinkError {
+ heightNow uint32, _ lnwire.CustomRecords) *LinkError {
return f.checkHtlcTransitResult
}
@@ -968,6 +968,17 @@ func (f *mockChannelLink) CommitmentCustomBlob() fn.Option[tlv.Blob] {
return fn.None[tlv.Blob]()
}
+// AuxBandwidth returns the bandwidth that can be used for a channel,
+// expressed in milli-satoshi. This might be different from the regular
+// BTC bandwidth for custom channels. This will always return fn.None()
+// for a regular (non-custom) channel.
+func (f *mockChannelLink) AuxBandwidth(lnwire.MilliSatoshi,
+ lnwire.ShortChannelID,
+ fn.Option[tlv.Blob], AuxTrafficShaper) fn.Result[OptionalBandwidth] {
+
+ return fn.Ok(fn.None[lnwire.MilliSatoshi]())
+}
+
var _ ChannelLink = (*mockChannelLink)(nil)
const testInvoiceCltvExpiry = 6
diff --git a/htlcswitch/quiescer.go b/htlcswitch/quiescer.go
index 27d0deb8c6..468ad5e708 100644
--- a/htlcswitch/quiescer.go
+++ b/htlcswitch/quiescer.go
@@ -6,7 +6,7 @@ import (
"time"
"github.com/btcsuite/btclog/v2"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
)
diff --git a/htlcswitch/quiescer_test.go b/htlcswitch/quiescer_test.go
index da08909d57..6ce9563e45 100644
--- a/htlcswitch/quiescer_test.go
+++ b/htlcswitch/quiescer_test.go
@@ -5,7 +5,7 @@ import (
"testing"
"time"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/stretchr/testify/require"
diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go
index 1a08275ec9..c94c677966 100644
--- a/htlcswitch/switch.go
+++ b/htlcswitch/switch.go
@@ -17,7 +17,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/contractcourt"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/kvdb"
@@ -917,6 +917,7 @@ func (s *Switch) getLocalLink(pkt *htlcPacket, htlc *lnwire.UpdateAddHTLC) (
currentHeight := atomic.LoadUint32(&s.bestHeight)
htlcErr := link.CheckHtlcTransit(
htlc.PaymentHash, htlc.Amount, htlc.Expiry, currentHeight,
+ htlc.CustomRecords,
)
if htlcErr != nil {
log.Errorf("Link %v policy for local forward not "+
@@ -1605,7 +1606,7 @@ out:
}
}
- log.Infof("Received outside contract resolution, "+
+ log.Debugf("Received outside contract resolution, "+
"mapping to: %v", spew.Sdump(pkt))
// We don't check the error, as the only failure we can
@@ -2887,10 +2888,9 @@ func (s *Switch) handlePacketAdd(packet *htlcPacket,
failure = link.CheckHtlcForward(
htlc.PaymentHash, packet.incomingAmount,
packet.amount, packet.incomingTimeout,
- packet.outgoingTimeout,
- packet.inboundFee,
- currentHeight,
- packet.originalOutgoingChanID,
+ packet.outgoingTimeout, packet.inboundFee,
+ currentHeight, packet.originalOutgoingChanID,
+ htlc.CustomRecords,
)
}
diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go
index abfb8e4d5b..8809321460 100644
--- a/htlcswitch/switch_test.go
+++ b/htlcswitch/switch_test.go
@@ -17,7 +17,7 @@ import (
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hodl"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
diff --git a/input/input.go b/input/input.go
index 088b20401f..4a9a4b55c0 100644
--- a/input/input.go
+++ b/input/input.go
@@ -6,7 +6,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/tlv"
)
diff --git a/input/mocks.go b/input/mocks.go
index bbd4550c5f..6d90bc28df 100644
--- a/input/mocks.go
+++ b/input/mocks.go
@@ -8,7 +8,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/tlv"
diff --git a/input/script_utils.go b/input/script_utils.go
index 91ca55292f..000efe9585 100644
--- a/input/script_utils.go
+++ b/input/script_utils.go
@@ -14,7 +14,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
"golang.org/x/crypto/ripemd160"
diff --git a/input/taproot.go b/input/taproot.go
index 2ca6e97236..5ca4dd0c66 100644
--- a/input/taproot.go
+++ b/input/taproot.go
@@ -8,7 +8,7 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/waddrmgr"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
)
const (
diff --git a/input/taproot_test.go b/input/taproot_test.go
index a1259be196..3a1e000374 100644
--- a/input/taproot_test.go
+++ b/input/taproot_test.go
@@ -10,7 +10,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/stretchr/testify/require"
diff --git a/intercepted_forward.go b/intercepted_forward.go
index 791d4bd583..5cb1ca192b 100644
--- a/intercepted_forward.go
+++ b/intercepted_forward.go
@@ -3,7 +3,7 @@ package lnd
import (
"errors"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go
index f5a6c6a95f..cc76d5aefa 100644
--- a/invoices/invoiceregistry.go
+++ b/invoices/invoiceregistry.go
@@ -1275,7 +1275,11 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked(
invoiceToExpire = makeInvoiceExpiry(ctx.hash, invoice)
}
- i.hodlSubscribe(hodlChan, ctx.circuitKey)
+ // Subscribe to the resolution if the caller specified a
+ // notification channel.
+ if hodlChan != nil {
+ i.hodlSubscribe(hodlChan, ctx.circuitKey)
+ }
default:
panic("unknown action")
diff --git a/invoices/modification_interceptor.go b/invoices/modification_interceptor.go
index 97e75e8cc5..58f5b63d07 100644
--- a/invoices/modification_interceptor.go
+++ b/invoices/modification_interceptor.go
@@ -5,7 +5,7 @@ import (
"fmt"
"sync/atomic"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
)
var (
diff --git a/itest/lnd_funding_test.go b/itest/lnd_funding_test.go
index 54180abf57..0b08da32b3 100644
--- a/itest/lnd_funding_test.go
+++ b/itest/lnd_funding_test.go
@@ -12,7 +12,7 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainreg"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/funding"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/labels"
diff --git a/itest/lnd_sweep_test.go b/itest/lnd_sweep_test.go
index 099014aff0..158e8768f9 100644
--- a/itest/lnd_sweep_test.go
+++ b/itest/lnd_sweep_test.go
@@ -8,7 +8,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/contractcourt"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lncfg"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnrpc/invoicesrpc"
@@ -1119,9 +1119,9 @@ func testSweepHTLCs(ht *lntest.HarnessTest) {
// The sweeping tx has two inputs, one from wallet, the other
// from the force close tx. We now check whether the first tx
// spends from the force close tx of Alice->Bob.
- found := fn.Any(func(inp *wire.TxIn) bool {
+ found := fn.Any(txns[0].TxIn, func(inp *wire.TxIn) bool {
return inp.PreviousOutPoint.Hash == abCloseTxid
- }, txns[0].TxIn)
+ })
// If the first tx spends an outpoint from the force close tx
// of Alice->Bob, then it must be the incoming HTLC sweeping
diff --git a/lnrpc/devrpc/dev_server.go b/lnrpc/devrpc/dev_server.go
index 60f30dd7ed..b26b144c81 100644
--- a/lnrpc/devrpc/dev_server.go
+++ b/lnrpc/devrpc/dev_server.go
@@ -16,7 +16,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lncfg"
"github.com/lightningnetwork/lnd/lnrpc"
diff --git a/lnrpc/invoicesrpc/utils.go b/lnrpc/invoicesrpc/utils.go
index 955ba6acf2..19ade28fd8 100644
--- a/lnrpc/invoicesrpc/utils.go
+++ b/lnrpc/invoicesrpc/utils.go
@@ -1,8 +1,10 @@
package invoicesrpc
import (
+ "cmp"
"encoding/hex"
"fmt"
+ "slices"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/chaincfg"
@@ -160,6 +162,11 @@ func CreateRPCInvoice(invoice *invoices.Invoice,
rpcHtlcs = append(rpcHtlcs, &rpcHtlc)
}
+ // Perform an inplace sort of the HTLCs to ensure they are ordered.
+ slices.SortFunc(rpcHtlcs, func(i, j *lnrpc.InvoiceHTLC) int {
+ return cmp.Compare(i.HtlcIndex, j.HtlcIndex)
+ })
+
rpcInvoice := &lnrpc.Invoice{
Memo: string(invoice.Memo),
RHash: rHash,
diff --git a/lnrpc/marshall_utils.go b/lnrpc/marshall_utils.go
index 230fea35b6..96d3342d83 100644
--- a/lnrpc/marshall_utils.go
+++ b/lnrpc/marshall_utils.go
@@ -11,7 +11,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/wallet"
"github.com/lightningnetwork/lnd/aliasmgr"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"golang.org/x/exp/maps"
@@ -221,12 +221,18 @@ func UnmarshallCoinSelectionStrategy(strategy CoinSelectionStrategy,
// MarshalAliasMap converts a ScidAliasMap to its proto counterpart. This is
// used in various RPCs that handle scid alias mappings.
func MarshalAliasMap(scidMap aliasmgr.ScidAliasMap) []*AliasMap {
- return fn.Map(func(base lnwire.ShortChannelID) *AliasMap {
- return &AliasMap{
- BaseScid: base.ToUint64(),
- Aliases: fn.Map(func(a lnwire.ShortChannelID) uint64 {
- return a.ToUint64()
- }, scidMap[base]),
- }
- }, maps.Keys(scidMap))
+ return fn.Map(
+ maps.Keys(scidMap),
+ func(base lnwire.ShortChannelID) *AliasMap {
+ return &AliasMap{
+ BaseScid: base.ToUint64(),
+ Aliases: fn.Map(
+ scidMap[base],
+ func(a lnwire.ShortChannelID) uint64 {
+ return a.ToUint64()
+ },
+ ),
+ }
+ },
+ )
}
diff --git a/lnrpc/routerrpc/forward_interceptor.go b/lnrpc/routerrpc/forward_interceptor.go
index 9da831ac04..72df3d0199 100644
--- a/lnrpc/routerrpc/forward_interceptor.go
+++ b/lnrpc/routerrpc/forward_interceptor.go
@@ -3,7 +3,7 @@ package routerrpc
import (
"errors"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lnrpc"
diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go
index 9421e991b6..7d73681094 100644
--- a/lnrpc/routerrpc/router_backend.go
+++ b/lnrpc/routerrpc/router_backend.go
@@ -16,7 +16,7 @@ import (
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/feature"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lntypes"
diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go
index 7f1a7edf07..9499fa25a3 100644
--- a/lnrpc/routerrpc/router_server.go
+++ b/lnrpc/routerrpc/router_server.go
@@ -16,7 +16,7 @@ import (
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/lightningnetwork/lnd/aliasmgr"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnrpc/invoicesrpc"
"github.com/lightningnetwork/lnd/lntypes"
diff --git a/lnrpc/walletrpc/walletkit_server.go b/lnrpc/walletrpc/walletkit_server.go
index c6dec6fbd5..4f477cdbd4 100644
--- a/lnrpc/walletrpc/walletkit_server.go
+++ b/lnrpc/walletrpc/walletkit_server.go
@@ -31,7 +31,7 @@ import (
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/labels"
@@ -1145,9 +1145,9 @@ func (w *WalletKit) getWaitingCloseChannel(
return nil, err
}
- channel := fn.Find(func(c *channeldb.OpenChannel) bool {
+ channel := fn.Find(chans, func(c *channeldb.OpenChannel) bool {
return c.FundingOutpoint == chanPoint
- }, chans)
+ })
return channel.UnwrapOrErr(errors.New("channel not found"))
}
@@ -1231,18 +1231,23 @@ func (w *WalletKit) BumpForceCloseFee(_ context.Context,
pendingSweeps := maps.Values(inputsMap)
// Discard everything except for the anchor sweeps.
- anchors := fn.Filter(func(sweep *sweep.PendingInputResponse) bool {
- // Only filter for anchor inputs because these are the only
- // inputs which can be used to bump a closed unconfirmed
- // commitment transaction.
- if sweep.WitnessType != input.CommitmentAnchor &&
- sweep.WitnessType != input.TaprootAnchorSweepSpend {
-
- return false
- }
+ anchors := fn.Filter(
+ pendingSweeps,
+ func(sweep *sweep.PendingInputResponse) bool {
+ // Only filter for anchor inputs because these are the
+ // only inputs which can be used to bump a closed
+ // unconfirmed commitment transaction.
+ isCommitAnchor := sweep.WitnessType ==
+ input.CommitmentAnchor
+ isTaprootSweepSpend := sweep.WitnessType ==
+ input.TaprootAnchorSweepSpend
+ if !isCommitAnchor && !isTaprootSweepSpend {
+ return false
+ }
- return commitSet.Contains(sweep.OutPoint.Hash)
- }, pendingSweeps)
+ return commitSet.Contains(sweep.OutPoint.Hash)
+ },
+ )
if len(anchors) == 0 {
return nil, fmt.Errorf("unable to find pending anchor outputs")
@@ -1754,7 +1759,7 @@ func (w *WalletKit) fundPsbtInternalWallet(account string,
return true
}
- eligibleUtxos := fn.Filter(filterFn, utxos)
+ eligibleUtxos := fn.Filter(utxos, filterFn)
// Validate all inputs against our known list of UTXOs
// now.
diff --git a/lntest/harness.go b/lntest/harness.go
index f96a3aadd7..8e8fcd3936 100644
--- a/lntest/harness.go
+++ b/lntest/harness.go
@@ -13,7 +13,7 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/go-errors/errors"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/kvdb/etcd"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnrpc/invoicesrpc"
diff --git a/lntest/harness_assertion.go b/lntest/harness_assertion.go
index 1b079fea16..11cbefdd5c 100644
--- a/lntest/harness_assertion.go
+++ b/lntest/harness_assertion.go
@@ -19,7 +19,7 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnrpc/invoicesrpc"
"github.com/lightningnetwork/lnd/lnrpc/routerrpc"
@@ -270,7 +270,7 @@ func (h *HarnessTest) AssertNumActiveEdges(hn *node.HarnessNode,
IncludeUnannounced: includeUnannounced,
}
resp := hn.RPC.DescribeGraph(req)
- activeEdges := fn.Filter(filterDisabled, resp.Edges)
+ activeEdges := fn.Filter(resp.Edges, filterDisabled)
total := len(activeEdges)
if total-old == expected {
diff --git a/lntest/miner/miner.go b/lntest/miner/miner.go
index e9e380bbb3..0229d6a47f 100644
--- a/lntest/miner/miner.go
+++ b/lntest/miner/miner.go
@@ -17,7 +17,7 @@ import (
"github.com/btcsuite/btcd/integration/rpctest"
"github.com/btcsuite/btcd/rpcclient"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lntest/node"
"github.com/lightningnetwork/lnd/lntest/wait"
"github.com/stretchr/testify/require"
@@ -296,10 +296,7 @@ func (h *HarnessMiner) AssertTxInMempool(txid chainhash.Hash) *wire.MsgTx {
return fmt.Errorf("empty mempool")
}
- isEqual := func(memTx chainhash.Hash) bool {
- return memTx == txid
- }
- result := fn.Find(isEqual, mempool)
+ result := fn.Find(mempool, fn.Eq(txid))
if result.IsNone() {
return fmt.Errorf("txid %v not found in "+
diff --git a/lntest/mock/walletcontroller.go b/lntest/mock/walletcontroller.go
index 8b7ef55380..fa623bf84d 100644
--- a/lntest/mock/walletcontroller.go
+++ b/lntest/mock/walletcontroller.go
@@ -16,7 +16,7 @@ import (
base "github.com/btcsuite/btcwallet/wallet"
"github.com/btcsuite/btcwallet/wallet/txauthor"
"github.com/btcsuite/btcwallet/wtxmgr"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
diff --git a/lntest/node/state.go b/lntest/node/state.go
index a89ab7d2cc..38f02f3a4c 100644
--- a/lntest/node/state.go
+++ b/lntest/node/state.go
@@ -7,7 +7,7 @@ import (
"time"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnrpc/walletrpc"
"github.com/lightningnetwork/lnd/lntest/rpc"
@@ -324,11 +324,11 @@ func (s *State) updateEdgeStats() {
req := &lnrpc.ChannelGraphRequest{IncludeUnannounced: true}
resp := s.rpc.DescribeGraph(req)
- s.Edge.Total = len(fn.Filter(filterDisabled, resp.Edges))
+ s.Edge.Total = len(fn.Filter(resp.Edges, filterDisabled))
req = &lnrpc.ChannelGraphRequest{IncludeUnannounced: false}
resp = s.rpc.DescribeGraph(req)
- s.Edge.Public = len(fn.Filter(filterDisabled, resp.Edges))
+ s.Edge.Public = len(fn.Filter(resp.Edges, filterDisabled))
}
// updateWalletBalance creates stats for the node's wallet balance.
diff --git a/lnwallet/aux_leaf_store.go b/lnwallet/aux_leaf_store.go
index c457a92509..28a78e09db 100644
--- a/lnwallet/aux_leaf_store.go
+++ b/lnwallet/aux_leaf_store.go
@@ -5,7 +5,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
diff --git a/lnwallet/aux_resolutions.go b/lnwallet/aux_resolutions.go
index 382232640d..b36e2d6368 100644
--- a/lnwallet/aux_resolutions.go
+++ b/lnwallet/aux_resolutions.go
@@ -4,7 +4,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
diff --git a/lnwallet/aux_signer.go b/lnwallet/aux_signer.go
index 01abe1aae3..510b64b5d1 100644
--- a/lnwallet/aux_signer.go
+++ b/lnwallet/aux_signer.go
@@ -2,7 +2,7 @@ package lnwallet
import (
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
diff --git a/lnwallet/btcwallet/btcwallet.go b/lnwallet/btcwallet/btcwallet.go
index 5d28574cbe..b9a909fbd3 100644
--- a/lnwallet/btcwallet/btcwallet.go
+++ b/lnwallet/btcwallet/btcwallet.go
@@ -27,7 +27,7 @@ import (
"github.com/btcsuite/btcwallet/wtxmgr"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/blockcache"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
diff --git a/lnwallet/chainfee/filtermanager.go b/lnwallet/chainfee/filtermanager.go
index 26fa56aef1..2d6fd0a2e1 100644
--- a/lnwallet/chainfee/filtermanager.go
+++ b/lnwallet/chainfee/filtermanager.go
@@ -9,7 +9,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/rpcclient"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
)
const (
diff --git a/lnwallet/chancloser/aux_closer.go b/lnwallet/chancloser/aux_closer.go
index 8b1c445ca3..62f475dd43 100644
--- a/lnwallet/chancloser/aux_closer.go
+++ b/lnwallet/chancloser/aux_closer.go
@@ -4,7 +4,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
diff --git a/lnwallet/chancloser/chancloser.go b/lnwallet/chancloser/chancloser.go
index 17112b29e0..398a8a9f3e 100644
--- a/lnwallet/chancloser/chancloser.go
+++ b/lnwallet/chancloser/chancloser.go
@@ -12,7 +12,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/labels"
diff --git a/lnwallet/chancloser/chancloser_test.go b/lnwallet/chancloser/chancloser_test.go
index 28709fd5f8..fe71fe5e3b 100644
--- a/lnwallet/chancloser/chancloser_test.go
+++ b/lnwallet/chancloser/chancloser_test.go
@@ -14,7 +14,7 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
diff --git a/lnwallet/chancloser/interface.go b/lnwallet/chancloser/interface.go
index 729cdc545b..f774c81039 100644
--- a/lnwallet/chancloser/interface.go
+++ b/lnwallet/chancloser/interface.go
@@ -6,7 +6,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
diff --git a/lnwallet/chanfunding/canned_assembler.go b/lnwallet/chanfunding/canned_assembler.go
index b3457f21bf..e28cbb96d1 100644
--- a/lnwallet/chanfunding/canned_assembler.go
+++ b/lnwallet/chanfunding/canned_assembler.go
@@ -8,7 +8,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
)
diff --git a/lnwallet/chanfunding/interface.go b/lnwallet/chanfunding/interface.go
index 3512b32ff9..e40c4a1157 100644
--- a/lnwallet/chanfunding/interface.go
+++ b/lnwallet/chanfunding/interface.go
@@ -8,7 +8,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/wallet"
"github.com/btcsuite/btcwallet/wtxmgr"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
diff --git a/lnwallet/chanfunding/psbt_assembler.go b/lnwallet/chanfunding/psbt_assembler.go
index f678f520fc..dd1bedd05a 100644
--- a/lnwallet/chanfunding/psbt_assembler.go
+++ b/lnwallet/chanfunding/psbt_assembler.go
@@ -13,7 +13,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
)
diff --git a/lnwallet/channel.go b/lnwallet/channel.go
index 7f70e600c6..d190acdf5e 100644
--- a/lnwallet/channel.go
+++ b/lnwallet/channel.go
@@ -25,7 +25,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
@@ -600,7 +600,7 @@ func (lc *LightningChannel) extractPayDescs(feeRate chainfee.SatPerKWeight,
htlc := htlc
- auxLeaf := fn.ChainOption(
+ auxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
leaves := l.OutgoingHtlcLeaves
if htlc.Incoming {
@@ -1106,7 +1106,7 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate,
feeRate, wireMsg.Amount.ToSatoshis(), remoteDustLimit,
)
if !isDustRemote {
- auxLeaf := fn.ChainOption(
+ auxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
leaves := l.OutgoingHtlcLeaves
return leaves[pd.HtlcIndex].AuxTapLeaf
@@ -2088,7 +2088,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64,
// Since it is the remote breach we are reconstructing, the output
// going to us will be a to-remote script with our local params.
- remoteAuxLeaf := fn.ChainOption(
+ remoteAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.RemoteAuxLeaf
},
@@ -2102,7 +2102,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64,
return nil, err
}
- localAuxLeaf := fn.ChainOption(
+ localAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.LocalAuxLeaf
},
@@ -2229,7 +2229,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64,
return nil, fmt.Errorf("unable to aux resolve: %w", err)
}
- br.LocalResolutionBlob = resolveBlob.Option()
+ br.LocalResolutionBlob = resolveBlob.OkToSome()
}
// Similarly, if their balance exceeds the remote party's dust limit,
@@ -2308,7 +2308,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64,
return nil, fmt.Errorf("unable to aux resolve: %w", err)
}
- br.RemoteResolutionBlob = resolveBlob.Option()
+ br.RemoteResolutionBlob = resolveBlob.OkToSome()
}
// Finally, with all the necessary data constructed, we can pad the
@@ -2338,7 +2338,7 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel,
// We'll generate the original second level witness script now, as
// we'll need it if we're revoking an HTLC output on the remote
// commitment transaction, and *they* go to the second level.
- secondLevelAuxLeaf := fn.ChainOption(
+ secondLevelAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) fn.Option[input.AuxTapLeaf] {
return fn.MapOption(func(val uint16) input.AuxTapLeaf {
idx := input.HtlcIndex(val)
@@ -2366,7 +2366,7 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel,
// HTLC script. Otherwise, is this was an outgoing HTLC that we sent,
// then from the PoV of the remote commitment state, they're the
// receiver of this HTLC.
- htlcLeaf := fn.ChainOption(
+ htlcLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) fn.Option[input.AuxTapLeaf] {
return fn.MapOption(func(val uint16) input.AuxTapLeaf {
idx := input.HtlcIndex(val)
@@ -2693,13 +2693,13 @@ type HtlcView struct {
// AuxOurUpdates returns the outgoing HTLCs as a read-only copy of
// AuxHtlcDescriptors.
func (v *HtlcView) AuxOurUpdates() []AuxHtlcDescriptor {
- return fn.Map(newAuxHtlcDescriptor, v.Updates.Local)
+ return fn.Map(v.Updates.Local, newAuxHtlcDescriptor)
}
// AuxTheirUpdates returns the incoming HTLCs as a read-only copy of
// AuxHtlcDescriptors.
func (v *HtlcView) AuxTheirUpdates() []AuxHtlcDescriptor {
- return fn.Map(newAuxHtlcDescriptor, v.Updates.Remote)
+ return fn.Map(v.Updates.Remote, newAuxHtlcDescriptor)
}
// fetchHTLCView returns all the candidate HTLC updates which should be
@@ -2917,9 +2917,9 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView,
// The fee rate of our view is always the last UpdateFee message from
// the channel's OpeningParty.
openerUpdates := view.Updates.GetForParty(lc.channelState.Initiator())
- feeUpdates := fn.Filter(func(u *paymentDescriptor) bool {
+ feeUpdates := fn.Filter(openerUpdates, func(u *paymentDescriptor) bool {
return u.EntryType == FeeUpdate
- }, openerUpdates)
+ })
lastFeeUpdate := fn.Last(feeUpdates)
lastFeeUpdate.WhenSome(func(pd *paymentDescriptor) {
newView.FeePerKw = chainfee.SatPerKWeight(
@@ -2942,14 +2942,17 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView,
for _, party := range parties {
// First we run through non-add entries in both logs,
// populating the skip sets.
- resolutions := fn.Filter(func(pd *paymentDescriptor) bool {
- switch pd.EntryType {
- case Settle, Fail, MalformedFail:
- return true
- default:
- return false
- }
- }, view.Updates.GetForParty(party))
+ resolutions := fn.Filter(
+ view.Updates.GetForParty(party),
+ func(pd *paymentDescriptor) bool {
+ switch pd.EntryType {
+ case Settle, Fail, MalformedFail:
+ return true
+ default:
+ return false
+ }
+ },
+ )
for _, entry := range resolutions {
addEntry, err := lc.fetchParent(
@@ -3002,10 +3005,16 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView,
// settled HTLCs, and debiting the chain state balance due to any newly
// added HTLCs.
for _, party := range parties {
- liveAdds := fn.Filter(func(pd *paymentDescriptor) bool {
- return pd.EntryType == Add &&
- !skip.GetForParty(party).Contains(pd.HtlcIndex)
- }, view.Updates.GetForParty(party))
+ liveAdds := fn.Filter(
+ view.Updates.GetForParty(party),
+ func(pd *paymentDescriptor) bool {
+ isAdd := pd.EntryType == Add
+ shouldSkip := skip.GetForParty(party).
+ Contains(pd.HtlcIndex)
+
+ return isAdd && !shouldSkip
+ },
+ )
for _, entry := range liveAdds {
// Skip the entries that have already had their add
@@ -3063,7 +3072,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView,
uncommittedUpdates := lntypes.MapDual(
view.Updates,
func(us []*paymentDescriptor) []*paymentDescriptor {
- return fn.Filter(isUncommitted, us)
+ return fn.Filter(us, isUncommitted)
},
)
@@ -3189,7 +3198,7 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing,
htlcFee := HtlcTimeoutFee(chanType, feePerKw)
outputAmt := htlc.Amount.ToSatoshis() - htlcFee
- auxLeaf := fn.ChainOption(
+ auxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
leaves := l.IncomingHtlcLeaves
return leaves[htlc.HtlcIndex].SecondLevelLeaf
@@ -3270,7 +3279,7 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing,
htlcFee := HtlcSuccessFee(chanType, feePerKw)
outputAmt := htlc.Amount.ToSatoshis() - htlcFee
- auxLeaf := fn.ChainOption(
+ auxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
leaves := l.OutgoingHtlcLeaves
return leaves[htlc.HtlcIndex].SecondLevelLeaf
@@ -4802,7 +4811,7 @@ func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel,
htlcFee := HtlcSuccessFee(chanType, feePerKw)
outputAmt := htlc.Amount.ToSatoshis() - htlcFee
- auxLeaf := fn.ChainOption(func(
+ auxLeaf := fn.FlatMapOption(func(
l CommitAuxLeaves) input.AuxTapLeaf {
leaves := l.IncomingHtlcLeaves
@@ -4895,7 +4904,7 @@ func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel,
htlcFee := HtlcTimeoutFee(chanType, feePerKw)
outputAmt := htlc.Amount.ToSatoshis() - htlcFee
- auxLeaf := fn.ChainOption(func(
+ auxLeaf := fn.FlatMapOption(func(
l CommitAuxLeaves) input.AuxTapLeaf {
leaves := l.OutgoingHtlcLeaves
@@ -6766,7 +6775,7 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, //nolint:funlen
// Before we can generate the proper sign descriptor, we'll need to
// locate the output index of our non-delayed output on the commitment
// transaction.
- remoteAuxLeaf := fn.ChainOption(
+ remoteAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.RemoteAuxLeaf
},
@@ -6870,7 +6879,7 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, //nolint:funlen
return nil, fmt.Errorf("unable to aux resolve: %w", err)
}
- commitResolution.ResolutionBlob = resolveBlob.Option()
+ commitResolution.ResolutionBlob = resolveBlob.OkToSome()
}
closeSummary := channeldb.ChannelCloseSummary{
@@ -7059,7 +7068,7 @@ func newOutgoingHtlcResolution(signer input.Signer,
// First, we'll re-generate the script used to send the HTLC to the
// remote party within their commitment transaction.
- auxLeaf := fn.ChainOption(func(l CommitAuxLeaves) input.AuxTapLeaf {
+ auxLeaf := fn.FlatMapOption(func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.OutgoingHtlcLeaves[htlc.HtlcIndex].AuxTapLeaf
})(auxLeaves)
htlcScriptInfo, err := genHtlcScript(
@@ -7149,7 +7158,7 @@ func newOutgoingHtlcResolution(signer input.Signer,
return nil, fmt.Errorf("unable to aux resolve: %w", err)
}
- resolutionBlob := resolveRes.Option()
+ resolutionBlob := resolveRes.OkToSome()
return &OutgoingHtlcResolution{
Expiry: htlc.RefundTimeout,
@@ -7171,7 +7180,7 @@ func newOutgoingHtlcResolution(signer input.Signer,
// With the fee calculated, re-construct the second level timeout
// transaction.
- secondLevelAuxLeaf := fn.ChainOption(
+ secondLevelAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
leaves := l.OutgoingHtlcLeaves
return leaves[htlc.HtlcIndex].SecondLevelLeaf
@@ -7366,7 +7375,7 @@ func newOutgoingHtlcResolution(signer input.Signer,
if err := resolveRes.Err(); err != nil {
return nil, fmt.Errorf("unable to aux resolve: %w", err)
}
- resolutionBlob := resolveRes.Option()
+ resolutionBlob := resolveRes.OkToSome()
return &OutgoingHtlcResolution{
Expiry: htlc.RefundTimeout,
@@ -7406,7 +7415,7 @@ func newIncomingHtlcResolution(signer input.Signer,
// First, we'll re-generate the script the remote party used to
// send the HTLC to us in their commitment transaction.
- auxLeaf := fn.ChainOption(func(l CommitAuxLeaves) input.AuxTapLeaf {
+ auxLeaf := fn.FlatMapOption(func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.IncomingHtlcLeaves[htlc.HtlcIndex].AuxTapLeaf
})(auxLeaves)
scriptInfo, err := genHtlcScript(
@@ -7497,7 +7506,7 @@ func newIncomingHtlcResolution(signer input.Signer,
return nil, fmt.Errorf("unable to aux resolve: %w", err)
}
- resolutionBlob := resolveRes.Option()
+ resolutionBlob := resolveRes.OkToSome()
return &IncomingHtlcResolution{
ClaimOutpoint: op,
@@ -7507,7 +7516,7 @@ func newIncomingHtlcResolution(signer input.Signer,
}, nil
}
- secondLevelAuxLeaf := fn.ChainOption(
+ secondLevelAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
leaves := l.IncomingHtlcLeaves
return leaves[htlc.HtlcIndex].SecondLevelLeaf
@@ -7707,7 +7716,7 @@ func newIncomingHtlcResolution(signer input.Signer,
return nil, fmt.Errorf("unable to aux resolve: %w", err)
}
- resolutionBlob := resolveRes.Option()
+ resolutionBlob := resolveRes.OkToSome()
return &IncomingHtlcResolution{
SignedSuccessTx: successTx,
@@ -8011,7 +8020,7 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel,
leaseExpiry = chanState.ThawHeight
}
- localAuxLeaf := fn.ChainOption(
+ localAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.LocalAuxLeaf
},
@@ -8126,7 +8135,7 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel,
return nil, fmt.Errorf("unable to aux resolve: %w", err)
}
- commitResolution.ResolutionBlob = resolveBlob.Option()
+ commitResolution.ResolutionBlob = resolveBlob.OkToSome()
}
// Once the delay output has been found (if it exists), then we'll also
diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go
index f7ecd32277..d0caa97812 100644
--- a/lnwallet/channel_test.go
+++ b/lnwallet/channel_test.go
@@ -25,7 +25,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
@@ -730,9 +730,12 @@ func TestCommitHTLCSigCustomRecordSize(t *testing.T) {
// Replace the default PackSigs implementation to return a
// large custom records blob.
- mockSigner.ExpectedCalls = fn.Filter(func(c *mock.Call) bool {
- return c.Method != "PackSigs"
- }, mockSigner.ExpectedCalls)
+ mockSigner.ExpectedCalls = fn.Filter(
+ mockSigner.ExpectedCalls,
+ func(c *mock.Call) bool {
+ return c.Method != "PackSigs"
+ },
+ )
mockSigner.On("PackSigs", mock.Anything).
Return(fn.Ok(fn.Some(largeBlob)))
})
diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go
index 8b364a01df..787e8a71e1 100644
--- a/lnwallet/commitment.go
+++ b/lnwallet/commitment.go
@@ -11,7 +11,7 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
@@ -836,7 +836,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance,
continue
}
- auxLeaf := fn.ChainOption(
+ auxLeaf := fn.FlatMapOption(
func(leaves input.HtlcAuxLeaves) input.AuxTapLeaf {
return leaves[htlc.HtlcIndex].AuxTapLeaf
},
@@ -864,7 +864,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance,
continue
}
- auxLeaf := fn.ChainOption(
+ auxLeaf := fn.FlatMapOption(
func(leaves input.HtlcAuxLeaves) input.AuxTapLeaf {
return leaves[htlc.HtlcIndex].AuxTapLeaf
},
@@ -1323,7 +1323,7 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash,
// Compute the to_local script. From our PoV, when facing a remote
// commitment, the to_local output belongs to them.
- localAuxLeaf := fn.ChainOption(
+ localAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.LocalAuxLeaf
},
@@ -1338,7 +1338,7 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash,
// Compute the to_remote script. From our PoV, when facing a remote
// commitment, the to_remote output belongs to us.
- remoteAuxLeaf := fn.ChainOption(
+ remoteAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.RemoteAuxLeaf
},
diff --git a/lnwallet/commitment_chain.go b/lnwallet/commitment_chain.go
index fa2abe0aa2..871a139c5c 100644
--- a/lnwallet/commitment_chain.go
+++ b/lnwallet/commitment_chain.go
@@ -1,7 +1,7 @@
package lnwallet
import (
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
)
// commitmentChain represents a chain of unrevoked commitments. The tail of the
diff --git a/lnwallet/config.go b/lnwallet/config.go
index 425fe15dad..c60974be6d 100644
--- a/lnwallet/config.go
+++ b/lnwallet/config.go
@@ -5,7 +5,7 @@ import (
"github.com/btcsuite/btcwallet/wallet"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
diff --git a/lnwallet/interface.go b/lnwallet/interface.go
index c9dee9202a..64f8546310 100644
--- a/lnwallet/interface.go
+++ b/lnwallet/interface.go
@@ -19,7 +19,7 @@ import (
base "github.com/btcsuite/btcwallet/wallet"
"github.com/btcsuite/btcwallet/wallet/txauthor"
"github.com/btcsuite/btcwallet/wtxmgr"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwallet/chanvalidate"
diff --git a/lnwallet/mock.go b/lnwallet/mock.go
index a8610dc779..39e520d276 100644
--- a/lnwallet/mock.go
+++ b/lnwallet/mock.go
@@ -18,7 +18,7 @@ import (
"github.com/btcsuite/btcwallet/wtxmgr"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/tlv"
diff --git a/lnwallet/musig_session.go b/lnwallet/musig_session.go
index 822aa48a14..748e5fa958 100644
--- a/lnwallet/musig_session.go
+++ b/lnwallet/musig_session.go
@@ -11,7 +11,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire"
diff --git a/lnwallet/reservation.go b/lnwallet/reservation.go
index fd35d95076..a8a0cacd4b 100644
--- a/lnwallet/reservation.go
+++ b/lnwallet/reservation.go
@@ -12,7 +12,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
diff --git a/lnwallet/rpcwallet/rpcwallet.go b/lnwallet/rpcwallet/rpcwallet.go
index bf6aa61df3..426712b597 100644
--- a/lnwallet/rpcwallet/rpcwallet.go
+++ b/lnwallet/rpcwallet/rpcwallet.go
@@ -22,7 +22,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/waddrmgr"
basewallet "github.com/btcsuite/btcwallet/wallet"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lncfg"
diff --git a/lnwallet/test/test_interface.go b/lnwallet/test/test_interface.go
index c006aa2e50..27de51708c 100644
--- a/lnwallet/test/test_interface.go
+++ b/lnwallet/test/test_interface.go
@@ -34,7 +34,7 @@ import (
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/chainntnfs/btcdnotify"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
diff --git a/lnwallet/test_utils.go b/lnwallet/test_utils.go
index ff9adfbd79..738558e224 100644
--- a/lnwallet/test_utils.go
+++ b/lnwallet/test_utils.go
@@ -16,7 +16,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
diff --git a/lnwallet/transactions_test.go b/lnwallet/transactions_test.go
index 135d1866bc..38131eaa72 100644
--- a/lnwallet/transactions_test.go
+++ b/lnwallet/transactions_test.go
@@ -21,7 +21,7 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
diff --git a/lnwallet/update_log.go b/lnwallet/update_log.go
index 2d1f65c9fa..b2b8af58d1 100644
--- a/lnwallet/update_log.go
+++ b/lnwallet/update_log.go
@@ -1,7 +1,7 @@
package lnwallet
import (
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
)
// updateLog is an append-only log that stores updates to a node's commitment
diff --git a/lnwallet/wallet.go b/lnwallet/wallet.go
index ad6354e2e8..2646d7c8f0 100644
--- a/lnwallet/wallet.go
+++ b/lnwallet/wallet.go
@@ -23,7 +23,7 @@ import (
"github.com/btcsuite/btcwallet/wallet"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
@@ -733,7 +733,7 @@ func (l *LightningWallet) RegisterFundingIntent(expectedID [32]byte,
}
if _, ok := l.fundingIntents[expectedID]; ok {
- return fmt.Errorf("%w: already has intent registered: %v",
+ return fmt.Errorf("%w: already has intent registered: %x",
ErrDuplicatePendingChanID, expectedID[:])
}
diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go
index e523279498..577379623f 100644
--- a/lnwire/channel_reestablish.go
+++ b/lnwire/channel_reestablish.go
@@ -5,7 +5,7 @@ import (
"io"
"github.com/btcsuite/btcd/btcec/v2"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/tlv"
)
diff --git a/lnwire/custom_records.go b/lnwire/custom_records.go
index 8177cbe821..a63aa5dfb0 100644
--- a/lnwire/custom_records.go
+++ b/lnwire/custom_records.go
@@ -6,7 +6,7 @@ import (
"io"
"sort"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/tlv"
)
@@ -179,9 +179,12 @@ func (c CustomRecords) SerializeTo(w io.Writer) error {
// ProduceRecordsSorted converts a slice of record producers into a slice of
// records and then sorts it by type.
func ProduceRecordsSorted(recordProducers ...tlv.RecordProducer) []tlv.Record {
- records := fn.Map(func(producer tlv.RecordProducer) tlv.Record {
- return producer.Record()
- }, recordProducers)
+ records := fn.Map(
+ recordProducers,
+ func(producer tlv.RecordProducer) tlv.Record {
+ return producer.Record()
+ },
+ )
// Ensure that the set of records are sorted before we attempt to
// decode from the stream, to ensure they're canonical.
@@ -212,9 +215,9 @@ func TlvMapToRecords(tlvMap tlv.TypeMap) []tlv.Record {
// RecordsAsProducers converts a slice of records into a slice of record
// producers.
func RecordsAsProducers(records []tlv.Record) []tlv.RecordProducer {
- return fn.Map(func(record tlv.Record) tlv.RecordProducer {
+ return fn.Map(records, func(record tlv.Record) tlv.RecordProducer {
return &record
- }, records)
+ })
}
// EncodeRecords encodes the given records into a byte slice.
diff --git a/lnwire/custom_records_test.go b/lnwire/custom_records_test.go
index 8ff6af10ba..d4aad2e546 100644
--- a/lnwire/custom_records_test.go
+++ b/lnwire/custom_records_test.go
@@ -4,7 +4,7 @@ import (
"bytes"
"testing"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -182,9 +182,12 @@ func TestCustomRecordsExtendRecordProducers(t *testing.T) {
func serializeRecordProducers(t *testing.T,
producers []tlv.RecordProducer) []byte {
- tlvRecords := fn.Map(func(p tlv.RecordProducer) tlv.Record {
- return p.Record()
- }, producers)
+ tlvRecords := fn.Map(
+ producers,
+ func(p tlv.RecordProducer) tlv.Record {
+ return p.Record()
+ },
+ )
stream, err := tlv.NewStream(tlvRecords...)
require.NoError(t, err)
diff --git a/lnwire/dyn_ack.go b/lnwire/dyn_ack.go
index 24f23a228d..d477461e7b 100644
--- a/lnwire/dyn_ack.go
+++ b/lnwire/dyn_ack.go
@@ -5,7 +5,7 @@ import (
"io"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/tlv"
)
diff --git a/lnwire/dyn_propose.go b/lnwire/dyn_propose.go
index b0cc1198e9..394fff6f37 100644
--- a/lnwire/dyn_propose.go
+++ b/lnwire/dyn_propose.go
@@ -6,7 +6,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/tlv"
)
diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go
index c4ca260e1e..4681426cbb 100644
--- a/lnwire/extra_bytes.go
+++ b/lnwire/extra_bytes.go
@@ -5,7 +5,7 @@ import (
"fmt"
"io"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/tlv"
)
diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go
index 952e90a7e6..6bfbb465ec 100644
--- a/lnwire/lnwire_test.go
+++ b/lnwire/lnwire_test.go
@@ -22,7 +22,7 @@ import (
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/tlv"
"github.com/lightningnetwork/lnd/tor"
diff --git a/lnwire/onion_error.go b/lnwire/onion_error.go
index 5f05e1ef9f..7b65a85f4e 100644
--- a/lnwire/onion_error.go
+++ b/lnwire/onion_error.go
@@ -10,7 +10,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/tlv"
)
diff --git a/lnwire/onion_error_test.go b/lnwire/onion_error_test.go
index 9c39be6d5c..5c3d0291a5 100644
--- a/lnwire/onion_error_test.go
+++ b/lnwire/onion_error_test.go
@@ -9,7 +9,7 @@ import (
"testing"
"github.com/davecgh/go-spew/spew"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/stretchr/testify/require"
)
diff --git a/log.go b/log.go
index a3efd03335..46047fb56c 100644
--- a/log.go
+++ b/log.go
@@ -9,6 +9,7 @@ import (
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/autopilot"
"github.com/lightningnetwork/lnd/build"
+ "github.com/lightningnetwork/lnd/chainio"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/chainreg"
"github.com/lightningnetwork/lnd/chanacceptor"
@@ -196,6 +197,7 @@ func SetupLoggers(root *build.SubLoggerManager, interceptor signal.Interceptor)
root, blindedpath.Subsystem, interceptor, blindedpath.UseLogger,
)
AddV1SubLogger(root, graphdb.Subsystem, interceptor, graphdb.UseLogger)
+ AddSubLogger(root, chainio.Subsystem, interceptor, chainio.UseLogger)
}
// AddSubLogger is a helper method to conveniently create and register the
diff --git a/msgmux/msg_router.go b/msgmux/msg_router.go
index db9e783990..736c085a95 100644
--- a/msgmux/msg_router.go
+++ b/msgmux/msg_router.go
@@ -6,7 +6,7 @@ import (
"sync"
"github.com/btcsuite/btcd/btcec/v2"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
)
@@ -91,8 +91,8 @@ func sendQueryErr[Q any](sendChan chan fn.Req[Q, error], queryArg Q,
quitChan chan struct{}) error {
return fn.ElimEither(
- fn.Iden, fn.Iden,
sendQuery(sendChan, queryArg, quitChan).Either,
+ fn.Iden, fn.Iden,
)
}
diff --git a/peer/brontide.go b/peer/brontide.go
index 6bc49445ee..a41b5080cb 100644
--- a/peer/brontide.go
+++ b/peer/brontide.go
@@ -26,7 +26,7 @@ import (
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/discovery"
"github.com/lightningnetwork/lnd/feature"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/funding"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
@@ -400,6 +400,10 @@ type Config struct {
// way contracts are resolved.
AuxResolver fn.Option[lnwallet.AuxContractResolver]
+ // AuxTrafficShaper is an optional auxiliary traffic shaper that can be
+ // used to manage the bandwidth of peer links.
+ AuxTrafficShaper fn.Option[htlcswitch.AuxTrafficShaper]
+
// PongBuf is a slice we'll reuse instead of allocating memory on the
// heap. Since only reads will occur and no writes, there is no need
// for any synchronization primitives. As a result, it's safe to share
@@ -1330,6 +1334,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint,
ShouldFwdExpEndorsement: p.cfg.ShouldFwdExpEndorsement,
DisallowQuiescence: p.cfg.DisallowQuiescence ||
!p.remoteFeatures.HasFeature(lnwire.QuiescenceOptional),
+ AuxTrafficShaper: p.cfg.AuxTrafficShaper,
}
// Before adding our new link, purge the switch of any pending or live
diff --git a/peer/brontide_test.go b/peer/brontide_test.go
index c3d1bee48b..eded658887 100644
--- a/peer/brontide_test.go
+++ b/peer/brontide_test.go
@@ -13,7 +13,7 @@ import (
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lntest/wait"
"github.com/lightningnetwork/lnd/lnwallet"
diff --git a/peer/musig_chan_closer.go b/peer/musig_chan_closer.go
index 6f69a8c5b8..149ebcfa0c 100644
--- a/peer/musig_chan_closer.go
+++ b/peer/musig_chan_closer.go
@@ -4,7 +4,7 @@ import (
"fmt"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chancloser"
diff --git a/peer/test_utils.go b/peer/test_utils.go
index eb510a53b1..34c42e2f7c 100644
--- a/peer/test_utils.go
+++ b/peer/test_utils.go
@@ -18,7 +18,7 @@ import (
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channelnotifier"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/input"
diff --git a/protofsm/daemon_events.go b/protofsm/daemon_events.go
index e5de0b6951..bca7283d39 100644
--- a/protofsm/daemon_events.go
+++ b/protofsm/daemon_events.go
@@ -5,7 +5,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
)
diff --git a/protofsm/msg_mapper.go b/protofsm/msg_mapper.go
index b96d677e6b..5e24255fa3 100644
--- a/protofsm/msg_mapper.go
+++ b/protofsm/msg_mapper.go
@@ -1,7 +1,7 @@
package protofsm
import (
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
)
diff --git a/protofsm/state_machine.go b/protofsm/state_machine.go
index b71d5efe42..a81f5746b2 100644
--- a/protofsm/state_machine.go
+++ b/protofsm/state_machine.go
@@ -10,7 +10,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwire"
)
@@ -21,6 +21,12 @@ const (
pollInterval = time.Millisecond * 100
)
+var (
+ // ErrStateMachineShutdown occurs when trying to feed an event to a
+ // StateMachine that has been asked to Stop.
+ ErrStateMachineShutdown = fmt.Errorf("StateMachine is shutting down")
+)
+
// EmittedEvent is a special type that can be emitted by a state transition.
// This can container internal events which are to be routed back to the state,
// or external events which are to be sent to the daemon.
@@ -287,7 +293,7 @@ func (s *StateMachine[Event, Env]) CurrentState() (State[Event, Env], error) {
}
if !fn.SendOrQuit(s.stateQuery, query, s.quit) {
- return nil, fmt.Errorf("state machine is shutting down")
+ return nil, ErrStateMachineShutdown
}
return fn.RecvOrTimeout(query.CurrentState, time.Second)
@@ -322,6 +328,8 @@ func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[
// executeDaemonEvent executes a daemon event, which is a special type of event
// that can be emitted as part of the state transition function of the state
// machine. An error is returned if the type of event is unknown.
+//
+//nolint:funlen
func (s *StateMachine[Event, Env]) executeDaemonEvent(
event DaemonEvent) error {
@@ -347,7 +355,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
// If a post-send event was specified, then we'll funnel
// that back into the main state machine now as well.
return fn.MapOptionZ(daemonEvent.PostSendEvent, func(event Event) error { //nolint:ll
- return s.wg.Go(func(ctx context.Context) {
+ launched := s.wg.Go(func(ctx context.Context) {
log.Debugf("FSM(%v): sending "+
"post-send event: %v",
s.cfg.Env.Name(),
@@ -356,6 +364,12 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
s.SendEvent(event)
})
+
+ if !launched {
+ return ErrStateMachineShutdown
+ }
+
+ return nil
})
}
@@ -368,7 +382,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
// Otherwise, this has a SendWhen predicate, so we'll need
// launch a goroutine to poll the SendWhen, then send only once
// the predicate is true.
- return s.wg.Go(func(ctx context.Context) {
+ launched := s.wg.Go(func(ctx context.Context) {
predicateTicker := time.NewTicker(
s.cfg.CustomPollInterval.UnwrapOr(pollInterval),
)
@@ -407,6 +421,12 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
}
})
+ if !launched {
+ return ErrStateMachineShutdown
+ }
+
+ return nil
+
// If this is a broadcast transaction event, then we'll broadcast with
// the label attached.
case *BroadcastTxn:
@@ -436,7 +456,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
return fmt.Errorf("unable to register spend: %w", err)
}
- return s.wg.Go(func(ctx context.Context) {
+ launched := s.wg.Go(func(ctx context.Context) {
for {
select {
case spend, ok := <-spendEvent.Spend:
@@ -461,6 +481,12 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
}
})
+ if !launched {
+ return ErrStateMachineShutdown
+ }
+
+ return nil
+
// The state machine has requested a new event to be sent once a
// specified txid+pkScript pair has confirmed.
case *RegisterConf[Event]:
@@ -476,7 +502,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
return fmt.Errorf("unable to register conf: %w", err)
}
- return s.wg.Go(func(ctx context.Context) {
+ launched := s.wg.Go(func(ctx context.Context) {
for {
select {
case <-confEvent.Confirmed:
@@ -498,6 +524,12 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
}
}
})
+
+ if !launched {
+ return ErrStateMachineShutdown
+ }
+
+ return nil
}
return fmt.Errorf("unknown daemon event: %T", event)
diff --git a/protofsm/state_machine_test.go b/protofsm/state_machine_test.go
index fc30fcefc3..fc7a4ccfdc 100644
--- a/protofsm/state_machine_test.go
+++ b/protofsm/state_machine_test.go
@@ -10,7 +10,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
diff --git a/routing/bandwidth.go b/routing/bandwidth.go
index 12e82131dc..a552628c79 100644
--- a/routing/bandwidth.go
+++ b/routing/bandwidth.go
@@ -3,7 +3,7 @@ package routing
import (
"fmt"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lnwire"
@@ -29,39 +29,6 @@ type bandwidthHints interface {
firstHopCustomBlob() fn.Option[tlv.Blob]
}
-// TlvTrafficShaper is an interface that allows the sender to determine if a
-// payment should be carried by a channel based on the TLV records that may be
-// present in the `update_add_htlc` message or the channel commitment itself.
-type TlvTrafficShaper interface {
- AuxHtlcModifier
-
- // ShouldHandleTraffic is called in order to check if the channel
- // identified by the provided channel ID may have external mechanisms
- // that would allow it to carry out the payment.
- ShouldHandleTraffic(cid lnwire.ShortChannelID,
- fundingBlob fn.Option[tlv.Blob]) (bool, error)
-
- // PaymentBandwidth returns the available bandwidth for a custom channel
- // decided by the given channel aux blob and HTLC blob. A return value
- // of 0 means there is no bandwidth available. To find out if a channel
- // is a custom channel that should be handled by the traffic shaper, the
- // HandleTraffic method should be called first.
- PaymentBandwidth(htlcBlob, commitmentBlob fn.Option[tlv.Blob],
- linkBandwidth,
- htlcAmt lnwire.MilliSatoshi) (lnwire.MilliSatoshi, error)
-}
-
-// AuxHtlcModifier is an interface that allows the sender to modify the outgoing
-// HTLC of a payment by changing the amount or the wire message tlv records.
-type AuxHtlcModifier interface {
- // ProduceHtlcExtraData is a function that, based on the previous extra
- // data blob of an HTLC, may produce a different blob or modify the
- // amount of bitcoin this htlc should carry.
- ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi,
- htlcCustomRecords lnwire.CustomRecords) (lnwire.MilliSatoshi,
- lnwire.CustomRecords, error)
-}
-
// getLinkQuery is the function signature used to lookup a link.
type getLinkQuery func(lnwire.ShortChannelID) (
htlcswitch.ChannelLink, error)
@@ -73,7 +40,7 @@ type bandwidthManager struct {
getLink getLinkQuery
localChans map[lnwire.ShortChannelID]struct{}
firstHopBlob fn.Option[tlv.Blob]
- trafficShaper fn.Option[TlvTrafficShaper]
+ trafficShaper fn.Option[htlcswitch.AuxTrafficShaper]
}
// newBandwidthManager creates a bandwidth manager for the source node provided
@@ -84,13 +51,14 @@ type bandwidthManager struct {
// that are inactive, or just don't have enough bandwidth to carry the payment.
func newBandwidthManager(graph Graph, sourceNode route.Vertex,
linkQuery getLinkQuery, firstHopBlob fn.Option[tlv.Blob],
- trafficShaper fn.Option[TlvTrafficShaper]) (*bandwidthManager, error) {
+ ts fn.Option[htlcswitch.AuxTrafficShaper]) (*bandwidthManager,
+ error) {
manager := &bandwidthManager{
getLink: linkQuery,
localChans: make(map[lnwire.ShortChannelID]struct{}),
firstHopBlob: firstHopBlob,
- trafficShaper: trafficShaper,
+ trafficShaper: ts,
}
// First, we'll collect the set of outbound edges from the target
@@ -166,44 +134,15 @@ func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID,
result, err := fn.MapOptionZ(
b.trafficShaper,
- func(ts TlvTrafficShaper) fn.Result[bandwidthResult] {
- fundingBlob := link.FundingCustomBlob()
- shouldHandle, err := ts.ShouldHandleTraffic(
- cid, fundingBlob,
- )
- if err != nil {
- return bandwidthErr(fmt.Errorf("traffic "+
- "shaper failed to decide whether to "+
- "handle traffic: %w", err))
- }
-
- log.Debugf("ShortChannelID=%v: external traffic "+
- "shaper is handling traffic: %v", cid,
- shouldHandle)
-
- // If this channel isn't handled by the external traffic
- // shaper, we'll return early.
- if !shouldHandle {
- return fn.Ok(bandwidthResult{})
- }
-
- // Ask for a specific bandwidth to be used for the
- // channel.
- commitmentBlob := link.CommitmentCustomBlob()
- auxBandwidth, err := ts.PaymentBandwidth(
- b.firstHopBlob, commitmentBlob, linkBandwidth,
- amount,
- )
+ func(s htlcswitch.AuxTrafficShaper) fn.Result[bandwidthResult] {
+ auxBandwidth, err := link.AuxBandwidth(
+ amount, cid, b.firstHopBlob, s,
+ ).Unpack()
if err != nil {
return bandwidthErr(fmt.Errorf("failed to get "+
- "bandwidth from external traffic "+
- "shaper: %w", err))
+ "auxiliary bandwidth: %w", err))
}
- log.Debugf("ShortChannelID=%v: external traffic "+
- "shaper reported available bandwidth: %v", cid,
- auxBandwidth)
-
// We don't know the actual HTLC amount that will be
// sent using the custom channel. But we'll still want
// to make sure we can add another HTLC, using the
@@ -213,7 +152,7 @@ func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID,
// the max number of HTLCs on the channel. A proper
// balance check is done elsewhere.
return fn.Ok(bandwidthResult{
- bandwidth: fn.Some(auxBandwidth),
+ bandwidth: auxBandwidth,
htlcAmount: fn.Some[lnwire.MilliSatoshi](0),
})
},
diff --git a/routing/bandwidth_test.go b/routing/bandwidth_test.go
index 4872b5a7ec..b31d0095ac 100644
--- a/routing/bandwidth_test.go
+++ b/routing/bandwidth_test.go
@@ -5,7 +5,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/go-errors/errors"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
@@ -118,7 +118,9 @@ func TestBandwidthManager(t *testing.T) {
m, err := newBandwidthManager(
g, sourceNode.pubkey, testCase.linkQuery,
fn.None[[]byte](),
- fn.Some[TlvTrafficShaper](&mockTrafficShaper{}),
+ fn.Some[htlcswitch.AuxTrafficShaper](
+ &mockTrafficShaper{},
+ ),
)
require.NoError(t, err)
diff --git a/routing/blinding.go b/routing/blinding.go
index 7c84063469..0c27e87439 100644
--- a/routing/blinding.go
+++ b/routing/blinding.go
@@ -6,7 +6,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2"
sphinx "github.com/lightningnetwork/lightning-onion"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
diff --git a/routing/blinding_test.go b/routing/blinding_test.go
index 410dfaf643..8f83f7fd82 100644
--- a/routing/blinding_test.go
+++ b/routing/blinding_test.go
@@ -6,7 +6,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2"
sphinx "github.com/lightningnetwork/lightning-onion"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go
index 315b0dff22..e4241dac53 100644
--- a/routing/integrated_routing_context_test.go
+++ b/routing/integrated_routing_context_test.go
@@ -7,7 +7,7 @@ import (
"testing"
"time"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
diff --git a/routing/localchans/manager.go b/routing/localchans/manager.go
index d7380439ac..cd9e58fcaa 100644
--- a/routing/localchans/manager.go
+++ b/routing/localchans/manager.go
@@ -11,7 +11,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/discovery"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnwire"
diff --git a/routing/missioncontrol.go b/routing/missioncontrol.go
index de892392e7..3bc9be7aba 100644
--- a/routing/missioncontrol.go
+++ b/routing/missioncontrol.go
@@ -13,7 +13,7 @@ import (
"github.com/btcsuite/btcwallet/walletdb"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
diff --git a/routing/mock_test.go b/routing/mock_test.go
index 3cdb5ebaf2..86fd765499 100644
--- a/routing/mock_test.go
+++ b/routing/mock_test.go
@@ -8,7 +8,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lntypes"
@@ -107,7 +107,7 @@ var _ PaymentSessionSource = (*mockPaymentSessionSourceOld)(nil)
func (m *mockPaymentSessionSourceOld) NewPaymentSession(
_ *LightningPayment, _ fn.Option[tlv.Blob],
- _ fn.Option[TlvTrafficShaper]) (PaymentSession, error) {
+ _ fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, error) {
return &mockPaymentSessionOld{
routes: m.routes,
@@ -635,7 +635,8 @@ var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil)
func (m *mockPaymentSessionSource) NewPaymentSession(
payment *LightningPayment, firstHopBlob fn.Option[tlv.Blob],
- tlvShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) {
+ tlvShaper fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession,
+ error) {
args := m.Called(payment, firstHopBlob, tlvShaper)
return args.Get(0).(PaymentSession), args.Error(1)
@@ -895,6 +896,19 @@ func (m *mockLink) Bandwidth() lnwire.MilliSatoshi {
return m.bandwidth
}
+// AuxBandwidth returns the bandwidth that can be used for a channel,
+// expressed in milli-satoshi. This might be different from the regular
+// BTC bandwidth for custom channels. This will always return fn.None()
+// for a regular (non-custom) channel.
+func (m *mockLink) AuxBandwidth(lnwire.MilliSatoshi, lnwire.ShortChannelID,
+ fn.Option[tlv.Blob],
+ htlcswitch.AuxTrafficShaper) fn.Result[htlcswitch.OptionalBandwidth] {
+
+ return fn.Ok[htlcswitch.OptionalBandwidth](
+ fn.None[lnwire.MilliSatoshi](),
+ )
+}
+
// EligibleToForward returns the mock's configured eligibility.
func (m *mockLink) EligibleToForward() bool {
return !m.ineligible
diff --git a/routing/pathfind.go b/routing/pathfind.go
index 8e40c5bc4b..80f4b1e68f 100644
--- a/routing/pathfind.go
+++ b/routing/pathfind.go
@@ -12,7 +12,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/feature"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lnutils"
diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go
index da29c79a25..c463b8135b 100644
--- a/routing/pathfind_test.go
+++ b/routing/pathfind_test.go
@@ -21,7 +21,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
sphinx "github.com/lightningnetwork/lightning-onion"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch"
diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go
index 267ce3965d..5c4bae5585 100644
--- a/routing/payment_lifecycle.go
+++ b/routing/payment_lifecycle.go
@@ -10,7 +10,7 @@ import (
"github.com/davecgh/go-spew/spew"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lntypes"
@@ -761,7 +761,8 @@ func (p *paymentLifecycle) amendFirstHopData(rt *route.Route) error {
// and apply its side effects to the UpdateAddHTLC message.
result, err := fn.MapOptionZ(
p.router.cfg.TrafficShaper,
- func(ts TlvTrafficShaper) fn.Result[extraDataRequest] {
+ //nolint:ll
+ func(ts htlcswitch.AuxTrafficShaper) fn.Result[extraDataRequest] {
newAmt, newRecords, err := ts.ProduceHtlcExtraData(
rt.TotalAmount, p.firstHopCustomRecords,
)
@@ -774,7 +775,7 @@ func (p *paymentLifecycle) amendFirstHopData(rt *route.Route) error {
return fn.Err[extraDataRequest](err)
}
- log.Debugf("TLV traffic shaper returned custom "+
+ log.Debugf("Aux traffic shaper returned custom "+
"records %v and amount %d msat for HTLC",
spew.Sdump(newRecords), newAmt)
diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go
index 315c1bad58..72aa631419 100644
--- a/routing/payment_lifecycle_test.go
+++ b/routing/payment_lifecycle_test.go
@@ -9,7 +9,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lnmock"
"github.com/lightningnetwork/lnd/lntest/wait"
@@ -30,7 +30,7 @@ func createTestPaymentLifecycle() *paymentLifecycle {
quitChan := make(chan struct{})
rt := &ChannelRouter{
cfg: &Config{
- TrafficShaper: fn.Some[TlvTrafficShaper](
+ TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper](
&mockTrafficShaper{},
),
},
@@ -83,7 +83,7 @@ func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) {
Payer: mockPayer,
Clock: mockClock,
MissionControl: mockMissionControl,
- TrafficShaper: fn.Some[TlvTrafficShaper](
+ TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper](
&mockTrafficShaper{},
),
},
diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go
index d5f1a6af41..5e4eb23d7f 100644
--- a/routing/payment_session_source.go
+++ b/routing/payment_session_source.go
@@ -2,8 +2,9 @@ package routing
import (
"github.com/btcsuite/btcd/btcec/v2"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
+ "github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
@@ -52,7 +53,8 @@ type SessionSource struct {
// payment's destination.
func (m *SessionSource) NewPaymentSession(p *LightningPayment,
firstHopBlob fn.Option[tlv.Blob],
- trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) {
+ trafficShaper fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession,
+ error) {
getBandwidthHints := func(graph Graph) (bandwidthHints, error) {
return newBandwidthManager(
diff --git a/routing/result_interpretation.go b/routing/result_interpretation.go
index 089213d65e..bc1749dfb1 100644
--- a/routing/result_interpretation.go
+++ b/routing/result_interpretation.go
@@ -6,7 +6,7 @@ import (
"io"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
@@ -578,7 +578,7 @@ func extractMCRoute(r *route.Route) *mcRoute {
// extractMCHops extracts the Hop fields that MC actually uses from a slice of
// Hops.
func extractMCHops(hops []*route.Hop) mcHops {
- return fn.Map(extractMCHop, hops)
+ return fn.Map(hops, extractMCHop)
}
// extractMCHop extracts the Hop fields that MC actually uses from a Hop.
diff --git a/routing/result_interpretation_test.go b/routing/result_interpretation_test.go
index b213eb1835..8c67bdeea9 100644
--- a/routing/result_interpretation_test.go
+++ b/routing/result_interpretation_test.go
@@ -6,7 +6,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/davecgh/go-spew/spew"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
diff --git a/routing/router.go b/routing/router.go
index 9eabe0b2ae..468510a6c7 100644
--- a/routing/router.go
+++ b/routing/router.go
@@ -19,7 +19,7 @@ import (
"github.com/lightningnetwork/lnd/amp"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lntypes"
@@ -157,7 +157,7 @@ type PaymentSessionSource interface {
// finding a path to the payment's destination.
NewPaymentSession(p *LightningPayment,
firstHopBlob fn.Option[tlv.Blob],
- trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession,
+ ts fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession,
error)
// NewPaymentSessionEmpty creates a new paymentSession instance that is
@@ -297,7 +297,7 @@ type Config struct {
// TrafficShaper is an optional traffic shaper that can be used to
// control the outgoing channel of a payment.
- TrafficShaper fn.Option[TlvTrafficShaper]
+ TrafficShaper fn.Option[htlcswitch.AuxTrafficShaper]
}
// EdgeLocator is a struct used to identify a specific edge.
diff --git a/routing/router_test.go b/routing/router_test.go
index 2923f1fb90..22c9d14e50 100644
--- a/routing/router_test.go
+++ b/routing/router_test.go
@@ -23,7 +23,7 @@ import (
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
@@ -170,7 +170,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
Clock: clock.NewTestClock(time.Unix(1, 0)),
ApplyChannelUpdate: graphBuilder.ApplyChannelUpdate,
ClosedSCIDs: mockClosedSCIDs,
- TrafficShaper: fn.Some[TlvTrafficShaper](
+ TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper](
&mockTrafficShaper{},
),
})
@@ -2206,8 +2206,10 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) {
NextPaymentID: func() (uint64, error) {
return 0, nil
},
- ClosedSCIDs: mockClosedSCIDs,
- TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}),
+ ClosedSCIDs: mockClosedSCIDs,
+ TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper](
+ &mockTrafficShaper{},
+ ),
}}
// Register mockers with the expected method calls.
@@ -2291,8 +2293,10 @@ func TestSendToRouteSkipTempErrNonMPP(t *testing.T) {
NextPaymentID: func() (uint64, error) {
return 0, nil
},
- ClosedSCIDs: mockClosedSCIDs,
- TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}),
+ ClosedSCIDs: mockClosedSCIDs,
+ TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper](
+ &mockTrafficShaper{},
+ ),
}}
// Expect an error to be returned.
@@ -2347,8 +2351,10 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) {
NextPaymentID: func() (uint64, error) {
return 0, nil
},
- ClosedSCIDs: mockClosedSCIDs,
- TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}),
+ ClosedSCIDs: mockClosedSCIDs,
+ TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper](
+ &mockTrafficShaper{},
+ ),
}}
// Create the error to be returned.
@@ -2431,8 +2437,10 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) {
NextPaymentID: func() (uint64, error) {
return 0, nil
},
- ClosedSCIDs: mockClosedSCIDs,
- TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}),
+ ClosedSCIDs: mockClosedSCIDs,
+ TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper](
+ &mockTrafficShaper{},
+ ),
}}
// Create the error to be returned.
@@ -2519,8 +2527,10 @@ func TestSendToRouteTempFailure(t *testing.T) {
NextPaymentID: func() (uint64, error) {
return 0, nil
},
- ClosedSCIDs: mockClosedSCIDs,
- TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}),
+ ClosedSCIDs: mockClosedSCIDs,
+ TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper](
+ &mockTrafficShaper{},
+ ),
}}
// Create the error to be returned.
diff --git a/rpcserver.go b/rpcserver.go
index d7d2e0186c..72e2fa4afd 100644
--- a/rpcserver.go
+++ b/rpcserver.go
@@ -46,7 +46,7 @@ import (
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/discovery"
"github.com/lightningnetwork/lnd/feature"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/funding"
"github.com/lightningnetwork/lnd/graph"
graphdb "github.com/lightningnetwork/lnd/graph/db"
@@ -8068,9 +8068,9 @@ func (r *rpcServer) VerifyChanBackup(ctx context.Context,
}
return &lnrpc.VerifyChanBackupResponse{
- ChanPoints: fn.Map(func(c chanbackup.Single) string {
+ ChanPoints: fn.Map(channels, func(c chanbackup.Single) string {
return c.FundingOutpoint.String()
- }, channels),
+ }),
}, nil
}
diff --git a/rpcserver_test.go b/rpcserver_test.go
index b4b66e719c..b686c9020a 100644
--- a/rpcserver_test.go
+++ b/rpcserver_test.go
@@ -5,7 +5,7 @@ import (
"testing"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
diff --git a/server.go b/server.go
index f8f8239ed6..9fd0b7a006 100644
--- a/server.go
+++ b/server.go
@@ -28,6 +28,7 @@ import (
"github.com/lightningnetwork/lnd/aliasmgr"
"github.com/lightningnetwork/lnd/autopilot"
"github.com/lightningnetwork/lnd/brontide"
+ "github.com/lightningnetwork/lnd/chainio"
"github.com/lightningnetwork/lnd/chainreg"
"github.com/lightningnetwork/lnd/chanacceptor"
"github.com/lightningnetwork/lnd/chanbackup"
@@ -39,7 +40,7 @@ import (
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/discovery"
"github.com/lightningnetwork/lnd/feature"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/funding"
"github.com/lightningnetwork/lnd/graph"
graphdb "github.com/lightningnetwork/lnd/graph/db"
@@ -356,6 +357,10 @@ type server struct {
// txPublisher is a publisher with fee-bumping capability.
txPublisher *sweep.TxPublisher
+ // blockbeatDispatcher is a block dispatcher that notifies subscribers
+ // of new blocks.
+ blockbeatDispatcher *chainio.BlockbeatDispatcher
+
quit chan struct{}
wg sync.WaitGroup
@@ -623,6 +628,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
readPool: readPool,
chansToRestore: chansToRestore,
+ blockbeatDispatcher: chainio.NewBlockbeatDispatcher(
+ cc.ChainNotifier,
+ ),
channelNotifier: channelnotifier.New(
dbs.ChanStateDB.ChannelStateDB(),
),
@@ -665,6 +673,17 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
quit: make(chan struct{}),
}
+ // Start the low-level services once they are initialized.
+ //
+ // TODO(yy): break the server startup into four steps,
+ // 1. init the low-level services.
+ // 2. start the low-level services.
+ // 3. init the high-level services.
+ // 4. start the high-level services.
+ if err := s.startLowLevelServices(); err != nil {
+ return nil, err
+ }
+
currentHash, currentHeight, err := s.cc.ChainIO.GetBestBlock()
if err != nil {
return nil, err
@@ -1813,6 +1832,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
}
s.connMgr = cmgr
+ // Finally, register the subsystems in blockbeat.
+ s.registerBlockConsumers()
+
return s, nil
}
@@ -1845,6 +1867,25 @@ func (s *server) UpdateRoutingConfig(cfg *routing.MissionControlConfig) {
routerCfg.MaxMcHistory = cfg.MaxMcHistory
}
+// registerBlockConsumers registers the subsystems that consume block events.
+// By calling `RegisterQueue`, a list of subsystems are registered in the
+// blockbeat for block notifications. When a new block arrives, the subsystems
+// in the same queue are notified sequentially, and different queues are
+// notified concurrently.
+//
+// NOTE: To put a subsystem in a different queue, create a slice and pass it to
+// a new `RegisterQueue` call.
+func (s *server) registerBlockConsumers() {
+ // In this queue, when a new block arrives, it will be received and
+ // processed in this order: chainArb -> sweeper -> txPublisher.
+ consumers := []chainio.Consumer{
+ s.chainArb,
+ s.sweeper,
+ s.txPublisher,
+ }
+ s.blockbeatDispatcher.RegisterQueue(consumers)
+}
+
// signAliasUpdate takes a ChannelUpdate and returns the signature. This is
// used for option_scid_alias channels where the ChannelUpdate to be sent back
// may differ from what is on disk.
@@ -2067,12 +2108,41 @@ func (c cleaner) run() {
}
}
+// startLowLevelServices starts the low-level services of the server. These
+// services must be started successfully before running the main server. The
+// services are,
+// 1. the chain notifier.
+//
+// TODO(yy): identify and add more low-level services here.
+func (s *server) startLowLevelServices() error {
+ var startErr error
+
+ cleanup := cleaner{}
+
+ cleanup = cleanup.add(s.cc.ChainNotifier.Stop)
+ if err := s.cc.ChainNotifier.Start(); err != nil {
+ startErr = err
+ }
+
+ if startErr != nil {
+ cleanup.run()
+ }
+
+ return startErr
+}
+
// Start starts the main daemon server, all requested listeners, and any helper
// goroutines.
// NOTE: This function is safe for concurrent access.
//
//nolint:funlen
func (s *server) Start() error {
+ // Get the current blockbeat.
+ beat, err := s.getStartingBeat()
+ if err != nil {
+ return err
+ }
+
var startErr error
// If one sub system fails to start, the following code ensures that the
@@ -2126,12 +2196,6 @@ func (s *server) Start() error {
return
}
- cleanup = cleanup.add(s.cc.ChainNotifier.Stop)
- if err := s.cc.ChainNotifier.Start(); err != nil {
- startErr = err
- return
- }
-
cleanup = cleanup.add(s.cc.BestBlockTracker.Stop)
if err := s.cc.BestBlockTracker.Start(); err != nil {
startErr = err
@@ -2167,13 +2231,13 @@ func (s *server) Start() error {
}
cleanup = cleanup.add(s.txPublisher.Stop)
- if err := s.txPublisher.Start(); err != nil {
+ if err := s.txPublisher.Start(beat); err != nil {
startErr = err
return
}
cleanup = cleanup.add(s.sweeper.Stop)
- if err := s.sweeper.Start(); err != nil {
+ if err := s.sweeper.Start(beat); err != nil {
startErr = err
return
}
@@ -2218,7 +2282,7 @@ func (s *server) Start() error {
}
cleanup = cleanup.add(s.chainArb.Stop)
- if err := s.chainArb.Start(); err != nil {
+ if err := s.chainArb.Start(beat); err != nil {
startErr = err
return
}
@@ -2459,6 +2523,17 @@ func (s *server) Start() error {
srvrLog.Infof("Auto peer bootstrapping is disabled")
}
+ // Start the blockbeat after all other subsystems have been
+ // started so they are ready to receive new blocks.
+ cleanup = cleanup.add(func() error {
+ s.blockbeatDispatcher.Stop()
+ return nil
+ })
+ if err := s.blockbeatDispatcher.Start(); err != nil {
+ startErr = err
+ return
+ }
+
// Set the active flag now that we've completed the full
// startup.
atomic.StoreInt32(&s.active, 1)
@@ -2483,6 +2558,9 @@ func (s *server) Stop() error {
// Shutdown connMgr first to prevent conns during shutdown.
s.connMgr.Stop()
+ // Stop dispatching blocks to other systems immediately.
+ s.blockbeatDispatcher.Stop()
+
// Shutdown the wallet, funding manager, and the rpc server.
if err := s.chanStatusMgr.Stop(); err != nil {
srvrLog.Warnf("failed to stop chanStatusMgr: %v", err)
@@ -4222,6 +4300,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq,
MsgRouter: s.implCfg.MsgRouter,
AuxChanCloser: s.implCfg.AuxChanCloser,
AuxResolver: s.implCfg.AuxContractResolver,
+ AuxTrafficShaper: s.implCfg.TrafficShaper,
ShouldFwdExpEndorsement: func() bool {
if s.cfg.ProtocolOptions.NoExperimentalEndorsement() {
return false
@@ -5151,3 +5230,35 @@ func (s *server) fetchClosedChannelSCIDs() map[lnwire.ShortChannelID]struct{} {
return closedSCIDs
}
+
+// getStartingBeat returns the current beat. This is used during the startup to
+// initialize blockbeat consumers.
+func (s *server) getStartingBeat() (*chainio.Beat, error) {
+ // beat is the current blockbeat.
+ var beat *chainio.Beat
+
+ // We should get a notification with the current best block immediately
+ // by passing a nil block.
+ blockEpochs, err := s.cc.ChainNotifier.RegisterBlockEpochNtfn(nil)
+ if err != nil {
+ return beat, fmt.Errorf("register block epoch ntfn: %w", err)
+ }
+ defer blockEpochs.Cancel()
+
+ // We registered for the block epochs with a nil request. The notifier
+ // should send us the current best block immediately. So we need to
+ // wait for it here because we need to know the current best height.
+ select {
+ case bestBlock := <-blockEpochs.Epochs:
+ srvrLog.Infof("Received initial block %v at height %d",
+ bestBlock.Hash, bestBlock.Height)
+
+ // Update the current blockbeat.
+ beat = chainio.NewBeat(*bestBlock)
+
+ case <-s.quit:
+ srvrLog.Debug("LND shutting down")
+ }
+
+ return beat, nil
+}
diff --git a/subrpcserver_config.go b/subrpcserver_config.go
index 30755c05e4..102e211187 100644
--- a/subrpcserver_config.go
+++ b/subrpcserver_config.go
@@ -11,7 +11,7 @@ import (
"github.com/lightningnetwork/lnd/autopilot"
"github.com/lightningnetwork/lnd/chainreg"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/invoices"
diff --git a/sweep/aggregator.go b/sweep/aggregator.go
index a0a1b0a540..e97ccb9a21 100644
--- a/sweep/aggregator.go
+++ b/sweep/aggregator.go
@@ -5,7 +5,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
diff --git a/sweep/aggregator_test.go b/sweep/aggregator_test.go
index 6df0d73fa2..2cb89bdc38 100644
--- a/sweep/aggregator_test.go
+++ b/sweep/aggregator_test.go
@@ -8,7 +8,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go
index adb4db65ed..a43f875850 100644
--- a/sweep/fee_bumper.go
+++ b/sweep/fee_bumper.go
@@ -12,8 +12,9 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/chain"
+ "github.com/lightningnetwork/lnd/chainio"
"github.com/lightningnetwork/lnd/chainntnfs"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/labels"
"github.com/lightningnetwork/lnd/lntypes"
@@ -65,7 +66,7 @@ type Bumper interface {
// and monitors its confirmation status for potential fee bumping. It
// returns a chan that the caller can use to receive updates about the
// broadcast result and potential RBF attempts.
- Broadcast(req *BumpRequest) (<-chan *BumpResult, error)
+ Broadcast(req *BumpRequest) <-chan *BumpResult
}
// BumpEvent represents the event of a fee bumping attempt.
@@ -75,7 +76,17 @@ const (
// TxPublished is sent when the broadcast attempt is finished.
TxPublished BumpEvent = iota
- // TxFailed is sent when the broadcast attempt fails.
+ // TxFailed is sent when the tx has encountered a fee-related error
+ // during its creation or broadcast, or an internal error from the fee
+ // bumper. In either case the inputs in this tx should be retried with
+ // either a different grouping strategy or an increased budget.
+ //
+ // NOTE: We also send this event when there's a third party spend
+ // event, and the sweeper will handle cleaning this up once it's
+ // confirmed.
+ //
+ // TODO(yy): Remove the above usage once we remove sweeping non-CPFP
+ // anchors.
TxFailed
// TxReplaced is sent when the original tx is replaced by a new one.
@@ -84,6 +95,11 @@ const (
// TxConfirmed is sent when the tx is confirmed.
TxConfirmed
+ // TxFatal is sent when the inputs in this tx cannot be retried. Txns
+ // will end up in this state if they have encountered a non-fee related
+ // error, which means they cannot be retried with increased budget.
+ TxFatal
+
// sentinalEvent is used to check if an event is unknown.
sentinalEvent
)
@@ -99,6 +115,8 @@ func (e BumpEvent) String() string {
return "Replaced"
case TxConfirmed:
return "Confirmed"
+ case TxFatal:
+ return "Fatal"
default:
return "Unknown"
}
@@ -136,6 +154,10 @@ type BumpRequest struct {
// ExtraTxOut tracks if this bump request has an optional set of extra
// outputs to add to the transaction.
ExtraTxOut fn.Option[SweepOutput]
+
+ // Immediate is used to specify that the tx should be broadcast
+ // immediately.
+ Immediate bool
}
// MaxFeeRateAllowed returns the maximum fee rate allowed for the given
@@ -145,13 +167,13 @@ type BumpRequest struct {
func (r *BumpRequest) MaxFeeRateAllowed() (chainfee.SatPerKWeight, error) {
// We'll want to know if we have any blobs, as we need to factor this
// into the max fee rate for this bump request.
- hasBlobs := fn.Any(func(i input.Input) bool {
+ hasBlobs := fn.Any(r.Inputs, func(i input.Input) bool {
return fn.MapOptionZ(
i.ResolutionBlob(), func(b tlv.Blob) bool {
return len(b) > 0
},
)
- }, r.Inputs)
+ })
sweepAddrs := [][]byte{
r.DeliveryAddress.DeliveryAddress,
@@ -246,10 +268,22 @@ type BumpResult struct {
requestID uint64
}
+// String returns a human-readable string for the result.
+func (b *BumpResult) String() string {
+ desc := fmt.Sprintf("Event=%v", b.Event)
+ if b.Tx != nil {
+ desc += fmt.Sprintf(", Tx=%v", b.Tx.TxHash())
+ }
+
+ return fmt.Sprintf("[%s]", desc)
+}
+
// Validate validates the BumpResult so it's safe to use.
func (b *BumpResult) Validate() error {
- // Every result must have a tx.
- if b.Tx == nil {
+ isFailureEvent := b.Event == TxFailed || b.Event == TxFatal
+
+ // Every result must have a tx except the fatal or failed case.
+ if b.Tx == nil && !isFailureEvent {
return fmt.Errorf("%w: nil tx", ErrInvalidBumpResult)
}
@@ -263,8 +297,8 @@ func (b *BumpResult) Validate() error {
return fmt.Errorf("%w: nil replacing tx", ErrInvalidBumpResult)
}
- // If it's a failed event, it must have an error.
- if b.Event == TxFailed && b.Err == nil {
+ // If it's a failed or fatal event, it must have an error.
+ if isFailureEvent && b.Err == nil {
return fmt.Errorf("%w: nil error", ErrInvalidBumpResult)
}
@@ -311,6 +345,10 @@ type TxPublisher struct {
started atomic.Bool
stopped atomic.Bool
+ // Embed the blockbeat consumer struct to get access to the method
+ // `NotifyBlockProcessed` and the `BlockbeatChan`.
+ chainio.BeatConsumer
+
wg sync.WaitGroup
// cfg specifies the configuration of the TxPublisher.
@@ -338,14 +376,22 @@ type TxPublisher struct {
// Compile-time constraint to ensure TxPublisher implements Bumper.
var _ Bumper = (*TxPublisher)(nil)
+// Compile-time check for the chainio.Consumer interface.
+var _ chainio.Consumer = (*TxPublisher)(nil)
+
// NewTxPublisher creates a new TxPublisher.
func NewTxPublisher(cfg TxPublisherConfig) *TxPublisher {
- return &TxPublisher{
+ tp := &TxPublisher{
cfg: &cfg,
records: lnutils.SyncMap[uint64, *monitorRecord]{},
subscriberChans: lnutils.SyncMap[uint64, chan *BumpResult]{},
quit: make(chan struct{}),
}
+
+ // Mount the block consumer.
+ tp.BeatConsumer = chainio.NewBeatConsumer(tp.quit, tp.Name())
+
+ return tp
}
// isNeutrinoBackend checks if the wallet backend is neutrino.
@@ -353,60 +399,69 @@ func (t *TxPublisher) isNeutrinoBackend() bool {
return t.cfg.Wallet.BackEnd() == "neutrino"
}
-// Broadcast is used to publish the tx created from the given inputs. It will,
-// 1. init a fee function based on the given strategy.
-// 2. create an RBF-compliant tx and monitor it for confirmation.
-// 3. notify the initial broadcast result back to the caller.
-// The initial broadcast is guaranteed to be RBF-compliant unless the budget
-// specified cannot cover the fee.
+// Broadcast is used to publish the tx created from the given inputs. It will
+// register the broadcast request and return a chan to the caller to subscribe
+// the broadcast result. The initial broadcast is guaranteed to be
+// RBF-compliant unless the budget specified cannot cover the fee.
//
// NOTE: part of the Bumper interface.
-func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) {
- log.Tracef("Received broadcast request: %s", lnutils.SpewLogClosure(
- req))
+func (t *TxPublisher) Broadcast(req *BumpRequest) <-chan *BumpResult {
+ log.Tracef("Received broadcast request: %s",
+ lnutils.SpewLogClosure(req))
- // Attempt an initial broadcast which is guaranteed to comply with the
- // RBF rules.
- result, err := t.initialBroadcast(req)
- if err != nil {
- log.Errorf("Initial broadcast failed: %v", err)
-
- return nil, err
- }
+ // Store the request.
+ requestID, record := t.storeInitialRecord(req)
// Create a chan to send the result to the caller.
subscriber := make(chan *BumpResult, 1)
- t.subscriberChans.Store(result.requestID, subscriber)
+ t.subscriberChans.Store(requestID, subscriber)
- // Send the initial broadcast result to the caller.
- t.handleResult(result)
+ // Publish the tx immediately if specified.
+ if req.Immediate {
+ t.handleInitialBroadcast(record, requestID)
+ }
+
+ return subscriber
+}
+
+// storeInitialRecord initializes a monitor record and saves it in the map.
+func (t *TxPublisher) storeInitialRecord(req *BumpRequest) (
+ uint64, *monitorRecord) {
- return subscriber, nil
+ // Increase the request counter.
+ //
+ // NOTE: this is the only place where we increase the counter.
+ requestID := t.requestCounter.Add(1)
+
+ // Register the record.
+ record := &monitorRecord{req: req}
+ t.records.Store(requestID, record)
+
+ return requestID, record
+}
+
+// NOTE: part of the `chainio.Consumer` interface.
+func (t *TxPublisher) Name() string {
+ return "TxPublisher"
}
-// initialBroadcast initializes a fee function, creates an RBF-compliant tx and
-// broadcasts it.
-func (t *TxPublisher) initialBroadcast(req *BumpRequest) (*BumpResult, error) {
+// initializeTx initializes a fee function and creates an RBF-compliant tx. If
+// succeeded, the initial tx is stored in the records map.
+func (t *TxPublisher) initializeTx(requestID uint64, req *BumpRequest) error {
// Create a fee bumping algorithm to be used for future RBF.
feeAlgo, err := t.initializeFeeFunction(req)
if err != nil {
- return nil, fmt.Errorf("init fee function: %w", err)
+ return fmt.Errorf("init fee function: %w", err)
}
// Create the initial tx to be broadcasted. This tx is guaranteed to
// comply with the RBF restrictions.
- requestID, err := t.createRBFCompliantTx(req, feeAlgo)
+ err = t.createRBFCompliantTx(requestID, req, feeAlgo)
if err != nil {
- return nil, fmt.Errorf("create RBF-compliant tx: %w", err)
+ return fmt.Errorf("create RBF-compliant tx: %w", err)
}
- // Broadcast the tx and return the monitored record.
- result, err := t.broadcast(requestID)
- if err != nil {
- return nil, fmt.Errorf("broadcast sweep tx: %w", err)
- }
-
- return result, nil
+ return nil
}
// initializeFeeFunction initializes a fee function to be used for this request
@@ -442,8 +497,8 @@ func (t *TxPublisher) initializeFeeFunction(
// so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee
// and redo the process until the tx is valid, or return an error when non-RBF
// related errors occur or the budget has been used up.
-func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest,
- f FeeFunction) (uint64, error) {
+func (t *TxPublisher) createRBFCompliantTx(requestID uint64, req *BumpRequest,
+ f FeeFunction) error {
for {
// Create a new tx with the given fee rate and check its
@@ -452,18 +507,19 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest,
switch {
case err == nil:
- // The tx is valid, return the request ID.
- requestID := t.storeRecord(
- sweepCtx.tx, req, f, sweepCtx.fee,
+ // The tx is valid, store it.
+ t.storeRecord(
+ requestID, sweepCtx.tx, req, f, sweepCtx.fee,
sweepCtx.outpointToTxIndex,
)
- log.Infof("Created tx %v for %v inputs: feerate=%v, "+
- "fee=%v, inputs=%v", sweepCtx.tx.TxHash(),
- len(req.Inputs), f.FeeRate(), sweepCtx.fee,
+ log.Infof("Created initial sweep tx=%v for %v inputs: "+
+ "feerate=%v, fee=%v, inputs:\n%v",
+ sweepCtx.tx.TxHash(), len(req.Inputs),
+ f.FeeRate(), sweepCtx.fee,
inputTypeSummary(req.Inputs))
- return requestID, nil
+ return nil
// If the error indicates the fees paid is not enough, we will
// ask the fee function to increase the fee rate and retry.
@@ -494,7 +550,7 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest,
// cluster these inputs differetly.
increased, err = f.Increment()
if err != nil {
- return 0, err
+ return err
}
}
@@ -504,21 +560,15 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest,
// mempool acceptance.
default:
log.Debugf("Failed to create RBF-compliant tx: %v", err)
- return 0, err
+ return err
}
}
}
// storeRecord stores the given record in the records map.
-func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest,
- f FeeFunction, fee btcutil.Amount,
- outpointToTxIndex map[wire.OutPoint]int) uint64 {
-
- // Increase the request counter.
- //
- // NOTE: this is the only place where we increase the
- // counter.
- requestID := t.requestCounter.Add(1)
+func (t *TxPublisher) storeRecord(requestID uint64, tx *wire.MsgTx,
+ req *BumpRequest, f FeeFunction, fee btcutil.Amount,
+ outpointToTxIndex map[wire.OutPoint]int) {
// Register the record.
t.records.Store(requestID, &monitorRecord{
@@ -528,8 +578,6 @@ func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest,
fee: fee,
outpointToTxIndex: outpointToTxIndex,
})
-
- return requestID
}
// createAndCheckTx creates a tx based on the given inputs, change output
@@ -659,8 +707,7 @@ func (t *TxPublisher) notifyResult(result *BumpResult) {
return
}
- log.Debugf("Sending result for requestID=%v, tx=%v", id,
- result.Tx.TxHash())
+ log.Debugf("Sending result %v for requestID=%v", result, id)
select {
// Send the result to the subscriber.
@@ -678,20 +725,31 @@ func (t *TxPublisher) notifyResult(result *BumpResult) {
func (t *TxPublisher) removeResult(result *BumpResult) {
id := result.requestID
- // Remove the record from the maps if there's an error. This means this
- // tx has failed its broadcast and cannot be retried. There are two
- // cases,
- // - when the budget cannot cover the fee.
- // - when a non-RBF related error occurs.
+ var txid chainhash.Hash
+ if result.Tx != nil {
+ txid = result.Tx.TxHash()
+ }
+
+ // Remove the record from the maps if there's an error or the tx is
+ // confirmed. When there's an error, it means this tx has failed its
+ // broadcast and cannot be retried. There are two cases it may fail,
+ // - when the budget cannot cover the increased fee calculated by the
+ // fee function, hence the budget is used up.
+ // - when a non-fee related error returned from PublishTransaction.
switch result.Event {
case TxFailed:
log.Errorf("Removing monitor record=%v, tx=%v, due to err: %v",
- id, result.Tx.TxHash(), result.Err)
+ id, txid, result.Err)
case TxConfirmed:
- // Remove the record is the tx is confirmed.
+ // Remove the record if the tx is confirmed.
log.Debugf("Removing confirmed monitor record=%v, tx=%v", id,
- result.Tx.TxHash())
+ txid)
+
+ case TxFatal:
+ // Remove the record if there's an error.
+ log.Debugf("Removing monitor record=%v due to fatal err: %v",
+ id, result.Err)
// Do nothing if it's neither failed or confirmed.
default:
@@ -737,20 +795,18 @@ type monitorRecord struct {
// Start starts the publisher by subscribing to block epoch updates and kicking
// off the monitor loop.
-func (t *TxPublisher) Start() error {
+func (t *TxPublisher) Start(beat chainio.Blockbeat) error {
log.Info("TxPublisher starting...")
if t.started.Swap(true) {
return fmt.Errorf("TxPublisher started more than once")
}
- blockEvent, err := t.cfg.Notifier.RegisterBlockEpochNtfn(nil)
- if err != nil {
- return fmt.Errorf("register block epoch ntfn: %w", err)
- }
+ // Set the current height.
+ t.currentHeight.Store(beat.Height())
t.wg.Add(1)
- go t.monitor(blockEvent)
+ go t.monitor()
log.Debugf("TxPublisher started")
@@ -778,33 +834,25 @@ func (t *TxPublisher) Stop() error {
// to be bumped. If so, it will attempt to bump the fee of the tx.
//
// NOTE: Must be run as a goroutine.
-func (t *TxPublisher) monitor(blockEvent *chainntnfs.BlockEpochEvent) {
- defer blockEvent.Cancel()
+func (t *TxPublisher) monitor() {
defer t.wg.Done()
for {
select {
- case epoch, ok := <-blockEvent.Epochs:
- if !ok {
- // We should stop the publisher before stopping
- // the chain service. Otherwise it indicates an
- // error.
- log.Error("Block epoch channel closed, exit " +
- "monitor")
-
- return
- }
-
- log.Debugf("TxPublisher received new block: %v",
- epoch.Height)
+ case beat := <-t.BlockbeatChan:
+ height := beat.Height()
+ log.Debugf("TxPublisher received new block: %v", height)
// Update the best known height for the publisher.
- t.currentHeight.Store(epoch.Height)
+ t.currentHeight.Store(height)
// Check all monitored txns to see if any of them needs
// to be bumped.
t.processRecords()
+ // Notify we've processed the block.
+ t.NotifyBlockProcessed(beat, nil)
+
case <-t.quit:
log.Debug("Fee bumper stopped, exit monitor")
return
@@ -819,18 +867,27 @@ func (t *TxPublisher) processRecords() {
// confirmed.
confirmedRecords := make(map[uint64]*monitorRecord)
- // feeBumpRecords stores a map of the records which need to be bumped.
+ // feeBumpRecords stores a map of records which need to be bumped.
feeBumpRecords := make(map[uint64]*monitorRecord)
- // failedRecords stores a map of the records which has inputs being
- // spent by a third party.
+ // failedRecords stores a map of records which has inputs being spent
+ // by a third party.
//
// NOTE: this is only used for neutrino backend.
failedRecords := make(map[uint64]*monitorRecord)
+ // initialRecords stores a map of records which are being created and
+ // published for the first time.
+ initialRecords := make(map[uint64]*monitorRecord)
+
// visitor is a helper closure that visits each record and divides them
// into two groups.
visitor := func(requestID uint64, r *monitorRecord) error {
+ if r.tx == nil {
+ initialRecords[requestID] = r
+ return nil
+ }
+
log.Tracef("Checking monitor recordID=%v for tx=%v", requestID,
r.tx.TxHash())
@@ -858,17 +915,20 @@ func (t *TxPublisher) processRecords() {
return nil
}
- // Iterate through all the records and divide them into two groups.
+ // Iterate through all the records and divide them into four groups.
t.records.ForEach(visitor)
+ // Handle the initial broadcast.
+ for requestID, r := range initialRecords {
+ t.handleInitialBroadcast(r, requestID)
+ }
+
// For records that are confirmed, we'll notify the caller about this
// result.
for requestID, r := range confirmedRecords {
- rec := r
-
log.Debugf("Tx=%v is confirmed", r.tx.TxHash())
t.wg.Add(1)
- go t.handleTxConfirmed(rec, requestID)
+ go t.handleTxConfirmed(r, requestID)
}
// Get the current height to be used in the following goroutines.
@@ -876,22 +936,18 @@ func (t *TxPublisher) processRecords() {
// For records that are not confirmed, we perform a fee bump if needed.
for requestID, r := range feeBumpRecords {
- rec := r
-
log.Debugf("Attempting to fee bump Tx=%v", r.tx.TxHash())
t.wg.Add(1)
- go t.handleFeeBumpTx(requestID, rec, currentHeight)
+ go t.handleFeeBumpTx(requestID, r, currentHeight)
}
// For records that are failed, we'll notify the caller about this
// result.
for requestID, r := range failedRecords {
- rec := r
-
log.Debugf("Tx=%v has inputs been spent by a third party, "+
"failing it now", r.tx.TxHash())
t.wg.Add(1)
- go t.handleThirdPartySpent(rec, requestID)
+ go t.handleThirdPartySpent(r, requestID)
}
}
@@ -916,6 +972,96 @@ func (t *TxPublisher) handleTxConfirmed(r *monitorRecord, requestID uint64) {
t.handleResult(result)
}
+// handleInitialTxError takes the error from `initializeTx` and decides the
+// bump event. It will construct a BumpResult and handles it.
+func (t *TxPublisher) handleInitialTxError(requestID uint64, err error) {
+ // We now decide what type of event to send.
+ var event BumpEvent
+
+ switch {
+ // When the error is due to a dust output, we'll send a TxFailed so
+ // these inputs can be retried with a different group in the next
+ // block.
+ case errors.Is(err, ErrTxNoOutput):
+ event = TxFailed
+
+ // When the error is due to budget being used up, we'll send a TxFailed
+ // so these inputs can be retried with a different group in the next
+ // block.
+ case errors.Is(err, ErrMaxPosition):
+ event = TxFailed
+
+ // When the error is due to zero fee rate delta, we'll send a TxFailed
+ // so these inputs can be retried in the next block.
+ case errors.Is(err, ErrZeroFeeRateDelta):
+ event = TxFailed
+
+ // Otherwise this is not a fee-related error and the tx cannot be
+ // retried. In that case we will fail ALL the inputs in this tx, which
+ // means they will be removed from the sweeper and never be tried
+ // again.
+ //
+ // TODO(yy): Find out which input is causing the failure and fail that
+ // one only.
+ default:
+ event = TxFatal
+ }
+
+ result := &BumpResult{
+ Event: event,
+ Err: err,
+ requestID: requestID,
+ }
+
+ t.handleResult(result)
+}
+
+// handleInitialBroadcast is called when a new request is received. It will
+// handle the initial tx creation and broadcast. In details,
+// 1. init a fee function based on the given strategy.
+// 2. create an RBF-compliant tx and monitor it for confirmation.
+// 3. notify the initial broadcast result back to the caller.
+func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord,
+ requestID uint64) {
+
+ log.Debugf("Initial broadcast for requestID=%v", requestID)
+
+ var (
+ result *BumpResult
+ err error
+ )
+
+ // Attempt an initial broadcast which is guaranteed to comply with the
+ // RBF rules.
+ //
+ // Create the initial tx to be broadcasted.
+ err = t.initializeTx(requestID, r.req)
+ if err != nil {
+ log.Errorf("Initial broadcast failed: %v", err)
+
+ // We now handle the initialization error and exit.
+ t.handleInitialTxError(requestID, err)
+
+ return
+ }
+
+ // Successfully created the first tx, now broadcast it.
+ result, err = t.broadcast(requestID)
+ if err != nil {
+ // The broadcast failed, which can only happen if the tx record
+ // cannot be found or the aux sweeper returns an error. In
+ // either case, we will send back a TxFail event so these
+ // inputs can be retried.
+ result = &BumpResult{
+ Event: TxFailed,
+ Err: err,
+ requestID: requestID,
+ }
+ }
+
+ t.handleResult(result)
+}
+
// handleFeeBumpTx checks if the tx needs to be bumped, and if so, it will
// attempt to bump the fee of the tx.
//
@@ -1382,7 +1528,7 @@ func prepareSweepTx(inputs []input.Input, changePkScript lnwallet.AddrWithKey,
return err
}
- extraChangeOut = extraOut.LeftToOption()
+ extraChangeOut = extraOut.LeftToSome()
return nil
},
diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go
index 5030dee227..54c67dbe28 100644
--- a/sweep/fee_bumper_test.go
+++ b/sweep/fee_bumper_test.go
@@ -11,7 +11,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/chain"
"github.com/lightningnetwork/lnd/chainntnfs"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet"
@@ -91,6 +91,12 @@ func TestBumpResultValidate(t *testing.T) {
}
require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult)
+ // A fatal event without a failure reason will give an error.
+ b = BumpResult{
+ Event: TxFailed,
+ }
+ require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult)
+
// A confirmed event without fee info will give an error.
b = BumpResult{
Tx: &wire.MsgTx{},
@@ -104,6 +110,20 @@ func TestBumpResultValidate(t *testing.T) {
Event: TxPublished,
}
require.NoError(t, b.Validate())
+
+ // Tx is allowed to be nil in a TxFailed event.
+ b = BumpResult{
+ Event: TxFailed,
+ Err: errDummy,
+ }
+ require.NoError(t, b.Validate())
+
+ // Tx is allowed to be nil in a TxFatal event.
+ b = BumpResult{
+ Event: TxFatal,
+ Err: errDummy,
+ }
+ require.NoError(t, b.Validate())
}
// TestCalcSweepTxWeight checks that the weight of the sweep tx is calculated
@@ -332,13 +352,10 @@ func TestStoreRecord(t *testing.T) {
}
// Call the method under test.
- requestID := tp.storeRecord(tx, req, feeFunc, fee, utxoIndex)
-
- // Check the request ID is as expected.
- require.Equal(t, initialCounter+1, requestID)
+ tp.storeRecord(initialCounter, tx, req, feeFunc, fee, utxoIndex)
// Read the saved record and compare.
- record, ok := tp.records.Load(requestID)
+ record, ok := tp.records.Load(initialCounter)
require.True(t, ok)
require.Equal(t, tx, record.tx)
require.Equal(t, feeFunc, record.feeFunction)
@@ -635,23 +652,19 @@ func TestCreateRBFCompliantTx(t *testing.T) {
},
}
+ var requestCounter atomic.Uint64
for _, tc := range testCases {
tc := tc
+ rid := requestCounter.Add(1)
t.Run(tc.name, func(t *testing.T) {
tc.setupMock()
// Call the method under test.
- id, err := tp.createRBFCompliantTx(req, m.feeFunc)
+ err := tp.createRBFCompliantTx(rid, req, m.feeFunc)
// Check the result is as expected.
require.ErrorIs(t, err, tc.expectedErr)
-
- // If there's an error, expect the requestID to be
- // empty.
- if tc.expectedErr != nil {
- require.Zero(t, id)
- }
})
}
}
@@ -684,7 +697,8 @@ func TestTxPublisherBroadcast(t *testing.T) {
// Create a testing record and put it in the map.
fee := btcutil.Amount(1000)
- requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex)
+ requestID := uint64(1)
+ tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex)
// Quickly check when the requestID cannot be found, an error is
// returned.
@@ -779,6 +793,9 @@ func TestRemoveResult(t *testing.T) {
op: 0,
}
+ // Create a test request ID counter.
+ requestCounter := atomic.Uint64{}
+
testCases := []struct {
name string
setupRecord func() uint64
@@ -790,12 +807,13 @@ func TestRemoveResult(t *testing.T) {
// removed.
name: "remove on TxConfirmed",
setupRecord: func() uint64 {
- id := tp.storeRecord(
- tx, req, m.feeFunc, fee, utxoIndex,
+ rid := requestCounter.Add(1)
+ tp.storeRecord(
+ rid, tx, req, m.feeFunc, fee, utxoIndex,
)
- tp.subscriberChans.Store(id, nil)
+ tp.subscriberChans.Store(rid, nil)
- return id
+ return rid
},
result: &BumpResult{
Event: TxConfirmed,
@@ -807,12 +825,13 @@ func TestRemoveResult(t *testing.T) {
// When the tx is failed, the records will be removed.
name: "remove on TxFailed",
setupRecord: func() uint64 {
- id := tp.storeRecord(
- tx, req, m.feeFunc, fee, utxoIndex,
+ rid := requestCounter.Add(1)
+ tp.storeRecord(
+ rid, tx, req, m.feeFunc, fee, utxoIndex,
)
- tp.subscriberChans.Store(id, nil)
+ tp.subscriberChans.Store(rid, nil)
- return id
+ return rid
},
result: &BumpResult{
Event: TxFailed,
@@ -825,12 +844,13 @@ func TestRemoveResult(t *testing.T) {
// Noop when the tx is neither confirmed or failed.
name: "noop when tx is not confirmed or failed",
setupRecord: func() uint64 {
- id := tp.storeRecord(
- tx, req, m.feeFunc, fee, utxoIndex,
+ rid := requestCounter.Add(1)
+ tp.storeRecord(
+ rid, tx, req, m.feeFunc, fee, utxoIndex,
)
- tp.subscriberChans.Store(id, nil)
+ tp.subscriberChans.Store(rid, nil)
- return id
+ return rid
},
result: &BumpResult{
Event: TxPublished,
@@ -885,7 +905,8 @@ func TestNotifyResult(t *testing.T) {
// Create a testing record and put it in the map.
fee := btcutil.Amount(1000)
- requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex)
+ requestID := uint64(1)
+ tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex)
// Create a subscription to the event.
subscriber := make(chan *BumpResult, 1)
@@ -933,41 +954,17 @@ func TestNotifyResult(t *testing.T) {
}
}
-// TestBroadcastSuccess checks the public `Broadcast` method can successfully
-// broadcast a tx based on the request.
-func TestBroadcastSuccess(t *testing.T) {
+// TestBroadcast checks the public `Broadcast` method can successfully register
+// a broadcast request.
+func TestBroadcast(t *testing.T) {
t.Parallel()
// Create a publisher using the mocks.
- tp, m := createTestPublisher(t)
+ tp, _ := createTestPublisher(t)
// Create a test feerate.
feerate := chainfee.SatPerKWeight(1000)
- // Mock the fee estimator to return the testing fee rate.
- //
- // We are not testing `NewLinearFeeFunction` here, so the actual params
- // used are irrelevant.
- m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
- feerate, nil).Once()
- m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once()
-
- // Mock the signer to always return a valid script.
- //
- // NOTE: we are not testing the utility of creating valid txes here, so
- // this is fine to be mocked. This behaves essentially as skipping the
- // Signer check and alaways assume the tx has a valid sig.
- script := &input.Script{}
- m.signer.On("ComputeInputScript", mock.Anything,
- mock.Anything).Return(script, nil)
-
- // Mock the testmempoolaccept to pass.
- m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
-
- // Mock the wallet to publish successfully.
- m.wallet.On("PublishTransaction",
- mock.Anything, mock.Anything).Return(nil).Once()
-
// Create a test request.
inp := createTestInput(1000, input.WitnessKeyHash)
@@ -981,27 +978,24 @@ func TestBroadcastSuccess(t *testing.T) {
}
// Send the req and expect no error.
- resultChan, err := tp.Broadcast(req)
- require.NoError(t, err)
-
- // Check the result is sent back.
- select {
- case <-time.After(time.Second):
- t.Fatal("timeout waiting for subscriber to receive result")
-
- case result := <-resultChan:
- // We expect the first result to be TxPublished.
- require.Equal(t, TxPublished, result.Event)
- }
+ resultChan := tp.Broadcast(req)
+ require.NotNil(t, resultChan)
// Validate the record was stored.
require.Equal(t, 1, tp.records.Len())
require.Equal(t, 1, tp.subscriberChans.Len())
+
+ // Validate the record.
+ rid := tp.requestCounter.Load()
+ record, found := tp.records.Load(rid)
+ require.True(t, found)
+ require.Equal(t, req, record.req)
}
-// TestBroadcastFail checks the public `Broadcast` returns the error or a
-// failed result when the broadcast fails.
-func TestBroadcastFail(t *testing.T) {
+// TestBroadcastImmediate checks the public `Broadcast` method can successfully
+// register a broadcast request and publish the tx when `Immediate` flag is
+// set.
+func TestBroadcastImmediate(t *testing.T) {
t.Parallel()
// Create a publisher using the mocks.
@@ -1020,64 +1014,27 @@ func TestBroadcastFail(t *testing.T) {
Budget: btcutil.Amount(1000),
MaxFeeRate: feerate * 10,
DeadlineHeight: 10,
+ Immediate: true,
}
- // Mock the fee estimator to return the testing fee rate.
+ // Mock the fee estimator to return an error.
//
- // We are not testing `NewLinearFeeFunction` here, so the actual params
- // used are irrelevant.
+ // NOTE: We are not testing `handleInitialBroadcast` here, but only
+ // interested in checking that this method is indeed called when
+ // `Immediate` is true. Thus we mock the method to return an error to
+ // quickly abort. As long as this mocked method is called, we know the
+ // `Immediate` flag works.
m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
- feerate, nil).Twice()
- m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice()
+ chainfee.SatPerKWeight(0), errDummy).Once()
- // Mock the signer to always return a valid script.
- //
- // NOTE: we are not testing the utility of creating valid txes here, so
- // this is fine to be mocked. This behaves essentially as skipping the
- // Signer check and alaways assume the tx has a valid sig.
- script := &input.Script{}
- m.signer.On("ComputeInputScript", mock.Anything,
- mock.Anything).Return(script, nil)
-
- // Mock the testmempoolaccept to return an error.
- m.wallet.On("CheckMempoolAcceptance",
- mock.Anything).Return(errDummy).Once()
-
- // Send the req and expect an error returned.
- resultChan, err := tp.Broadcast(req)
- require.ErrorIs(t, err, errDummy)
- require.Nil(t, resultChan)
-
- // Validate the record was NOT stored.
- require.Equal(t, 0, tp.records.Len())
- require.Equal(t, 0, tp.subscriberChans.Len())
-
- // Mock the testmempoolaccept again, this time it passes.
- m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
-
- // Mock the wallet to fail on publish.
- m.wallet.On("PublishTransaction",
- mock.Anything, mock.Anything).Return(errDummy).Once()
-
- // Send the req and expect no error returned.
- resultChan, err = tp.Broadcast(req)
- require.NoError(t, err)
-
- // Check the result is sent back.
- select {
- case <-time.After(time.Second):
- t.Fatal("timeout waiting for subscriber to receive result")
-
- case result := <-resultChan:
- // We expect the result to be TxFailed and the error is set in
- // the result.
- require.Equal(t, TxFailed, result.Event)
- require.ErrorIs(t, result.Err, errDummy)
- }
+ // Send the req and expect no error.
+ resultChan := tp.Broadcast(req)
+ require.NotNil(t, resultChan)
- // Validate the record was removed.
- require.Equal(t, 0, tp.records.Len())
- require.Equal(t, 0, tp.subscriberChans.Len())
+ // Validate the record was removed due to an error returned in initial
+ // broadcast.
+ require.Empty(t, tp.records.Len())
+ require.Empty(t, tp.subscriberChans.Len())
}
// TestCreateAnPublishFail checks all the error cases are handled properly in
@@ -1250,7 +1207,8 @@ func TestHandleTxConfirmed(t *testing.T) {
// Create a testing record and put it in the map.
fee := btcutil.Amount(1000)
- requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex)
+ requestID := uint64(1)
+ tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex)
record, ok := tp.records.Load(requestID)
require.True(t, ok)
@@ -1330,7 +1288,8 @@ func TestHandleFeeBumpTx(t *testing.T) {
// Create a testing record and put it in the map.
fee := btcutil.Amount(1000)
- requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex)
+ requestID := uint64(1)
+ tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex)
// Create a subscription to the event.
subscriber := make(chan *BumpResult, 1)
@@ -1531,3 +1490,183 @@ func TestProcessRecords(t *testing.T) {
require.Equal(t, requestID2, result.requestID)
}
}
+
+// TestHandleInitialBroadcastSuccess checks `handleInitialBroadcast` method can
+// successfully broadcast a tx based on the request.
+func TestHandleInitialBroadcastSuccess(t *testing.T) {
+ t.Parallel()
+
+ // Create a publisher using the mocks.
+ tp, m := createTestPublisher(t)
+
+ // Create a test feerate.
+ feerate := chainfee.SatPerKWeight(1000)
+
+ // Mock the fee estimator to return the testing fee rate.
+ //
+ // We are not testing `NewLinearFeeFunction` here, so the actual params
+ // used are irrelevant.
+ m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
+ feerate, nil).Once()
+ m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once()
+
+ // Mock the signer to always return a valid script.
+ //
+ // NOTE: we are not testing the utility of creating valid txes here, so
+ // this is fine to be mocked. This behaves essentially as skipping the
+ // Signer check and alaways assume the tx has a valid sig.
+ script := &input.Script{}
+ m.signer.On("ComputeInputScript", mock.Anything,
+ mock.Anything).Return(script, nil)
+
+ // Mock the testmempoolaccept to pass.
+ m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
+
+ // Mock the wallet to publish successfully.
+ m.wallet.On("PublishTransaction",
+ mock.Anything, mock.Anything).Return(nil).Once()
+
+ // Create a test request.
+ inp := createTestInput(1000, input.WitnessKeyHash)
+
+ // Create a testing bump request.
+ req := &BumpRequest{
+ DeliveryAddress: changePkScript,
+ Inputs: []input.Input{&inp},
+ Budget: btcutil.Amount(1000),
+ MaxFeeRate: feerate * 10,
+ DeadlineHeight: 10,
+ }
+
+ // Register the testing record use `Broadcast`.
+ resultChan := tp.Broadcast(req)
+
+ // Grab the monitor record from the map.
+ rid := tp.requestCounter.Load()
+ rec, ok := tp.records.Load(rid)
+ require.True(t, ok)
+
+ // Call the method under test.
+ tp.wg.Add(1)
+ tp.handleInitialBroadcast(rec, rid)
+
+ // Check the result is sent back.
+ select {
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for subscriber to receive result")
+
+ case result := <-resultChan:
+ // We expect the first result to be TxPublished.
+ require.Equal(t, TxPublished, result.Event)
+ }
+
+ // Validate the record was stored.
+ require.Equal(t, 1, tp.records.Len())
+ require.Equal(t, 1, tp.subscriberChans.Len())
+}
+
+// TestHandleInitialBroadcastFail checks `handleInitialBroadcast` returns the
+// error or a failed result when the broadcast fails.
+func TestHandleInitialBroadcastFail(t *testing.T) {
+ t.Parallel()
+
+ // Create a publisher using the mocks.
+ tp, m := createTestPublisher(t)
+
+ // Create a test feerate.
+ feerate := chainfee.SatPerKWeight(1000)
+
+ // Create a test request.
+ inp := createTestInput(1000, input.WitnessKeyHash)
+
+ // Create a testing bump request.
+ req := &BumpRequest{
+ DeliveryAddress: changePkScript,
+ Inputs: []input.Input{&inp},
+ Budget: btcutil.Amount(1000),
+ MaxFeeRate: feerate * 10,
+ DeadlineHeight: 10,
+ }
+
+ // Mock the fee estimator to return the testing fee rate.
+ //
+ // We are not testing `NewLinearFeeFunction` here, so the actual params
+ // used are irrelevant.
+ m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
+ feerate, nil).Twice()
+ m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice()
+
+ // Mock the signer to always return a valid script.
+ //
+ // NOTE: we are not testing the utility of creating valid txes here, so
+ // this is fine to be mocked. This behaves essentially as skipping the
+ // Signer check and alaways assume the tx has a valid sig.
+ script := &input.Script{}
+ m.signer.On("ComputeInputScript", mock.Anything,
+ mock.Anything).Return(script, nil)
+
+ // Mock the testmempoolaccept to return an error.
+ m.wallet.On("CheckMempoolAcceptance",
+ mock.Anything).Return(errDummy).Once()
+
+ // Register the testing record use `Broadcast`.
+ resultChan := tp.Broadcast(req)
+
+ // Grab the monitor record from the map.
+ rid := tp.requestCounter.Load()
+ rec, ok := tp.records.Load(rid)
+ require.True(t, ok)
+
+ // Call the method under test and expect an error returned.
+ tp.wg.Add(1)
+ tp.handleInitialBroadcast(rec, rid)
+
+ // Check the result is sent back.
+ select {
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for subscriber to receive result")
+
+ case result := <-resultChan:
+ // We expect the first result to be TxFatal.
+ require.Equal(t, TxFatal, result.Event)
+ }
+
+ // Validate the record was NOT stored.
+ require.Equal(t, 0, tp.records.Len())
+ require.Equal(t, 0, tp.subscriberChans.Len())
+
+ // Mock the testmempoolaccept again, this time it passes.
+ m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
+
+ // Mock the wallet to fail on publish.
+ m.wallet.On("PublishTransaction",
+ mock.Anything, mock.Anything).Return(errDummy).Once()
+
+ // Register the testing record use `Broadcast`.
+ resultChan = tp.Broadcast(req)
+
+ // Grab the monitor record from the map.
+ rid = tp.requestCounter.Load()
+ rec, ok = tp.records.Load(rid)
+ require.True(t, ok)
+
+ // Call the method under test.
+ tp.wg.Add(1)
+ tp.handleInitialBroadcast(rec, rid)
+
+ // Check the result is sent back.
+ select {
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for subscriber to receive result")
+
+ case result := <-resultChan:
+ // We expect the result to be TxFailed and the error is set in
+ // the result.
+ require.Equal(t, TxFailed, result.Event)
+ require.ErrorIs(t, result.Err, errDummy)
+ }
+
+ // Validate the record was removed.
+ require.Equal(t, 0, tp.records.Len())
+ require.Equal(t, 0, tp.subscriberChans.Len())
+}
diff --git a/sweep/fee_function.go b/sweep/fee_function.go
index cbf283e37d..eb2ed4d6b1 100644
--- a/sweep/fee_function.go
+++ b/sweep/fee_function.go
@@ -5,7 +5,7 @@ import (
"fmt"
"github.com/btcsuite/btcd/btcutil"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
)
@@ -14,6 +14,9 @@ var (
// ErrMaxPosition is returned when trying to increase the position of
// the fee function while it's already at its max.
ErrMaxPosition = errors.New("position already at max")
+
+ // ErrZeroFeeRateDelta is returned when the fee rate delta is zero.
+ ErrZeroFeeRateDelta = errors.New("fee rate delta is zero")
)
// mSatPerKWeight represents a fee rate in msat/kw.
@@ -169,7 +172,7 @@ func NewLinearFeeFunction(maxFeeRate chainfee.SatPerKWeight,
"endingFeeRate=%v, width=%v, delta=%v", start, end,
l.width, l.deltaFeeRate)
- return nil, fmt.Errorf("fee rate delta is zero")
+ return nil, ErrZeroFeeRateDelta
}
// Attach the calculated values to the fee function.
diff --git a/sweep/fee_function_test.go b/sweep/fee_function_test.go
index c278bb7f06..a55ce79a78 100644
--- a/sweep/fee_function_test.go
+++ b/sweep/fee_function_test.go
@@ -3,7 +3,7 @@ package sweep
import (
"testing"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/stretchr/testify/require"
)
diff --git a/sweep/interface.go b/sweep/interface.go
index f2fff84b08..6c8c2cfad2 100644
--- a/sweep/interface.go
+++ b/sweep/interface.go
@@ -4,7 +4,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet"
diff --git a/sweep/mock_test.go b/sweep/mock_test.go
index 34202b1453..f9471f22a0 100644
--- a/sweep/mock_test.go
+++ b/sweep/mock_test.go
@@ -4,7 +4,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet"
@@ -268,6 +268,13 @@ func (m *MockInputSet) StartingFeeRate() fn.Option[chainfee.SatPerKWeight] {
return args.Get(0).(fn.Option[chainfee.SatPerKWeight])
}
+// Immediate returns whether the inputs should be swept immediately.
+func (m *MockInputSet) Immediate() bool {
+ args := m.Called()
+
+ return args.Bool(0)
+}
+
// MockBumper is a mock implementation of the interface Bumper.
type MockBumper struct {
mock.Mock
@@ -277,14 +284,14 @@ type MockBumper struct {
var _ Bumper = (*MockBumper)(nil)
// Broadcast broadcasts the transaction to the network.
-func (m *MockBumper) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) {
+func (m *MockBumper) Broadcast(req *BumpRequest) <-chan *BumpResult {
args := m.Called(req)
if args.Get(0) == nil {
- return nil, args.Error(1)
+ return nil
}
- return args.Get(0).(chan *BumpResult), args.Error(1)
+ return args.Get(0).(chan *BumpResult)
}
// MockFeeFunction is a mock implementation of the FeeFunction interface.
diff --git a/sweep/sweeper.go b/sweep/sweeper.go
index 6257faac1f..d49f104af0 100644
--- a/sweep/sweeper.go
+++ b/sweep/sweeper.go
@@ -10,8 +10,9 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
+ "github.com/lightningnetwork/lnd/chainio"
"github.com/lightningnetwork/lnd/chainntnfs"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet"
@@ -222,6 +223,35 @@ func (p *SweeperInput) terminated() bool {
}
}
+// isMature returns a boolean indicating whether the input has a timelock that
+// has been reached or not. The locktime found is also returned.
+func (p *SweeperInput) isMature(currentHeight uint32) (bool, uint32) {
+ locktime, _ := p.RequiredLockTime()
+ if currentHeight < locktime {
+ log.Debugf("Input %v has locktime=%v, current height is %v",
+ p.OutPoint(), locktime, currentHeight)
+
+ return false, locktime
+ }
+
+ // If the input has a CSV that's not yet reached, we will skip
+ // this input and wait for the expiry.
+ //
+ // NOTE: We need to consider whether this input can be included in the
+ // next block or not, which means the CSV will be checked against the
+ // currentHeight plus one.
+ locktime = p.BlocksToMaturity() + p.HeightHint()
+ if currentHeight+1 < locktime {
+ log.Debugf("Input %v has CSV expiry=%v, current height is %v, "+
+ "skipped sweeping", p.OutPoint(), locktime,
+ currentHeight)
+
+ return false, locktime
+ }
+
+ return true, locktime
+}
+
// InputsMap is a type alias for a set of pending inputs.
type InputsMap = map[wire.OutPoint]*SweeperInput
@@ -280,6 +310,10 @@ type UtxoSweeper struct {
started uint32 // To be used atomically.
stopped uint32 // To be used atomically.
+ // Embed the blockbeat consumer struct to get access to the method
+ // `NotifyBlockProcessed` and the `BlockbeatChan`.
+ chainio.BeatConsumer
+
cfg *UtxoSweeperConfig
newInputs chan *sweepInputMessage
@@ -309,11 +343,14 @@ type UtxoSweeper struct {
// updated whenever a new block epoch is received.
currentHeight int32
- // bumpResultChan is a channel that receives broadcast results from the
+ // bumpRespChan is a channel that receives broadcast results from the
// TxPublisher.
- bumpResultChan chan *BumpResult
+ bumpRespChan chan *bumpResp
}
+// Compile-time check for the chainio.Consumer interface.
+var _ chainio.Consumer = (*UtxoSweeper)(nil)
+
// UtxoSweeperConfig contains dependencies of UtxoSweeper.
type UtxoSweeperConfig struct {
// GenSweepScript generates a P2WKH script belonging to the wallet where
@@ -387,7 +424,7 @@ type sweepInputMessage struct {
// New returns a new Sweeper instance.
func New(cfg *UtxoSweeperConfig) *UtxoSweeper {
- return &UtxoSweeper{
+ s := &UtxoSweeper{
cfg: cfg,
newInputs: make(chan *sweepInputMessage),
spendChan: make(chan *chainntnfs.SpendDetail),
@@ -395,12 +432,17 @@ func New(cfg *UtxoSweeperConfig) *UtxoSweeper {
pendingSweepsReqs: make(chan *pendingSweepsReq),
quit: make(chan struct{}),
inputs: make(InputsMap),
- bumpResultChan: make(chan *BumpResult, 100),
+ bumpRespChan: make(chan *bumpResp, 100),
}
+
+ // Mount the block consumer.
+ s.BeatConsumer = chainio.NewBeatConsumer(s.quit, s.Name())
+
+ return s
}
// Start starts the process of constructing and publish sweep txes.
-func (s *UtxoSweeper) Start() error {
+func (s *UtxoSweeper) Start(beat chainio.Blockbeat) error {
if !atomic.CompareAndSwapUint32(&s.started, 0, 1) {
return nil
}
@@ -411,49 +453,12 @@ func (s *UtxoSweeper) Start() error {
// not change from here on.
s.relayFeeRate = s.cfg.FeeEstimator.RelayFeePerKW()
- // We need to register for block epochs and retry sweeping every block.
- // We should get a notification with the current best block immediately
- // if we don't provide any epoch. We'll wait for that in the collector.
- blockEpochs, err := s.cfg.Notifier.RegisterBlockEpochNtfn(nil)
- if err != nil {
- return fmt.Errorf("register block epoch ntfn: %w", err)
- }
+ // Set the current height.
+ s.currentHeight = beat.Height()
// Start sweeper main loop.
s.wg.Add(1)
- go func() {
- defer blockEpochs.Cancel()
- defer s.wg.Done()
-
- s.collector(blockEpochs.Epochs)
-
- // The collector exited and won't longer handle incoming
- // requests. This can happen on shutdown, when the block
- // notifier shuts down before the sweeper and its clients. In
- // order to not deadlock the clients waiting for their requests
- // being handled, we handle them here and immediately return an
- // error. When the sweeper finally is shut down we can exit as
- // the clients will be notified.
- for {
- select {
- case inp := <-s.newInputs:
- inp.resultChan <- Result{
- Err: ErrSweeperShuttingDown,
- }
-
- case req := <-s.pendingSweepsReqs:
- req.errChan <- ErrSweeperShuttingDown
-
- case req := <-s.updateReqs:
- req.responseChan <- &updateResp{
- err: ErrSweeperShuttingDown,
- }
-
- case <-s.quit:
- return
- }
- }
- }()
+ go s.collector()
return nil
}
@@ -480,6 +485,11 @@ func (s *UtxoSweeper) Stop() error {
return nil
}
+// NOTE: part of the `chainio.Consumer` interface.
+func (s *UtxoSweeper) Name() string {
+ return "UtxoSweeper"
+}
+
// SweepInput sweeps inputs back into the wallet. The inputs will be batched and
// swept after the batch time window ends. A custom fee preference can be
// provided to determine what fee rate should be used for the input. Note that
@@ -502,7 +512,7 @@ func (s *UtxoSweeper) SweepInput(inp input.Input,
}
absoluteTimeLock, _ := inp.RequiredLockTime()
- log.Infof("Sweep request received: out_point=%v, witness_type=%v, "+
+ log.Debugf("Sweep request received: out_point=%v, witness_type=%v, "+
"relative_time_lock=%v, absolute_time_lock=%v, amount=%v, "+
"parent=(%v), params=(%v)", inp.OutPoint(), inp.WitnessType(),
inp.BlocksToMaturity(), absoluteTimeLock,
@@ -611,17 +621,8 @@ func (s *UtxoSweeper) removeConflictSweepDescendants(
// collector is the sweeper main loop. It processes new inputs, spend
// notifications and counts down to publication of the sweep tx.
-func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) {
- // We registered for the block epochs with a nil request. The notifier
- // should send us the current best block immediately. So we need to wait
- // for it here because we need to know the current best height.
- select {
- case bestBlock := <-blockEpochs:
- s.currentHeight = bestBlock.Height
-
- case <-s.quit:
- return
- }
+func (s *UtxoSweeper) collector() {
+ defer s.wg.Done()
for {
// Clean inputs, which will remove inputs that are swept,
@@ -681,9 +682,9 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) {
s.sweepPendingInputs(inputs)
}
- case result := <-s.bumpResultChan:
+ case resp := <-s.bumpRespChan:
// Handle the bump event.
- err := s.handleBumpEvent(result)
+ err := s.handleBumpEvent(resp)
if err != nil {
log.Errorf("Failed to handle bump event: %v",
err)
@@ -691,28 +692,33 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) {
// A new block comes in, update the bestHeight, perform a check
// over all pending inputs and publish sweeping txns if needed.
- case epoch, ok := <-blockEpochs:
- if !ok {
- // We should stop the sweeper before stopping
- // the chain service. Otherwise it indicates an
- // error.
- log.Error("Block epoch channel closed")
-
- return
- }
-
+ case beat := <-s.BlockbeatChan:
// Update the sweeper to the best height.
- s.currentHeight = epoch.Height
+ s.currentHeight = beat.Height()
// Update the inputs with the latest height.
inputs := s.updateSweeperInputs()
log.Debugf("Received new block: height=%v, attempt "+
- "sweeping %d inputs", epoch.Height, len(inputs))
+ "sweeping %d inputs:\n%s",
+ s.currentHeight, len(inputs),
+ lnutils.NewLogClosure(func() string {
+ inps := make(
+ []input.Input, 0, len(inputs),
+ )
+ for _, in := range inputs {
+ inps = append(inps, in)
+ }
+
+ return inputTypeSummary(inps)
+ }))
// Attempt to sweep any pending inputs.
s.sweepPendingInputs(inputs)
+ // Notify we've processed the block.
+ s.NotifyBlockProcessed(beat, nil)
+
case <-s.quit:
return
}
@@ -827,6 +833,7 @@ func (s *UtxoSweeper) sweep(set InputSet) error {
DeliveryAddress: sweepAddr,
MaxFeeRate: s.cfg.MaxFeeRate.FeePerKWeight(),
StartingFeeRate: set.StartingFeeRate(),
+ Immediate: set.Immediate(),
// TODO(yy): pass the strategy here.
}
@@ -837,27 +844,13 @@ func (s *UtxoSweeper) sweep(set InputSet) error {
// Broadcast will return a read-only chan that we will listen to for
// this publish result and future RBF attempt.
- resp, err := s.cfg.Publisher.Broadcast(req)
- if err != nil {
- outpoints := make([]wire.OutPoint, len(set.Inputs()))
- for i, inp := range set.Inputs() {
- outpoints[i] = inp.OutPoint()
- }
-
- log.Errorf("Initial broadcast failed: %v, inputs=\n%v", err,
- inputTypeSummary(set.Inputs()))
-
- // TODO(yy): find out which input is causing the failure.
- s.markInputsPublishFailed(outpoints)
-
- return err
- }
+ resp := s.cfg.Publisher.Broadcast(req)
// Successfully sent the broadcast attempt, we now handle the result by
// subscribing to the result chan and listen for future updates about
// this tx.
s.wg.Add(1)
- go s.monitorFeeBumpResult(resp)
+ go s.monitorFeeBumpResult(set, resp)
return nil
}
@@ -867,14 +860,14 @@ func (s *UtxoSweeper) sweep(set InputSet) error {
func (s *UtxoSweeper) markInputsPendingPublish(set InputSet) {
// Reschedule sweep.
for _, input := range set.Inputs() {
- pi, ok := s.inputs[input.OutPoint()]
+ op := input.OutPoint()
+ pi, ok := s.inputs[op]
if !ok {
// It could be that this input is an additional wallet
// input that was attached. In that case there also
// isn't a pending input to update.
log.Tracef("Skipped marking input as pending "+
- "published: %v not found in pending inputs",
- input.OutPoint())
+ "published: %v not found in pending inputs", op)
continue
}
@@ -885,8 +878,7 @@ func (s *UtxoSweeper) markInputsPendingPublish(set InputSet) {
// publish.
if pi.terminated() {
log.Errorf("Expect input %v to not have terminated "+
- "state, instead it has %v",
- input.OutPoint, pi.state)
+ "state, instead it has %v", op, pi.state)
continue
}
@@ -901,9 +893,7 @@ func (s *UtxoSweeper) markInputsPendingPublish(set InputSet) {
// markInputsPublished updates the sweeping tx in db and marks the list of
// inputs as published.
-func (s *UtxoSweeper) markInputsPublished(tr *TxRecord,
- inputs []*wire.TxIn) error {
-
+func (s *UtxoSweeper) markInputsPublished(tr *TxRecord, set InputSet) error {
// Mark this tx in db once successfully published.
//
// NOTE: this will behave as an overwrite, which is fine as the record
@@ -915,15 +905,15 @@ func (s *UtxoSweeper) markInputsPublished(tr *TxRecord,
}
// Reschedule sweep.
- for _, input := range inputs {
- pi, ok := s.inputs[input.PreviousOutPoint]
+ for _, input := range set.Inputs() {
+ op := input.OutPoint()
+ pi, ok := s.inputs[op]
if !ok {
// It could be that this input is an additional wallet
// input that was attached. In that case there also
// isn't a pending input to update.
log.Tracef("Skipped marking input as published: %v "+
- "not found in pending inputs",
- input.PreviousOutPoint)
+ "not found in pending inputs", op)
continue
}
@@ -932,8 +922,7 @@ func (s *UtxoSweeper) markInputsPublished(tr *TxRecord,
if pi.state != PendingPublish {
// We may get a Published if this is a replacement tx.
log.Debugf("Expect input %v to have %v, instead it "+
- "has %v", input.PreviousOutPoint,
- PendingPublish, pi.state)
+ "has %v", op, PendingPublish, pi.state)
continue
}
@@ -949,9 +938,10 @@ func (s *UtxoSweeper) markInputsPublished(tr *TxRecord,
}
// markInputsPublishFailed marks the list of inputs as failed to be published.
-func (s *UtxoSweeper) markInputsPublishFailed(outpoints []wire.OutPoint) {
+func (s *UtxoSweeper) markInputsPublishFailed(set InputSet) {
// Reschedule sweep.
- for _, op := range outpoints {
+ for _, inp := range set.Inputs() {
+ op := inp.OutPoint()
pi, ok := s.inputs[op]
if !ok {
// It could be that this input is an additional wallet
@@ -1054,6 +1044,12 @@ func (s *UtxoSweeper) handlePendingSweepsReq(
resps := make(map[wire.OutPoint]*PendingInputResponse, len(s.inputs))
for _, inp := range s.inputs {
+ // Skip immature inputs for compatibility.
+ mature, _ := inp.isMature(uint32(s.currentHeight))
+ if !mature {
+ continue
+ }
+
// Only the exported fields are set, as we expect the response
// to only be consumed externally.
op := inp.OutPoint()
@@ -1189,17 +1185,34 @@ func (s *UtxoSweeper) mempoolLookup(op wire.OutPoint) fn.Option[wire.MsgTx] {
return s.cfg.Mempool.LookupInputMempoolSpend(op)
}
-// handleNewInput processes a new input by registering spend notification and
-// scheduling sweeping for it.
-func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error {
+// calculateDefaultDeadline calculates the default deadline height for a sweep
+// request that has no deadline height specified.
+func (s *UtxoSweeper) calculateDefaultDeadline(pi *SweeperInput) int32 {
// Create a default deadline height, which will be used when there's no
// DeadlineHeight specified for a given input.
defaultDeadline := s.currentHeight + int32(s.cfg.NoDeadlineConfTarget)
+ // If the input is immature and has a locktime, we'll use the locktime
+ // height as the starting height.
+ matured, locktime := pi.isMature(uint32(s.currentHeight))
+ if !matured {
+ defaultDeadline = int32(locktime + s.cfg.NoDeadlineConfTarget)
+ log.Debugf("Input %v is immature, using locktime=%v instead "+
+ "of current height=%d as starting height",
+ pi.OutPoint(), locktime, s.currentHeight)
+ }
+
+ return defaultDeadline
+}
+
+// handleNewInput processes a new input by registering spend notification and
+// scheduling sweeping for it.
+func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error {
outpoint := input.input.OutPoint()
pi, pending := s.inputs[outpoint]
if pending {
- log.Debugf("Already has pending input %v received", outpoint)
+ log.Infof("Already has pending input %v received, old params: "+
+ "%v, new params %v", outpoint, pi.params, input.params)
s.handleExistingInput(input, pi)
@@ -1220,15 +1233,22 @@ func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error {
Input: input.input,
params: input.params,
rbf: rbfInfo,
- // Set the acutal deadline height.
- DeadlineHeight: input.params.DeadlineHeight.UnwrapOr(
- defaultDeadline,
- ),
}
+ // Set the acutal deadline height.
+ pi.DeadlineHeight = input.params.DeadlineHeight.UnwrapOr(
+ s.calculateDefaultDeadline(pi),
+ )
+
s.inputs[outpoint] = pi
log.Tracef("input %v, state=%v, added to inputs", outpoint, pi.state)
+ log.Infof("Registered sweep request at block %d: out_point=%v, "+
+ "witness_type=%v, amount=%v, deadline=%d, params=(%v)",
+ s.currentHeight, pi.OutPoint(), pi.WitnessType(),
+ btcutil.Amount(pi.SignDesc().Output.Value), pi.DeadlineHeight,
+ pi.params)
+
// Start watching for spend of this input, either by us or the remote
// party.
cancel, err := s.monitorSpend(
@@ -1457,11 +1477,6 @@ func (s *UtxoSweeper) markInputFailed(pi *SweeperInput, err error) {
pi.state = Failed
- // Remove all other inputs in this exclusive group.
- if pi.params.ExclusiveGroup != nil {
- s.removeExclusiveGroup(*pi.params.ExclusiveGroup)
- }
-
s.signalResult(pi, Result{Err: err})
}
@@ -1479,6 +1494,8 @@ func (s *UtxoSweeper) updateSweeperInputs() InputsMap {
// turn this inputs map into a SyncMap in case we wanna add concurrent
// access to the map in the future.
for op, input := range s.inputs {
+ log.Tracef("Checking input: %s, state=%v", input, input.state)
+
// If the input has reached a final state, that it's either
// been swept, or failed, or excluded, we will remove it from
// our sweeper.
@@ -1506,20 +1523,9 @@ func (s *UtxoSweeper) updateSweeperInputs() InputsMap {
// If the input has a locktime that's not yet reached, we will
// skip this input and wait for the locktime to be reached.
- locktime, _ := input.RequiredLockTime()
- if uint32(s.currentHeight) < locktime {
- log.Warnf("Skipping input %v due to locktime=%v not "+
- "reached, current height is %v", op, locktime,
- s.currentHeight)
-
- continue
- }
-
- // If the input has a CSV that's not yet reached, we will skip
- // this input and wait for the expiry.
- locktime = input.BlocksToMaturity() + input.HeightHint()
- if s.currentHeight < int32(locktime)-1 {
- log.Infof("Skipping input %v due to CSV expiry=%v not "+
+ mature, locktime := input.isMature(uint32(s.currentHeight))
+ if !mature {
+ log.Debugf("Skipping input %v due to locktime=%v not "+
"reached, current height is %v", op, locktime,
s.currentHeight)
@@ -1539,6 +1545,8 @@ func (s *UtxoSweeper) updateSweeperInputs() InputsMap {
// sweepPendingInputs is called when the ticker fires. It will create clusters
// and attempt to create and publish the sweeping transactions.
func (s *UtxoSweeper) sweepPendingInputs(inputs InputsMap) {
+ log.Debugf("Sweeping %v inputs", len(inputs))
+
// Cluster all of our inputs based on the specific Aggregator.
sets := s.cfg.Aggregator.ClusterInputs(inputs)
@@ -1580,11 +1588,24 @@ func (s *UtxoSweeper) sweepPendingInputs(inputs InputsMap) {
}
}
+// bumpResp wraps the result of a bump attempt returned from the fee bumper and
+// the inputs being used.
+type bumpResp struct {
+ // result is the result of the bump attempt returned from the fee
+ // bumper.
+ result *BumpResult
+
+ // set is the input set that was used in the bump attempt.
+ set InputSet
+}
+
// monitorFeeBumpResult subscribes to the passed result chan to listen for
// future updates about the sweeping tx.
//
// NOTE: must run as a goroutine.
-func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) {
+func (s *UtxoSweeper) monitorFeeBumpResult(set InputSet,
+ resultChan <-chan *BumpResult) {
+
defer s.wg.Done()
for {
@@ -1596,9 +1617,14 @@ func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) {
continue
}
+ resp := &bumpResp{
+ result: r,
+ set: set,
+ }
+
// Send the result back to the main event loop.
select {
- case s.bumpResultChan <- r:
+ case s.bumpRespChan <- resp:
case <-s.quit:
log.Debug("Sweeper shutting down, skip " +
"sending bump result")
@@ -1613,6 +1639,14 @@ func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) {
// in sweeper and rely solely on this event to mark
// inputs as Swept?
if r.Event == TxConfirmed || r.Event == TxFailed {
+ // Exit if the tx is failed to be created.
+ if r.Tx == nil {
+ log.Debugf("Received %v for nil tx, "+
+ "exit monitor", r.Event)
+
+ return
+ }
+
log.Debugf("Received %v for sweep tx %v, exit "+
"fee bump monitor", r.Event,
r.Tx.TxHash())
@@ -1634,25 +1668,28 @@ func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) {
// handleBumpEventTxFailed handles the case where the tx has been failed to
// publish.
-func (s *UtxoSweeper) handleBumpEventTxFailed(r *BumpResult) error {
+func (s *UtxoSweeper) handleBumpEventTxFailed(resp *bumpResp) {
+ r := resp.result
tx, err := r.Tx, r.Err
- log.Errorf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(), err)
-
- outpoints := make([]wire.OutPoint, 0, len(tx.TxIn))
- for _, inp := range tx.TxIn {
- outpoints = append(outpoints, inp.PreviousOutPoint)
+ if tx != nil {
+ log.Warnf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(),
+ err)
}
+ // NOTE: When marking the inputs as failed, we are using the input set
+ // instead of the inputs found in the tx. This is fine for current
+ // version of the sweeper because we always create a tx using ALL of
+ // the inputs specified by the set.
+ //
// TODO(yy): should we also remove the failed tx from db?
- s.markInputsPublishFailed(outpoints)
-
- return err
+ s.markInputsPublishFailed(resp.set)
}
// handleBumpEventTxReplaced handles the case where the sweeping tx has been
// replaced by a new one.
-func (s *UtxoSweeper) handleBumpEventTxReplaced(r *BumpResult) error {
+func (s *UtxoSweeper) handleBumpEventTxReplaced(resp *bumpResp) error {
+ r := resp.result
oldTx := r.ReplacedTx
newTx := r.Tx
@@ -1692,12 +1729,13 @@ func (s *UtxoSweeper) handleBumpEventTxReplaced(r *BumpResult) error {
}
// Mark the inputs as published using the replacing tx.
- return s.markInputsPublished(tr, r.Tx.TxIn)
+ return s.markInputsPublished(tr, resp.set)
}
// handleBumpEventTxPublished handles the case where the sweeping tx has been
// successfully published.
-func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error {
+func (s *UtxoSweeper) handleBumpEventTxPublished(resp *bumpResp) error {
+ r := resp.result
tx := r.Tx
tr := &TxRecord{
Txid: tx.TxHash(),
@@ -1707,7 +1745,7 @@ func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error {
// Inputs have been successfully published so we update their
// states.
- err := s.markInputsPublished(tr, tx.TxIn)
+ err := s.markInputsPublished(tr, resp.set)
if err != nil {
return err
}
@@ -1723,15 +1761,71 @@ func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error {
return nil
}
+// handleBumpEventTxFatal handles the case where there's an unexpected error
+// when creating or publishing the sweeping tx. In this case, the tx will be
+// removed from the sweeper store and the inputs will be marked as `Failed`,
+// which means they will not be retried.
+func (s *UtxoSweeper) handleBumpEventTxFatal(resp *bumpResp) error {
+ r := resp.result
+
+ // Remove the tx from the sweeper store if there is one. Since this is
+ // a broadcast error, it's likely there isn't a tx here.
+ if r.Tx != nil {
+ txid := r.Tx.TxHash()
+ log.Infof("Tx=%v failed with unexpected error: %v", txid, r.Err)
+
+ // Remove the tx from the sweeper db if it exists.
+ if err := s.cfg.Store.DeleteTx(txid); err != nil {
+ return fmt.Errorf("delete tx record for %v: %w", txid,
+ err)
+ }
+ }
+
+ // Mark the inputs as failed.
+ s.markInputsFailed(resp.set, r.Err)
+
+ return nil
+}
+
+// markInputsFailed marks all inputs found in the tx as failed. It will also
+// notify all the subscribers of these inputs.
+func (s *UtxoSweeper) markInputsFailed(set InputSet, err error) {
+ for _, inp := range set.Inputs() {
+ outpoint := inp.OutPoint()
+
+ input, ok := s.inputs[outpoint]
+ if !ok {
+ // It's very likely that a spending tx contains inputs
+ // that we don't know.
+ log.Tracef("Skipped marking input as failed: %v not "+
+ "found in pending inputs", outpoint)
+
+ continue
+ }
+
+ // If the input is already in a terminal state, we don't want
+ // to rewrite it, which also indicates an error as we only get
+ // an error event during the initial broadcast.
+ if input.terminated() {
+ log.Errorf("Skipped marking input=%v as failed due to "+
+ "unexpected state=%v", outpoint, input.state)
+
+ continue
+ }
+
+ s.markInputFailed(input, err)
+ }
+}
+
// handleBumpEvent handles the result sent from the bumper based on its event
// type.
//
// NOTE: TxConfirmed event is not handled, since we already subscribe to the
// input's spending event, we don't need to do anything here.
-func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error {
- log.Debugf("Received bump event [%v] for tx %v", r.Event, r.Tx.TxHash())
+func (s *UtxoSweeper) handleBumpEvent(r *bumpResp) error {
+ log.Debugf("Received bump result %v", r.result)
- switch r.Event {
+ switch r.result.Event {
// The tx has been published, we update the inputs' state and create a
// record to be stored in the sweeper db.
case TxPublished:
@@ -1739,12 +1833,18 @@ func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error {
// The tx has failed, we update the inputs' state.
case TxFailed:
- return s.handleBumpEventTxFailed(r)
+ s.handleBumpEventTxFailed(r)
+ return nil
// The tx has been replaced, we will remove the old tx and replace it
// with the new one.
case TxReplaced:
return s.handleBumpEventTxReplaced(r)
+
+ // There's a fatal error in creating the tx, we will remove the tx from
+ // the sweeper db and mark the inputs as failed.
+ case TxFatal:
+ return s.handleBumpEventTxFatal(r)
}
return nil
diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go
index 2b61f67933..16a4a46fbe 100644
--- a/sweep/sweeper_test.go
+++ b/sweep/sweeper_test.go
@@ -1,6 +1,7 @@
package sweep
import (
+ "crypto/rand"
"errors"
"testing"
"time"
@@ -10,8 +11,9 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
+ "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/stretchr/testify/mock"
@@ -33,6 +35,41 @@ var (
})
)
+// createMockInput creates a mock input and saves it to the sweeper's inputs
+// map. The created input has the specified state and a random outpoint. It
+// will assert the method `OutPoint` is called at least once.
+func createMockInput(t *testing.T, s *UtxoSweeper,
+ state SweepState) *input.MockInput {
+
+ inp := &input.MockInput{}
+ t.Cleanup(func() {
+ inp.AssertExpectations(t)
+ })
+
+ randBuf := make([]byte, lntypes.HashSize)
+ _, err := rand.Read(randBuf)
+ require.NoError(t, err, "internal error, cannot generate random bytes")
+
+ randHash, err := chainhash.NewHash(randBuf)
+ require.NoError(t, err)
+
+ inp.On("OutPoint").Return(wire.OutPoint{
+ Hash: *randHash,
+ Index: 0,
+ })
+
+ // We don't do branch switches based on the witness type here so we
+ // just mock it.
+ inp.On("WitnessType").Return(input.CommitmentTimeLock).Maybe()
+
+ s.inputs[inp.OutPoint()] = &SweeperInput{
+ Input: inp,
+ state: state,
+ }
+
+ return inp
+}
+
// TestMarkInputsPendingPublish checks that given a list of inputs with
// different states, only the non-terminal state will be marked as `Published`.
func TestMarkInputsPendingPublish(t *testing.T) {
@@ -47,50 +84,21 @@ func TestMarkInputsPendingPublish(t *testing.T) {
set := &MockInputSet{}
defer set.AssertExpectations(t)
- // Create three testing inputs.
- //
- // inputNotExist specifies an input that's not found in the sweeper's
- // `pendingInputs` map.
- inputNotExist := &input.MockInput{}
- defer inputNotExist.AssertExpectations(t)
-
- inputNotExist.On("OutPoint").Return(wire.OutPoint{Index: 0})
-
- // inputInit specifies a newly created input.
- inputInit := &input.MockInput{}
- defer inputInit.AssertExpectations(t)
-
- inputInit.On("OutPoint").Return(wire.OutPoint{Index: 1})
-
- s.inputs[inputInit.OutPoint()] = &SweeperInput{
- state: Init,
- }
-
- // inputPendingPublish specifies an input that's about to be published.
- inputPendingPublish := &input.MockInput{}
- defer inputPendingPublish.AssertExpectations(t)
-
- inputPendingPublish.On("OutPoint").Return(wire.OutPoint{Index: 2})
-
- s.inputs[inputPendingPublish.OutPoint()] = &SweeperInput{
- state: PendingPublish,
- }
-
- // inputTerminated specifies an input that's terminated.
- inputTerminated := &input.MockInput{}
- defer inputTerminated.AssertExpectations(t)
-
- inputTerminated.On("OutPoint").Return(wire.OutPoint{Index: 3})
-
- s.inputs[inputTerminated.OutPoint()] = &SweeperInput{
- state: Excluded,
- }
+ // Create three inputs with different states.
+ // - inputInit specifies a newly created input.
+ // - inputPendingPublish specifies an input about to be published.
+ // - inputTerminated specifies an input that's terminated.
+ var (
+ inputInit = createMockInput(t, s, Init)
+ inputPendingPublish = createMockInput(t, s, PendingPublish)
+ inputTerminated = createMockInput(t, s, Excluded)
+ )
// Mark the test inputs. We expect the non-exist input and the
// inputTerminated to be skipped, and the rest to be marked as pending
// publish.
set.On("Inputs").Return([]input.Input{
- inputNotExist, inputInit, inputPendingPublish, inputTerminated,
+ inputInit, inputPendingPublish, inputTerminated,
})
s.markInputsPendingPublish(set)
@@ -122,36 +130,22 @@ func TestMarkInputsPublished(t *testing.T) {
dummyTR := &TxRecord{}
dummyErr := errors.New("dummy error")
+ // Create a mock input set.
+ set := &MockInputSet{}
+ defer set.AssertExpectations(t)
+
// Create a test sweeper.
s := New(&UtxoSweeperConfig{
Store: mockStore,
})
- // Create three testing inputs.
- //
- // inputNotExist specifies an input that's not found in the sweeper's
- // `inputs` map.
- inputNotExist := &wire.TxIn{
- PreviousOutPoint: wire.OutPoint{Index: 1},
- }
-
- // inputInit specifies a newly created input. When marking this as
- // published, we should see an error log as this input hasn't been
- // published yet.
- inputInit := &wire.TxIn{
- PreviousOutPoint: wire.OutPoint{Index: 2},
- }
- s.inputs[inputInit.PreviousOutPoint] = &SweeperInput{
- state: Init,
- }
-
- // inputPendingPublish specifies an input that's about to be published.
- inputPendingPublish := &wire.TxIn{
- PreviousOutPoint: wire.OutPoint{Index: 3},
- }
- s.inputs[inputPendingPublish.PreviousOutPoint] = &SweeperInput{
- state: PendingPublish,
- }
+ // Create two inputs with different states.
+ // - inputInit specifies a newly created input.
+ // - inputPendingPublish specifies an input about to be published.
+ var (
+ inputInit = createMockInput(t, s, Init)
+ inputPendingPublish = createMockInput(t, s, PendingPublish)
+ )
// First, check that when an error is returned from db, it's properly
// returned here.
@@ -171,9 +165,9 @@ func TestMarkInputsPublished(t *testing.T) {
// Mark the test inputs. We expect the non-exist input and the
// inputInit to be skipped, and the final input to be marked as
// published.
- err = s.markInputsPublished(dummyTR, []*wire.TxIn{
- inputNotExist, inputInit, inputPendingPublish,
- })
+ set.On("Inputs").Return([]input.Input{inputInit, inputPendingPublish})
+
+ err = s.markInputsPublished(dummyTR, set)
require.NoError(err)
// We expect unchanged number of pending inputs.
@@ -181,11 +175,11 @@ func TestMarkInputsPublished(t *testing.T) {
// We expect the init input's state to stay unchanged.
require.Equal(Init,
- s.inputs[inputInit.PreviousOutPoint].state)
+ s.inputs[inputInit.OutPoint()].state)
// We expect the pending-publish input's is now marked as published.
require.Equal(Published,
- s.inputs[inputPendingPublish.PreviousOutPoint].state)
+ s.inputs[inputPendingPublish.OutPoint()].state)
// Assert mocked statements are executed as expected.
mockStore.AssertExpectations(t)
@@ -202,117 +196,75 @@ func TestMarkInputsPublishFailed(t *testing.T) {
// Create a mock sweeper store.
mockStore := NewMockSweeperStore()
+ // Create a mock input set.
+ set := &MockInputSet{}
+ defer set.AssertExpectations(t)
+
// Create a test sweeper.
s := New(&UtxoSweeperConfig{
Store: mockStore,
})
- // Create testing inputs for each state.
- //
- // inputNotExist specifies an input that's not found in the sweeper's
- // `inputs` map.
- inputNotExist := &wire.TxIn{
- PreviousOutPoint: wire.OutPoint{Index: 1},
- }
-
- // inputInit specifies a newly created input. When marking this as
- // published, we should see an error log as this input hasn't been
- // published yet.
- inputInit := &wire.TxIn{
- PreviousOutPoint: wire.OutPoint{Index: 2},
- }
- s.inputs[inputInit.PreviousOutPoint] = &SweeperInput{
- state: Init,
- }
-
- // inputPendingPublish specifies an input that's about to be published.
- inputPendingPublish := &wire.TxIn{
- PreviousOutPoint: wire.OutPoint{Index: 3},
- }
- s.inputs[inputPendingPublish.PreviousOutPoint] = &SweeperInput{
- state: PendingPublish,
- }
-
- // inputPublished specifies an input that's published.
- inputPublished := &wire.TxIn{
- PreviousOutPoint: wire.OutPoint{Index: 4},
- }
- s.inputs[inputPublished.PreviousOutPoint] = &SweeperInput{
- state: Published,
- }
-
- // inputPublishFailed specifies an input that's failed to be published.
- inputPublishFailed := &wire.TxIn{
- PreviousOutPoint: wire.OutPoint{Index: 5},
- }
- s.inputs[inputPublishFailed.PreviousOutPoint] = &SweeperInput{
- state: PublishFailed,
- }
-
- // inputSwept specifies an input that's swept.
- inputSwept := &wire.TxIn{
- PreviousOutPoint: wire.OutPoint{Index: 6},
- }
- s.inputs[inputSwept.PreviousOutPoint] = &SweeperInput{
- state: Swept,
- }
-
- // inputExcluded specifies an input that's excluded.
- inputExcluded := &wire.TxIn{
- PreviousOutPoint: wire.OutPoint{Index: 7},
- }
- s.inputs[inputExcluded.PreviousOutPoint] = &SweeperInput{
- state: Excluded,
- }
-
- // inputFailed specifies an input that's failed.
- inputFailed := &wire.TxIn{
- PreviousOutPoint: wire.OutPoint{Index: 8},
- }
- s.inputs[inputFailed.PreviousOutPoint] = &SweeperInput{
- state: Failed,
- }
+ // Create inputs with different states.
+ // - inputInit specifies a newly created input. When marking this as
+ // published, we should see an error log as this input hasn't been
+ // published yet.
+ // - inputPendingPublish specifies an input about to be published.
+ // - inputPublished specifies an input that's published.
+ // - inputPublishFailed specifies an input that's failed to be
+ // published.
+ // - inputSwept specifies an input that's swept.
+ // - inputExcluded specifies an input that's excluded.
+ // - inputFailed specifies an input that's failed.
+ var (
+ inputInit = createMockInput(t, s, Init)
+ inputPendingPublish = createMockInput(t, s, PendingPublish)
+ inputPublished = createMockInput(t, s, Published)
+ inputPublishFailed = createMockInput(t, s, PublishFailed)
+ inputSwept = createMockInput(t, s, Swept)
+ inputExcluded = createMockInput(t, s, Excluded)
+ inputFailed = createMockInput(t, s, Failed)
+ )
- // Gather all inputs' outpoints.
- pendingOps := make([]wire.OutPoint, 0, len(s.inputs)+1)
- for op := range s.inputs {
- pendingOps = append(pendingOps, op)
- }
- pendingOps = append(pendingOps, inputNotExist.PreviousOutPoint)
+ // Gather all inputs.
+ set.On("Inputs").Return([]input.Input{
+ inputInit, inputPendingPublish, inputPublished,
+ inputPublishFailed, inputSwept, inputExcluded, inputFailed,
+ })
// Mark the test inputs. We expect the non-exist input and the
// inputInit to be skipped, and the final input to be marked as
// published.
- s.markInputsPublishFailed(pendingOps)
+ s.markInputsPublishFailed(set)
// We expect unchanged number of pending inputs.
require.Len(s.inputs, 7)
// We expect the init input's state to stay unchanged.
require.Equal(Init,
- s.inputs[inputInit.PreviousOutPoint].state)
+ s.inputs[inputInit.OutPoint()].state)
// We expect the pending-publish input's is now marked as publish
// failed.
require.Equal(PublishFailed,
- s.inputs[inputPendingPublish.PreviousOutPoint].state)
+ s.inputs[inputPendingPublish.OutPoint()].state)
// We expect the published input's is now marked as publish failed.
require.Equal(PublishFailed,
- s.inputs[inputPublished.PreviousOutPoint].state)
+ s.inputs[inputPublished.OutPoint()].state)
// We expect the publish failed input to stay unchanged.
require.Equal(PublishFailed,
- s.inputs[inputPublishFailed.PreviousOutPoint].state)
+ s.inputs[inputPublishFailed.OutPoint()].state)
// We expect the swept input to stay unchanged.
- require.Equal(Swept, s.inputs[inputSwept.PreviousOutPoint].state)
+ require.Equal(Swept, s.inputs[inputSwept.OutPoint()].state)
// We expect the excluded input to stay unchanged.
- require.Equal(Excluded, s.inputs[inputExcluded.PreviousOutPoint].state)
+ require.Equal(Excluded, s.inputs[inputExcluded.OutPoint()].state)
// We expect the failed input to stay unchanged.
- require.Equal(Failed, s.inputs[inputFailed.PreviousOutPoint].state)
+ require.Equal(Failed, s.inputs[inputFailed.OutPoint()].state)
// Assert mocked statements are executed as expected.
mockStore.AssertExpectations(t)
@@ -491,6 +443,7 @@ func TestUpdateSweeperInputs(t *testing.T) {
// returned.
inp2.On("RequiredLockTime").Return(
uint32(s.currentHeight+1), true).Once()
+ inp2.On("OutPoint").Return(wire.OutPoint{Index: 2}).Maybe()
input7 := &SweeperInput{state: Init, Input: inp2}
// Mock the input to have a CSV expiry in the future so it will NOT be
@@ -499,6 +452,7 @@ func TestUpdateSweeperInputs(t *testing.T) {
uint32(s.currentHeight), false).Once()
inp3.On("BlocksToMaturity").Return(uint32(2)).Once()
inp3.On("HeightHint").Return(uint32(s.currentHeight)).Once()
+ inp3.On("OutPoint").Return(wire.OutPoint{Index: 3}).Maybe()
input8 := &SweeperInput{state: Init, Input: inp3}
// Add the inputs to the sweeper. After the update, we should see the
@@ -704,11 +658,13 @@ func TestSweepPendingInputs(t *testing.T) {
setNeedWallet.On("Budget").Return(btcutil.Amount(1)).Once()
setNeedWallet.On("StartingFeeRate").Return(
fn.None[chainfee.SatPerKWeight]()).Once()
+ setNeedWallet.On("Immediate").Return(false).Once()
normalSet.On("Inputs").Return(nil).Maybe()
normalSet.On("DeadlineHeight").Return(testHeight).Once()
normalSet.On("Budget").Return(btcutil.Amount(1)).Once()
normalSet.On("StartingFeeRate").Return(
fn.None[chainfee.SatPerKWeight]()).Once()
+ normalSet.On("Immediate").Return(false).Once()
// Make pending inputs for testing. We don't need real values here as
// the returned clusters are mocked.
@@ -719,13 +675,8 @@ func TestSweepPendingInputs(t *testing.T) {
setNeedWallet, normalSet,
})
- // Mock `Broadcast` to return an error. This should cause the
- // `createSweepTx` inside `sweep` to fail. This is done so we can
- // terminate the method early as we are only interested in testing the
- // workflow in `sweepPendingInputs`. We don't need to test `sweep` here
- // as it should be tested in its own unit test.
- dummyErr := errors.New("dummy error")
- publisher.On("Broadcast", mock.Anything).Return(nil, dummyErr).Twice()
+ // Mock `Broadcast` to return a result.
+ publisher.On("Broadcast", mock.Anything).Return(nil).Twice()
// Call the method under test.
s.sweepPendingInputs(pis)
@@ -736,33 +687,33 @@ func TestSweepPendingInputs(t *testing.T) {
func TestHandleBumpEventTxFailed(t *testing.T) {
t.Parallel()
+ // Create a mock input set.
+ set := &MockInputSet{}
+ defer set.AssertExpectations(t)
+
// Create a test sweeper.
s := New(&UtxoSweeperConfig{})
- var (
- // Create four testing outpoints.
- op1 = wire.OutPoint{Hash: chainhash.Hash{1}}
- op2 = wire.OutPoint{Hash: chainhash.Hash{2}}
- op3 = wire.OutPoint{Hash: chainhash.Hash{3}}
- opNotExist = wire.OutPoint{Hash: chainhash.Hash{4}}
- )
+ // inputNotExist specifies an input that's not found in the sweeper's
+ // `pendingInputs` map.
+ inputNotExist := &input.MockInput{}
+ defer inputNotExist.AssertExpectations(t)
+ inputNotExist.On("OutPoint").Return(wire.OutPoint{Index: 0})
+ opNotExist := inputNotExist.OutPoint()
// Create three mock inputs.
- input1 := &input.MockInput{}
- defer input1.AssertExpectations(t)
-
- input2 := &input.MockInput{}
- defer input2.AssertExpectations(t)
+ var (
+ input1 = createMockInput(t, s, PendingPublish)
+ input2 = createMockInput(t, s, PendingPublish)
+ input3 = createMockInput(t, s, PendingPublish)
+ )
- input3 := &input.MockInput{}
- defer input3.AssertExpectations(t)
+ op1 := input1.OutPoint()
+ op2 := input2.OutPoint()
+ op3 := input3.OutPoint()
// Construct the initial state for the sweeper.
- s.inputs = InputsMap{
- op1: &SweeperInput{Input: input1, state: PendingPublish},
- op2: &SweeperInput{Input: input2, state: PendingPublish},
- op3: &SweeperInput{Input: input3, state: PendingPublish},
- }
+ set.On("Inputs").Return([]input.Input{input1, input2, input3})
// Create a testing tx that spends the first two inputs.
tx := &wire.MsgTx{
@@ -780,16 +731,26 @@ func TestHandleBumpEventTxFailed(t *testing.T) {
Err: errDummy,
}
+ // Create a testing bump response.
+ resp := &bumpResp{
+ result: br,
+ set: set,
+ }
+
// Call the method under test.
- err := s.handleBumpEvent(br)
- require.ErrorIs(t, err, errDummy)
+ err := s.handleBumpEvent(resp)
+ require.NoError(t, err)
// Assert the states of the first two inputs are updated.
require.Equal(t, PublishFailed, s.inputs[op1].state)
require.Equal(t, PublishFailed, s.inputs[op2].state)
- // Assert the state of the third input is not updated.
- require.Equal(t, PendingPublish, s.inputs[op3].state)
+ // Assert the state of the third input.
+ //
+ // NOTE: Although the tx doesn't spend it, we still mark this input as
+ // failed as we are treating the input set as the single source of
+ // truth.
+ require.Equal(t, PublishFailed, s.inputs[op3].state)
// Assert the non-existing input is not added to the pending inputs.
require.NotContains(t, s.inputs, opNotExist)
@@ -808,23 +769,21 @@ func TestHandleBumpEventTxReplaced(t *testing.T) {
wallet := &MockWallet{}
defer wallet.AssertExpectations(t)
+ // Create a mock input set.
+ set := &MockInputSet{}
+ defer set.AssertExpectations(t)
+
// Create a test sweeper.
s := New(&UtxoSweeperConfig{
Store: store,
Wallet: wallet,
})
- // Create a testing outpoint.
- op := wire.OutPoint{Hash: chainhash.Hash{1}}
-
// Create a mock input.
- inp := &input.MockInput{}
- defer inp.AssertExpectations(t)
+ inp := createMockInput(t, s, PendingPublish)
+ set.On("Inputs").Return([]input.Input{inp})
- // Construct the initial state for the sweeper.
- s.inputs = InputsMap{
- op: &SweeperInput{Input: inp, state: PendingPublish},
- }
+ op := inp.OutPoint()
// Create a testing tx that spends the input.
tx := &wire.MsgTx{
@@ -849,12 +808,18 @@ func TestHandleBumpEventTxReplaced(t *testing.T) {
Event: TxReplaced,
}
+ // Create a testing bump response.
+ resp := &bumpResp{
+ result: br,
+ set: set,
+ }
+
// Mock the store to return an error.
dummyErr := errors.New("dummy error")
store.On("GetTx", tx.TxHash()).Return(nil, dummyErr).Once()
// Call the method under test and assert the error is returned.
- err := s.handleBumpEventTxReplaced(br)
+ err := s.handleBumpEventTxReplaced(resp)
require.ErrorIs(t, err, dummyErr)
// Mock the store to return the old tx record.
@@ -869,7 +834,7 @@ func TestHandleBumpEventTxReplaced(t *testing.T) {
store.On("DeleteTx", tx.TxHash()).Return(dummyErr).Once()
// Call the method under test and assert the error is returned.
- err = s.handleBumpEventTxReplaced(br)
+ err = s.handleBumpEventTxReplaced(resp)
require.ErrorIs(t, err, dummyErr)
// Mock the store to return the old tx record and delete it without
@@ -889,7 +854,7 @@ func TestHandleBumpEventTxReplaced(t *testing.T) {
wallet.On("CancelRebroadcast", tx.TxHash()).Once()
// Call the method under test.
- err = s.handleBumpEventTxReplaced(br)
+ err = s.handleBumpEventTxReplaced(resp)
require.NoError(t, err)
// Assert the state of the input is updated.
@@ -905,22 +870,20 @@ func TestHandleBumpEventTxPublished(t *testing.T) {
store := &MockSweeperStore{}
defer store.AssertExpectations(t)
+ // Create a mock input set.
+ set := &MockInputSet{}
+ defer set.AssertExpectations(t)
+
// Create a test sweeper.
s := New(&UtxoSweeperConfig{
Store: store,
})
- // Create a testing outpoint.
- op := wire.OutPoint{Hash: chainhash.Hash{1}}
-
// Create a mock input.
- inp := &input.MockInput{}
- defer inp.AssertExpectations(t)
+ inp := createMockInput(t, s, PendingPublish)
+ set.On("Inputs").Return([]input.Input{inp})
- // Construct the initial state for the sweeper.
- s.inputs = InputsMap{
- op: &SweeperInput{Input: inp, state: PendingPublish},
- }
+ op := inp.OutPoint()
// Create a testing tx that spends the input.
tx := &wire.MsgTx{
@@ -936,6 +899,12 @@ func TestHandleBumpEventTxPublished(t *testing.T) {
Event: TxPublished,
}
+ // Create a testing bump response.
+ resp := &bumpResp{
+ result: br,
+ set: set,
+ }
+
// Mock the store to save the new tx record.
store.On("StoreTx", &TxRecord{
Txid: tx.TxHash(),
@@ -943,7 +912,7 @@ func TestHandleBumpEventTxPublished(t *testing.T) {
}).Return(nil).Once()
// Call the method under test.
- err := s.handleBumpEventTxPublished(br)
+ err := s.handleBumpEventTxPublished(resp)
require.NoError(t, err)
// Assert the state of the input is updated.
@@ -961,25 +930,21 @@ func TestMonitorFeeBumpResult(t *testing.T) {
wallet := &MockWallet{}
defer wallet.AssertExpectations(t)
+ // Create a mock input set.
+ set := &MockInputSet{}
+ defer set.AssertExpectations(t)
+
// Create a test sweeper.
s := New(&UtxoSweeperConfig{
Store: store,
Wallet: wallet,
})
- // Create a testing outpoint.
- op := wire.OutPoint{Hash: chainhash.Hash{1}}
-
// Create a mock input.
- inp := &input.MockInput{}
- defer inp.AssertExpectations(t)
-
- // Construct the initial state for the sweeper.
- s.inputs = InputsMap{
- op: &SweeperInput{Input: inp, state: PendingPublish},
- }
+ inp := createMockInput(t, s, PendingPublish)
// Create a testing tx that spends the input.
+ op := inp.OutPoint()
tx := &wire.MsgTx{
LockTime: 1,
TxIn: []*wire.TxIn{
@@ -1058,7 +1023,8 @@ func TestMonitorFeeBumpResult(t *testing.T) {
return resultChan
},
shouldExit: false,
- }, {
+ },
+ {
// When the sweeper is shutting down, the monitor loop
// should exit.
name: "exit on sweeper shutdown",
@@ -1085,7 +1051,7 @@ func TestMonitorFeeBumpResult(t *testing.T) {
s.wg.Add(1)
go func() {
- s.monitorFeeBumpResult(resultChan)
+ s.monitorFeeBumpResult(set, resultChan)
close(done)
}()
@@ -1111,3 +1077,125 @@ func TestMonitorFeeBumpResult(t *testing.T) {
})
}
}
+
+// TestMarkInputsFailed checks that given a list of inputs with different
+// states, the method `markInputsFailed` correctly marks the inputs as failed.
+func TestMarkInputsFailed(t *testing.T) {
+ t.Parallel()
+
+ require := require.New(t)
+
+ // Create a mock input set.
+ set := &MockInputSet{}
+ defer set.AssertExpectations(t)
+
+ // Create a test sweeper.
+ s := New(&UtxoSweeperConfig{})
+
+ // Create testing inputs for each state.
+ // - inputInit specifies a newly created input. When marking this as
+ // published, we should see an error log as this input hasn't been
+ // published yet.
+ // - inputPendingPublish specifies an input about to be published.
+ // - inputPublished specifies an input that's published.
+ // - inputPublishFailed specifies an input that's failed to be
+ // published.
+ // - inputSwept specifies an input that's swept.
+ // - inputExcluded specifies an input that's excluded.
+ // - inputFailed specifies an input that's failed.
+ var (
+ inputInit = createMockInput(t, s, Init)
+ inputPendingPublish = createMockInput(t, s, PendingPublish)
+ inputPublished = createMockInput(t, s, Published)
+ inputPublishFailed = createMockInput(t, s, PublishFailed)
+ inputSwept = createMockInput(t, s, Swept)
+ inputExcluded = createMockInput(t, s, Excluded)
+ inputFailed = createMockInput(t, s, Failed)
+ )
+
+ // Gather all inputs.
+ set.On("Inputs").Return([]input.Input{
+ inputInit, inputPendingPublish, inputPublished,
+ inputPublishFailed, inputSwept, inputExcluded, inputFailed,
+ })
+
+ // Mark the test inputs. We expect the non-exist input and
+ // inputSwept/inputExcluded/inputFailed to be skipped.
+ s.markInputsFailed(set, errDummy)
+
+ // We expect unchanged number of pending inputs.
+ require.Len(s.inputs, 7)
+
+ // We expect the init input's to be marked as failed.
+ require.Equal(Failed, s.inputs[inputInit.OutPoint()].state)
+
+ // We expect the pending-publish input to be marked as failed.
+ require.Equal(Failed, s.inputs[inputPendingPublish.OutPoint()].state)
+
+ // We expect the published input to be marked as failed.
+ require.Equal(Failed, s.inputs[inputPublished.OutPoint()].state)
+
+ // We expect the publish failed input to be markd as failed.
+ require.Equal(Failed, s.inputs[inputPublishFailed.OutPoint()].state)
+
+ // We expect the swept input to stay unchanged.
+ require.Equal(Swept, s.inputs[inputSwept.OutPoint()].state)
+
+ // We expect the excluded input to stay unchanged.
+ require.Equal(Excluded, s.inputs[inputExcluded.OutPoint()].state)
+
+ // We expect the failed input to stay unchanged.
+ require.Equal(Failed, s.inputs[inputFailed.OutPoint()].state)
+}
+
+// TestHandleBumpEventTxFatal checks that `handleBumpEventTxFatal` correctly
+// handles a `TxFatal` event.
+func TestHandleBumpEventTxFatal(t *testing.T) {
+ t.Parallel()
+
+ rt := require.New(t)
+
+ // Create a mock store.
+ store := &MockSweeperStore{}
+ defer store.AssertExpectations(t)
+
+ // Create a mock input set. We are not testing `markInputFailed` here,
+ // so the actual set doesn't matter.
+ set := &MockInputSet{}
+ defer set.AssertExpectations(t)
+ set.On("Inputs").Return(nil)
+
+ // Create a test sweeper.
+ s := New(&UtxoSweeperConfig{
+ Store: store,
+ })
+
+ // Create a dummy tx.
+ tx := &wire.MsgTx{
+ LockTime: 1,
+ }
+
+ // Create a testing bump response.
+ result := &BumpResult{
+ Err: errDummy,
+ Tx: tx,
+ }
+ resp := &bumpResp{
+ result: result,
+ set: set,
+ }
+
+ // Mock the store to return an error.
+ store.On("DeleteTx", mock.Anything).Return(errDummy).Once()
+
+ // Call the method under test and assert the error is returned.
+ err := s.handleBumpEventTxFatal(resp)
+ rt.ErrorIs(err, errDummy)
+
+ // Mock the store to return nil.
+ store.On("DeleteTx", mock.Anything).Return(nil).Once()
+
+ // Call the method under test and assert no error is returned.
+ err = s.handleBumpEventTxFatal(resp)
+ rt.NoError(err)
+}
diff --git a/sweep/tx_input_set.go b/sweep/tx_input_set.go
index ce144a8eb3..b80d52b0ea 100644
--- a/sweep/tx_input_set.go
+++ b/sweep/tx_input_set.go
@@ -8,7 +8,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
@@ -64,6 +64,13 @@ type InputSet interface {
// StartingFeeRate returns the max starting fee rate found in the
// inputs.
StartingFeeRate() fn.Option[chainfee.SatPerKWeight]
+
+ // Immediate returns a boolean to indicate whether the tx made from
+ // this input set should be published immediately.
+ //
+ // TODO(yy): create a new method `Params` to combine the informational
+ // methods DeadlineHeight, Budget, StartingFeeRate and Immediate.
+ Immediate() bool
}
// createWalletTxInput converts a wallet utxo into an object that can be added
@@ -141,7 +148,7 @@ func validateInputs(inputs []SweeperInput, deadlineHeight int32) error {
// dedupInputs is a set used to track unique outpoints of the inputs.
dedupInputs := fn.NewSet(
// Iterate all the inputs and map the function.
- fn.Map(func(inp SweeperInput) wire.OutPoint {
+ fn.Map(inputs, func(inp SweeperInput) wire.OutPoint {
// If the input has a deadline height, we'll check if
// it's the same as the specified.
inp.params.DeadlineHeight.WhenSome(func(h int32) {
@@ -156,7 +163,7 @@ func validateInputs(inputs []SweeperInput, deadlineHeight int32) error {
})
return inp.OutPoint()
- }, inputs)...,
+ })...,
)
// Make sure the inputs share the same deadline height when there is
@@ -414,3 +421,18 @@ func (b *BudgetInputSet) StartingFeeRate() fn.Option[chainfee.SatPerKWeight] {
return startingFeeRate
}
+
+// Immediate returns whether the inputs should be swept immediately.
+//
+// NOTE: part of the InputSet interface.
+func (b *BudgetInputSet) Immediate() bool {
+ for _, inp := range b.inputs {
+ // As long as one of the inputs is immediate, the whole set is
+ // immediate.
+ if inp.params.Immediate {
+ return true
+ }
+ }
+
+ return false
+}
diff --git a/sweep/tx_input_set_test.go b/sweep/tx_input_set_test.go
index 8d0850b20d..73f056a964 100644
--- a/sweep/tx_input_set_test.go
+++ b/sweep/tx_input_set_test.go
@@ -8,7 +8,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/stretchr/testify/require"
diff --git a/sweep/walletsweep.go b/sweep/walletsweep.go
index 81458fbfb0..3f790dc66f 100644
--- a/sweep/walletsweep.go
+++ b/sweep/walletsweep.go
@@ -10,7 +10,7 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/wtxmgr"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
diff --git a/sweep/walletsweep_test.go b/sweep/walletsweep_test.go
index 968d9cb4fb..c7a5dfc221 100644
--- a/sweep/walletsweep_test.go
+++ b/sweep/walletsweep_test.go
@@ -12,7 +12,7 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/wtxmgr"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lntest/mock"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
diff --git a/watchtower/blob/justice_kit.go b/watchtower/blob/justice_kit.go
index 7780239f07..9dc1af6258 100644
--- a/watchtower/blob/justice_kit.go
+++ b/watchtower/blob/justice_kit.go
@@ -10,7 +10,7 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
secp "github.com/decred/dcrd/dcrec/secp256k1/v4"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
diff --git a/watchtower/blob/justice_kit_test.go b/watchtower/blob/justice_kit_test.go
index a1d6ec9f2c..0d23e2e0fc 100644
--- a/watchtower/blob/justice_kit_test.go
+++ b/watchtower/blob/justice_kit_test.go
@@ -12,7 +12,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
diff --git a/watchtower/lookout/justice_descriptor_test.go b/watchtower/lookout/justice_descriptor_test.go
index 5045b4a0f4..ded2cd6031 100644
--- a/watchtower/lookout/justice_descriptor_test.go
+++ b/watchtower/lookout/justice_descriptor_test.go
@@ -11,7 +11,7 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
secp "github.com/decred/dcrd/dcrec/secp256k1/v4"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet"
diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go
index 7eb34f6e37..62d7609469 100644
--- a/watchtower/wtclient/backup_task_internal_test.go
+++ b/watchtower/wtclient/backup_task_internal_test.go
@@ -10,7 +10,7 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet"
diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go
index f3a4d5bf4e..e842876b65 100644
--- a/watchtower/wtclient/client_test.go
+++ b/watchtower/wtclient/client_test.go
@@ -19,7 +19,7 @@ import (
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channelnotifier"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
diff --git a/watchtower/wtclient/manager.go b/watchtower/wtclient/manager.go
index 01a9fa01ef..7a39c8ff73 100644
--- a/watchtower/wtclient/manager.go
+++ b/watchtower/wtclient/manager.go
@@ -12,7 +12,7 @@ import (
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channelnotifier"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/subscribe"
diff --git a/watchtower/wtdb/client_chan_summary.go b/watchtower/wtdb/client_chan_summary.go
index 6fec34c842..9ab77377b4 100644
--- a/watchtower/wtdb/client_chan_summary.go
+++ b/watchtower/wtdb/client_chan_summary.go
@@ -3,7 +3,7 @@ package wtdb
import (
"io"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
)
diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go
index b6f6affce6..6e6adacc02 100644
--- a/watchtower/wtdb/client_db.go
+++ b/watchtower/wtdb/client_db.go
@@ -9,7 +9,7 @@ import (
"sync"
"github.com/btcsuite/btcd/btcec/v2"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
diff --git a/zpay32/decode.go b/zpay32/decode.go
index 61099cf2f0..76c2c1ecf4 100644
--- a/zpay32/decode.go
+++ b/zpay32/decode.go
@@ -14,7 +14,7 @@ import (
"github.com/btcsuite/btcd/btcutil/bech32"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
)
diff --git a/zpay32/encode.go b/zpay32/encode.go
index 3e2d799776..43ccd5ecb1 100644
--- a/zpay32/encode.go
+++ b/zpay32/encode.go
@@ -9,7 +9,7 @@ import (
"github.com/btcsuite/btcd/btcutil/bech32"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
)
diff --git a/zpay32/invoice.go b/zpay32/invoice.go
index 7c18253eb0..9c5d86ce2f 100644
--- a/zpay32/invoice.go
+++ b/zpay32/invoice.go
@@ -8,7 +8,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
)
diff --git a/zpay32/invoice_test.go b/zpay32/invoice_test.go
index a4753431e7..55718007db 100644
--- a/zpay32/invoice_test.go
+++ b/zpay32/invoice_test.go
@@ -17,7 +17,7 @@ import (
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
sphinx "github.com/lightningnetwork/lightning-onion"
- "github.com/lightningnetwork/lnd/fn"
+ "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/stretchr/testify/require"
)