From 92036900337fce22e5afce7a86f2781369ea6fc8 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 26 Nov 2024 23:34:27 +0100 Subject: [PATCH] [client] Code cleaning in net pkg and fix exit node feature on Android(#2932) Code cleaning around the util/net package. The goal was to write a more understandable source code but modify nothing on the logic. Protect the WireGuard UDP listeners with marks. The implementation can support the VPN permission revocation events in thread safe way. It will be important if we start to support the running time route and DNS update features. - uniformize the file name convention: [struct_name] _ [functions] _ [os].go - code cleaning in net_linux.go - move env variables to env.go file --- client/iface/bind/control_android.go | 12 ++++ .../routemanager/systemops/systemops_linux.go | 2 +- go.mod | 2 +- go.sum | 4 +- util/net/conn.go | 31 ++++++++ util/net/dial.go | 58 +++++++++++++++ util/net/{dialer_ios.go => dial_ios.go} | 0 util/net/dialer_android.go | 25 ------- util/net/{dialer_nonios.go => dialer_dial.go} | 70 ------------------- util/net/dialer_init_android.go | 5 ++ .../{dialer_linux.go => dialer_init_linux.go} | 2 +- ...er_nonlinux.go => dialer_init_nonlinux.go} | 1 + util/net/env.go | 29 ++++++++ util/net/listen.go | 37 ++++++++++ util/net/{listener_ios.go => listen_ios.go} | 0 util/net/listener_android.go | 26 ------- util/net/listener_init_android.go | 6 ++ ...stener_linux.go => listener_init_linux.go} | 2 +- ..._nonlinux.go => listener_init_nonlinux.go} | 1 + ...{listener_nonios.go => listener_listen.go} | 25 ------- util/net/net.go | 12 ---- util/net/net_linux.go | 42 +++++++---- util/net/protectsocket_android.go | 34 ++++++++- 23 files changed, 245 insertions(+), 181 deletions(-) create mode 100644 client/iface/bind/control_android.go create mode 100644 util/net/conn.go create mode 100644 util/net/dial.go rename util/net/{dialer_ios.go => dial_ios.go} (100%) delete mode 100644 util/net/dialer_android.go rename util/net/{dialer_nonios.go => dialer_dial.go} (63%) create mode 100644 util/net/dialer_init_android.go rename util/net/{dialer_linux.go => dialer_init_linux.go} (88%) rename util/net/{dialer_nonlinux.go => dialer_init_nonlinux.go} (58%) create mode 100644 util/net/env.go create mode 100644 util/net/listen.go rename util/net/{listener_ios.go => listen_ios.go} (100%) delete mode 100644 util/net/listener_android.go create mode 100644 util/net/listener_init_android.go rename util/net/{listener_linux.go => listener_init_linux.go} (89%) rename util/net/{listener_nonlinux.go => listener_init_nonlinux.go} (61%) rename util/net/{listener_nonios.go => listener_listen.go} (84%) diff --git a/client/iface/bind/control_android.go b/client/iface/bind/control_android.go new file mode 100644 index 00000000000..b8a865e3908 --- /dev/null +++ b/client/iface/bind/control_android.go @@ -0,0 +1,12 @@ +package bind + +import ( + wireguard "golang.zx2c4.com/wireguard/conn" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +func init() { + // ControlFns is not thread safe and should only be modified during init. + *wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket) +} diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 1d629d6e975..455e3407e2a 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -55,7 +55,7 @@ type ruleParams struct { // isLegacy determines whether to use the legacy routing setup func isLegacy() bool { - return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true" + return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark() } // setIsLegacy sets the legacy routing setup diff --git a/go.mod b/go.mod index 0a16753ea43..e8c65542280 100644 --- a/go.mod +++ b/go.mod @@ -236,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 -replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 +replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 diff --git a/go.sum b/go.sum index a4d7ea7f9c1..47975d4eab4 100644 --- a/go.sum +++ b/go.sum @@ -527,8 +527,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= -github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g= -github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= +github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY= +github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= diff --git a/util/net/conn.go b/util/net/conn.go new file mode 100644 index 00000000000..26693f84166 --- /dev/null +++ b/util/net/conn.go @@ -0,0 +1,31 @@ +//go:build !ios + +package net + +import ( + "net" + + log "github.com/sirupsen/logrus" +) + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +func (c *Conn) Close() error { + err := c.Conn.Close() + + dialerCloseHooksMutex.RLock() + defer dialerCloseHooksMutex.RUnlock() + + for _, hook := range dialerCloseHooks { + if err := hook(c.ID, &c.Conn); err != nil { + log.Errorf("Error executing dialer close hook: %v", err) + } + } + + return err +} diff --git a/util/net/dial.go b/util/net/dial.go new file mode 100644 index 00000000000..59531149278 --- /dev/null +++ b/util/net/dial.go @@ -0,0 +1,58 @@ +//go:build !ios + +package net + +import ( + "fmt" + "net" + + log "github.com/sirupsen/logrus" +) + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + if CustomRoutingDisabled() { + return net.DialUDP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) + } + + return udpConn, nil +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + if CustomRoutingDisabled() { + return net.DialTCP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) + } + + return tcpConn, nil +} diff --git a/util/net/dialer_ios.go b/util/net/dial_ios.go similarity index 100% rename from util/net/dialer_ios.go rename to util/net/dial_ios.go diff --git a/util/net/dialer_android.go b/util/net/dialer_android.go deleted file mode 100644 index 4cbded53634..00000000000 --- a/util/net/dialer_android.go +++ /dev/null @@ -1,25 +0,0 @@ -package net - -import ( - "syscall" - - log "github.com/sirupsen/logrus" -) - -func (d *Dialer) init() { - d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { - err := c.Control(func(fd uintptr) { - androidProtectSocketLock.Lock() - f := androidProtectSocket - androidProtectSocketLock.Unlock() - if f == nil { - return - } - ok := f(int32(fd)) - if !ok { - log.Errorf("failed to protect socket: %d", fd) - } - }) - return err - } -} diff --git a/util/net/dialer_nonios.go b/util/net/dialer_dial.go similarity index 63% rename from util/net/dialer_nonios.go rename to util/net/dialer_dial.go index 34004a368c1..1659b622051 100644 --- a/util/net/dialer_nonios.go +++ b/util/net/dialer_dial.go @@ -81,28 +81,6 @@ func (d *Dialer) Dial(network, address string) (net.Conn, error) { return d.DialContext(context.Background(), network, address) } -// Conn wraps a net.Conn to override the Close method -type Conn struct { - net.Conn - ID ConnectionID -} - -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -func (c *Conn) Close() error { - err := c.Conn.Close() - - dialerCloseHooksMutex.RLock() - defer dialerCloseHooksMutex.RUnlock() - - for _, hook := range dialerCloseHooks { - if err := hook(c.ID, &c.Conn); err != nil { - log.Errorf("Error executing dialer close hook: %v", err) - } - } - - return err -} - func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { host, _, err := net.SplitHostPort(address) if err != nil { @@ -127,51 +105,3 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r return result.ErrorOrNil() } - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - if CustomRoutingDisabled() { - return net.DialUDP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - if CustomRoutingDisabled() { - return net.DialTCP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) - } - - return tcpConn, nil -} diff --git a/util/net/dialer_init_android.go b/util/net/dialer_init_android.go new file mode 100644 index 00000000000..63b9033484e --- /dev/null +++ b/util/net/dialer_init_android.go @@ -0,0 +1,5 @@ +package net + +func (d *Dialer) init() { + d.Dialer.Control = ControlProtectSocket +} diff --git a/util/net/dialer_linux.go b/util/net/dialer_init_linux.go similarity index 88% rename from util/net/dialer_linux.go rename to util/net/dialer_init_linux.go index aed5c59a322..d801e608086 100644 --- a/util/net/dialer_linux.go +++ b/util/net/dialer_init_linux.go @@ -7,6 +7,6 @@ import "syscall" // init configures the net.Dialer Control function to set the fwmark on the socket func (d *Dialer) init() { d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { - return SetRawSocketMark(c) + return setRawSocketMark(c) } } diff --git a/util/net/dialer_nonlinux.go b/util/net/dialer_init_nonlinux.go similarity index 58% rename from util/net/dialer_nonlinux.go rename to util/net/dialer_init_nonlinux.go index c838441bdb5..8c57ebbaa52 100644 --- a/util/net/dialer_nonlinux.go +++ b/util/net/dialer_init_nonlinux.go @@ -3,4 +3,5 @@ package net func (d *Dialer) init() { + // implemented on Linux and Android only } diff --git a/util/net/env.go b/util/net/env.go new file mode 100644 index 00000000000..099da39b760 --- /dev/null +++ b/util/net/env.go @@ -0,0 +1,29 @@ +package net + +import ( + "os" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/netstack" +) + +const ( + envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" + envSkipSocketMark = "NB_SKIP_SOCKET_MARK" +) + +func CustomRoutingDisabled() bool { + if netstack.IsEnabled() { + return true + } + return os.Getenv(envDisableCustomRouting) == "true" +} + +func SkipSocketMark() bool { + if skipSocketMark := os.Getenv(envSkipSocketMark); skipSocketMark == "true" { + log.Infof("%s is set to true, skipping SO_MARK", envSkipSocketMark) + return true + } + return false +} diff --git a/util/net/listen.go b/util/net/listen.go new file mode 100644 index 00000000000..3ae8a9435cb --- /dev/null +++ b/util/net/listen.go @@ -0,0 +1,37 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" +) + +// ListenUDP listens on the network address and returns a transport.UDPConn +// which includes support for write and close hooks. +func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.ListenUDP(network, laddr) + } + + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + + packetConn := conn.(*PacketConn) + udpConn, ok := packetConn.PacketConn.(*net.UDPConn) + if !ok { + if err := packetConn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) + } + + return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil +} diff --git a/util/net/listener_ios.go b/util/net/listen_ios.go similarity index 100% rename from util/net/listener_ios.go rename to util/net/listen_ios.go diff --git a/util/net/listener_android.go b/util/net/listener_android.go deleted file mode 100644 index d4167ad53a6..00000000000 --- a/util/net/listener_android.go +++ /dev/null @@ -1,26 +0,0 @@ -package net - -import ( - "syscall" - - log "github.com/sirupsen/logrus" -) - -// init configures the net.ListenerConfig Control function to set the fwmark on the socket -func (l *ListenerConfig) init() { - l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { - err := c.Control(func(fd uintptr) { - androidProtectSocketLock.Lock() - f := androidProtectSocket - androidProtectSocketLock.Unlock() - if f == nil { - return - } - ok := f(int32(fd)) - if !ok { - log.Errorf("failed to protect listener socket: %d", fd) - } - }) - return err - } -} diff --git a/util/net/listener_init_android.go b/util/net/listener_init_android.go new file mode 100644 index 00000000000..f7bfa1dab27 --- /dev/null +++ b/util/net/listener_init_android.go @@ -0,0 +1,6 @@ +package net + +// init configures the net.ListenerConfig Control function to set the fwmark on the socket +func (l *ListenerConfig) init() { + l.ListenConfig.Control = ControlProtectSocket +} diff --git a/util/net/listener_linux.go b/util/net/listener_init_linux.go similarity index 89% rename from util/net/listener_linux.go rename to util/net/listener_init_linux.go index 8d332160a04..e32d5d8942e 100644 --- a/util/net/listener_linux.go +++ b/util/net/listener_init_linux.go @@ -9,6 +9,6 @@ import ( // init configures the net.ListenerConfig Control function to set the fwmark on the socket func (l *ListenerConfig) init() { l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { - return SetRawSocketMark(c) + return setRawSocketMark(c) } } diff --git a/util/net/listener_nonlinux.go b/util/net/listener_init_nonlinux.go similarity index 61% rename from util/net/listener_nonlinux.go rename to util/net/listener_init_nonlinux.go index 14a6be49dc3..80f6f7f1a55 100644 --- a/util/net/listener_nonlinux.go +++ b/util/net/listener_init_nonlinux.go @@ -3,4 +3,5 @@ package net func (l *ListenerConfig) init() { + // implemented on Linux and Android only } diff --git a/util/net/listener_nonios.go b/util/net/listener_listen.go similarity index 84% rename from util/net/listener_nonios.go rename to util/net/listener_listen.go index ae4be34949b..efffba40e6e 100644 --- a/util/net/listener_nonios.go +++ b/util/net/listener_listen.go @@ -8,7 +8,6 @@ import ( "net" "sync" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" ) @@ -146,27 +145,3 @@ func closeConn(id ConnectionID, conn net.PacketConn) error { return err } - -// ListenUDP listens on the network address and returns a transport.UDPConn -// which includes support for write and close hooks. -func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { - if CustomRoutingDisabled() { - return net.ListenUDP(network, laddr) - } - - conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listen UDP: %w", err) - } - - packetConn := conn.(*PacketConn) - udpConn, ok := packetConn.PacketConn.(*net.UDPConn) - if !ok { - if err := packetConn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) - } - - return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil -} diff --git a/util/net/net.go b/util/net/net.go index 5448eb85a5f..403aa87e7d1 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -2,9 +2,6 @@ package net import ( "net" - "os" - - "github.com/netbirdio/netbird/client/iface/netstack" "github.com/google/uuid" ) @@ -16,8 +13,6 @@ const ( PreroutingFwmarkRedirected = 0x1BD01 PreroutingFwmarkMasquerade = 0x1BD11 PreroutingFwmarkMasqueradeReturn = 0x1BD12 - - envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" ) // ConnectionID provides a globally unique identifier for network connections. @@ -31,10 +26,3 @@ type RemoveHookFunc func(connID ConnectionID) error func GenerateConnID() ConnectionID { return ConnectionID(uuid.NewString()) } - -func CustomRoutingDisabled() bool { - if netstack.IsEnabled() { - return true - } - return os.Getenv(envDisableCustomRouting) == "true" -} diff --git a/util/net/net_linux.go b/util/net/net_linux.go index 98f49af8d00..fc486ebd496 100644 --- a/util/net/net_linux.go +++ b/util/net/net_linux.go @@ -4,29 +4,42 @@ package net import ( "fmt" - "os" "syscall" log "github.com/sirupsen/logrus" ) -const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK" - // SetSocketMark sets the SO_MARK option on the given socket connection func SetSocketMark(conn syscall.Conn) error { + if isSocketMarkDisabled() { + return nil + } + sysconn, err := conn.SyscallConn() if err != nil { return fmt.Errorf("get raw conn: %w", err) } - return SetRawSocketMark(sysconn) + return setRawSocketMark(sysconn) +} + +// SetSocketOpt sets the SO_MARK option on the given file descriptor +func SetSocketOpt(fd int) error { + if isSocketMarkDisabled() { + return nil + } + + return setSocketOptInt(fd) } -func SetRawSocketMark(conn syscall.RawConn) error { +func setRawSocketMark(conn syscall.RawConn) error { var setErr error err := conn.Control(func(fd uintptr) { - setErr = SetSocketOpt(int(fd)) + if isSocketMarkDisabled() { + return + } + setErr = setSocketOptInt(int(fd)) }) if err != nil { return fmt.Errorf("control: %w", err) @@ -39,17 +52,18 @@ func SetRawSocketMark(conn syscall.RawConn) error { return nil } -func SetSocketOpt(fd int) error { +func setSocketOptInt(fd int) error { + return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) +} + +func isSocketMarkDisabled() bool { if CustomRoutingDisabled() { log.Infof("Custom routing is disabled, skipping SO_MARK") - return nil + return true } - // Check for the new environment variable - if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" { - log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK") - return nil + if SkipSocketMark() { + return true } - - return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) + return false } diff --git a/util/net/protectsocket_android.go b/util/net/protectsocket_android.go index 64fb45aa44e..febed8a1e2b 100644 --- a/util/net/protectsocket_android.go +++ b/util/net/protectsocket_android.go @@ -1,14 +1,42 @@ package net -import "sync" +import ( + "fmt" + "sync" + "syscall" +) var ( androidProtectSocketLock sync.Mutex androidProtectSocket func(fd int32) bool ) -func SetAndroidProtectSocketFn(f func(fd int32) bool) { +func SetAndroidProtectSocketFn(fn func(fd int32) bool) { androidProtectSocketLock.Lock() - androidProtectSocket = f + androidProtectSocket = fn androidProtectSocketLock.Unlock() } + +// ControlProtectSocket is a Control function that sets the fwmark on the socket +func ControlProtectSocket(_, _ string, c syscall.RawConn) error { + var aErr error + err := c.Control(func(fd uintptr) { + androidProtectSocketLock.Lock() + defer androidProtectSocketLock.Unlock() + + if androidProtectSocket == nil { + aErr = fmt.Errorf("socket protection function not set") + return + } + + if !androidProtectSocket(int32(fd)) { + aErr = fmt.Errorf("failed to protect socket via Android") + } + }) + + if err != nil { + return err + } + + return aErr +}