diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 2592ff840b0..2e745a31e00 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -3,6 +3,7 @@ package iptables import ( "fmt" "net" + "slices" "strconv" "github.com/coreos/go-iptables/iptables" @@ -99,6 +100,16 @@ func (m *aclManager) AddPeerFiltering( ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal) specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, action, ipsetName) + + mangleSpecs := slices.Clone(specs) + mangleSpecs = append(mangleSpecs, + "-i", m.wgIface.Name(), + "-m", "addrtype", "--dst-type", "LOCAL", + "-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), + ) + + specs = append(specs, "-j", actionToStr(action)) + if ipsetName != "" { if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists { if err := ipset.Add(ipsetName, ip.String()); err != nil { @@ -130,7 +141,7 @@ func (m *aclManager) AddPeerFiltering( m.ipsetStore.addIpList(ipsetName, ipList) } - ok, err := m.iptablesClient.Exists("filter", chain, specs...) + ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...) if err != nil { return nil, fmt.Errorf("failed to check rule: %w", err) } @@ -138,16 +149,22 @@ func (m *aclManager) AddPeerFiltering( return nil, fmt.Errorf("rule already exists") } - if err := m.iptablesClient.Append("filter", chain, specs...); err != nil { + if err := m.iptablesClient.Append(tableFilter, chain, specs...); err != nil { return nil, err } + if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil { + log.Errorf("failed to add mangle rule: %v", err) + mangleSpecs = nil + } + rule := &Rule{ - ruleID: uuid.New().String(), - specs: specs, - ipsetName: ipsetName, - ip: ip.String(), - chain: chain, + ruleID: uuid.New().String(), + specs: specs, + mangleSpecs: mangleSpecs, + ipsetName: ipsetName, + ip: ip.String(), + chain: chain, } m.updateState() @@ -190,6 +207,12 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err) } + if r.mangleSpecs != nil { + if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil { + log.Errorf("failed to delete mangle rule: %v", err) + } + } + m.updateState() return nil @@ -310,17 +333,10 @@ func (m *aclManager) seedInitialEntries() { func (m *aclManager) seedInitialOptionalEntries() { m.optionalEntries["FORWARD"] = []entry{ { - spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules}, + spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"}, position: 2, }, } - - m.optionalEntries["PREROUTING"] = []entry{ - { - spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)}, - position: 1, - }, - } } func (m *aclManager) appendToEntries(chainName string, spec []string) { @@ -377,7 +393,7 @@ func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.A if dPort != "" { specs = append(specs, "--dport", dPort) } - return append(specs, "-j", actionToStr(action)) + return specs } func actionToStr(action firewall.Action) string { diff --git a/client/firewall/iptables/rule.go b/client/firewall/iptables/rule.go index 1047c5cf8ff..e90e32f8b02 100644 --- a/client/firewall/iptables/rule.go +++ b/client/firewall/iptables/rule.go @@ -5,9 +5,10 @@ type Rule struct { ruleID string ipsetName string - specs []string - ip string - chain string + specs []string + mangleSpecs []string + ip string + chain string } // GetRuleID returns the rule id diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 8c1d89e6833..0d1d659afee 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "net" + "slices" "strconv" "strings" "time" @@ -46,6 +47,7 @@ type AclManager struct { workTable *nftables.Table chainInputRules *nftables.Chain + chainPrerouting *nftables.Chain ipsetStore *ipsetStore rules map[string]*Rule @@ -118,23 +120,32 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { } if r.nftSet == nil { - err := m.rConn.DelRule(r.nftRule) - if err != nil { + if err := m.rConn.DelRule(r.nftRule); err != nil { log.Errorf("failed to delete rule: %v", err) } + if r.mangleRule != nil { + if err := m.rConn.DelRule(r.mangleRule); err != nil { + log.Errorf("failed to delete mangle rule: %v", err) + } + } delete(m.rules, r.GetRuleID()) return m.rConn.Flush() } ips, ok := m.ipsetStore.ips(r.nftSet.Name) if !ok { - err := m.rConn.DelRule(r.nftRule) - if err != nil { + if err := m.rConn.DelRule(r.nftRule); err != nil { log.Errorf("failed to delete rule: %v", err) } + if r.mangleRule != nil { + if err := m.rConn.DelRule(r.mangleRule); err != nil { + log.Errorf("failed to delete mangle rule: %v", err) + } + } delete(m.rules, r.GetRuleID()) return m.rConn.Flush() } + if _, ok := ips[r.ip.String()]; ok { err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}}) if err != nil { @@ -153,12 +164,16 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { return nil } - err := m.rConn.DelRule(r.nftRule) - if err != nil { + if err := m.rConn.DelRule(r.nftRule); err != nil { log.Errorf("failed to delete rule: %v", err) } - err = m.rConn.Flush() - if err != nil { + if r.mangleRule != nil { + if err := m.rConn.DelRule(r.mangleRule); err != nil { + log.Errorf("failed to delete mangle rule: %v", err) + } + } + + if err := m.rConn.Flush(); err != nil { return err } @@ -225,9 +240,12 @@ func (m *AclManager) Flush() error { return err } - if err := m.refreshRuleHandles(m.chainInputRules); err != nil { + if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil { log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err) } + if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil { + log.Errorf("failed to refresh rule handles prerouting chain: %v", err) + } return nil } @@ -244,10 +262,11 @@ func (m *AclManager) addIOFiltering( ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset) if r, ok := m.rules[ruleId]; ok { return &Rule{ - r.nftRule, - r.nftSet, - r.ruleID, - ip, + nftRule: r.nftRule, + mangleRule: r.mangleRule, + nftSet: r.nftSet, + ruleID: r.ruleID, + ip: ip, }, nil } @@ -340,11 +359,13 @@ func (m *AclManager) addIOFiltering( ) } + mainExpressions := slices.Clone(expressions) + switch action { case firewall.ActionAccept: - expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept}) + mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept}) case firewall.ActionDrop: - expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop}) + mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop}) } userData := []byte(strings.Join([]string{ruleId, comment}, " ")) @@ -353,15 +374,16 @@ func (m *AclManager) addIOFiltering( nftRule := m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, Chain: chain, - Exprs: expressions, + Exprs: mainExpressions, UserData: userData, }) rule := &Rule{ - nftRule: nftRule, - nftSet: ipset, - ruleID: ruleId, - ip: ip, + nftRule: nftRule, + mangleRule: m.createPreroutingRule(expressions, userData), + nftSet: ipset, + ruleID: ruleId, + ip: ip, } m.rules[ruleId] = rule if ipset != nil { @@ -370,6 +392,59 @@ func (m *AclManager) addIOFiltering( return rule, nil } +func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule { + if m.chainPrerouting == nil { + log.Warn("prerouting chain is not created") + return nil + } + + preroutingExprs := slices.Clone(expressions) + + // interface + preroutingExprs = append([]expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + }, preroutingExprs...) + + // local destination and mark + preroutingExprs = append(preroutingExprs, + &expr.Fib{ + Register: 1, + ResultADDRTYPE: true, + FlagDADDR: true, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL), + }, + + &expr.Immediate{ + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected), + }, + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + SourceRegister: true, + }, + ) + + return m.rConn.AddRule(&nftables.Rule{ + Table: m.workTable, + Chain: m.chainPrerouting, + Exprs: preroutingExprs, + UserData: userData, + }) +} + func (m *AclManager) createDefaultChains() (err error) { // chainNameInputRules chain := m.createChain(chainNameInputRules) @@ -413,7 +488,7 @@ func (m *AclManager) createDefaultChains() (err error) { // go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the // netbird peer IP. func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error { - preroutingChain := m.rConn.AddChain(&nftables.Chain{ + m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{ Name: chainNamePrerouting, Table: m.workTable, Type: nftables.ChainTypeFilter, @@ -421,8 +496,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error Priority: nftables.ChainPriorityMangle, }) - m.addPreroutingRule(preroutingChain) - m.addFwmarkToForward(chainFwFilter) if err := m.rConn.Flush(); err != nil { @@ -432,43 +505,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error return nil } -func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) { - m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: preroutingChain, - Exprs: []expr.Any{ - &expr.Meta{ - Key: expr.MetaKeyIIFNAME, - Register: 1, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Fib{ - Register: 1, - ResultADDRTYPE: true, - FlagDADDR: true, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL), - }, - &expr.Immediate{ - Register: 1, - Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected), - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - Register: 1, - SourceRegister: true, - }, - }, - }) -} - func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) { m.rConn.InsertRule(&nftables.Rule{ Table: m.workTable, @@ -484,8 +520,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) { Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected), }, &expr.Verdict{ - Kind: expr.VerdictJump, - Chain: m.chainInputRules.Name, + Kind: expr.VerdictAccept, }, }, }) @@ -632,6 +667,7 @@ func (m *AclManager) flushWithBackoff() (err error) { for i := 0; ; i++ { err = m.rConn.Flush() if err != nil { + log.Debugf("failed to flush nftables: %v", err) if !strings.Contains(err.Error(), "busy") { return } @@ -648,7 +684,7 @@ func (m *AclManager) flushWithBackoff() (err error) { return } -func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error { +func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error { if m.workTable == nil || chain == nil { return nil } @@ -665,7 +701,11 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error { split := bytes.Split(rule.UserData, []byte(" ")) r, ok := m.rules[string(split[0])] if ok { - *r.nftRule = *rule + if mangle { + *r.mangleRule = *rule + } else { + *r.nftRule = *rule + } } } diff --git a/client/firewall/nftables/rule_linux.go b/client/firewall/nftables/rule_linux.go index 678c10b4409..4d652346b95 100644 --- a/client/firewall/nftables/rule_linux.go +++ b/client/firewall/nftables/rule_linux.go @@ -8,10 +8,11 @@ import ( // Rule to handle management of rules type Rule struct { - nftRule *nftables.Rule - nftSet *nftables.Set - ruleID string - ip net.IP + nftRule *nftables.Rule + mangleRule *nftables.Rule + nftSet *nftables.Set + ruleID string + ip net.IP } // GetRuleID returns the rule id