diff --git a/demuxer/tcp.go b/demuxer/tcp.go index f732c57..3b71177 100644 --- a/demuxer/tcp.go +++ b/demuxer/tcp.go @@ -3,6 +3,7 @@ package demuxer import ( "context" + "net" "runtime" "runtime/debug" "time" @@ -10,6 +11,7 @@ import ( "github.com/m-lab/go/bytecount" "github.com/google/gopacket" + "github.com/google/gopacket/layers" "github.com/m-lab/go/anonymize" "github.com/m-lab/packet-headers/metrics" "github.com/m-lab/packet-headers/saver" @@ -23,6 +25,11 @@ type Saver interface { State() string } +type detector interface { + AddSyn(net.IP) + IsFlooding(net.IP) bool +} + // A saver implementation that only drains the packet and UUID channels, to // effectively ignore all packets and UUIDs for specific flows. type drain struct { @@ -91,6 +98,8 @@ type TCP struct { UUIDChan chan<- UUIDEvent uuidReadChan <-chan UUIDEvent + synFloodDetector detector + // We use a generational GC. Every time the GC timer advances, we garbage // collect all savers in oldFlows and make all the currentFlows into // oldFlows. It is only through this garbage collection process that @@ -178,6 +187,7 @@ func (d *TCP) savePacket(ctx context.Context, packet gopacket.Packet) { metrics.DemuxerBadPacket.Inc() return } + // Send the packet to the saver. s := d.getSaver(ctx, fromPacket(packet)) if s == nil { @@ -254,6 +264,22 @@ func (d *TCP) CapturePackets(ctx context.Context, packets <-chan gopacket.Packet for { select { case packet := <-packets: + // Early SYN flood detection + srcIP := net.IP(packet.NetworkLayer().NetworkFlow().Src().Raw()) + dstIP := net.IP(packet.NetworkLayer().NetworkFlow().Dst().Raw()) + tcpLayer := packet.TransportLayer().(*layers.TCP) + + // Update flood detection only for SYN packets + if tcpLayer.SYN && !tcpLayer.ACK { + d.synFloodDetector.AddSyn(srcIP) + } + + // Drop packets if either endpoint is flooding + if d.synFloodDetector.IsFlooding(srcIP) || d.synFloodDetector.IsFlooding(dstIP) { + metrics.SynFloodDrops.Inc() + continue + } + // Get a packet and save it. d.savePacket(ctx, packet) case ev = <-d.uuidReadChan: @@ -271,12 +297,16 @@ func (d *TCP) CapturePackets(ctx context.Context, packets <-chan gopacket.Packet // NewTCP creates a demuxer.TCP, which is the system which chooses which channel // to send TCP/IP packets for subsequent saving to a file. -func NewTCP(anon anonymize.IPAnonymizer, dataDir string, uuidWaitDuration, maxFlowDuration time.Duration, maxIdleRAM bytecount.ByteCount, stream bool, maxHeap uint64, maxFlows int) *TCP { +func NewTCP(anon anonymize.IPAnonymizer, dataDir string, + uuidWaitDuration, maxFlowDuration time.Duration, maxIdleRAM bytecount.ByteCount, + stream bool, maxHeap uint64, maxFlows int, detector detector) *TCP { uuidc := make(chan UUIDEvent, 100) return &TCP{ UUIDChan: uuidc, uuidReadChan: uuidc, + synFloodDetector: detector, + currentFlows: make(map[FlowKey]Saver), oldFlows: make(map[FlowKey]Saver), maxIdleRAM: maxIdleRAM, diff --git a/demuxer/tcp_test.go b/demuxer/tcp_test.go index f68a9c5..cdba5fb 100644 --- a/demuxer/tcp_test.go +++ b/demuxer/tcp_test.go @@ -5,6 +5,7 @@ import ( "errors" "io/ioutil" "log" + "net" "os" "os/exec" "sync" @@ -37,12 +38,17 @@ func (f *fakePacketSource) run() { } } +type noopDetector struct{} + +func (d *noopDetector) AddSyn(net.IP) {} +func (d *noopDetector) IsFlooding(net.IP) bool { return false } + func TestTCPDryRun(t *testing.T) { - dir, err := ioutil.TempDir("", "TestTCPDryRun") + dir, err := os.MkdirTemp("", "TestTCPDryRun") rtx.Must(err, "Could not create directory") defer os.RemoveAll(dir) - tcpdm := NewTCP(anonymize.New(anonymize.None), dir, 500*time.Millisecond, time.Second, 1000000000, true, uint64(2*bytecount.Gigabyte), 0) + tcpdm := NewTCP(anonymize.New(anonymize.None), dir, 500*time.Millisecond, time.Second, 1000000000, true, uint64(2*bytecount.Gigabyte), 0, nil) // While we have a demuxer created, make sure that the processing path for // packets does not crash when given a nil packet. @@ -85,7 +91,7 @@ func TestTCPWithRealPcaps(t *testing.T) { rtx.Must(err, "Could not create directory") defer os.RemoveAll(dir) - tcpdm := NewTCP(anonymize.New(anonymize.None), dir, 500*time.Millisecond, time.Second, 1000000000, true, uint64(2*bytecount.Gigabyte), 0) + tcpdm := NewTCP(anonymize.New(anonymize.None), dir, 500*time.Millisecond, time.Second, 1000000000, true, uint64(2*bytecount.Gigabyte), 0, &noopDetector{}) st := &statusTracker{} tcpdm.status = st ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -242,7 +248,7 @@ func TestUUIDWontBlock(t *testing.T) { rtx.Must(err, "Could not create directory") defer os.RemoveAll(dir) - tcpdm := NewTCP(anonymize.New(anonymize.None), dir, 15*time.Second, 30*time.Second, 1, true, uint64(2*bytecount.Gigabyte), 0) + tcpdm := NewTCP(anonymize.New(anonymize.None), dir, 15*time.Second, 30*time.Second, 1, true, uint64(2*bytecount.Gigabyte), 0, &noopDetector{}) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) var wg sync.WaitGroup diff --git a/detector/synflood.go b/detector/synflood.go new file mode 100644 index 0000000..07c4dd8 --- /dev/null +++ b/detector/synflood.go @@ -0,0 +1,131 @@ +package detector + +import ( + "hash" + "net" + "sync" + "time" + + "github.com/cespare/xxhash/v2" +) + +type SynFloodDetector struct { + sketches []*CMSketch // Array of Count-Min sketches for time windows + windowSize time.Duration // How long each sketch covers (e.g. 10s) + numWindows int // Number of sliding windows to maintain + currentWindow int // Index of current active window + threshold uint32 // Max SYNs per source per window before flagging + lastRotation time.Time // When we last rotated windows + mu sync.RWMutex +} + +// Individual Count-Min sketch +type CMSketch struct { + width uint32 // Number of counters per row + depth uint32 // Number of rows (hash functions) + counts [][]uint32 // The actual counter matrix + hashes []hash.Hash64 // Hash functions for each row +} + +// Constructor for flood detector +func NewSynFloodDetector(windowSize time.Duration, numWindows int, width, depth uint32, threshold uint32) *SynFloodDetector { + sketches := make([]*CMSketch, numWindows) + for i := 0; i < numWindows; i++ { + sketches[i] = NewCMSketch(width, depth) + } + + return &SynFloodDetector{ + sketches: sketches, + windowSize: windowSize, + numWindows: numWindows, + currentWindow: 0, + threshold: threshold, + lastRotation: time.Now(), + } +} + +// Add a SYN packet from an IP +func (d *SynFloodDetector) AddSyn(srcIP net.IP) { + d.rotateIfNeeded() + d.sketches[d.currentWindow].Add(srcIP) +} + +// Check if an IP is currently flooding +func (d *SynFloodDetector) IsFlooding(srcIP net.IP) bool { + d.rotateIfNeeded() + // Sum estimates across all windows + var total uint32 + for _, sketch := range d.sketches { + total += sketch.Estimate(srcIP) + } + return total >= d.threshold +} + +// Rotate windows periodically +func (d *SynFloodDetector) rotateIfNeeded() { + now := time.Now() + if now.Sub(d.lastRotation) < d.windowSize { + return + } + + // Reset next window and advance + nextWindow := (d.currentWindow + 1) % d.numWindows + d.sketches[nextWindow].Reset() + d.currentWindow = nextWindow + d.lastRotation = now +} + +// Constructor for Count-Min sketch +func NewCMSketch(width, depth uint32) *CMSketch { + counts := make([][]uint32, depth) + for i := range counts { + counts[i] = make([]uint32, width) + } + + hashes := make([]hash.Hash64, depth) + for i := range hashes { + // Initialize with different seeds + hashes[i] = xxhash.New() + hashes[i].Write([]byte{byte(i)}) + } + + return &CMSketch{ + width: width, + depth: depth, + counts: counts, + hashes: hashes, + } +} + +// Add an item to sketch +func (s *CMSketch) Add(item []byte) { + for i := range s.hashes { + s.hashes[i].Reset() + s.hashes[i].Write(item) + h := s.hashes[i].Sum64() % uint64(s.width) + s.counts[i][h]++ + } +} + +// Get estimated count for item +func (s *CMSketch) Estimate(item []byte) uint32 { + var min uint32 = ^uint32(0) // Max uint32 value + for i := range s.hashes { + s.hashes[i].Reset() + s.hashes[i].Write(item) + h := s.hashes[i].Sum64() % uint64(s.width) + if s.counts[i][h] < min { + min = s.counts[i][h] + } + } + return min +} + +// Reset all counters in a sketch +func (s *CMSketch) Reset() { + for i := range s.counts { + for j := range s.counts[i] { + s.counts[i][j] = 0 + } + } +} diff --git a/detector/synflood_test.go b/detector/synflood_test.go new file mode 100644 index 0000000..7f69b02 --- /dev/null +++ b/detector/synflood_test.go @@ -0,0 +1,360 @@ +package detector + +import ( + "fmt" + "net" + "testing" + "time" +) + +func TestNewSynFloodDetector(t *testing.T) { + tests := []struct { + name string + windowSize time.Duration + numWindows int + width uint32 + depth uint32 + threshold uint32 + }{ + { + name: "basic detector", + windowSize: time.Second, + numWindows: 2, + width: 1000, + depth: 4, + threshold: 100, + }, + { + name: "minimal detector", + windowSize: time.Millisecond, + numWindows: 1, + width: 1, + depth: 1, + threshold: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewSynFloodDetector(tt.windowSize, tt.numWindows, tt.width, tt.depth, tt.threshold) + if d == nil { + t.Fatal("Expected non-nil detector") + } + if len(d.sketches) != tt.numWindows { + t.Errorf("Got %d sketches, want %d", len(d.sketches), tt.numWindows) + } + if d.threshold != tt.threshold { + t.Errorf("Got threshold %d, want %d", d.threshold, tt.threshold) + } + }) + } +} + +func TestCMSketch(t *testing.T) { + s := NewCMSketch(1000, 4) + if s == nil { + t.Fatal("Expected non-nil sketch") + } + + // Test Add and Estimate + ip := net.ParseIP("192.0.2.1").To4() + if ip == nil { + t.Fatal("Failed to parse IP") + } + + // Initial count should be 0 + if count := s.Estimate(ip); count != 0 { + t.Errorf("Initial count should be 0, got %d", count) + } + + // Add once + s.Add(ip) + if count := s.Estimate(ip); count != 1 { + t.Errorf("Count after one add should be 1, got %d", count) + } + + // Test Reset + s.Reset() + if count := s.Estimate(ip); count != 0 { + t.Errorf("Count after reset should be 0, got %d", count) + } +} + +func TestSynFloodDetection(t *testing.T) { + // Create a detector with small numbers for testing + d := NewSynFloodDetector(10*time.Millisecond, 2, 100, 4, 3) + ip := net.ParseIP("192.0.2.1").To4() + + // Should not be flooding initially + if d.IsFlooding(ip) { + t.Error("Should not detect flooding before any SYNs") + } + + // Add SYNs up to threshold + for i := 0; i < 3; i++ { + d.AddSyn(ip) + } + + // Should detect flooding + if !d.IsFlooding(ip) { + t.Error("Should detect flooding after threshold exceeded") + } + + // Test window rotation + time.Sleep(11 * time.Millisecond) // Wait for window rotation + d.AddSyn(ip) // This should go to new window + + // Should still be flooding (counts from both windows) + if !d.IsFlooding(ip) { + t.Error("Should still detect flooding across windows") + } + + // Wait for first window to expire + time.Sleep(11 * time.Millisecond) + // Force rotation + d.rotateIfNeeded() + + // Should no longer be flooding (old window rotated out) + if d.IsFlooding(ip) { + t.Error("Should not detect flooding after window rotation") + } +} + +func TestMultipleIPs(t *testing.T) { + d := NewSynFloodDetector(time.Second, 2, 100, 4, 3) + ip1 := net.ParseIP("192.0.2.1").To4() + ip2 := net.ParseIP("192.0.2.2").To4() + + // Add SYNs for first IP + for i := 0; i < 3; i++ { + d.AddSyn(ip1) + } + + // Check that only first IP is flooding + if !d.IsFlooding(ip1) { + t.Error("IP1 should be flooding") + } + if d.IsFlooding(ip2) { + t.Error("IP2 should not be flooding") + } + + // Add SYNs for second IP + for i := 0; i < 3; i++ { + d.AddSyn(ip2) + } + + // Both should be flooding + if !d.IsFlooding(ip1) { + t.Error("IP1 should still be flooding") + } + if !d.IsFlooding(ip2) { + t.Error("IP2 should now be flooding") + } +} + +func TestWindowRotation(t *testing.T) { + d := NewSynFloodDetector(10*time.Millisecond, 3, 100, 4, 5) + ip := net.ParseIP("192.0.2.1").To4() + + // Add SYNs in first window + d.AddSyn(ip) + d.AddSyn(ip) + + initialWindow := d.currentWindow + + // Wait for rotation + time.Sleep(11 * time.Millisecond) + d.rotateIfNeeded() + + if d.currentWindow == initialWindow { + t.Error("Window should have rotated") + } + + // Add more SYNs in new window + d.AddSyn(ip) + d.AddSyn(ip) + + // Total should be 4 (not flooding yet) + if d.IsFlooding(ip) { + t.Error("Should not be flooding with 4 SYNs across windows") + } + + // Add one more to exceed threshold + d.AddSyn(ip) + if !d.IsFlooding(ip) { + t.Error("Should be flooding after 5 SYNs") + } +} + +// Benchmarks +func BenchmarkAddSyn_SingleIP(b *testing.B) { + d := NewSynFloodDetector(time.Second, 2, 1000, 4, 100) + ip := net.ParseIP("192.0.2.1").To4() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + d.AddSyn(ip) + } +} + +// Benchmark adding SYNs from different IPs +func BenchmarkAddSyn_MultipleIPs(b *testing.B) { + d := NewSynFloodDetector(time.Second, 2, 1000, 4, 100) + // Pre-generate IPs to avoid IP generation overhead in benchmark + ips := make([]net.IP, 1000) + for i := range ips { + ips[i] = net.IPv4(192, 0, byte(i/256), byte(i%256)).To4() + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + d.AddSyn(ips[i%len(ips)]) + } +} + +// Benchmark flood checking for a single IP +func BenchmarkIsFlooding_SingleIP(b *testing.B) { + d := NewSynFloodDetector(time.Second, 2, 1000, 4, 100) + ip := net.ParseIP("192.0.2.1").To4() + + // Add some SYNs first + for i := 0; i < 50; i++ { + d.AddSyn(ip) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + d.IsFlooding(ip) + } +} + +// Benchmark flood checking with window rotation +func BenchmarkIsFlooding_WithRotation(b *testing.B) { + d := NewSynFloodDetector(time.Nanosecond, 2, 1000, 4, 100) // Very small window to force rotation + ip := net.ParseIP("192.0.2.1").To4() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + d.IsFlooding(ip) + } +} + +// Benchmark the CM Sketch operations directly +func BenchmarkCMSketch_Add(b *testing.B) { + s := NewCMSketch(1000, 4) + ip := net.ParseIP("192.0.2.1").To4() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Add(ip) + } +} + +func BenchmarkCMSketch_Estimate(b *testing.B) { + s := NewCMSketch(1000, 4) + ip := net.ParseIP("192.0.2.1").To4() + + // Add some data first + for i := 0; i < 50; i++ { + s.Add(ip) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Estimate(ip) + } +} + +// Benchmark different sketch sizes +func BenchmarkCMSketch_DifferentSizes(b *testing.B) { + sizes := []struct { + width uint32 + depth uint32 + }{ + {100, 2}, + {1000, 4}, + {10000, 8}, + } + + ip := net.ParseIP("192.0.2.1").To4() + + for _, size := range sizes { + b.Run(fmt.Sprintf("Width=%d_Depth=%d", size.width, size.depth), func(b *testing.B) { + s := NewCMSketch(size.width, size.depth) + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Add(ip) + s.Estimate(ip) + } + }) + } +} + +// Benchmark memory pressure scenarios +func BenchmarkUnderMemoryPressure(b *testing.B) { + // Create a lot of sketches to simulate memory pressure + sketches := make([]*CMSketch, 100) + for i := range sketches { + sketches[i] = NewCMSketch(1000, 4) + } + + d := NewSynFloodDetector(time.Second, 2, 1000, 4, 100) + ip := net.ParseIP("192.0.2.1").To4() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + d.AddSyn(ip) + d.IsFlooding(ip) + } +} + +func BenchmarkAccuracyVsSize(b *testing.B) { + configs := []struct { + name string + width uint32 + depth uint32 + numIPs int + synsPerIP int + }{ + {"Speedtest-W4k-D4", 4000, 4, 200, 100}, // ~64KB + {"Speedtest-W8k-D4", 8000, 4, 200, 100}, // ~128KB + {"Speedtest-W8k-D5", 8000, 5, 200, 100}, // ~160KB + {"Wide-W4k-D4", 4000, 4, 100, 100}, // ~64KB + {"Medium-1kIPs-W10k-D5", 10000, 5, 1000, 50}, // ~200KB + } + + for _, cfg := range configs { + b.Run(cfg.name, func(b *testing.B) { + s := NewCMSketch(cfg.width, cfg.depth) + ips := make([]net.IP, cfg.numIPs) + for i := range ips { + ips[i] = net.IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).To4() + } + + // Measure accuracy + b.ReportMetric(0, "error_pct") // Custom metric + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // Add known number of SYNs + for _, ip := range ips { + for k := 0; k < cfg.synsPerIP; k++ { + s.Add(ip) + } + } + + // Calculate error rate + var totalError float64 + for _, ip := range ips { + est := s.Estimate(ip) + error := float64(est-uint32(cfg.synsPerIP)) / float64(cfg.synsPerIP) + totalError += error + } + avgError := (totalError / float64(cfg.numIPs)) * 100 + b.ReportMetric(avgError, "error_pct") + + s.Reset() + } + }) + } +} diff --git a/go.mod b/go.mod index 06d8d88..b62e2fa 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/m-lab/packet-headers go 1.20 require ( + github.com/cespare/xxhash/v2 v2.1.2 github.com/google/gopacket v1.1.19 github.com/m-lab/go v0.1.66 github.com/m-lab/tcp-info v1.5.3 @@ -13,7 +14,6 @@ require ( require ( github.com/araddon/dateparse v0.0.0-20200409225146-d820a6159ab1 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/gocarina/gocsv v0.0.0-20220729221910-a7386ae0b221 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect diff --git a/main.go b/main.go index 85a0fd4..b123385 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,7 @@ import ( "github.com/m-lab/go/rtx" "github.com/m-lab/go/warnonerror" "github.com/m-lab/packet-headers/demuxer" + "github.com/m-lab/packet-headers/detector" "github.com/m-lab/packet-headers/muxer" "github.com/m-lab/packet-headers/tcpinfohandler" "github.com/m-lab/tcp-info/eventsocket" @@ -46,6 +47,13 @@ var ( // Context and injected variables to allow smoke testing of main() mainCtx, mainCancel = context.WithCancel(context.Background()) pcapOpenLive = pcap.OpenLive + + // SYN flood detection tunables. + synFloodWindow = flag.Duration("synflood.window", 5*time.Second, "Duration of each window for SYN flood detection") + synFloodWindows = flag.Int("synflood.numwindows", 12, "Number of windows to maintain (window × numwindows = block duration)") + synFloodWidth = flag.Uint("synflood.width", 8000, "Width of Count-Min Sketch (larger = more accurate, more memory)") + synFloodDepth = flag.Uint("synflood.depth", 4, "Number of hash functions (4-5 typical)") + synFloodThreshold = flag.Uint("synflood.threshold", 100, "Number of SYNs allowed per source IP across all windows before blocking") ) func init() { @@ -150,7 +158,14 @@ func main() { // Get ready to save the incoming packets to files. tcpdm := demuxer.NewTCP( anonymize.New(anonymize.IPAnonymizationFlag), *dir, *uuidWaitDuration, - *captureDuration, maxIdleRAM, *streamToDisk, uint64(maxHeap), *maxFlows) + *captureDuration, maxIdleRAM, *streamToDisk, uint64(maxHeap), *maxFlows, + detector.NewSynFloodDetector( + *synFloodWindow, + *synFloodWindows, + uint32(*synFloodWidth), + uint32(*synFloodDepth), + uint32(*synFloodThreshold), + )) // Inform the demuxer of new UUIDs h := tcpinfohandler.New(mainCtx, tcpdm.UUIDChan) diff --git a/metrics/metrics.go b/metrics/metrics.go index c4a0545..aa6b3e3 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -24,6 +24,13 @@ var ( []string{"saverstate"}, ) + SynFloodDrops = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "pcap_synflood_drops_total", + Help: "Number of flows dropped due to SYN flood detection", + }, + ) + // Savers are internal data structures each with a single associated // goroutine, that are allocated and run once for each connection. The start // and stop of that goroutine is counted in SaversStarted and SaversStopped,