diff --git a/network/portfwd/docker/docker.go b/network/portfwd/docker/docker.go index 7718714..d1ce181 100644 --- a/network/portfwd/docker/docker.go +++ b/network/portfwd/docker/docker.go @@ -36,6 +36,13 @@ func PortForwardings(tables nufftables.TableMap, family nufftables.TableFamily) if nattable == nil { return nil } + return grabPortForwardings(nattable) +} + +// grabPortForwardings is a convenience helper to wire up the individual port +// forwarding detectors in a single place, making maintenance easier for both +// PROD and TEST. +func grabPortForwardings(nattable *nufftables.Table) []*portfinder.ForwardedPortRange { forwardedPorts := forwardedPortsMk1(nattable) forwardedPorts = append(forwardedPorts, forwardedPortsMk2(nattable)...) forwardedPorts = append(forwardedPorts, forwardedPortsMk3(nattable)...) @@ -69,7 +76,7 @@ func forwardedPortsInChainMk2(chain *nufftables.Chain) []*portfinder.ForwardedPo family := chain.Table.Family forwardedPorts := []*portfinder.ForwardedPortRange{} for _, rule := range chain.Rules { - exprs, proto := nftget.L4ProtoTcpUdp(rule.Exprs) + exprs, proto := nftget.MetaL4ProtoTcpUdp(rule.Exprs) exprs, origIP := nftget.OptionalIPv46(exprs, family) exprs, port := nufftables.OfTypeTransformed(exprs, nftget.Port) exprs, dnat := dsl.TargetDNAT(exprs) @@ -112,7 +119,7 @@ func forwardedPortsInChainMk3(chain *nufftables.Chain) []*portfinder.ForwardedPo forwardedPorts := []*portfinder.ForwardedPortRange{} for _, rule := range chain.Rules { exprs, origIP := nftget.OptionalDestIPv46(rule.Exprs, family) - exprs, proto := nftget.L4ProtoTcpUdp(exprs) + exprs, proto := nftget.PayloadL4ProtoTcpUdp(exprs) exprs, port := nftget.PayloadPort(exprs) exprs, dnat := dsl.TargetDNAT(exprs) if exprs == nil || dnat.Flags&dnatWithIPsAndPorts != dnatWithIPsAndPorts || port == 0 { diff --git a/network/portfwd/docker/docker_test.go b/network/portfwd/docker/docker_test.go index cd6cb90..bef401c 100644 --- a/network/portfwd/docker/docker_test.go +++ b/network/portfwd/docker/docker_test.go @@ -18,20 +18,12 @@ import ( "github.com/thediveo/morbyd/session" "github.com/thediveo/notwork/netns" "github.com/thediveo/nufftables" - "github.com/thediveo/nufftables/portfinder" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" . "github.com/thediveo/success" ) -func fwports(nattable *nufftables.Table) []*portfinder.ForwardedPortRange { - forwardedPorts := forwardedPortsMk1(nattable) - forwardedPorts = append(forwardedPorts, forwardedPortsMk2(nattable)...) - forwardedPorts = append(forwardedPorts, forwardedPortsMk3(nattable)...) - return forwardedPorts -} - var _ = Describe("Docker port forwarding", Ordered, func() { var cntrPID int @@ -76,7 +68,7 @@ var _ = Describe("Docker port forwarding", Ordered, func() { nattable := tables.Table("nat", nufftables.TableFamilyIPv4) Expect(nattable).NotTo(BeNil()) Expect(nattable.ChainsByName).NotTo(BeEmpty()) - forwardedPorts := fwports(nattable) + forwardedPorts := grabPortForwardings(nattable) Expect(forwardedPorts).To(ContainElement(And( HaveField("Protocol", "tcp"), HaveField("IP", net.ParseIP("127.0.0.1").To4()), @@ -103,8 +95,11 @@ var _ = Describe("Docker port forwarding", Ordered, func() { nattable := tables.Table("nat", nufftables.TableFamilyIPv4) Expect(nattable).NotTo(BeNil()) Expect(nattable.ChainsByName).NotTo(BeEmpty()) - forwardedPorts := fwports(nattable) - Expect(forwardedPorts).To(ContainElements( + forwardedPorts := grabPortForwardings(nattable) + // Ensure to exactly match in order to catch any false positives; this + // is possible in this case because we're looking at the nft inside the + // container and thus know what should be there and what shouldn't. + Expect(forwardedPorts).To(ConsistOf( And( HaveField("Protocol", "tcp"), HaveField("IP", net.ParseIP("127.0.0.11").To4()), diff --git a/network/portfwd/nftget/l4proto.go b/network/portfwd/nftget/l4proto.go index 6e0fc83..42851c6 100644 --- a/network/portfwd/nftget/l4proto.go +++ b/network/portfwd/nftget/l4proto.go @@ -10,13 +10,17 @@ import ( "golang.org/x/sys/unix" ) -// L4ProtoTcpUdp returns the transport layer protocol name checked for from -// either a Meta/Cmp twin-expression or a Payload/Cmp twin-expression, together -// with the remaining expressions; otherwise, it returns nil. -func L4ProtoTcpUdp(exprs nufftables.Expressions) (nufftables.Expressions, string) { - if exprs, proto := nufftables.PrefixedOfTypeTransformed(exprs, isMetaL4Proto, TcpUdp); exprs != nil { - return exprs, proto - } +// MetaL4ProtoTcpUdp returns the transport layer protocol name checked for from +// a Meta/Cmp twin-expression, together with the remaining expressions; +// otherwise, it returns nil. +func MetaL4ProtoTcpUdp(exprs nufftables.Expressions) (nufftables.Expressions, string) { + return nufftables.PrefixedOfTypeTransformed(exprs, isMetaL4Proto, TcpUdp) +} + +// PayloadL4ProtoTcpUdp returns the transport layer protocol name checked for +// from a Payload/Cmp twin-expression, together with the remaining expressions; +// otherwise, it returns nil. +func PayloadL4ProtoTcpUdp(exprs nufftables.Expressions) (nufftables.Expressions, string) { return nufftables.PrefixedOfTypeTransformed(exprs, isPayloadIPv4L4Proto, TcpUdp) } diff --git a/network/portfwd/nftget/l4proto_test.go b/network/portfwd/nftget/l4proto_test.go index 868edc1..f9795cf 100644 --- a/network/portfwd/nftget/l4proto_test.go +++ b/network/portfwd/nftget/l4proto_test.go @@ -56,7 +56,7 @@ var _ = Describe("nftables L4 proto getter", func() { if cmp != nil { exprs = append(exprs, cmp) } - exprs, protoname := L4ProtoTcpUdp(exprs) + exprs, protoname := MetaL4ProtoTcpUdp(exprs) if expectedName == "" { Expect(exprs).To(BeNil()) } else {