From 9feaa8d767476895fe47513dbeaa22d9e460d022 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 30 Dec 2024 23:05:11 +0100 Subject: [PATCH] Add icmp forwarder --- .../firewall/uspfilter/forwarder/forwarder.go | 6 +- client/firewall/uspfilter/forwarder/icmp.go | 95 +++++++++++++++++++ client/firewall/uspfilter/forwarder/udp.go | 1 - 3 files changed, 99 insertions(+), 3 deletions(-) create mode 100644 client/firewall/uspfilter/forwarder/icmp.go diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index f3920065851..b9bd471ef98 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -10,6 +10,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -37,6 +38,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) { TransportProtocols: []stack.TransportProtocolFactory{ tcp.NewProtocol, udp.NewProtocol, + icmp.NewProtocol4, }, HandleLocal: false, }) @@ -101,14 +103,14 @@ func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) { cancel: cancel, } - // Set up TCP forwarder tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP) s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) - // Set up UDP forwarder udpForwarder := udp.NewForwarder(s, f.handleUDP) s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) + s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP) + log.Debugf("forwarder: Initialization complete with NIC %d", nicID) return f, nil } diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go new file mode 100644 index 00000000000..5fb80afb5da --- /dev/null +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -0,0 +1,95 @@ +package forwarder + +import ( + "context" + "net" + "time" + + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// handleICMP handles ICMP packets from the network stack +func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool { + ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) + defer cancel() + + lc := net.ListenConfig{} + // TODO: support non-root + conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") + if err != nil { + f.logger.Error("Failed to create ICMP socket for %v: %v", id, err) + return false + } + defer func() { + if err := conn.Close(); err != nil { + f.logger.Debug("Failed to close ICMP socket: %v", err) + } + }() + + dstIP := net.IP(id.LocalAddress.AsSlice()) + dst := &net.IPAddr{IP: dstIP} + + // Get the complete ICMP message (header + data) + fullPacket := stack.PayloadSince(pkt.TransportHeader()) + payload := fullPacket.AsSlice() + + icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice()) + + // For Echo Requests, send and handle response + if icmpHdr.Type() == header.ICMPv4Echo { + _, err = conn.WriteTo(payload, dst) + if err != nil { + f.logger.Error("Failed to write ICMP packet for %v: %v", id, err) + return false + } + + f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v", + id, icmpHdr.Type(), icmpHdr.Code()) + + return f.handleEchoResponse(conn, id) + } + + // TODO: forward other ICMP types + + return true +} + +func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) bool { + if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + f.logger.Error("Failed to set read deadline for ICMP response: %v", err) + return false + } + + response := make([]byte, f.endpoint.mtu) + n, _, err := conn.ReadFrom(response) + if err != nil { + if !isTimeout(err) { + f.logger.Error("Failed to read ICMP response: %v", err) + } + return false + } + + ipHdr := make([]byte, header.IPv4MinimumSize) + ip := header.IPv4(ipHdr) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(header.IPv4MinimumSize + n), + TTL: 64, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: id.LocalAddress, + DstAddr: id.RemoteAddress, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + fullPacket := make([]byte, 0, len(ipHdr)+n) + fullPacket = append(fullPacket, ipHdr...) + fullPacket = append(fullPacket, response[:n]...) + + if err := f.InjectIncomingPacket(fullPacket); err != nil { + f.logger.Error("Failed to inject ICMP response: %v", err) + return false + } + + f.logger.Trace("Forwarded ICMP echo reply for %v", id) + return true +} diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index a6f3ab993dc..cbe86f48655 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -9,7 +9,6 @@ import ( "sync/atomic" "time" - log "github.com/sirupsen/logrus" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp"