Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: SYN flood detector #59

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion demuxer/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package demuxer

import (
"context"
"net"
"runtime"
"runtime/debug"
"time"

"github.com/m-lab/go/bytecount"

"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/m-lab/go/anonymize"
"github.com/m-lab/packet-headers/metrics"
"github.com/m-lab/packet-headers/saver"
Expand All @@ -23,6 +25,11 @@ type Saver interface {
State() string
}

type detector interface {
AddSyn(net.IP)
IsFlooding(net.IP) bool
}

// A saver implementation that only drains the packet and UUID channels, to
// effectively ignore all packets and UUIDs for specific flows.
type drain struct {
Expand Down Expand Up @@ -91,6 +98,8 @@ type TCP struct {
UUIDChan chan<- UUIDEvent
uuidReadChan <-chan UUIDEvent

synFloodDetector detector

// We use a generational GC. Every time the GC timer advances, we garbage
// collect all savers in oldFlows and make all the currentFlows into
// oldFlows. It is only through this garbage collection process that
Expand Down Expand Up @@ -178,6 +187,7 @@ func (d *TCP) savePacket(ctx context.Context, packet gopacket.Packet) {
metrics.DemuxerBadPacket.Inc()
return
}

// Send the packet to the saver.
s := d.getSaver(ctx, fromPacket(packet))
if s == nil {
Expand Down Expand Up @@ -254,6 +264,22 @@ func (d *TCP) CapturePackets(ctx context.Context, packets <-chan gopacket.Packet
for {
select {
case packet := <-packets:
// Early SYN flood detection
srcIP := net.IP(packet.NetworkLayer().NetworkFlow().Src().Raw())
dstIP := net.IP(packet.NetworkLayer().NetworkFlow().Dst().Raw())
tcpLayer := packet.TransportLayer().(*layers.TCP)

// Update flood detection only for SYN packets
if tcpLayer.SYN && !tcpLayer.ACK {
d.synFloodDetector.AddSyn(srcIP)
}

// Drop packets if either endpoint is flooding
if d.synFloodDetector.IsFlooding(srcIP) || d.synFloodDetector.IsFlooding(dstIP) {
metrics.SynFloodDrops.Inc()
continue
}

// Get a packet and save it.
d.savePacket(ctx, packet)
case ev = <-d.uuidReadChan:
Expand All @@ -271,12 +297,16 @@ func (d *TCP) CapturePackets(ctx context.Context, packets <-chan gopacket.Packet

// NewTCP creates a demuxer.TCP, which is the system which chooses which channel
// to send TCP/IP packets for subsequent saving to a file.
func NewTCP(anon anonymize.IPAnonymizer, dataDir string, uuidWaitDuration, maxFlowDuration time.Duration, maxIdleRAM bytecount.ByteCount, stream bool, maxHeap uint64, maxFlows int) *TCP {
func NewTCP(anon anonymize.IPAnonymizer, dataDir string,
uuidWaitDuration, maxFlowDuration time.Duration, maxIdleRAM bytecount.ByteCount,
stream bool, maxHeap uint64, maxFlows int, detector detector) *TCP {
uuidc := make(chan UUIDEvent, 100)
return &TCP{
UUIDChan: uuidc,
uuidReadChan: uuidc,

synFloodDetector: detector,

currentFlows: make(map[FlowKey]Saver),
oldFlows: make(map[FlowKey]Saver),
maxIdleRAM: maxIdleRAM,
Expand Down
14 changes: 10 additions & 4 deletions demuxer/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"io/ioutil"
"log"
"net"
"os"
"os/exec"
"sync"
Expand Down Expand Up @@ -37,12 +38,17 @@ func (f *fakePacketSource) run() {
}
}

type noopDetector struct{}

func (d *noopDetector) AddSyn(net.IP) {}
func (d *noopDetector) IsFlooding(net.IP) bool { return false }

func TestTCPDryRun(t *testing.T) {
dir, err := ioutil.TempDir("", "TestTCPDryRun")
dir, err := os.MkdirTemp("", "TestTCPDryRun")
rtx.Must(err, "Could not create directory")
defer os.RemoveAll(dir)

tcpdm := NewTCP(anonymize.New(anonymize.None), dir, 500*time.Millisecond, time.Second, 1000000000, true, uint64(2*bytecount.Gigabyte), 0)
tcpdm := NewTCP(anonymize.New(anonymize.None), dir, 500*time.Millisecond, time.Second, 1000000000, true, uint64(2*bytecount.Gigabyte), 0, nil)

// While we have a demuxer created, make sure that the processing path for
// packets does not crash when given a nil packet.
Expand Down Expand Up @@ -85,7 +91,7 @@ func TestTCPWithRealPcaps(t *testing.T) {
rtx.Must(err, "Could not create directory")
defer os.RemoveAll(dir)

tcpdm := NewTCP(anonymize.New(anonymize.None), dir, 500*time.Millisecond, time.Second, 1000000000, true, uint64(2*bytecount.Gigabyte), 0)
tcpdm := NewTCP(anonymize.New(anonymize.None), dir, 500*time.Millisecond, time.Second, 1000000000, true, uint64(2*bytecount.Gigabyte), 0, &noopDetector{})
st := &statusTracker{}
tcpdm.status = st
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
Expand Down Expand Up @@ -242,7 +248,7 @@ func TestUUIDWontBlock(t *testing.T) {
rtx.Must(err, "Could not create directory")
defer os.RemoveAll(dir)

tcpdm := NewTCP(anonymize.New(anonymize.None), dir, 15*time.Second, 30*time.Second, 1, true, uint64(2*bytecount.Gigabyte), 0)
tcpdm := NewTCP(anonymize.New(anonymize.None), dir, 15*time.Second, 30*time.Second, 1, true, uint64(2*bytecount.Gigabyte), 0, &noopDetector{})
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)

var wg sync.WaitGroup
Expand Down
131 changes: 131 additions & 0 deletions detector/synflood.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package detector

import (
"hash"
"net"
"sync"
"time"

"github.com/cespare/xxhash/v2"
)

type SynFloodDetector struct {
sketches []*CMSketch // Array of Count-Min sketches for time windows
windowSize time.Duration // How long each sketch covers (e.g. 10s)
numWindows int // Number of sliding windows to maintain
currentWindow int // Index of current active window
threshold uint32 // Max SYNs per source per window before flagging
lastRotation time.Time // When we last rotated windows
mu sync.RWMutex
}

// Individual Count-Min sketch
type CMSketch struct {
width uint32 // Number of counters per row
depth uint32 // Number of rows (hash functions)
counts [][]uint32 // The actual counter matrix
hashes []hash.Hash64 // Hash functions for each row
}

// Constructor for flood detector
func NewSynFloodDetector(windowSize time.Duration, numWindows int, width, depth uint32, threshold uint32) *SynFloodDetector {
sketches := make([]*CMSketch, numWindows)
for i := 0; i < numWindows; i++ {
sketches[i] = NewCMSketch(width, depth)
}

return &SynFloodDetector{
sketches: sketches,
windowSize: windowSize,
numWindows: numWindows,
currentWindow: 0,
threshold: threshold,
lastRotation: time.Now(),
}
}

// Add a SYN packet from an IP
func (d *SynFloodDetector) AddSyn(srcIP net.IP) {
d.rotateIfNeeded()
d.sketches[d.currentWindow].Add(srcIP)
}

// Check if an IP is currently flooding
func (d *SynFloodDetector) IsFlooding(srcIP net.IP) bool {
d.rotateIfNeeded()
// Sum estimates across all windows
var total uint32
for _, sketch := range d.sketches {
total += sketch.Estimate(srcIP)
}
return total >= d.threshold
}

// Rotate windows periodically
func (d *SynFloodDetector) rotateIfNeeded() {
now := time.Now()
if now.Sub(d.lastRotation) < d.windowSize {
return
}

// Reset next window and advance
nextWindow := (d.currentWindow + 1) % d.numWindows
d.sketches[nextWindow].Reset()
d.currentWindow = nextWindow
d.lastRotation = now
}

// Constructor for Count-Min sketch
func NewCMSketch(width, depth uint32) *CMSketch {
counts := make([][]uint32, depth)
for i := range counts {
counts[i] = make([]uint32, width)
}

hashes := make([]hash.Hash64, depth)
for i := range hashes {
// Initialize with different seeds
hashes[i] = xxhash.New()
hashes[i].Write([]byte{byte(i)})
}

return &CMSketch{
width: width,
depth: depth,
counts: counts,
hashes: hashes,
}
}

// Add an item to sketch
func (s *CMSketch) Add(item []byte) {
for i := range s.hashes {
s.hashes[i].Reset()
s.hashes[i].Write(item)
h := s.hashes[i].Sum64() % uint64(s.width)
s.counts[i][h]++
}
}

// Get estimated count for item
func (s *CMSketch) Estimate(item []byte) uint32 {
var min uint32 = ^uint32(0) // Max uint32 value
for i := range s.hashes {
s.hashes[i].Reset()
s.hashes[i].Write(item)
h := s.hashes[i].Sum64() % uint64(s.width)
if s.counts[i][h] < min {
min = s.counts[i][h]
}
}
return min
}

// Reset all counters in a sketch
func (s *CMSketch) Reset() {
for i := range s.counts {
for j := range s.counts[i] {
s.counts[i][j] = 0
}
}
}
Loading