From ae314ec7941f083d4294eba3cd9a05462bdef8be Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Thu, 3 Feb 2022 15:34:25 +0100 Subject: [PATCH] htlcswitch: add an always on mode to interceptable switch Co-authored-by: Juan Pablo Civile --- config.go | 4 + docs/release-notes/release-notes-0.15.0.md | 4 + htlcswitch/interceptable_switch.go | 68 +++++++++++-- htlcswitch/link.go | 18 ++-- htlcswitch/link_test.go | 30 +++--- htlcswitch/switch_test.go | 106 +++++++++++++++++---- htlcswitch/test_utils.go | 14 +-- peer/test_utils.go | 4 +- sample-lnd.conf | 3 + server.go | 4 +- 10 files changed, 201 insertions(+), 54 deletions(-) diff --git a/config.go b/config.go index 6a710ee795..6f7a66b983 100644 --- a/config.go +++ b/config.go @@ -354,6 +354,10 @@ type Config struct { RejectHTLC bool `long:"rejecthtlc" description:"If true, lnd will not forward any HTLCs that are meant as onward payments. This option will still allow lnd to send HTLCs and receive HTLCs but lnd won't be used as a hop."` + // RequireInterceptor determines whether the HTLC interceptor is + // registered regardless of whether the RPC is called or not. + RequireInterceptor bool `long:"requireinterceptor" description:"Whether to always intercept HTLCs, even if no stream is attached"` + StaggerInitialReconnect bool `long:"stagger-initial-reconnect" description:"If true, will apply a randomized staggering between 0s and 30s when reconnecting to persistent peers on startup. The first 10 reconnections will be attempted instantly, regardless of the flag's value"` MaxOutgoingCltvExpiry uint32 `long:"max-cltv-expiry" description:"The maximum number of blocks funds could be locked up for when forwarding payments."` diff --git a/docs/release-notes/release-notes-0.15.0.md b/docs/release-notes/release-notes-0.15.0.md index 140908d69f..f521ef3cc6 100644 --- a/docs/release-notes/release-notes-0.15.0.md +++ b/docs/release-notes/release-notes-0.15.0.md @@ -101,6 +101,10 @@ change, it allows encrypted failure messages to be returned to the sender. Additionally it is possible to signal a malformed htlc. +* Add an [always on](https://github.com/lightningnetwork/lnd/pull/6232) mode to + the HTLC interceptor API. This enables interception applications where every + packet must be intercepted. + ## Database * [Add ForAll implementation for etcd to speed up diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index 13f70dcf20..87ef1e61d9 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -45,6 +45,10 @@ type InterceptableSwitch struct { // client connect and disconnect. interceptorRegistration chan ForwardInterceptor + // requireInterceptor indicates whether processing should block if no + // interceptor is connected. + requireInterceptor bool + // interceptor is the handler for intercepted packets. interceptor ForwardInterceptor @@ -58,6 +62,7 @@ type InterceptableSwitch struct { type interceptedPackets struct { packets []*htlcPacket linkQuit chan struct{} + isReplay bool } // FwdAction defines the various resolution types. @@ -101,13 +106,16 @@ type fwdResolution struct { } // NewInterceptableSwitch returns an instance of InterceptableSwitch. -func NewInterceptableSwitch(s *Switch) *InterceptableSwitch { +func NewInterceptableSwitch(s *Switch, + requireInterceptor bool) *InterceptableSwitch { + return &InterceptableSwitch{ htlcSwitch: s, intercepted: make(chan *interceptedPackets), interceptorRegistration: make(chan ForwardInterceptor), holdForwards: make(map[channeldb.CircuitKey]InterceptedForward), resolutionChan: make(chan *fwdResolution), + requireInterceptor: requireInterceptor, quit: make(chan struct{}), } @@ -155,9 +163,7 @@ func (s *InterceptableSwitch) run() { case packets := <-s.intercepted: var notIntercepted []*htlcPacket for _, p := range packets.packets { - if s.interceptor == nil || - !s.interceptForward(p) { - + if !s.interceptForward(p, packets.isReplay) { notIntercepted = append( notIntercepted, p, ) @@ -178,7 +184,6 @@ func (s *InterceptableSwitch) run() { } } } - func (s *InterceptableSwitch) sendForward(fwd InterceptedForward) { err := s.interceptor(fwd.Packet()) if err != nil { @@ -191,12 +196,28 @@ func (s *InterceptableSwitch) sendForward(fwd InterceptedForward) { func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) { s.interceptor = interceptor + // Replay all currently held htlcs. When an interceptor is not required, + // there may be none because they've been cleared after the previous + // disconnect. if interceptor != nil { log.Debugf("Interceptor connected") + for _, fwd := range s.holdForwards { + s.sendForward(fwd) + } + return } + // The interceptor disconnects. If an interceptor is required, keep the + // held htlcs. + if s.requireInterceptor { + log.Infof("Interceptor disconnected, retaining held packets") + + return + } + + // Interceptor is not required. Release held forwards. log.Infof("Interceptor disconnected, resolving held packets") for _, fwd := range s.holdForwards { @@ -260,7 +281,7 @@ func (s *InterceptableSwitch) Resolve(res *FwdResolution) error { // interceptor. If the interceptor signals the resume action, the htlcs are // forwarded to the switch. The link's quit signal should be provided to allow // cancellation of forwarding during link shutdown. -func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{}, +func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{}, isReplay bool, packets ...*htlcPacket) error { // Synchronize with the main event loop. This should be light in the @@ -269,6 +290,7 @@ func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{}, case s.intercepted <- &interceptedPackets{ packets: packets, linkQuit: linkQuit, + isReplay: isReplay, }: case <-linkQuit: @@ -283,7 +305,15 @@ func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{}, // interceptForward forwards the packet to the external interceptor after // checking the interception criteria. -func (s *InterceptableSwitch) interceptForward(packet *htlcPacket) bool { +func (s *InterceptableSwitch) interceptForward(packet *htlcPacket, + isReplay bool) bool { + + // Process normally if an interceptor is not required and not + // registered. + if !s.requireInterceptor && s.interceptor == nil { + return false + } + switch htlc := packet.htlc.(type) { case *lnwire.UpdateAddHTLC: // We are not interested in intercepting initiated payments. @@ -307,9 +337,31 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket) bool { htlcSwitch: s.htlcSwitch, } + if s.interceptor == nil && !isReplay { + // There is no interceptor registered, we are in + // interceptor-required mode, and this is a new packet + // + // Because the interceptor has never seen this packet + // yet, it is still safe to fail back. This limits the + // backlog of htlcs when the interceptor is down. + err := intercepted.FailWithCode( + lnwire.CodeTemporaryChannelFailure, + ) + if err != nil { + log.Errorf("Cannot fail packet: %v", err) + } + + return true + } + s.holdForwards[inKey] = intercepted - s.sendForward(intercepted) + // If there is no interceptor registered, we must be in + // interceptor-required mode. The packet is kept in the queue + // until the interceptor registers itself. + if s.interceptor != nil { + s.sendForward(intercepted) + } return true diff --git a/htlcswitch/link.go b/htlcswitch/link.go index d71d20d14f..a867f6b00b 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -141,7 +141,7 @@ type ChannelLinkConfig struct { // switch. The function returns and error in case it fails to send one or // more packets. The link's quit signal should be provided to allow // cancellation of forwarding during link shutdown. - ForwardPackets func(chan struct{}, ...*htlcPacket) error + ForwardPackets func(chan struct{}, bool, ...*htlcPacket) error // DecodeHopIterators facilitates batched decoding of HTLC Sphinx onion // blobs, which are then used to inform how to forward an HTLC. @@ -1720,7 +1720,7 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { l.uncommittedPreimages = append(l.uncommittedPreimages, pre) // Pipeline this settle, send it to the switch. - go l.forwardBatch(settlePacket) + go l.forwardBatch(false, settlePacket) case *lnwire.UpdateFailMalformedHTLC: // Convert the failure type encoded within the HTLC fail @@ -2744,7 +2744,7 @@ func (l *channelLink) processRemoteSettleFails(fwdPkg *channeldb.FwdPkg, // Only spawn the task forward packets we have a non-zero number. if len(switchPackets) > 0 { - go l.forwardBatch(switchPackets...) + go l.forwardBatch(false, switchPackets...) } } @@ -3043,14 +3043,17 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, return } - l.log.Debugf("forwarding %d packets to switch", len(switchPackets)) + replay := fwdPkg.State != channeldb.FwdStateLockedIn + + l.log.Debugf("forwarding %d packets to switch: replay=%v", + len(switchPackets), replay) // NOTE: This call is made synchronous so that we ensure all circuits // are committed in the exact order that they are processed in the link. // Failing to do this could cause reorderings/gaps in the range of // opened circuits, which violates assumptions made by the circuit // trimming. - l.forwardBatch(switchPackets...) + l.forwardBatch(replay, switchPackets...) } // processExitHop handles an htlc for which this link is the exit hop. It @@ -3184,7 +3187,7 @@ func (l *channelLink) settleHTLC(preimage lntypes.Preimage, // forwardBatch forwards the given htlcPackets to the switch, and waits on the // err chan for the individual responses. This method is intended to be spawned // as a goroutine so the responses can be handled in the background. -func (l *channelLink) forwardBatch(packets ...*htlcPacket) { +func (l *channelLink) forwardBatch(replay bool, packets ...*htlcPacket) { // Don't forward packets for which we already have a response in our // mailbox. This could happen if a packet fails and is buffered in the // mailbox, and the incoming link flaps. @@ -3197,7 +3200,8 @@ func (l *channelLink) forwardBatch(packets ...*htlcPacket) { filteredPkts = append(filteredPkts, pkt) } - if err := l.cfg.ForwardPackets(l.quit, filteredPkts...); err != nil { + err := l.cfg.ForwardPackets(l.quit, replay, filteredPkts...) + if err != nil { log.Errorf("Unhandled error while reforwarding htlc "+ "settle/fail over htlcswitch: %v", err) } diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index d2cfdbcea8..5320336ec0 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -1940,12 +1940,14 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( // the firing via force feeding. bticker := ticker.NewForce(time.Hour) aliceCfg := ChannelLinkConfig{ - FwrdingPolicy: globalPolicy, - Peer: alicePeer, - Switch: aliceSwitch, - BestHeight: aliceSwitch.BestHeight, - Circuits: aliceSwitch.CircuitModifier(), - ForwardPackets: aliceSwitch.ForwardPackets, + FwrdingPolicy: globalPolicy, + Peer: alicePeer, + Switch: aliceSwitch, + BestHeight: aliceSwitch.BestHeight, + Circuits: aliceSwitch.CircuitModifier(), + ForwardPackets: func(linkQuit chan struct{}, _ bool, packets ...*htlcPacket) error { + return aliceSwitch.ForwardPackets(linkQuit, packets...) + }, DecodeHopIterators: decoder.DecodeHopIterators, ExtractErrorEncrypter: func(*btcec.PublicKey) ( hop.ErrorEncrypter, lnwire.FailCode) { @@ -4491,12 +4493,14 @@ func (h *persistentLinkHarness) restartLink( // the firing via force feeding. bticker := ticker.NewForce(time.Hour) aliceCfg := ChannelLinkConfig{ - FwrdingPolicy: globalPolicy, - Peer: alicePeer, - Switch: aliceSwitch, - BestHeight: aliceSwitch.BestHeight, - Circuits: aliceSwitch.CircuitModifier(), - ForwardPackets: aliceSwitch.ForwardPackets, + FwrdingPolicy: globalPolicy, + Peer: alicePeer, + Switch: aliceSwitch, + BestHeight: aliceSwitch.BestHeight, + Circuits: aliceSwitch.CircuitModifier(), + ForwardPackets: func(linkQuit chan struct{}, _ bool, packets ...*htlcPacket) error { + return aliceSwitch.ForwardPackets(linkQuit, packets...) + }, DecodeHopIterators: decoder.DecodeHopIterators, ExtractErrorEncrypter: func(*btcec.PublicKey) ( hop.ErrorEncrypter, lnwire.FailCode) { @@ -6694,7 +6698,7 @@ func TestPipelineSettle(t *testing.T) { // erroneously forwarded. If the forwardChan is closed before the last // step, then the test will fail. forwardChan := make(chan struct{}) - fwdPkts := func(c chan struct{}, hp ...*htlcPacket) error { + fwdPkts := func(c chan struct{}, _ bool, hp ...*htlcPacket) error { close(forwardChan) return nil } diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 4c22b5e8e9..148a195391 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -3273,7 +3273,7 @@ func TestSwitchHoldForward(t *testing.T) { t: t, interceptedChan: make(chan InterceptedPacket), } - switchForwardInterceptor := NewInterceptableSwitch(s) + switchForwardInterceptor := NewInterceptableSwitch(s, false) require.NoError(t, switchForwardInterceptor.Start()) switchForwardInterceptor.SetInterceptor(forwardInterceptor.InterceptForwardHtlc) @@ -3281,9 +3281,9 @@ func TestSwitchHoldForward(t *testing.T) { // Test resume a hold forward. assertNumCircuits(t, s, 0, 0) - if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil { - t.Fatalf("can't forward htlc packet: %v", err) - } + err = switchForwardInterceptor.ForwardPackets(linkQuit, false, ogPacket) + require.NoError(t, err) + assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) @@ -3303,15 +3303,16 @@ func TestSwitchHoldForward(t *testing.T) { PaymentPreimage: preimage, }, } - if err := switchForwardInterceptor.ForwardPackets(linkQuit, settle); err != nil { - t.Fatalf("can't forward htlc packet: %v", err) - } + err = switchForwardInterceptor.ForwardPackets(linkQuit, false, settle) + require.NoError(t, err) + assertOutgoingLinkReceive(t, aliceChannelLink, true) assertNumCircuits(t, s, 0, 0) // Test resume a hold forward after disconnection. - err = switchForwardInterceptor.ForwardPackets(nil, ogPacket) - require.NoError(t, err) + require.NoError(t, switchForwardInterceptor.ForwardPackets( + linkQuit, false, ogPacket, + )) // Wait until the packet is offered to the interceptor. _ = forwardInterceptor.getIntercepted() @@ -3328,7 +3329,9 @@ func TestSwitchHoldForward(t *testing.T) { // Settle the htlc to close the circuit. settle.outgoingHTLCID = 1 - require.NoError(t, switchForwardInterceptor.ForwardPackets(nil, settle)) + require.NoError(t, switchForwardInterceptor.ForwardPackets( + linkQuit, false, settle, + )) assertOutgoingLinkReceive(t, aliceChannelLink, true) assertNumCircuits(t, s, 0, 0) @@ -3338,9 +3341,9 @@ func TestSwitchHoldForward(t *testing.T) { forwardInterceptor.InterceptForwardHtlc, ) - if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil { - t.Fatalf("can't forward htlc packet: %v", err) - } + require.NoError(t, switchForwardInterceptor.ForwardPackets( + linkQuit, false, ogPacket, + )) assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) @@ -3355,7 +3358,7 @@ func TestSwitchHoldForward(t *testing.T) { // Test failing a hold forward with a failure message. require.NoError(t, - switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket), + switchForwardInterceptor.ForwardPackets(linkQuit, false, ogPacket), ) assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) @@ -3375,7 +3378,7 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, s, 0, 0) // Test failing a hold forward with a malformed htlc failure. - err = switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket) + err = switchForwardInterceptor.ForwardPackets(linkQuit, false, ogPacket) require.NoError(t, err) assertNumCircuits(t, s, 0, 0) @@ -3404,9 +3407,9 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, s, 0, 0) // Test settling a hold forward - if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil { - t.Fatalf("can't forward htlc packet: %v", err) - } + require.NoError(t, switchForwardInterceptor.ForwardPackets( + linkQuit, false, ogPacket, + )) assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) @@ -3420,6 +3423,73 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, s, 0, 0) require.NoError(t, switchForwardInterceptor.Stop()) + + // Test always-on interception. + switchForwardInterceptor = NewInterceptableSwitch(s, true) + require.NoError(t, switchForwardInterceptor.Start()) + + // Forward a fresh packet. It is expected to be failed immediately, + // because there is no interceptor registered. + require.NoError(t, switchForwardInterceptor.ForwardPackets( + linkQuit, false, ogPacket, + )) + + assertOutgoingLinkReceive(t, bobChannelLink, false) + assertOutgoingLinkReceive(t, aliceChannelLink, true) + assertNumCircuits(t, s, 0, 0) + + // Forward a replayed packet. It is expected to be held until the + // interceptor connects. To continue the test, it needs to be ran in a + // goroutine. + errChan := make(chan error) + go func() { + errChan <- switchForwardInterceptor.ForwardPackets( + linkQuit, true, ogPacket, + ) + }() + + // Assert that nothing is forward to the switch. + assertOutgoingLinkReceive(t, bobChannelLink, false) + assertNumCircuits(t, s, 0, 0) + + // Register an interceptor. + switchForwardInterceptor.SetInterceptor( + forwardInterceptor.InterceptForwardHtlc, + ) + + // Expect the ForwardPackets call to unblock. + require.NoError(t, <-errChan) + + // Now expect the queued packet to come through. + forwardInterceptor.getIntercepted() + + // Disconnect and reconnect interceptor. + switchForwardInterceptor.SetInterceptor(nil) + switchForwardInterceptor.SetInterceptor( + forwardInterceptor.InterceptForwardHtlc, + ) + + // A replay of the held packet is expected. + intercepted := forwardInterceptor.getIntercepted() + + // Settle the packet. + require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{ + Key: intercepted.IncomingCircuit, + Action: FwdActionSettle, + Preimage: preimage, + })) + assertOutgoingLinkReceive(t, bobChannelLink, false) + assertOutgoingLinkReceive(t, aliceChannelLink, true) + assertNumCircuits(t, s, 0, 0) + + require.NoError(t, switchForwardInterceptor.Stop()) + + select { + case <-forwardInterceptor.interceptedChan: + require.Fail(t, "unexpected interception") + + default: + } } // TestSwitchDustForwarding tests that the switch properly fails HTLC's which diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 014adfc52d..7dfd1599c3 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -1135,12 +1135,14 @@ func (h *hopNetwork) createChannelLink(server, peer *mockServer, link := NewChannelLink( ChannelLinkConfig{ - Switch: server.htlcSwitch, - BestHeight: server.htlcSwitch.BestHeight, - FwrdingPolicy: h.globalPolicy, - Peer: peer, - Circuits: server.htlcSwitch.CircuitModifier(), - ForwardPackets: server.htlcSwitch.ForwardPackets, + Switch: server.htlcSwitch, + BestHeight: server.htlcSwitch.BestHeight, + FwrdingPolicy: h.globalPolicy, + Peer: peer, + Circuits: server.htlcSwitch.CircuitModifier(), + ForwardPackets: func(linkQuit chan struct{}, _ bool, packets ...*htlcPacket) error { + return server.htlcSwitch.ForwardPackets(linkQuit, packets...) + }, DecodeHopIterators: decoder.DecodeHopIterators, ExtractErrorEncrypter: func(*btcec.PublicKey) ( hop.ErrorEncrypter, lnwire.FailCode) { diff --git a/peer/test_utils.go b/peer/test_utils.go index 75cc19f91f..40f3295634 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -367,7 +367,9 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, Switch: mockSwitch, ChanActiveTimeout: chanActiveTimeout, - InterceptSwitch: htlcswitch.NewInterceptableSwitch(nil), + InterceptSwitch: htlcswitch.NewInterceptableSwitch( + nil, false, + ), ChannelDB: dbAlice.ChannelStateDB(), FeeEstimator: estimator, diff --git a/sample-lnd.conf b/sample-lnd.conf index 714aaea7dc..029547e064 100644 --- a/sample-lnd.conf +++ b/sample-lnd.conf @@ -347,6 +347,9 @@ ; used as a hop. ; rejecthtlc=true +; If true, all HTLCs will be held until they are handled by an interceptor +; requireinterceptor=true + ; If true, will apply a randomized staggering between 0s and 30s when ; reconnecting to persistent peers on startup. The first 10 reconnections will be ; attempted instantly, regardless of the flag's value diff --git a/server.go b/server.go index b2d3b1c9d3..28ac518074 100644 --- a/server.go +++ b/server.go @@ -654,7 +654,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, if err != nil { return nil, err } - s.interceptableSwitch = htlcswitch.NewInterceptableSwitch(s.htlcSwitch) + s.interceptableSwitch = htlcswitch.NewInterceptableSwitch( + s.htlcSwitch, s.cfg.RequireInterceptor, + ) chanStatusMgrCfg := &netann.ChanStatusConfig{ ChanStatusSampleInterval: cfg.ChanStatusSampleInterval,