Skip to content

Commit

Permalink
Refactor the endpoint verification
Browse files Browse the repository at this point in the history
  • Loading branch information
weiiwang01 committed Sep 25, 2023
1 parent 9ab1847 commit 1bbd0b0
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 173 deletions.
200 changes: 32 additions & 168 deletions internal/analyzer/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,12 @@ package analyzer

import (
"context"
"crypto/hmac"
"crypto/rand"
"encoding/binary"
"fmt"
"github.com/weiiwang01/wpex/internal/exchange"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"log/slog"
"net"
"sync"
"time"
)

const (
Expand All @@ -23,138 +18,9 @@ const (
cookieReplySize = 64
)

type unverifiedEndpoint struct {
cookie [cookieSize]byte
ttl time.Time
}

type unverifiedEndpoints struct {
mu sync.Mutex
endpoints map[string]unverifiedEndpoint
}

func (p *unverifiedEndpoints) cleanup() {
now := time.Now()
for addr := range p.endpoints {
if p.endpoints[addr].ttl.Before(now) {
slog.Debug("remove expired endpoint cookie", "addr", addr)
delete(p.endpoints, addr)
}
}
}

// CreateReply generate a cookie reply for the handshake initiation message.
func (p *unverifiedEndpoints) CreateReply(addr net.UDPAddr, pubkey []byte, index uint32, mac1 []byte) ([]byte, error) {
p.mu.Lock()
defer p.mu.Unlock()

p.cleanup()

secret := make([]byte, 16)
if _, err := rand.Read(secret); err != nil {
return nil, err
}
mac, err := blake2s.New128(secret)
if err != nil {
return nil, err
}
src, err := addr.AddrPort().MarshalBinary()
if err != nil {
return nil, err
}
mac.Write(src)
nonce := make([]byte, 24)
if _, err = rand.Read(nonce); err != nil {
return nil, err
}
hash, _ := blake2s.New256(nil)
hash.Write([]byte("cookie--"))
hash.Write(pubkey)
aead, err := chacha20poly1305.NewX(hash.Sum(nil))
if err != nil {
return nil, err
}
cookie := mac.Sum(nil)
p.endpoints[addr.String()] = unverifiedEndpoint{
cookie: [16]byte(mac.Sum(nil)),
ttl: time.Now().Add(time.Duration(10) * time.Second),
}
reply := make([]byte, 64)
reply[0] = 3
binary.BigEndian.PutUint32(reply[4:8], index)
copy(reply[8:32], nonce)
aead.Seal(reply[:32], nonce, cookie, mac1)
return reply, nil
}

// Verify verifies the mac2 in the handshake initiation message responding the cookie reply.
// Corresponding cookie will be removed if the verification succeed.
func (p *unverifiedEndpoints) Verify(addr net.UDPAddr, msg []byte) bool {
p.mu.Lock()
defer p.mu.Unlock()

p.cleanup()

endpoint, ok := p.endpoints[addr.String()]
if !ok {
return false
}
mac2 := msg[148-16:]
mac, _ := blake2s.New128(endpoint.cookie[:])
mac.Write(msg[:148-16])
if hmac.Equal(mac2, mac.Sum(nil)) {
delete(p.endpoints, addr.String())
slog.Debug("endpoint verified", "addr", addr.String())
return true
}
slog.Debug("endpoint verification failed", "addr", addr.String())
return false
}

func (p *unverifiedEndpoints) pendingVerify(addr net.UDPAddr) bool {
p.mu.Lock()
defer p.mu.Unlock()

p.cleanup()

_, ok := p.endpoints[addr.String()]
return ok
}

type WireguardAnalyzer struct {
table exchange.ExchangeTable
publicKeys [][]byte
unverified unverifiedEndpoints
}

func (t *WireguardAnalyzer) matchPubkey(packet []byte) []byte {
if len(packet) < 32 {
return nil
}
if len(t.publicKeys) == 0 {
return nil
}
l := len(packet)
mac1 := packet[l-32 : l-16]
d := packet[:l-32]
for _, key := range t.publicKeys {
hash, err := blake2s.New256(nil)
if err != nil {
continue
}
hash.Write([]byte("mac1----"))
hash.Write(key)
mackey := hash.Sum(nil)
mac, err := blake2s.New128(mackey)
if err != nil {
continue
}
mac.Write(d)
if hmac.Equal(mac1, mac.Sum(nil)) {
return key
}
}
return nil
table exchange.ExchangeTable
checker EndpointChecker
}

func (t *WireguardAnalyzer) decodeIndex(index []byte) uint32 {
Expand All @@ -168,36 +34,33 @@ func (t *WireguardAnalyzer) analyseHandshakeInitiation(packet []byte, peer net.U
return nil, nil
}
sender := t.decodeIndex(packet[4:8])
if len(t.publicKeys) > 0 {
pubkey := t.matchPubkey(packet)
if pubkey == nil {
logger.Warn("invalid mac1 in handshake initiation")
if t.table.Contains(peer) {
goto send
}
switch t.checker.VerifyHandshakeInitiation(peer, packet) {
case illegalEndpoint:
return nil, nil
case endpointNotVerified:
logger.Debug("send cookie reply to unknown endpoint")
reply, err := t.checker.CreateReply(peer, packet)
if err != nil {
logger.Error(fmt.Sprintf("failed to create cookie reply: %s", err))
return nil, nil
}
if !t.table.Contains(peer) {
mac1Start := handshakeInitiationSize - macSize*2
mac2Start := mac1Start + macSize
if !t.unverified.Verify(peer, packet) {
if t.unverified.pendingVerify(peer) {
logger.Debug("ignore handshake initiation from endpoint pending verification")
return nil, nil
}
reply, err := t.unverified.CreateReply(peer, pubkey, sender, packet[mac1Start:mac2Start])
if err != nil {
logger.Error(fmt.Sprintf("fail to create cookie reply: %s", err))
return nil, nil
}
logger.Debug("send cookie reply to unknown endpoint")
return []net.UDPAddr{peer}, reply
} else {
newPacket := make([]byte, handshakeInitiationSize)
copy(newPacket, packet[:mac2Start])
packet = newPacket
}
}
return []net.UDPAddr{peer}, reply
case endpointPendingVerification:
logger.Debug("ignore handshake initiation from endpoint pending verification")
return nil, nil
case endpointVerified:
newPacket := make([]byte, handshakeInitiationSize)
copy(newPacket, packet[:handshakeInitiationSize-macSize])
packet = newPacket
default:
panic("unknown endpoint verification status")
}
send:
if err := t.table.AddPeerAddr(sender, peer); err != nil {
logger.Error(fmt.Sprintf("fail to add address: %s", err))
logger.Error(fmt.Sprintf("failed to add address: %s", err))
return nil, nil
}
addresses := t.table.ListAddrs(peer)
Expand All @@ -211,8 +74,7 @@ func (t *WireguardAnalyzer) analyseHandshakeResponse(packet []byte, peer net.UDP
logger.Warn(fmt.Sprintf("invalid handshake response: expected length %d, got %d", handshakeResponseSize, len(packet)))
return nil, nil
}
if len(t.publicKeys) > 0 && t.matchPubkey(packet) == nil {
logger.Warn("incorrect mac1 in handshake response")
if !t.checker.VerifyHandshakeResponse(peer, packet) {
return nil, nil
}
sender := t.decodeIndex(packet[4:8])
Expand Down Expand Up @@ -310,15 +172,17 @@ func (t *WireguardAnalyzer) Analyse(packet []byte, peer net.UDPAddr) ([]net.UDPA
}
}

func MakeWireguardAnalyzer(publicKeys [][]byte) WireguardAnalyzer {
func MakeWireguardAnalyzer(pubkeys [][]byte) WireguardAnalyzer {
salt := make([]byte, 32)
_, err := rand.Read(salt)
if err != nil {
panic(fmt.Errorf("failed to generate salt: %w", err))
}
return WireguardAnalyzer{
table: exchange.MakeExchangeTable(),
publicKeys: publicKeys,
unverified: unverifiedEndpoints{endpoints: make(map[string]unverifiedEndpoint)},
table: exchange.MakeExchangeTable(),
checker: EndpointChecker{
endpoints: make(map[string]unverifiedEndpoint),
pubkeys: pubkeys,
},
}
}
8 changes: 4 additions & 4 deletions internal/analyzer/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ func mapAddrs(as []net.UDPAddr) []string {
func TestWireguardAnalyzer_VerifyMac1(t *testing.T) {
analyzer := MakeWireguardAnalyzer([][]byte{pubkeyA, pubkeyB})
addr1, _ := net.ResolveUDPAddr("udp", "127.0.0.1:51820")
addrs, _ := analyzer.Analyse(handshakeInitiationSession1AtoB, *addr1)
if addrs == nil {
_, data := analyzer.Analyse(handshakeInitiationSession1AtoB, *addr1)
if data == nil {
t.Error("mac1 verification failed")
}
analyzer = MakeWireguardAnalyzer([][]byte{fakePubkey})
addrs, _ = analyzer.Analyse(handshakeInitiationSession1AtoB, *addr1)
if addrs != nil {
_, data = analyzer.Analyse(handshakeInitiationSession1AtoB, *addr1)
if data != nil {
t.Errorf("mac1 verification didn't fail")
}
}
Expand Down
Loading

0 comments on commit 1bbd0b0

Please sign in to comment.