From 169f0c0bf4feda33c8e9bbcc2aa55893b2a41beb Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Mon, 7 Feb 2022 08:53:10 +0100 Subject: [PATCH 1/2] routerrpc+htlcswitch: move intercepted htlc tracking to switch In this commit we move the tracking of the outstanding intercepted htlcs to InterceptableSwitch. This is a preparation for making the htlc interceptor required. Required interception involves tracking outstanding htlcs across multiple grpc client sessions. The per-session routerrpc forwardInterceptor object is therefore no longer the best place for that. --- htlcswitch/interceptable_switch.go | 292 +++++++++++++++++++++---- htlcswitch/interfaces.go | 5 +- htlcswitch/switch_test.go | 105 ++++++--- lnrpc/routerrpc/forward_interceptor.go | 172 ++++----------- lnrpc/routerrpc/router_server.go | 4 +- server.go | 6 + 6 files changed, 378 insertions(+), 206 deletions(-) diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index 72e8ae3b74..13f70dcf20 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -30,64 +30,260 @@ var ( // Settle - routes UpdateFulfillHTLC to the originating link. // Fail - routes UpdateFailHTLC to the originating link. type InterceptableSwitch struct { - sync.RWMutex - // htlcSwitch is the underline switch htlcSwitch *Switch - // fwdInterceptor is the callback that is called for each forward of - // an incoming htlc. It should return true if it is interested in handling - // it. - fwdInterceptor ForwardInterceptor + // intercepted is where we stream all intercepted packets coming from + // the switch. + intercepted chan *interceptedPackets + + // resolutionChan is where we stream all responses coming from the + // interceptor client. + resolutionChan chan *fwdResolution + + // interceptorRegistration is a channel that we use to synchronize + // client connect and disconnect. + interceptorRegistration chan ForwardInterceptor + + // interceptor is the handler for intercepted packets. + interceptor ForwardInterceptor + + // holdForwards keeps track of outstanding intercepted forwards. + holdForwards map[channeldb.CircuitKey]InterceptedForward + + wg sync.WaitGroup + quit chan struct{} +} + +type interceptedPackets struct { + packets []*htlcPacket + linkQuit chan struct{} +} + +// FwdAction defines the various resolution types. +type FwdAction int + +const ( + // FwdActionResume forwards the intercepted packet to the switch. + FwdActionResume FwdAction = iota + + // FwdActionSettle settles the intercepted packet with a preimage. + FwdActionSettle + + // FwdActionFail fails the intercepted packet back to the sender. + FwdActionFail +) + +// FwdResolution defines the action to be taken on an intercepted packet. +type FwdResolution struct { + // Key is the incoming circuit key of the htlc. + Key channeldb.CircuitKey + + // Action is the action to take on the intercepted htlc. + Action FwdAction + + // Preimage is the preimage that is to be used for settling if Action is + // FwdActionSettle. + Preimage lntypes.Preimage + + // FailureMessage is the encrypted failure message that is to be passed + // back to the sender if action is FwdActionFail. + FailureMessage []byte + + // FailureCode is the failure code that is to be passed back to the + // sender if action is FwdActionFail. + FailureCode lnwire.FailCode +} + +type fwdResolution struct { + resolution *FwdResolution + errChan chan error } // NewInterceptableSwitch returns an instance of InterceptableSwitch. func NewInterceptableSwitch(s *Switch) *InterceptableSwitch { - return &InterceptableSwitch{htlcSwitch: s} + return &InterceptableSwitch{ + htlcSwitch: s, + intercepted: make(chan *interceptedPackets), + interceptorRegistration: make(chan ForwardInterceptor), + holdForwards: make(map[channeldb.CircuitKey]InterceptedForward), + resolutionChan: make(chan *fwdResolution), + + quit: make(chan struct{}), + } } -// SetInterceptor sets the ForwardInterceptor to be used. +// SetInterceptor sets the ForwardInterceptor to be used. A nil argument +// unregisters the current interceptor. func (s *InterceptableSwitch) SetInterceptor( interceptor ForwardInterceptor) { - s.Lock() - defer s.Unlock() - s.fwdInterceptor = interceptor + // Synchronize setting the handler with the main loop to prevent race + // conditions. + select { + case s.interceptorRegistration <- interceptor: + + case <-s.quit: + } } -// ForwardPackets attempts to forward the batch of htlcs through the -// switch, any failed packets will be returned to the provided -// ChannelLink. The link's quit signal should be provided to allow -// cancellation of forwarding during link shutdown. -func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{}, - packets ...*htlcPacket) error { +func (s *InterceptableSwitch) Start() error { + s.wg.Add(1) + go func() { + defer s.wg.Done() + + s.run() + }() + + return nil +} - var interceptor ForwardInterceptor - s.Lock() - interceptor = s.fwdInterceptor - s.Unlock() +func (s *InterceptableSwitch) Stop() error { + close(s.quit) + s.wg.Wait() - // Optimize for the case we don't have an interceptor. - if interceptor == nil { - return s.htlcSwitch.ForwardPackets(linkQuit, packets...) + return nil +} + +func (s *InterceptableSwitch) run() { + for { + select { + // An interceptor registration or de-registration came in. + case interceptor := <-s.interceptorRegistration: + s.setInterceptor(interceptor) + + case packets := <-s.intercepted: + var notIntercepted []*htlcPacket + for _, p := range packets.packets { + if s.interceptor == nil || + !s.interceptForward(p) { + + notIntercepted = append( + notIntercepted, p, + ) + } + } + err := s.htlcSwitch.ForwardPackets( + packets.linkQuit, notIntercepted..., + ) + if err != nil { + log.Errorf("Cannot forward packets: %v", err) + } + + case res := <-s.resolutionChan: + res.errChan <- s.resolve(res.resolution) + + case <-s.quit: + return + } + } +} + +func (s *InterceptableSwitch) sendForward(fwd InterceptedForward) { + err := s.interceptor(fwd.Packet()) + if err != nil { + // Only log the error. If we couldn't send the packet, we assume + // that the interceptor will reconnect so that we can retry. + log.Debugf("Interceptor cannot handle forward: %v", err) } +} + +func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) { + s.interceptor = interceptor + + if interceptor != nil { + log.Debugf("Interceptor connected") - var notIntercepted []*htlcPacket - for _, p := range packets { - if !s.interceptForward(p, interceptor, linkQuit) { - notIntercepted = append(notIntercepted, p) + return + } + + log.Infof("Interceptor disconnected, resolving held packets") + + for _, fwd := range s.holdForwards { + if err := fwd.Resume(); err != nil { + log.Errorf("Failed to resume hold forward %v", err) + } + } + s.holdForwards = make(map[channeldb.CircuitKey]InterceptedForward) +} + +func (s *InterceptableSwitch) resolve(res *FwdResolution) error { + intercepted, ok := s.holdForwards[res.Key] + if !ok { + return fmt.Errorf("fwd %v not found", res.Key) + } + delete(s.holdForwards, res.Key) + + switch res.Action { + case FwdActionResume: + return intercepted.Resume() + + case FwdActionSettle: + return intercepted.Settle(res.Preimage) + + case FwdActionFail: + if len(res.FailureMessage) > 0 { + return intercepted.Fail(res.FailureMessage) } + + return intercepted.FailWithCode(res.FailureCode) + + default: + return fmt.Errorf("unrecognized action %v", res.Action) } - return s.htlcSwitch.ForwardPackets(linkQuit, notIntercepted...) } -// interceptForward checks if there is any external interceptor interested in -// this packet. Currently only htlc type of UpdateAddHTLC that are forwarded -// are being checked for interception. It can be extended in the future given -// the right use case. -func (s *InterceptableSwitch) interceptForward(packet *htlcPacket, - interceptor ForwardInterceptor, linkQuit chan struct{}) bool { +// Resolve resolves an intercepted packet. +func (s *InterceptableSwitch) Resolve(res *FwdResolution) error { + internalRes := &fwdResolution{ + resolution: res, + errChan: make(chan error, 1), + } + + select { + case s.resolutionChan <- internalRes: + + case <-s.quit: + return errors.New("switch shutting down") + } + select { + case err := <-internalRes.errChan: + return err + + case <-s.quit: + return errors.New("switch shutting down") + } +} + +// ForwardPackets attempts to forward the batch of htlcs to a connected +// interceptor. If the interceptor signals the resume action, the htlcs are +// forwarded to the switch. The link's quit signal should be provided to allow +// cancellation of forwarding during link shutdown. +func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{}, + packets ...*htlcPacket) error { + + // Synchronize with the main event loop. This should be light in the + // case where there is no interceptor. + select { + case s.intercepted <- &interceptedPackets{ + packets: packets, + linkQuit: linkQuit, + }: + + case <-linkQuit: + log.Debugf("Forward cancelled because link quit") + + case <-s.quit: + return errors.New("interceptable switch quit") + } + + return nil +} + +// interceptForward forwards the packet to the external interceptor after +// checking the interception criteria. +func (s *InterceptableSwitch) interceptForward(packet *htlcPacket) bool { switch htlc := packet.htlc.(type) { case *lnwire.UpdateAddHTLC: // We are not interested in intercepting initiated payments. @@ -95,15 +291,28 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket, return false } + inKey := channeldb.CircuitKey{ + ChanID: packet.incomingChanID, + HtlcID: packet.incomingHTLCID, + } + + // Ignore already held htlcs. + if _, ok := s.holdForwards[inKey]; ok { + return true + } + intercepted := &interceptedForward{ - linkQuit: linkQuit, htlc: htlc, packet: packet, htlcSwitch: s.htlcSwitch, } - // If this htlc was intercepted, don't handle the forward. - return interceptor(intercepted) + s.holdForwards[inKey] = intercepted + + s.sendForward(intercepted) + + return true + default: return false } @@ -113,7 +322,6 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket, // It is passed from the switch to external interceptors that are interested // in holding forwards and resolve them manually. type interceptedForward struct { - linkQuit chan struct{} htlc *lnwire.UpdateAddHTLC packet *htlcPacket htlcSwitch *Switch @@ -139,10 +347,12 @@ func (f *interceptedForward) Packet() InterceptedPacket { // Resume resumes the default behavior as if the packet was not intercepted. func (f *interceptedForward) Resume() error { - return f.htlcSwitch.ForwardPackets(f.linkQuit, f.packet) + // Forward to the switch. A link quit channel isn't needed, because we + // are on a different thread now. + return f.htlcSwitch.ForwardPackets(nil, f.packet) } -// Fail notifies the intention to fail an existing hold forward with an +// Fail notifies the intention to Fail an existing hold forward with an // encrypted failure reason. func (f *interceptedForward) Fail(reason []byte) error { obfuscatedReason := f.packet.obfuscator.IntermediateEncrypt(reason) diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 4b201ee1e6..1b80adab3e 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -234,6 +234,9 @@ type TowerClient interface { type InterceptableHtlcForwarder interface { // SetInterceptor sets a ForwardInterceptor. SetInterceptor(interceptor ForwardInterceptor) + + // Resolve resolves an intercepted packet. + Resolve(res *FwdResolution) error } // ForwardInterceptor is a function that is invoked from the switch for every @@ -242,7 +245,7 @@ type InterceptableHtlcForwarder interface { // to resolve it manually later in case it is held. // The return value indicates if this handler will take control of this forward // and resolve it later or let the switch execute its default behavior. -type ForwardInterceptor func(InterceptedForward) bool +type ForwardInterceptor func(InterceptedPacket) error // InterceptedPacket contains the relevant information for the interceptor about // an htlc. diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index ffbca49070..4c22b5e8e9 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -3140,32 +3140,29 @@ func getThreeHopEvents(channels *clusterChannels, htlcID uint64, } type mockForwardInterceptor struct { - intercepted InterceptedForward + t *testing.T + + interceptedChan chan InterceptedPacket } func (m *mockForwardInterceptor) InterceptForwardHtlc( - intercepted InterceptedForward) bool { - - m.intercepted = intercepted - return true -} + intercepted InterceptedPacket) error { -func (m *mockForwardInterceptor) settle(preimage lntypes.Preimage) error { - return m.intercepted.Settle(preimage) -} + m.interceptedChan <- intercepted -func (m *mockForwardInterceptor) fail(reason []byte) error { - return m.intercepted.Fail(reason) + return nil } -func (m *mockForwardInterceptor) failWithCode( - code lnwire.FailCode) error { +func (m *mockForwardInterceptor) getIntercepted() InterceptedPacket { + select { + case p := <-m.interceptedChan: + return p - return m.intercepted.FailWithCode(code) -} + case <-time.After(time.Second): + require.Fail(m.t, "timeout") -func (m *mockForwardInterceptor) resume() error { - return m.intercepted.Resume() + return InterceptedPacket{} + } } func assertNumCircuits(t *testing.T, s *Switch, pending, opened int) { @@ -3272,12 +3269,17 @@ func TestSwitchHoldForward(t *testing.T) { }, } - forwardInterceptor := &mockForwardInterceptor{} + forwardInterceptor := &mockForwardInterceptor{ + t: t, + interceptedChan: make(chan InterceptedPacket), + } switchForwardInterceptor := NewInterceptableSwitch(s) + require.NoError(t, switchForwardInterceptor.Start()) + switchForwardInterceptor.SetInterceptor(forwardInterceptor.InterceptForwardHtlc) linkQuit := make(chan struct{}) - // Test resume a hold forward + // Test resume a hold forward. assertNumCircuits(t, s, 0, 0) if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil { t.Fatalf("can't forward htlc packet: %v", err) @@ -3285,9 +3287,10 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) - if err := forwardInterceptor.resume(); err != nil { - t.Fatalf("failed to resume forward") - } + require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{ + Action: FwdActionResume, + Key: forwardInterceptor.getIntercepted().IncomingCircuit, + })) assertOutgoingLinkReceive(t, bobChannelLink, true) assertNumCircuits(t, s, 1, 1) @@ -3306,16 +3309,46 @@ func TestSwitchHoldForward(t *testing.T) { assertOutgoingLinkReceive(t, aliceChannelLink, true) assertNumCircuits(t, s, 0, 0) + // Test resume a hold forward after disconnection. + err = switchForwardInterceptor.ForwardPackets(nil, ogPacket) + require.NoError(t, err) + + // Wait until the packet is offered to the interceptor. + _ = forwardInterceptor.getIntercepted() + + // No forward expected yet. + assertNumCircuits(t, s, 0, 0) + assertOutgoingLinkReceive(t, bobChannelLink, false) + + // Disconnect should resume the forwarding. + switchForwardInterceptor.SetInterceptor(nil) + + assertOutgoingLinkReceive(t, bobChannelLink, true) + assertNumCircuits(t, s, 1, 1) + + // Settle the htlc to close the circuit. + settle.outgoingHTLCID = 1 + require.NoError(t, switchForwardInterceptor.ForwardPackets(nil, settle)) + + assertOutgoingLinkReceive(t, aliceChannelLink, true) + assertNumCircuits(t, s, 0, 0) + // Test failing a hold forward + switchForwardInterceptor.SetInterceptor( + forwardInterceptor.InterceptForwardHtlc, + ) + if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil { t.Fatalf("can't forward htlc packet: %v", err) } assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) - if err := forwardInterceptor.fail(nil); err != nil { - t.Fatalf("failed to cancel forward %v", err) - } + require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{ + Action: FwdActionFail, + Key: forwardInterceptor.getIntercepted().IncomingCircuit, + FailureCode: lnwire.CodeTemporaryChannelFailure, + })) assertOutgoingLinkReceive(t, bobChannelLink, false) assertOutgoingLinkReceive(t, aliceChannelLink, true) assertNumCircuits(t, s, 0, 0) @@ -3328,7 +3361,11 @@ func TestSwitchHoldForward(t *testing.T) { assertOutgoingLinkReceive(t, bobChannelLink, false) reason := lnwire.OpaqueReason([]byte{1, 2, 3}) - require.NoError(t, forwardInterceptor.fail(reason)) + require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{ + Action: FwdActionFail, + Key: forwardInterceptor.getIntercepted().IncomingCircuit, + FailureMessage: reason, + })) assertOutgoingLinkReceive(t, bobChannelLink, false) packet := assertOutgoingLinkReceive(t, aliceChannelLink, true) @@ -3345,7 +3382,11 @@ func TestSwitchHoldForward(t *testing.T) { assertOutgoingLinkReceive(t, bobChannelLink, false) code := lnwire.CodeInvalidOnionKey - require.NoError(t, forwardInterceptor.failWithCode(code)) + require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{ + Action: FwdActionFail, + Key: forwardInterceptor.getIntercepted().IncomingCircuit, + FailureCode: code, + })) assertOutgoingLinkReceive(t, bobChannelLink, false) packet = assertOutgoingLinkReceive(t, aliceChannelLink, true) @@ -3369,12 +3410,16 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) - if err := forwardInterceptor.settle(preimage); err != nil { - t.Fatal("failed to cancel forward") - } + require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{ + Key: forwardInterceptor.getIntercepted().IncomingCircuit, + Action: FwdActionSettle, + Preimage: preimage, + })) assertOutgoingLinkReceive(t, bobChannelLink, false) assertOutgoingLinkReceive(t, aliceChannelLink, true) assertNumCircuits(t, s, 0, 0) + + require.NoError(t, switchForwardInterceptor.Stop()) } // TestSwitchDustForwarding tests that the switch properly fails HTLC's which diff --git a/lnrpc/routerrpc/forward_interceptor.go b/lnrpc/routerrpc/forward_interceptor.go index c52a23d2a6..9fba60f1f1 100644 --- a/lnrpc/routerrpc/forward_interceptor.go +++ b/lnrpc/routerrpc/forward_interceptor.go @@ -2,7 +2,6 @@ package routerrpc import ( "errors" - "sync" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/htlcswitch" @@ -27,36 +26,19 @@ var ( // interceptor streaming session. // It is created when the stream opens and disconnects when the stream closes. type forwardInterceptor struct { - // server is the Server reference - server *Server - - // holdForwards is a map of current hold forwards and their corresponding - // ForwardResolver. - holdForwards map[channeldb.CircuitKey]htlcswitch.InterceptedForward - // stream is the bidirectional RPC stream stream Router_HtlcInterceptorServer - // quit is a channel that is closed when this forwardInterceptor is shutting - // down. - quit chan struct{} - - // intercepted is where we stream all intercepted packets coming from - // the switch. - intercepted chan htlcswitch.InterceptedForward - - wg sync.WaitGroup + htlcSwitch htlcswitch.InterceptableHtlcForwarder } // newForwardInterceptor creates a new forwardInterceptor. -func newForwardInterceptor(server *Server, stream Router_HtlcInterceptorServer) *forwardInterceptor { +func newForwardInterceptor(htlcSwitch htlcswitch.InterceptableHtlcForwarder, + stream Router_HtlcInterceptorServer) *forwardInterceptor { + return &forwardInterceptor{ - server: server, - stream: stream, - holdForwards: make( - map[channeldb.CircuitKey]htlcswitch.InterceptedForward), - quit: make(chan struct{}), - intercepted: make(chan htlcswitch.InterceptedForward), + htlcSwitch: htlcSwitch, + stream: stream, } } @@ -67,42 +49,18 @@ func newForwardInterceptor(server *Server, stream Router_HtlcInterceptorServer) // To coordinate all this and make sure it is safe for concurrent access all // packets are sent to the main where they are handled. func (r *forwardInterceptor) run() error { - // make sure we disconnect and resolves all remaining packets if any. - defer r.onDisconnect() - // Register our interceptor so we receive all forwarded packets. - interceptableForwarder := r.server.cfg.RouterBackend.InterceptableForwarder - interceptableForwarder.SetInterceptor(r.onIntercept) - defer interceptableForwarder.SetInterceptor(nil) - - // start a go routine that reads client resolutions. - errChan := make(chan error) - resolutionRequests := make(chan *ForwardHtlcInterceptResponse) - r.wg.Add(1) - go r.readClientResponses(resolutionRequests, errChan) + r.htlcSwitch.SetInterceptor(r.onIntercept) + defer r.htlcSwitch.SetInterceptor(nil) - // run the main loop that synchronizes both sides input into one go routine. for { - select { - case intercepted := <-r.intercepted: - log.Tracef("sending intercepted packet to client %v", intercepted) - // in case we couldn't forward we exit the loop and drain the - // current interceptor as this indicates on a connection problem. - if err := r.holdAndForwardToClient(intercepted); err != nil { - return err - } - case resolution := <-resolutionRequests: - log.Tracef("resolving intercepted packet %v", resolution) - // in case we couldn't resolve we just add a log line since this - // does not indicate on any connection problem. - if err := r.resolveFromClient(resolution); err != nil { - log.Warnf("client resolution of intercepted "+ - "packet failed %v", err) - } - case err := <-errChan: + resp, err := r.stream.Recv() + if err != nil { + return err + } + + if err := r.resolveFromClient(resp); err != nil { return err - case <-r.server.quit: - return nil } } } @@ -111,54 +69,14 @@ func (r *forwardInterceptor) run() error { // packet. Our interceptor makes sure we hold the packet and then signal to the // main loop to handle the packet. We only return true if we were able // to deliver the packet to the main loop. -func (r *forwardInterceptor) onIntercept(p htlcswitch.InterceptedForward) bool { - select { - case r.intercepted <- p: - return true - case <-r.quit: - return false - case <-r.server.quit: - return false - } -} +func (r *forwardInterceptor) onIntercept( + htlc htlcswitch.InterceptedPacket) error { -func (r *forwardInterceptor) readClientResponses( - resolutionChan chan *ForwardHtlcInterceptResponse, errChan chan error) { + log.Tracef("Sending intercepted packet to client %v", htlc) - defer r.wg.Done() - for { - resp, err := r.stream.Recv() - if err != nil { - errChan <- err - return - } - - // Now that we have the response from the RPC client, send it to - // the responses chan. - select { - case resolutionChan <- resp: - case <-r.quit: - return - case <-r.server.quit: - return - } - } -} - -// holdAndForwardToClient forwards the intercepted htlc to the client. -func (r *forwardInterceptor) holdAndForwardToClient( - forward htlcswitch.InterceptedForward) error { - - htlc := forward.Packet() inKey := htlc.IncomingCircuit - // Ignore already held htlcs. - if _, ok := r.holdForwards[inKey]; ok { - return nil - } - // First hold the forward, then send to client. - r.holdForwards[inKey] = forward interceptionRequest := &ForwardHtlcInterceptRequest{ IncomingCircuitKey: &CircuitKey{ ChanId: inKey.ChanID.ToUint64(), @@ -181,20 +99,19 @@ func (r *forwardInterceptor) holdAndForwardToClient( func (r *forwardInterceptor) resolveFromClient( in *ForwardHtlcInterceptResponse) error { + log.Tracef("Resolving intercepted packet %v", in) + circuitKey := channeldb.CircuitKey{ ChanID: lnwire.NewShortChanIDFromInt(in.IncomingCircuitKey.ChanId), HtlcID: in.IncomingCircuitKey.HtlcId, } - var interceptedForward htlcswitch.InterceptedForward - interceptedForward, ok := r.holdForwards[circuitKey] - if !ok { - return ErrFwdNotExists - } - delete(r.holdForwards, circuitKey) switch in.Action { case ResolveHoldForwardAction_RESUME: - return interceptedForward.Resume() + return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{ + Key: circuitKey, + Action: htlcswitch.FwdActionResume, + }) case ResolveHoldForwardAction_FAIL: // Fail with an encrypted reason. @@ -219,7 +136,11 @@ func (r *forwardInterceptor) resolveFromClient( ) } - return interceptedForward.Fail(in.FailureMessage) + return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{ + Key: circuitKey, + Action: htlcswitch.FwdActionFail, + FailureMessage: in.FailureMessage, + }) } var code lnwire.FailCode @@ -244,14 +165,11 @@ func (r *forwardInterceptor) resolveFromClient( ) } - err := interceptedForward.FailWithCode(code) - if err == htlcswitch.ErrUnsupportedFailureCode { - return status.Errorf( - codes.InvalidArgument, err.Error(), - ) - } - - return err + return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{ + Key: circuitKey, + Action: htlcswitch.FwdActionFail, + FailureCode: code, + }) case ResolveHoldForwardAction_SETTLE: if in.Preimage == nil { @@ -261,7 +179,12 @@ func (r *forwardInterceptor) resolveFromClient( if err != nil { return err } - return interceptedForward.Settle(preimage) + + return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{ + Key: circuitKey, + Action: htlcswitch.FwdActionSettle, + Preimage: preimage, + }) default: return status.Errorf( @@ -270,20 +193,3 @@ func (r *forwardInterceptor) resolveFromClient( ) } } - -// onDisconnect removes all previousely held forwards from -// the store. Before they are removed it ensure to resume as the default -// behavior. -func (r *forwardInterceptor) onDisconnect() { - // Then close the channel so all go routine will exit. - close(r.quit) - - log.Infof("RPC interceptor disconnected, resolving held packets") - for key, forward := range r.holdForwards { - if err := forward.Resume(); err != nil { - log.Errorf("failed to resume hold forward %v", err) - } - delete(r.holdForwards, key) - } - r.wg.Wait() -} diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go index dc14af4efa..0c60a62d52 100644 --- a/lnrpc/routerrpc/router_server.go +++ b/lnrpc/routerrpc/router_server.go @@ -890,7 +890,9 @@ func (s *Server) HtlcInterceptor(stream Router_HtlcInterceptorServer) error { defer atomic.CompareAndSwapInt32(&s.forwardInterceptorActive, 1, 0) // run the forward interceptor. - return newForwardInterceptor(s, stream).run() + return newForwardInterceptor( + s.cfg.RouterBackend.InterceptableForwarder, stream, + ).run() } func extractOutPoint(req *UpdateChanStatusRequest) (*wire.OutPoint, error) { diff --git a/server.go b/server.go index 6326205266..b2d3b1c9d3 100644 --- a/server.go +++ b/server.go @@ -1786,6 +1786,12 @@ func (s *server) Start() error { } cleanup = cleanup.add(s.htlcSwitch.Stop) + if err := s.interceptableSwitch.Start(); err != nil { + startErr = err + return + } + cleanup = cleanup.add(s.interceptableSwitch.Stop) + if err := s.chainArb.Start(); err != nil { startErr = err return From ae314ec7941f083d4294eba3cd9a05462bdef8be Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Thu, 3 Feb 2022 15:34:25 +0100 Subject: [PATCH 2/2] htlcswitch: add an always on mode to interceptable switch Co-authored-by: Juan Pablo Civile --- config.go | 4 + docs/release-notes/release-notes-0.15.0.md | 4 + htlcswitch/interceptable_switch.go | 68 +++++++++++-- htlcswitch/link.go | 18 ++-- htlcswitch/link_test.go | 30 +++--- htlcswitch/switch_test.go | 106 +++++++++++++++++---- htlcswitch/test_utils.go | 14 +-- peer/test_utils.go | 4 +- sample-lnd.conf | 3 + server.go | 4 +- 10 files changed, 201 insertions(+), 54 deletions(-) diff --git a/config.go b/config.go index 6a710ee795..6f7a66b983 100644 --- a/config.go +++ b/config.go @@ -354,6 +354,10 @@ type Config struct { RejectHTLC bool `long:"rejecthtlc" description:"If true, lnd will not forward any HTLCs that are meant as onward payments. This option will still allow lnd to send HTLCs and receive HTLCs but lnd won't be used as a hop."` + // RequireInterceptor determines whether the HTLC interceptor is + // registered regardless of whether the RPC is called or not. + RequireInterceptor bool `long:"requireinterceptor" description:"Whether to always intercept HTLCs, even if no stream is attached"` + StaggerInitialReconnect bool `long:"stagger-initial-reconnect" description:"If true, will apply a randomized staggering between 0s and 30s when reconnecting to persistent peers on startup. The first 10 reconnections will be attempted instantly, regardless of the flag's value"` MaxOutgoingCltvExpiry uint32 `long:"max-cltv-expiry" description:"The maximum number of blocks funds could be locked up for when forwarding payments."` diff --git a/docs/release-notes/release-notes-0.15.0.md b/docs/release-notes/release-notes-0.15.0.md index 140908d69f..f521ef3cc6 100644 --- a/docs/release-notes/release-notes-0.15.0.md +++ b/docs/release-notes/release-notes-0.15.0.md @@ -101,6 +101,10 @@ change, it allows encrypted failure messages to be returned to the sender. Additionally it is possible to signal a malformed htlc. +* Add an [always on](https://github.com/lightningnetwork/lnd/pull/6232) mode to + the HTLC interceptor API. This enables interception applications where every + packet must be intercepted. + ## Database * [Add ForAll implementation for etcd to speed up diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index 13f70dcf20..87ef1e61d9 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -45,6 +45,10 @@ type InterceptableSwitch struct { // client connect and disconnect. interceptorRegistration chan ForwardInterceptor + // requireInterceptor indicates whether processing should block if no + // interceptor is connected. + requireInterceptor bool + // interceptor is the handler for intercepted packets. interceptor ForwardInterceptor @@ -58,6 +62,7 @@ type InterceptableSwitch struct { type interceptedPackets struct { packets []*htlcPacket linkQuit chan struct{} + isReplay bool } // FwdAction defines the various resolution types. @@ -101,13 +106,16 @@ type fwdResolution struct { } // NewInterceptableSwitch returns an instance of InterceptableSwitch. -func NewInterceptableSwitch(s *Switch) *InterceptableSwitch { +func NewInterceptableSwitch(s *Switch, + requireInterceptor bool) *InterceptableSwitch { + return &InterceptableSwitch{ htlcSwitch: s, intercepted: make(chan *interceptedPackets), interceptorRegistration: make(chan ForwardInterceptor), holdForwards: make(map[channeldb.CircuitKey]InterceptedForward), resolutionChan: make(chan *fwdResolution), + requireInterceptor: requireInterceptor, quit: make(chan struct{}), } @@ -155,9 +163,7 @@ func (s *InterceptableSwitch) run() { case packets := <-s.intercepted: var notIntercepted []*htlcPacket for _, p := range packets.packets { - if s.interceptor == nil || - !s.interceptForward(p) { - + if !s.interceptForward(p, packets.isReplay) { notIntercepted = append( notIntercepted, p, ) @@ -178,7 +184,6 @@ func (s *InterceptableSwitch) run() { } } } - func (s *InterceptableSwitch) sendForward(fwd InterceptedForward) { err := s.interceptor(fwd.Packet()) if err != nil { @@ -191,12 +196,28 @@ func (s *InterceptableSwitch) sendForward(fwd InterceptedForward) { func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) { s.interceptor = interceptor + // Replay all currently held htlcs. When an interceptor is not required, + // there may be none because they've been cleared after the previous + // disconnect. if interceptor != nil { log.Debugf("Interceptor connected") + for _, fwd := range s.holdForwards { + s.sendForward(fwd) + } + return } + // The interceptor disconnects. If an interceptor is required, keep the + // held htlcs. + if s.requireInterceptor { + log.Infof("Interceptor disconnected, retaining held packets") + + return + } + + // Interceptor is not required. Release held forwards. log.Infof("Interceptor disconnected, resolving held packets") for _, fwd := range s.holdForwards { @@ -260,7 +281,7 @@ func (s *InterceptableSwitch) Resolve(res *FwdResolution) error { // interceptor. If the interceptor signals the resume action, the htlcs are // forwarded to the switch. The link's quit signal should be provided to allow // cancellation of forwarding during link shutdown. -func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{}, +func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{}, isReplay bool, packets ...*htlcPacket) error { // Synchronize with the main event loop. This should be light in the @@ -269,6 +290,7 @@ func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{}, case s.intercepted <- &interceptedPackets{ packets: packets, linkQuit: linkQuit, + isReplay: isReplay, }: case <-linkQuit: @@ -283,7 +305,15 @@ func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{}, // interceptForward forwards the packet to the external interceptor after // checking the interception criteria. -func (s *InterceptableSwitch) interceptForward(packet *htlcPacket) bool { +func (s *InterceptableSwitch) interceptForward(packet *htlcPacket, + isReplay bool) bool { + + // Process normally if an interceptor is not required and not + // registered. + if !s.requireInterceptor && s.interceptor == nil { + return false + } + switch htlc := packet.htlc.(type) { case *lnwire.UpdateAddHTLC: // We are not interested in intercepting initiated payments. @@ -307,9 +337,31 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket) bool { htlcSwitch: s.htlcSwitch, } + if s.interceptor == nil && !isReplay { + // There is no interceptor registered, we are in + // interceptor-required mode, and this is a new packet + // + // Because the interceptor has never seen this packet + // yet, it is still safe to fail back. This limits the + // backlog of htlcs when the interceptor is down. + err := intercepted.FailWithCode( + lnwire.CodeTemporaryChannelFailure, + ) + if err != nil { + log.Errorf("Cannot fail packet: %v", err) + } + + return true + } + s.holdForwards[inKey] = intercepted - s.sendForward(intercepted) + // If there is no interceptor registered, we must be in + // interceptor-required mode. The packet is kept in the queue + // until the interceptor registers itself. + if s.interceptor != nil { + s.sendForward(intercepted) + } return true diff --git a/htlcswitch/link.go b/htlcswitch/link.go index d71d20d14f..a867f6b00b 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -141,7 +141,7 @@ type ChannelLinkConfig struct { // switch. The function returns and error in case it fails to send one or // more packets. The link's quit signal should be provided to allow // cancellation of forwarding during link shutdown. - ForwardPackets func(chan struct{}, ...*htlcPacket) error + ForwardPackets func(chan struct{}, bool, ...*htlcPacket) error // DecodeHopIterators facilitates batched decoding of HTLC Sphinx onion // blobs, which are then used to inform how to forward an HTLC. @@ -1720,7 +1720,7 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { l.uncommittedPreimages = append(l.uncommittedPreimages, pre) // Pipeline this settle, send it to the switch. - go l.forwardBatch(settlePacket) + go l.forwardBatch(false, settlePacket) case *lnwire.UpdateFailMalformedHTLC: // Convert the failure type encoded within the HTLC fail @@ -2744,7 +2744,7 @@ func (l *channelLink) processRemoteSettleFails(fwdPkg *channeldb.FwdPkg, // Only spawn the task forward packets we have a non-zero number. if len(switchPackets) > 0 { - go l.forwardBatch(switchPackets...) + go l.forwardBatch(false, switchPackets...) } } @@ -3043,14 +3043,17 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, return } - l.log.Debugf("forwarding %d packets to switch", len(switchPackets)) + replay := fwdPkg.State != channeldb.FwdStateLockedIn + + l.log.Debugf("forwarding %d packets to switch: replay=%v", + len(switchPackets), replay) // NOTE: This call is made synchronous so that we ensure all circuits // are committed in the exact order that they are processed in the link. // Failing to do this could cause reorderings/gaps in the range of // opened circuits, which violates assumptions made by the circuit // trimming. - l.forwardBatch(switchPackets...) + l.forwardBatch(replay, switchPackets...) } // processExitHop handles an htlc for which this link is the exit hop. It @@ -3184,7 +3187,7 @@ func (l *channelLink) settleHTLC(preimage lntypes.Preimage, // forwardBatch forwards the given htlcPackets to the switch, and waits on the // err chan for the individual responses. This method is intended to be spawned // as a goroutine so the responses can be handled in the background. -func (l *channelLink) forwardBatch(packets ...*htlcPacket) { +func (l *channelLink) forwardBatch(replay bool, packets ...*htlcPacket) { // Don't forward packets for which we already have a response in our // mailbox. This could happen if a packet fails and is buffered in the // mailbox, and the incoming link flaps. @@ -3197,7 +3200,8 @@ func (l *channelLink) forwardBatch(packets ...*htlcPacket) { filteredPkts = append(filteredPkts, pkt) } - if err := l.cfg.ForwardPackets(l.quit, filteredPkts...); err != nil { + err := l.cfg.ForwardPackets(l.quit, replay, filteredPkts...) + if err != nil { log.Errorf("Unhandled error while reforwarding htlc "+ "settle/fail over htlcswitch: %v", err) } diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index d2cfdbcea8..5320336ec0 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -1940,12 +1940,14 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( // the firing via force feeding. bticker := ticker.NewForce(time.Hour) aliceCfg := ChannelLinkConfig{ - FwrdingPolicy: globalPolicy, - Peer: alicePeer, - Switch: aliceSwitch, - BestHeight: aliceSwitch.BestHeight, - Circuits: aliceSwitch.CircuitModifier(), - ForwardPackets: aliceSwitch.ForwardPackets, + FwrdingPolicy: globalPolicy, + Peer: alicePeer, + Switch: aliceSwitch, + BestHeight: aliceSwitch.BestHeight, + Circuits: aliceSwitch.CircuitModifier(), + ForwardPackets: func(linkQuit chan struct{}, _ bool, packets ...*htlcPacket) error { + return aliceSwitch.ForwardPackets(linkQuit, packets...) + }, DecodeHopIterators: decoder.DecodeHopIterators, ExtractErrorEncrypter: func(*btcec.PublicKey) ( hop.ErrorEncrypter, lnwire.FailCode) { @@ -4491,12 +4493,14 @@ func (h *persistentLinkHarness) restartLink( // the firing via force feeding. bticker := ticker.NewForce(time.Hour) aliceCfg := ChannelLinkConfig{ - FwrdingPolicy: globalPolicy, - Peer: alicePeer, - Switch: aliceSwitch, - BestHeight: aliceSwitch.BestHeight, - Circuits: aliceSwitch.CircuitModifier(), - ForwardPackets: aliceSwitch.ForwardPackets, + FwrdingPolicy: globalPolicy, + Peer: alicePeer, + Switch: aliceSwitch, + BestHeight: aliceSwitch.BestHeight, + Circuits: aliceSwitch.CircuitModifier(), + ForwardPackets: func(linkQuit chan struct{}, _ bool, packets ...*htlcPacket) error { + return aliceSwitch.ForwardPackets(linkQuit, packets...) + }, DecodeHopIterators: decoder.DecodeHopIterators, ExtractErrorEncrypter: func(*btcec.PublicKey) ( hop.ErrorEncrypter, lnwire.FailCode) { @@ -6694,7 +6698,7 @@ func TestPipelineSettle(t *testing.T) { // erroneously forwarded. If the forwardChan is closed before the last // step, then the test will fail. forwardChan := make(chan struct{}) - fwdPkts := func(c chan struct{}, hp ...*htlcPacket) error { + fwdPkts := func(c chan struct{}, _ bool, hp ...*htlcPacket) error { close(forwardChan) return nil } diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 4c22b5e8e9..148a195391 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -3273,7 +3273,7 @@ func TestSwitchHoldForward(t *testing.T) { t: t, interceptedChan: make(chan InterceptedPacket), } - switchForwardInterceptor := NewInterceptableSwitch(s) + switchForwardInterceptor := NewInterceptableSwitch(s, false) require.NoError(t, switchForwardInterceptor.Start()) switchForwardInterceptor.SetInterceptor(forwardInterceptor.InterceptForwardHtlc) @@ -3281,9 +3281,9 @@ func TestSwitchHoldForward(t *testing.T) { // Test resume a hold forward. assertNumCircuits(t, s, 0, 0) - if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil { - t.Fatalf("can't forward htlc packet: %v", err) - } + err = switchForwardInterceptor.ForwardPackets(linkQuit, false, ogPacket) + require.NoError(t, err) + assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) @@ -3303,15 +3303,16 @@ func TestSwitchHoldForward(t *testing.T) { PaymentPreimage: preimage, }, } - if err := switchForwardInterceptor.ForwardPackets(linkQuit, settle); err != nil { - t.Fatalf("can't forward htlc packet: %v", err) - } + err = switchForwardInterceptor.ForwardPackets(linkQuit, false, settle) + require.NoError(t, err) + assertOutgoingLinkReceive(t, aliceChannelLink, true) assertNumCircuits(t, s, 0, 0) // Test resume a hold forward after disconnection. - err = switchForwardInterceptor.ForwardPackets(nil, ogPacket) - require.NoError(t, err) + require.NoError(t, switchForwardInterceptor.ForwardPackets( + linkQuit, false, ogPacket, + )) // Wait until the packet is offered to the interceptor. _ = forwardInterceptor.getIntercepted() @@ -3328,7 +3329,9 @@ func TestSwitchHoldForward(t *testing.T) { // Settle the htlc to close the circuit. settle.outgoingHTLCID = 1 - require.NoError(t, switchForwardInterceptor.ForwardPackets(nil, settle)) + require.NoError(t, switchForwardInterceptor.ForwardPackets( + linkQuit, false, settle, + )) assertOutgoingLinkReceive(t, aliceChannelLink, true) assertNumCircuits(t, s, 0, 0) @@ -3338,9 +3341,9 @@ func TestSwitchHoldForward(t *testing.T) { forwardInterceptor.InterceptForwardHtlc, ) - if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil { - t.Fatalf("can't forward htlc packet: %v", err) - } + require.NoError(t, switchForwardInterceptor.ForwardPackets( + linkQuit, false, ogPacket, + )) assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) @@ -3355,7 +3358,7 @@ func TestSwitchHoldForward(t *testing.T) { // Test failing a hold forward with a failure message. require.NoError(t, - switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket), + switchForwardInterceptor.ForwardPackets(linkQuit, false, ogPacket), ) assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) @@ -3375,7 +3378,7 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, s, 0, 0) // Test failing a hold forward with a malformed htlc failure. - err = switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket) + err = switchForwardInterceptor.ForwardPackets(linkQuit, false, ogPacket) require.NoError(t, err) assertNumCircuits(t, s, 0, 0) @@ -3404,9 +3407,9 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, s, 0, 0) // Test settling a hold forward - if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil { - t.Fatalf("can't forward htlc packet: %v", err) - } + require.NoError(t, switchForwardInterceptor.ForwardPackets( + linkQuit, false, ogPacket, + )) assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) @@ -3420,6 +3423,73 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, s, 0, 0) require.NoError(t, switchForwardInterceptor.Stop()) + + // Test always-on interception. + switchForwardInterceptor = NewInterceptableSwitch(s, true) + require.NoError(t, switchForwardInterceptor.Start()) + + // Forward a fresh packet. It is expected to be failed immediately, + // because there is no interceptor registered. + require.NoError(t, switchForwardInterceptor.ForwardPackets( + linkQuit, false, ogPacket, + )) + + assertOutgoingLinkReceive(t, bobChannelLink, false) + assertOutgoingLinkReceive(t, aliceChannelLink, true) + assertNumCircuits(t, s, 0, 0) + + // Forward a replayed packet. It is expected to be held until the + // interceptor connects. To continue the test, it needs to be ran in a + // goroutine. + errChan := make(chan error) + go func() { + errChan <- switchForwardInterceptor.ForwardPackets( + linkQuit, true, ogPacket, + ) + }() + + // Assert that nothing is forward to the switch. + assertOutgoingLinkReceive(t, bobChannelLink, false) + assertNumCircuits(t, s, 0, 0) + + // Register an interceptor. + switchForwardInterceptor.SetInterceptor( + forwardInterceptor.InterceptForwardHtlc, + ) + + // Expect the ForwardPackets call to unblock. + require.NoError(t, <-errChan) + + // Now expect the queued packet to come through. + forwardInterceptor.getIntercepted() + + // Disconnect and reconnect interceptor. + switchForwardInterceptor.SetInterceptor(nil) + switchForwardInterceptor.SetInterceptor( + forwardInterceptor.InterceptForwardHtlc, + ) + + // A replay of the held packet is expected. + intercepted := forwardInterceptor.getIntercepted() + + // Settle the packet. + require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{ + Key: intercepted.IncomingCircuit, + Action: FwdActionSettle, + Preimage: preimage, + })) + assertOutgoingLinkReceive(t, bobChannelLink, false) + assertOutgoingLinkReceive(t, aliceChannelLink, true) + assertNumCircuits(t, s, 0, 0) + + require.NoError(t, switchForwardInterceptor.Stop()) + + select { + case <-forwardInterceptor.interceptedChan: + require.Fail(t, "unexpected interception") + + default: + } } // TestSwitchDustForwarding tests that the switch properly fails HTLC's which diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 014adfc52d..7dfd1599c3 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -1135,12 +1135,14 @@ func (h *hopNetwork) createChannelLink(server, peer *mockServer, link := NewChannelLink( ChannelLinkConfig{ - Switch: server.htlcSwitch, - BestHeight: server.htlcSwitch.BestHeight, - FwrdingPolicy: h.globalPolicy, - Peer: peer, - Circuits: server.htlcSwitch.CircuitModifier(), - ForwardPackets: server.htlcSwitch.ForwardPackets, + Switch: server.htlcSwitch, + BestHeight: server.htlcSwitch.BestHeight, + FwrdingPolicy: h.globalPolicy, + Peer: peer, + Circuits: server.htlcSwitch.CircuitModifier(), + ForwardPackets: func(linkQuit chan struct{}, _ bool, packets ...*htlcPacket) error { + return server.htlcSwitch.ForwardPackets(linkQuit, packets...) + }, DecodeHopIterators: decoder.DecodeHopIterators, ExtractErrorEncrypter: func(*btcec.PublicKey) ( hop.ErrorEncrypter, lnwire.FailCode) { diff --git a/peer/test_utils.go b/peer/test_utils.go index 75cc19f91f..40f3295634 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -367,7 +367,9 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, Switch: mockSwitch, ChanActiveTimeout: chanActiveTimeout, - InterceptSwitch: htlcswitch.NewInterceptableSwitch(nil), + InterceptSwitch: htlcswitch.NewInterceptableSwitch( + nil, false, + ), ChannelDB: dbAlice.ChannelStateDB(), FeeEstimator: estimator, diff --git a/sample-lnd.conf b/sample-lnd.conf index 714aaea7dc..029547e064 100644 --- a/sample-lnd.conf +++ b/sample-lnd.conf @@ -347,6 +347,9 @@ ; used as a hop. ; rejecthtlc=true +; If true, all HTLCs will be held until they are handled by an interceptor +; requireinterceptor=true + ; If true, will apply a randomized staggering between 0s and 30s when ; reconnecting to persistent peers on startup. The first 10 reconnections will be ; attempted instantly, regardless of the flag's value diff --git a/server.go b/server.go index b2d3b1c9d3..28ac518074 100644 --- a/server.go +++ b/server.go @@ -654,7 +654,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, if err != nil { return nil, err } - s.interceptableSwitch = htlcswitch.NewInterceptableSwitch(s.htlcSwitch) + s.interceptableSwitch = htlcswitch.NewInterceptableSwitch( + s.htlcSwitch, s.cfg.RequireInterceptor, + ) chanStatusMgrCfg := &netann.ChanStatusConfig{ ChanStatusSampleInterval: cfg.ChanStatusSampleInterval,