From 1bbd0b05eb9bf2c4017fc08a43db8c7878667416 Mon Sep 17 00:00:00 2001 From: Weii Wang Date: Mon, 25 Sep 2023 15:54:00 +0800 Subject: [PATCH] Refactor the endpoint verification --- internal/analyzer/analyzer.go | 200 +++++------------------------ internal/analyzer/analyzer_test.go | 8 +- internal/analyzer/endpoints.go | 136 ++++++++++++++++++++ internal/analyzer/helper.go | 45 +++++++ internal/relay/relay.go | 2 +- 5 files changed, 218 insertions(+), 173 deletions(-) create mode 100644 internal/analyzer/endpoints.go create mode 100644 internal/analyzer/helper.go diff --git a/internal/analyzer/analyzer.go b/internal/analyzer/analyzer.go index 5bce757..6be6056 100644 --- a/internal/analyzer/analyzer.go +++ b/internal/analyzer/analyzer.go @@ -2,17 +2,12 @@ package analyzer import ( "context" - "crypto/hmac" "crypto/rand" "encoding/binary" "fmt" "github.com/weiiwang01/wpex/internal/exchange" - "golang.org/x/crypto/blake2s" - "golang.org/x/crypto/chacha20poly1305" "log/slog" "net" - "sync" - "time" ) const ( @@ -23,138 +18,9 @@ const ( cookieReplySize = 64 ) -type unverifiedEndpoint struct { - cookie [cookieSize]byte - ttl time.Time -} - -type unverifiedEndpoints struct { - mu sync.Mutex - endpoints map[string]unverifiedEndpoint -} - -func (p *unverifiedEndpoints) cleanup() { - now := time.Now() - for addr := range p.endpoints { - if p.endpoints[addr].ttl.Before(now) { - slog.Debug("remove expired endpoint cookie", "addr", addr) - delete(p.endpoints, addr) - } - } -} - -// CreateReply generate a cookie reply for the handshake initiation message. -func (p *unverifiedEndpoints) CreateReply(addr net.UDPAddr, pubkey []byte, index uint32, mac1 []byte) ([]byte, error) { - p.mu.Lock() - defer p.mu.Unlock() - - p.cleanup() - - secret := make([]byte, 16) - if _, err := rand.Read(secret); err != nil { - return nil, err - } - mac, err := blake2s.New128(secret) - if err != nil { - return nil, err - } - src, err := addr.AddrPort().MarshalBinary() - if err != nil { - return nil, err - } - mac.Write(src) - nonce := make([]byte, 24) - if _, err = rand.Read(nonce); err != nil { - return nil, err - } - hash, _ := blake2s.New256(nil) - hash.Write([]byte("cookie--")) - hash.Write(pubkey) - aead, err := chacha20poly1305.NewX(hash.Sum(nil)) - if err != nil { - return nil, err - } - cookie := mac.Sum(nil) - p.endpoints[addr.String()] = unverifiedEndpoint{ - cookie: [16]byte(mac.Sum(nil)), - ttl: time.Now().Add(time.Duration(10) * time.Second), - } - reply := make([]byte, 64) - reply[0] = 3 - binary.BigEndian.PutUint32(reply[4:8], index) - copy(reply[8:32], nonce) - aead.Seal(reply[:32], nonce, cookie, mac1) - return reply, nil -} - -// Verify verifies the mac2 in the handshake initiation message responding the cookie reply. -// Corresponding cookie will be removed if the verification succeed. -func (p *unverifiedEndpoints) Verify(addr net.UDPAddr, msg []byte) bool { - p.mu.Lock() - defer p.mu.Unlock() - - p.cleanup() - - endpoint, ok := p.endpoints[addr.String()] - if !ok { - return false - } - mac2 := msg[148-16:] - mac, _ := blake2s.New128(endpoint.cookie[:]) - mac.Write(msg[:148-16]) - if hmac.Equal(mac2, mac.Sum(nil)) { - delete(p.endpoints, addr.String()) - slog.Debug("endpoint verified", "addr", addr.String()) - return true - } - slog.Debug("endpoint verification failed", "addr", addr.String()) - return false -} - -func (p *unverifiedEndpoints) pendingVerify(addr net.UDPAddr) bool { - p.mu.Lock() - defer p.mu.Unlock() - - p.cleanup() - - _, ok := p.endpoints[addr.String()] - return ok -} - type WireguardAnalyzer struct { - table exchange.ExchangeTable - publicKeys [][]byte - unverified unverifiedEndpoints -} - -func (t *WireguardAnalyzer) matchPubkey(packet []byte) []byte { - if len(packet) < 32 { - return nil - } - if len(t.publicKeys) == 0 { - return nil - } - l := len(packet) - mac1 := packet[l-32 : l-16] - d := packet[:l-32] - for _, key := range t.publicKeys { - hash, err := blake2s.New256(nil) - if err != nil { - continue - } - hash.Write([]byte("mac1----")) - hash.Write(key) - mackey := hash.Sum(nil) - mac, err := blake2s.New128(mackey) - if err != nil { - continue - } - mac.Write(d) - if hmac.Equal(mac1, mac.Sum(nil)) { - return key - } - } - return nil + table exchange.ExchangeTable + checker EndpointChecker } func (t *WireguardAnalyzer) decodeIndex(index []byte) uint32 { @@ -168,36 +34,33 @@ func (t *WireguardAnalyzer) analyseHandshakeInitiation(packet []byte, peer net.U return nil, nil } sender := t.decodeIndex(packet[4:8]) - if len(t.publicKeys) > 0 { - pubkey := t.matchPubkey(packet) - if pubkey == nil { - logger.Warn("invalid mac1 in handshake initiation") + if t.table.Contains(peer) { + goto send + } + switch t.checker.VerifyHandshakeInitiation(peer, packet) { + case illegalEndpoint: + return nil, nil + case endpointNotVerified: + logger.Debug("send cookie reply to unknown endpoint") + reply, err := t.checker.CreateReply(peer, packet) + if err != nil { + logger.Error(fmt.Sprintf("failed to create cookie reply: %s", err)) return nil, nil } - if !t.table.Contains(peer) { - mac1Start := handshakeInitiationSize - macSize*2 - mac2Start := mac1Start + macSize - if !t.unverified.Verify(peer, packet) { - if t.unverified.pendingVerify(peer) { - logger.Debug("ignore handshake initiation from endpoint pending verification") - return nil, nil - } - reply, err := t.unverified.CreateReply(peer, pubkey, sender, packet[mac1Start:mac2Start]) - if err != nil { - logger.Error(fmt.Sprintf("fail to create cookie reply: %s", err)) - return nil, nil - } - logger.Debug("send cookie reply to unknown endpoint") - return []net.UDPAddr{peer}, reply - } else { - newPacket := make([]byte, handshakeInitiationSize) - copy(newPacket, packet[:mac2Start]) - packet = newPacket - } - } + return []net.UDPAddr{peer}, reply + case endpointPendingVerification: + logger.Debug("ignore handshake initiation from endpoint pending verification") + return nil, nil + case endpointVerified: + newPacket := make([]byte, handshakeInitiationSize) + copy(newPacket, packet[:handshakeInitiationSize-macSize]) + packet = newPacket + default: + panic("unknown endpoint verification status") } +send: if err := t.table.AddPeerAddr(sender, peer); err != nil { - logger.Error(fmt.Sprintf("fail to add address: %s", err)) + logger.Error(fmt.Sprintf("failed to add address: %s", err)) return nil, nil } addresses := t.table.ListAddrs(peer) @@ -211,8 +74,7 @@ func (t *WireguardAnalyzer) analyseHandshakeResponse(packet []byte, peer net.UDP logger.Warn(fmt.Sprintf("invalid handshake response: expected length %d, got %d", handshakeResponseSize, len(packet))) return nil, nil } - if len(t.publicKeys) > 0 && t.matchPubkey(packet) == nil { - logger.Warn("incorrect mac1 in handshake response") + if !t.checker.VerifyHandshakeResponse(peer, packet) { return nil, nil } sender := t.decodeIndex(packet[4:8]) @@ -310,15 +172,17 @@ func (t *WireguardAnalyzer) Analyse(packet []byte, peer net.UDPAddr) ([]net.UDPA } } -func MakeWireguardAnalyzer(publicKeys [][]byte) WireguardAnalyzer { +func MakeWireguardAnalyzer(pubkeys [][]byte) WireguardAnalyzer { salt := make([]byte, 32) _, err := rand.Read(salt) if err != nil { panic(fmt.Errorf("failed to generate salt: %w", err)) } return WireguardAnalyzer{ - table: exchange.MakeExchangeTable(), - publicKeys: publicKeys, - unverified: unverifiedEndpoints{endpoints: make(map[string]unverifiedEndpoint)}, + table: exchange.MakeExchangeTable(), + checker: EndpointChecker{ + endpoints: make(map[string]unverifiedEndpoint), + pubkeys: pubkeys, + }, } } diff --git a/internal/analyzer/analyzer_test.go b/internal/analyzer/analyzer_test.go index cdca574..57f2d1c 100644 --- a/internal/analyzer/analyzer_test.go +++ b/internal/analyzer/analyzer_test.go @@ -34,13 +34,13 @@ func mapAddrs(as []net.UDPAddr) []string { func TestWireguardAnalyzer_VerifyMac1(t *testing.T) { analyzer := MakeWireguardAnalyzer([][]byte{pubkeyA, pubkeyB}) addr1, _ := net.ResolveUDPAddr("udp", "127.0.0.1:51820") - addrs, _ := analyzer.Analyse(handshakeInitiationSession1AtoB, *addr1) - if addrs == nil { + _, data := analyzer.Analyse(handshakeInitiationSession1AtoB, *addr1) + if data == nil { t.Error("mac1 verification failed") } analyzer = MakeWireguardAnalyzer([][]byte{fakePubkey}) - addrs, _ = analyzer.Analyse(handshakeInitiationSession1AtoB, *addr1) - if addrs != nil { + _, data = analyzer.Analyse(handshakeInitiationSession1AtoB, *addr1) + if data != nil { t.Errorf("mac1 verification didn't fail") } } diff --git a/internal/analyzer/endpoints.go b/internal/analyzer/endpoints.go new file mode 100644 index 0000000..23b479f --- /dev/null +++ b/internal/analyzer/endpoints.go @@ -0,0 +1,136 @@ +package analyzer + +import ( + "crypto/hmac" + "log/slog" + "net" + "sync" + "time" +) + +type endpointVerificationStatus int + +const ( + illegalEndpoint endpointVerificationStatus = iota + endpointNotVerified + endpointPendingVerification + endpointVerified +) + +type unverifiedEndpoint struct { + cookie [cookieSize]byte + ttl time.Time +} + +type EndpointChecker struct { + mu sync.Mutex + pubkeys [][]byte + endpoints map[string]unverifiedEndpoint +} + +func (c *EndpointChecker) cleanup() { + now := time.Now() + for addr := range c.endpoints { + if c.endpoints[addr].ttl.Before(now) { + slog.Debug("remove expired endpoint cookie", "addr", addr) + delete(c.endpoints, addr) + } + } +} + +// CreateReply generate a cookie reply for the handshake initiation message. +func (c *EndpointChecker) CreateReply(addr net.UDPAddr, msg []byte) ([]byte, error) { + pubkey := c.matchPubkey(msg) + if pubkey == nil { + panic("public key not found") + } + secret, err := token(32) + if err != nil { + return nil, err + } + src, err := addr.AddrPort().MarshalBinary() + if err != nil { + return nil, err + } + cookie := mac32(nil, [32]byte(secret), src) + nonce, err := token(24) + if err != nil { + return nil, err + } + key := hash(nil, []byte("cookie--"), pubkey) + c.mu.Lock() + defer c.mu.Unlock() + c.endpoints[addr.String()] = unverifiedEndpoint{ + cookie: cookie, + ttl: time.Now().Add(time.Duration(10) * time.Second), + } + reply := make([]byte, 64) + reply[0] = 3 + copy(reply[4:8], msg[4:8]) + copy(reply[8:32], nonce) + mac1 := msg[handshakeInitiationSize-2*macSize : handshakeInitiationSize-macSize] + xaead(reply[:32], key, [24]byte(nonce), cookie[:], mac1) + return reply, nil +} + +func (c *EndpointChecker) matchPubkey(packet []byte) []byte { + if len(packet) < 2*macSize { + return nil + } + if len(c.pubkeys) == 0 { + return nil + } + l := len(packet) + mac1 := packet[l-2*macSize : l-macSize] + d := packet[:l-2*macSize] + for _, pubkey := range c.pubkeys { + key := hash(nil, []byte("mac1----"), pubkey) + m := mac32(nil, key, d) + if hmac.Equal(mac1, m[:]) { + return pubkey + } + } + return nil +} + +func (c *EndpointChecker) VerifyHandshakeInitiation(addr net.UDPAddr, msg []byte) endpointVerificationStatus { + if len(c.pubkeys) == 0 { + return endpointVerified + } + logger := slog.With("addr", addr.String()) + pubkey := c.matchPubkey(msg) + if pubkey == nil { + logger.Warn("invalid mac1 in handshake initiation") + return illegalEndpoint + } + + c.mu.Lock() + defer c.mu.Unlock() + c.cleanup() + + endpoint, ok := c.endpoints[addr.String()] + if !ok { + return endpointNotVerified + } + + mac2 := mac16(nil, endpoint.cookie, msg[:handshakeInitiationSize-macSize]) + if hmac.Equal(msg[handshakeInitiationSize-macSize:], mac2[:]) { + delete(c.endpoints, addr.String()) + slog.Debug("endpoint verified", "addr", addr.String()) + return endpointVerified + } + return endpointPendingVerification +} + +func (c *EndpointChecker) VerifyHandshakeResponse(addr net.UDPAddr, msg []byte) bool { + if len(c.pubkeys) == 0 { + return true + } + logger := slog.With("addr", addr.String()) + pubkey := c.matchPubkey(msg) + if pubkey == nil { + logger.Warn("invalid mac1 in handshake response") + return false + } + return true +} diff --git a/internal/analyzer/helper.go b/internal/analyzer/helper.go new file mode 100644 index 0000000..7de10e1 --- /dev/null +++ b/internal/analyzer/helper.go @@ -0,0 +1,45 @@ +package analyzer + +import ( + "crypto/rand" + "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/chacha20poly1305" +) + +func hash(dst []byte, data ...[]byte) [32]byte { + hash, _ := blake2s.New256(nil) + for _, d := range data { + hash.Write(d) + } + return [32]byte(hash.Sum(dst)) +} + +func mac16(dst []byte, key [16]byte, data ...[]byte) [16]byte { + mac, _ := blake2s.New128(key[:]) + for _, d := range data { + mac.Write(d) + } + return [16]byte(mac.Sum(dst)) +} + +func mac32(dst []byte, key [32]byte, data ...[]byte) [16]byte { + mac, _ := blake2s.New128(key[:]) + for _, d := range data { + mac.Write(d) + } + return [16]byte(mac.Sum(dst)) +} + +func xaead(dst []byte, key [32]byte, nonce [24]byte, plain []byte, auth []byte) []byte { + xaead, _ := chacha20poly1305.NewX(key[:]) + return xaead.Seal(dst, nonce[:], plain, auth) +} + +func token(n int) ([]byte, error) { + b := make([]byte, n) + _, err := rand.Read(b) + if err != nil { + return nil, err + } + return b, nil +} diff --git a/internal/relay/relay.go b/internal/relay/relay.go index 6276d7e..512bd64 100644 --- a/internal/relay/relay.go +++ b/internal/relay/relay.go @@ -25,7 +25,7 @@ func (r *Relay) sendUDP() { for packet := range r.send { if packet.isBroadcast { if !r.limit.Allow() { - slog.Error("broadcast rate limit exceeded", "src", packet.source.String(), "dst", packet.addr.String()) + slog.Warn("broadcast rate limit exceeded", "src", packet.source.String(), "dst", packet.addr.String()) } } _, err := r.conn.WriteToUDP(packet.data, &packet.addr)