Skip to content

Commit

Permalink
Merge pull request #900 from guggero/export-coin-selection
Browse files Browse the repository at this point in the history
wallet: export coin selection strategy code for re-use
  • Loading branch information
Roasbeef authored Jan 25, 2024
2 parents 7e3e5ed + d27dac3 commit 6b096b0
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 102 deletions.
158 changes: 106 additions & 52 deletions wallet/createtx.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package wallet

import (
"errors"
"fmt"
"math/rand"
"sort"
Expand All @@ -21,16 +22,8 @@ import (
"github.com/btcsuite/btcwallet/wtxmgr"
)

// byAmount defines the methods needed to satisify sort.Interface to
// sort credits by their output amount.
type byAmount []wtxmgr.Credit

func (s byAmount) Len() int { return len(s) }
func (s byAmount) Less(i, j int) bool { return s[i].Amount < s[j].Amount }
func (s byAmount) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

func makeInputSource(eligible []wtxmgr.Credit) txauthor.InputSource {
// Current inputs and their total value. These are closed over by the
func makeInputSource(eligible []Coin) txauthor.InputSource {
// Current inputs and their total value. These are closed over by the
// returned input source and reused across multiple calls.
currentTotal := btcutil.Amount(0)
currentInputs := make([]*wire.TxIn, 0, len(eligible))
Expand All @@ -41,15 +34,25 @@ func makeInputSource(eligible []wtxmgr.Credit) txauthor.InputSource {
[]btcutil.Amount, [][]byte, error) {

for currentTotal < target && len(eligible) != 0 {
nextCredit := &eligible[0]
nextCredit := eligible[0]
prevOut := nextCredit.TxOut
outpoint := nextCredit.OutPoint
eligible = eligible[1:]
nextInput := wire.NewTxIn(&nextCredit.OutPoint, nil, nil)
currentTotal += nextCredit.Amount

nextInput := wire.NewTxIn(&outpoint, nil, nil)
currentTotal += btcutil.Amount(prevOut.Value)
currentInputs = append(currentInputs, nextInput)
currentScripts = append(currentScripts, nextCredit.PkScript)
currentInputValues = append(currentInputValues, nextCredit.Amount)
currentScripts = append(
currentScripts, prevOut.PkScript,
)
currentInputValues = append(
currentInputValues,
btcutil.Amount(prevOut.Value),
)
}
return currentTotal, currentInputs, currentInputValues, currentScripts, nil

return currentTotal, currentInputs, currentInputValues,
currentScripts, nil
}
}

Expand Down Expand Up @@ -123,6 +126,11 @@ func (w *Wallet) txToOutputs(outputs []*wire.TxOut,
return nil, err
}

// Fall back to default coin selection strategy if none is supplied.
if coinSelectionStrategy == nil {
coinSelectionStrategy = CoinSelectionLargest
}

var tx *txauthor.AuthoredTx
err = walletdb.Update(w.db, func(dbtx walletdb.ReadWriteTx) error {
addrmgrNs, changeSource, err := w.addrMgrWithChangeSource(
Expand All @@ -139,40 +147,27 @@ func (w *Wallet) txToOutputs(outputs []*wire.TxOut,
return err
}

var inputSource txauthor.InputSource

switch coinSelectionStrategy {
// Pick largest outputs first.
case CoinSelectionLargest:
sort.Sort(sort.Reverse(byAmount(eligible)))
inputSource = makeInputSource(eligible)

// Select coins at random. This prevents the creation of ever
// smaller utxos over time that may never become economical to
// spend.
case CoinSelectionRandom:
// Skip inputs that do not raise the total transaction
// output value at the requested fee rate.
var positivelyYielding []wtxmgr.Credit
for _, output := range eligible {
output := output

if !inputYieldsPositively(&output, feeSatPerKb) {
continue
}

positivelyYielding = append(
positivelyYielding, output,
)
// Wrap our coins in a type that implements the SelectableCoin
// interface, so we can arrange them according to the selected
// coin selection strategy.
wrappedEligible := make([]Coin, len(eligible))
for i := range eligible {
wrappedEligible[i] = Coin{
TxOut: wire.TxOut{
Value: int64(eligible[i].Amount),
PkScript: eligible[i].PkScript,
},
OutPoint: eligible[i].OutPoint,
}

rand.Shuffle(len(positivelyYielding), func(i, j int) {
positivelyYielding[i], positivelyYielding[j] =
positivelyYielding[j], positivelyYielding[i]
})

inputSource = makeInputSource(positivelyYielding)
}
arrangedCoins, err := coinSelectionStrategy.ArrangeCoins(
wrappedEligible, feeSatPerKb,
)
if err != nil {
return err
}

inputSource := makeInputSource(arrangedCoins)

tx, err = txauthor.NewUnsignedTransaction(
outputs, feeSatPerKb, inputSource, changeSource,
Expand Down Expand Up @@ -261,7 +256,7 @@ func (w *Wallet) txToOutputs(outputs []*wire.TxOut,

return nil
})
if err != nil && err != walletdb.ErrDryRunRollBack {
if err != nil && !errors.Is(err, walletdb.ErrDryRunRollBack) {
return nil, err
}

Expand Down Expand Up @@ -290,7 +285,7 @@ func (w *Wallet) findEligibleOutputs(dbtx walletdb.ReadTx,
output := &unspent[i]

// Only include this output if it meets the required number of
// confirmations. Coinbase transactions must have have reached
// confirmations. Coinbase transactions must have reached
// maturity before their outputs may be spent.
if !confirmed(minconf, output.Height, bs.Height) {
continue
Expand Down Expand Up @@ -337,11 +332,13 @@ func (w *Wallet) findEligibleOutputs(dbtx walletdb.ReadTx,
// best-case added virtual size. For edge cases this function can return true
// while the input is yielding slightly negative as part of the final
// transaction.
func inputYieldsPositively(credit *wtxmgr.Credit, feeRatePerKb btcutil.Amount) bool {
func inputYieldsPositively(credit *wire.TxOut,
feeRatePerKb btcutil.Amount) bool {

inputSize := txsizes.GetMinInputVirtualSize(credit.PkScript)
inputFee := feeRatePerKb * btcutil.Amount(inputSize) / 1000

return inputFee < credit.Amount
return inputFee < btcutil.Amount(credit.Value)
}

// addrMgrWithChangeSource returns the address manager bucket and a change
Expand Down Expand Up @@ -386,6 +383,9 @@ func (w *Wallet) addrMgrWithChangeSource(dbtx walletdb.ReadWriteTx,
scriptSize = txsizes.P2WPKHPkScriptSize
case waddrmgr.TaprootPubKey:
scriptSize = txsizes.P2TRPkScriptSize
default:
return nil, nil, fmt.Errorf("unsupported address type: %v",
addrType)
}

newChangeScript := func() ([]byte, error) {
Expand Down Expand Up @@ -446,3 +446,57 @@ func validateMsgTx(tx *wire.MsgTx, prevScripts [][]byte,
}
return nil
}

// sortByAmount is a generic sortable type for sorting coins by their amount.
type sortByAmount []Coin

func (s sortByAmount) Len() int { return len(s) }
func (s sortByAmount) Less(i, j int) bool {
return s[i].Value < s[j].Value
}
func (s sortByAmount) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

// LargestFirstCoinSelector is an implementation of the CoinSelectionStrategy
// that always selects the largest coins first.
type LargestFirstCoinSelector struct{}

// ArrangeCoins takes a list of coins and arranges them according to the
// specified coin selection strategy and fee rate.
func (*LargestFirstCoinSelector) ArrangeCoins(eligible []Coin,
_ btcutil.Amount) ([]Coin, error) {

sort.Sort(sort.Reverse(sortByAmount(eligible)))

return eligible, nil
}

// RandomCoinSelector is an implementation of the CoinSelectionStrategy that
// selects coins at random. This prevents the creation of ever smaller UTXOs
// over time that may never become economical to spend.
type RandomCoinSelector struct{}

// ArrangeCoins takes a list of coins and arranges them according to the
// specified coin selection strategy and fee rate.
func (*RandomCoinSelector) ArrangeCoins(eligible []Coin,
feeSatPerKb btcutil.Amount) ([]Coin, error) {

// Skip inputs that do not raise the total transaction output
// value at the requested fee rate.
positivelyYielding := make([]Coin, 0, len(eligible))
for _, output := range eligible {
output := output

if !inputYieldsPositively(&output.TxOut, feeSatPerKb) {
continue
}

positivelyYielding = append(positivelyYielding, output)
}

rand.Shuffle(len(positivelyYielding), func(i, j int) {
positivelyYielding[i], positivelyYielding[j] =
positivelyYielding[j], positivelyYielding[i]
})

return positivelyYielding, nil
}
4 changes: 2 additions & 2 deletions wallet/createtx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ func TestInputYield(t *testing.T) {
pkScript, err := txscript.PayToAddrScript(addr)
require.NoError(t, err)

credit := &wtxmgr.Credit{
Amount: 1000,
credit := &wire.TxOut{
Value: 1000,
PkScript: pkScript,
}

Expand Down
110 changes: 66 additions & 44 deletions wallet/psbt.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package wallet

import (
"bytes"
"errors"
"fmt"

"github.com/btcsuite/btcd/btcutil"
Expand Down Expand Up @@ -82,48 +83,6 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, keyScope *waddrmgr.KeyScope,
amt += output.Value
}

// addInputInfo is a helper function that fetches the UTXO information
// of an input and attaches it to the PSBT packet.
addInputInfo := func(inputs []*wire.TxIn) error {
packet.Inputs = make([]psbt.PInput, len(inputs))
for idx, in := range inputs {
tx, utxo, derivationPath, _, err := w.FetchInputInfo(
&in.PreviousOutPoint,
)
if err != nil {
return fmt.Errorf("error fetching UTXO: %v",
err)
}

addr, witnessProgram, _, err := w.ScriptForOutput(utxo)
if err != nil {
return fmt.Errorf("error fetching UTXO "+
"script: %v", err)
}

// We don't want to include the witness or any script
// on the unsigned TX just yet.
packet.UnsignedTx.TxIn[idx].Witness = wire.TxWitness{}
packet.UnsignedTx.TxIn[idx].SignatureScript = nil

switch {
case txscript.IsPayToTaproot(utxo.PkScript):
addInputInfoSegWitV1(
&packet.Inputs[idx], utxo,
derivationPath,
)

default:
addInputInfoSegWitV0(
&packet.Inputs[idx], tx, utxo,
derivationPath, addr, witnessProgram,
)
}
}

return nil
}

var tx *txauthor.AuthoredTx
switch {
// We need to do coin selection.
Expand All @@ -146,7 +105,16 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, keyScope *waddrmgr.KeyScope,
// include the witness as the resulting PSBT isn't expected not
// should be signed yet.
packet.UnsignedTx.TxIn = tx.Tx.TxIn
err = addInputInfo(tx.Tx.TxIn)
packet.Inputs = make([]psbt.PInput, len(packet.UnsignedTx.TxIn))

for idx := range packet.UnsignedTx.TxIn {
// We don't want to include the witness or any script
// on the unsigned TX just yet.
packet.UnsignedTx.TxIn[idx].Witness = wire.TxWitness{}
packet.UnsignedTx.TxIn[idx].SignatureScript = nil
}

err := w.DecorateInputs(packet, true)
if err != nil {
return 0, err
}
Expand All @@ -155,7 +123,16 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, keyScope *waddrmgr.KeyScope,
// a change output if necessary.
default:
// Make sure all inputs provided are actually ours.
err = addInputInfo(txIn)
packet.Inputs = make([]psbt.PInput, len(packet.UnsignedTx.TxIn))

for idx := range packet.UnsignedTx.TxIn {
// We don't want to include the witness or any script
// on the unsigned TX just yet.
packet.UnsignedTx.TxIn[idx].Witness = wire.TxWitness{}
packet.UnsignedTx.TxIn[idx].SignatureScript = nil
}

err := w.DecorateInputs(packet, true)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -265,6 +242,51 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, keyScope *waddrmgr.KeyScope,
return changeIndex, nil
}

// DecorateInputs fetches the UTXO information of all inputs it can identify and
// adds the required information to the package's inputs. The failOnUnknown
// boolean controls whether the method should return an error if it cannot
// identify an input or if it should just skip it.
func (w *Wallet) DecorateInputs(packet *psbt.Packet, failOnUnknown bool) error {
for idx := range packet.Inputs {
txIn := packet.UnsignedTx.TxIn[idx]

tx, utxo, derivationPath, _, err := w.FetchInputInfo(
&txIn.PreviousOutPoint,
)

switch {
// If the error just means it's not an input our wallet controls
// and the user doesn't care about that, then we can just skip
// this input and continue.
case errors.Is(err, ErrNotMine) && !failOnUnknown:
continue

case err != nil:
return fmt.Errorf("error fetching UTXO: %v", err)
}

addr, witnessProgram, _, err := w.ScriptForOutput(utxo)
if err != nil {
return fmt.Errorf("error fetching UTXO script: %v", err)
}

switch {
case txscript.IsPayToTaproot(utxo.PkScript):
addInputInfoSegWitV1(
&packet.Inputs[idx], utxo, derivationPath,
)

default:
addInputInfoSegWitV0(
&packet.Inputs[idx], tx, utxo, derivationPath,
addr, witnessProgram,
)
}
}

return nil
}

// addInputInfoSegWitV0 adds the UTXO and BIP32 derivation info for a SegWit v0
// PSBT input (p2wkh, np2wkh) from the given wallet information.
func addInputInfoSegWitV0(in *psbt.PInput, prevTx *wire.MsgTx, utxo *wire.TxOut,
Expand Down
Loading

0 comments on commit 6b096b0

Please sign in to comment.