Skip to content

Commit

Permalink
Merge pull request #8774 from yyforyongyu/fix-size-calc
Browse files Browse the repository at this point in the history
lnwallet+sweep: fix wrong unit used in fee calculation
  • Loading branch information
Roasbeef authored May 28, 2024
2 parents bc6292f + 2e40a23 commit ff85328
Show file tree
Hide file tree
Showing 38 changed files with 301 additions and 156 deletions.
5 changes: 3 additions & 2 deletions contractcourt/breach_arbitrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/labels"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
Expand Down Expand Up @@ -1497,14 +1498,14 @@ func (b *BreachArbitrator) createSweepTx(inputs ...input.Input) (*wire.MsgTx,
spendableOutputs = append(spendableOutputs, inp)
}

txWeight := int64(weightEstimate.Weight())
txWeight := weightEstimate.Weight()

return b.sweepSpendableOutputsTxn(txWeight, spendableOutputs...)
}

// sweepSpendableOutputsTxn creates a signed transaction from a sequence of
// spendable outputs by sweeping the funds into a single p2wkh output.
func (b *BreachArbitrator) sweepSpendableOutputsTxn(txWeight int64,
func (b *BreachArbitrator) sweepSpendableOutputsTxn(txWeight lntypes.WeightUnit,
inputs ...input.Input) (*wire.MsgTx, error) {

// First, we obtain a new public key script from the wallet which we'll
Expand Down
44 changes: 22 additions & 22 deletions htlcswitch/link_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2480,7 +2480,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {

// Calculate the fee buffer for a channel state. Account for htlcs on
// the potential channel state as well.
htlcWeight := int64(1) * input.HTLCWeight
htlcWeight := lntypes.WeightUnit(1) * input.HTLCWeight
feeBuffer := lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -2517,7 +2517,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {

// Calculate the fee buffer for a channel state. Account for htlcs on
// the potential channel state as well.
htlcWeight = int64(2) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(2) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -2578,7 +2578,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {

// Calculate the fee buffer for a channel state. Account for htlcs on
// the potential channel state as well.
htlcWeight = int64(1) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(1) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -2609,7 +2609,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {

// Calculate the fee buffer for a channel state. Account for htlcs on
// the potential channel state as well.
htlcWeight = int64(2) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(2) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -2675,7 +2675,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {

// Calculate the fee buffer for a channel state. Account for htlcs on
// the potential channel state as well.
htlcWeight = int64(1) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(1) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -2721,7 +2721,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {

// Calculate the fee buffer for a channel state. Account for htlcs on
// the potential channel state as well.
htlcWeight = int64(2) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(2) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -2771,7 +2771,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {

// Calculate the fee buffer for a channel state. Account for htlcs on
// the potential channel state as well.
htlcWeight = int64(1) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(1) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -2835,7 +2835,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {

// Calculate the fee buffer for a channel state. Account for htlcs on
// the potential channel state as well.
htlcWeight = int64(2) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(2) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -2880,7 +2880,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {

// Calculate the fee buffer for a channel state. Account for htlcs on
// the potential channel state as well.
htlcWeight = int64(1) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(1) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -2976,7 +2976,7 @@ func TestChannelLinkTrimCircuitsPending(t *testing.T) {

// Calculate the fee buffer for a channel state. Account for htlcs on
// the potential channel state as well.
htlcWeight := int64(1) * input.HTLCWeight
htlcWeight := lntypes.WeightUnit(1) * input.HTLCWeight
feeBuffer := lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -3022,7 +3022,7 @@ func TestChannelLinkTrimCircuitsPending(t *testing.T) {

// Calculate the fee buffer for a channel state. Account for htlcs on
// the potential channel state as well.
htlcWeight = int64(1+halfHtlcs) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(1+halfHtlcs) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -3113,7 +3113,7 @@ func TestChannelLinkTrimCircuitsPending(t *testing.T) {
// With two HTLCs on the pending commit, and two added to the in-memory
// commitment state, the resulting bandwidth should reflect that Alice
// is paying the all htlc amounts in addition to all htlc fees.
htlcWeight = int64(1+numHtlcs) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(1+numHtlcs) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -3205,7 +3205,7 @@ func TestChannelLinkTrimCircuitsPending(t *testing.T) {
// Since the latter two HTLCs have been completely dropped from memory,
// only the first two HTLCs we added should still be reflected in the
// channel bandwidth.
htlcWeight = int64(1+halfHtlcs) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(1+halfHtlcs) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -3264,7 +3264,7 @@ func TestChannelLinkTrimCircuitsNoCommit(t *testing.T) {

// Calculate the fee buffer for a channel state. Account for htlcs on
// the potential channel state as well.
htlcWeight := int64(1) * input.HTLCWeight
htlcWeight := lntypes.WeightUnit(1) * input.HTLCWeight
feeBuffer := lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -3309,7 +3309,7 @@ func TestChannelLinkTrimCircuitsNoCommit(t *testing.T) {

// We account for the 2 htlcs and the additional one which would be
// needed when sending and htlc.
htlcWeight = int64(1+halfHtlcs) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(1+halfHtlcs) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -3387,7 +3387,7 @@ func TestChannelLinkTrimCircuitsNoCommit(t *testing.T) {
}

// Alice's bandwidth should have reverted back to her starting value.
htlcWeight = int64(1) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(1) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand All @@ -3414,7 +3414,7 @@ func TestChannelLinkTrimCircuitsNoCommit(t *testing.T) {

// We account for the 2 htlcs and the additional one which would be
// needed when sending and htlc.
htlcWeight = int64(1+halfHtlcs) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(1+halfHtlcs) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -3492,7 +3492,7 @@ func TestChannelLinkTrimCircuitsNoCommit(t *testing.T) {
t.Fatalf("expected %d packet to be failed", halfHtlcs)
}

htlcWeight = int64(1) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(1) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -3540,7 +3540,7 @@ func TestChannelLinkTrimCircuitsRemoteCommit(t *testing.T) {

// Calculate the fee buffer for a channel state. Account for htlcs on
// the potential channel state as well.
htlcWeight := int64(1) * input.HTLCWeight
htlcWeight := lntypes.WeightUnit(1) * input.HTLCWeight
feeBuffer := lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -3602,7 +3602,7 @@ func TestChannelLinkTrimCircuitsRemoteCommit(t *testing.T) {

// The resulting bandwidth should reflect that Alice is paying both
// htlc amounts, in addition to both htlc fees.
htlcWeight = int64(1+numHtlcs) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(1+numHtlcs) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down Expand Up @@ -3700,7 +3700,7 @@ func TestChannelLinkBandwidthChanReserve(t *testing.T) {

// Calculate the fee buffer for a channel state. Account for htlcs on
// the potential channel state as well.
htlcWeight := int64(1) * input.HTLCWeight
htlcWeight := lntypes.WeightUnit(1) * input.HTLCWeight
feeBuffer := lnwallet.CalcFeeBuffer(feePerKw, commitWeight+htlcWeight)

// The starting bandwidth of the channel should be exactly the amount
Expand All @@ -3727,7 +3727,7 @@ func TestChannelLinkBandwidthChanReserve(t *testing.T) {
_ = harness.aliceLink.handleSwitchPacket(addPkt)
time.Sleep(time.Millisecond * 100)

htlcWeight = int64(2) * input.HTLCWeight
htlcWeight = lntypes.WeightUnit(2) * input.HTLCWeight
feeBuffer = lnwallet.CalcFeeBuffer(
feePerKw, commitWeight+htlcWeight,
)
Expand Down
2 changes: 1 addition & 1 deletion input/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ type TxInfo struct {
Fee btcutil.Amount

// Weight is the weight of the tx.
Weight int64
Weight lntypes.WeightUnit
}

// String returns a human readable version of the tx info.
Expand Down
5 changes: 3 additions & 2 deletions input/mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/stretchr/testify/mock"
)

Expand Down Expand Up @@ -157,10 +158,10 @@ func (m *MockWitnessType) WitnessGenerator(signer Signer,
// if it would be included in a tx. It also returns if the output itself is a
// nested p2sh output, if so then we need to take into account the extra
// sigScript data size.
func (m *MockWitnessType) SizeUpperBound() (int, bool, error) {
func (m *MockWitnessType) SizeUpperBound() (lntypes.WeightUnit, bool, error) {
args := m.Called()

return args.Int(0), args.Bool(1), args.Error(2)
return args.Get(0).(lntypes.WeightUnit), args.Bool(1), args.Error(2)
}

// AddWeightEstimation adds the estimated size of the witness in bytes to the
Expand Down
40 changes: 25 additions & 15 deletions input/size.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/waddrmgr"
"github.com/lightningnetwork/lnd/lntypes"
)

const (
Expand Down Expand Up @@ -862,9 +863,9 @@ type TxWeightEstimator struct {
hasWitness bool
inputCount uint32
outputCount uint32
inputSize int
inputWitnessSize int
outputSize int
inputSize lntypes.VByte
inputWitnessSize lntypes.WeightUnit
outputSize lntypes.VByte
}

// AddP2PKHInput updates the weight estimate to account for an additional input
Expand All @@ -888,7 +889,9 @@ func (twe *TxWeightEstimator) AddP2WKHInput() *TxWeightEstimator {
// AddWitnessInput updates the weight estimate to account for an additional
// input spending a native pay-to-witness output. This accepts the total size
// of the witness as a parameter.
func (twe *TxWeightEstimator) AddWitnessInput(witnessSize int) *TxWeightEstimator {
func (twe *TxWeightEstimator) AddWitnessInput(
witnessSize lntypes.WeightUnit) *TxWeightEstimator {

twe.inputSize += InputSize
twe.inputWitnessSize += witnessSize
twe.inputCount++
Expand All @@ -905,7 +908,8 @@ func (twe *TxWeightEstimator) AddWitnessInput(witnessSize int) *TxWeightEstimato
// NOTE: The leaf witness size must be calculated without the byte that accounts
// for the number of witness elements, only the total size of all elements on
// the stack that are consumed by the revealed script should be counted.
func (twe *TxWeightEstimator) AddTapscriptInput(leafWitnessSize int,
func (twe *TxWeightEstimator) AddTapscriptInput(
leafWitnessSize lntypes.WeightUnit,
tapscript *waddrmgr.Tapscript) *TxWeightEstimator {

// We add 1 byte for the total number of witness elements.
Expand All @@ -915,7 +919,9 @@ func (twe *TxWeightEstimator) AddTapscriptInput(leafWitnessSize int,
1 + len(tapscript.ControlBlock.InclusionProof)

twe.inputSize += InputSize
twe.inputWitnessSize += leafWitnessSize + controlBlockWitnessSize
twe.inputWitnessSize += leafWitnessSize + lntypes.WeightUnit(
controlBlockWitnessSize,
)
twe.inputCount++
twe.hasWitness = true

Expand Down Expand Up @@ -956,7 +962,9 @@ func (twe *TxWeightEstimator) AddNestedP2WKHInput() *TxWeightEstimator {

// AddNestedP2WSHInput updates the weight estimate to account for an additional
// input spending a P2SH output with a nested P2WSH redeem script.
func (twe *TxWeightEstimator) AddNestedP2WSHInput(witnessSize int) *TxWeightEstimator {
func (twe *TxWeightEstimator) AddNestedP2WSHInput(
witnessSize lntypes.WeightUnit) *TxWeightEstimator {

twe.inputSize += InputSize + NestedP2WSHSize
twe.inputWitnessSize += witnessSize
twe.inputCount++
Expand All @@ -967,7 +975,7 @@ func (twe *TxWeightEstimator) AddNestedP2WSHInput(witnessSize int) *TxWeightEsti

// AddTxOutput adds a known TxOut to the weight estimator.
func (twe *TxWeightEstimator) AddTxOutput(txOut *wire.TxOut) *TxWeightEstimator {
twe.outputSize += txOut.SerializeSize()
twe.outputSize += lntypes.VByte(txOut.SerializeSize())
twe.outputCount++

return twe
Expand Down Expand Up @@ -1020,18 +1028,20 @@ func (twe *TxWeightEstimator) AddP2SHOutput() *TxWeightEstimator {

// AddOutput estimates the weight of an output based on the pkScript.
func (twe *TxWeightEstimator) AddOutput(pkScript []byte) *TxWeightEstimator {
twe.outputSize += BaseOutputSize + len(pkScript)
twe.outputSize += BaseOutputSize + lntypes.VByte(len(pkScript))
twe.outputCount++

return twe
}

// Weight gets the estimated weight of the transaction.
func (twe *TxWeightEstimator) Weight() int {
txSizeStripped := BaseTxSize +
wire.VarIntSerializeSize(uint64(twe.inputCount)) + twe.inputSize +
wire.VarIntSerializeSize(uint64(twe.outputCount)) + twe.outputSize
weight := txSizeStripped * witnessScaleFactor
func (twe *TxWeightEstimator) Weight() lntypes.WeightUnit {
inputCount := wire.VarIntSerializeSize(uint64(twe.inputCount))
outputCount := wire.VarIntSerializeSize(uint64(twe.outputCount))
txSizeStripped := BaseTxSize + lntypes.VByte(inputCount) +
twe.inputSize + lntypes.VByte(outputCount) + twe.outputSize
weight := lntypes.WeightUnit(txSizeStripped * witnessScaleFactor)

if twe.hasWitness {
weight += WitnessHeaderSize + twe.inputWitnessSize
}
Expand All @@ -1041,5 +1051,5 @@ func (twe *TxWeightEstimator) Weight() int {
// VSize gets the estimated virtual size of the transactions, in vbytes.
func (twe *TxWeightEstimator) VSize() int {
// A tx's vsize is 1/4 of the weight, rounded up.
return (twe.Weight() + witnessScaleFactor - 1) / witnessScaleFactor
return int(twe.Weight().ToVB())
}
11 changes: 6 additions & 5 deletions input/size_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,12 @@ func TestTxWeightEstimator(t *testing.T) {
tx.AddTxOut(&wire.TxOut{PkScript: p2shScript})
}

expectedWeight := blockchain.GetTransactionWeight(btcutil.NewTx(tx))
if weightEstimate.Weight() != int(expectedWeight) {
t.Errorf("Case %d: Got wrong weight: expected %d, got %d",
i, expectedWeight, weightEstimate.Weight())
}
expectedWeight := blockchain.GetTransactionWeight(
btcutil.NewTx(tx),
)
require.EqualValuesf(t, expectedWeight, weightEstimate.Weight(),
"Case %d: Got wrong weight: expected %d, got %d",
i, expectedWeight, weightEstimate.Weight())
}
}

Expand Down
Loading

0 comments on commit ff85328

Please sign in to comment.