diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b1d057e --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +wpex \ No newline at end of file diff --git a/internal/analyzer/analyzer.go b/internal/analyzer/analyzer.go index 6be6056..ce5a7b8 100644 --- a/internal/analyzer/analyzer.go +++ b/internal/analyzer/analyzer.go @@ -2,12 +2,12 @@ package analyzer import ( "context" - "crypto/rand" "encoding/binary" "fmt" "github.com/weiiwang01/wpex/internal/exchange" "log/slog" "net" + "time" ) const ( @@ -20,7 +20,7 @@ const ( type WireguardAnalyzer struct { table exchange.ExchangeTable - checker EndpointChecker + checker macChecker } func (t *WireguardAnalyzer) decodeIndex(index []byte) uint32 { @@ -33,32 +33,32 @@ func (t *WireguardAnalyzer) analyseHandshakeInitiation(packet []byte, peer net.U logger.Warn(fmt.Sprintf("invalid handshake initiation: expected length %d, got %d", handshakeInitiationSize, len(packet))) return nil, nil } - sender := t.decodeIndex(packet[4:8]) - if t.table.Contains(peer) { - goto send - } - switch t.checker.VerifyHandshakeInitiation(peer, packet) { - case illegalEndpoint: - return nil, nil - case endpointNotVerified: - logger.Debug("send cookie reply to unknown endpoint") - reply, err := t.checker.CreateReply(peer, packet) - if err != nil { - logger.Error(fmt.Sprintf("failed to create cookie reply: %s", err)) + + if t.checker.RequireCheck() { + pubkey := t.checker.MatchPubkey(packet) + if pubkey == nil { + logger.Warn("invalid mac1 in handshake initiation") return nil, nil } - return []net.UDPAddr{peer}, reply - case endpointPendingVerification: - logger.Debug("ignore handshake initiation from endpoint pending verification") - return nil, nil - case endpointVerified: - newPacket := make([]byte, handshakeInitiationSize) - copy(newPacket, packet[:handshakeInitiationSize-macSize]) - packet = newPacket - default: - panic("unknown endpoint verification status") + mac2Ok := t.checker.VerifyMac2(peer, packet) + known := t.table.Contains(peer) + if !mac2Ok && !known { + logger.Debug("send cookie reply to handshake initiation from unknown endpoint") + reply, err := t.checker.CreateReply(pubkey, peer, packet) + if err != nil { + logger.Error(fmt.Sprintf("failed to create cookie reply: %s", err)) + return nil, nil + } + return []net.UDPAddr{peer}, reply + } + if mac2Ok { + logger.Debug("strip mac2 from handshake initiation") + newPacket := make([]byte, handshakeInitiationSize) + copy(newPacket, packet[:handshakeInitiationSize-macSize]) + packet = newPacket + } } -send: + sender := t.decodeIndex(packet[4:8]) if err := t.table.AddPeerAddr(sender, peer); err != nil { logger.Error(fmt.Sprintf("failed to add address: %s", err)) return nil, nil @@ -74,8 +74,14 @@ func (t *WireguardAnalyzer) analyseHandshakeResponse(packet []byte, peer net.UDP logger.Warn(fmt.Sprintf("invalid handshake response: expected length %d, got %d", handshakeResponseSize, len(packet))) return nil, nil } - if !t.checker.VerifyHandshakeResponse(peer, packet) { - return nil, nil + if t.checker.RequireCheck() { + if t.checker.MatchPubkey(packet) == nil { + logger.Warn("invalid mac1 in handshake response") + return nil, nil + } + if t.checker.VerifyMac2(peer, packet) { + logger.Debug("strip mac2 from handshake response") + } } sender := t.decodeIndex(packet[4:8]) if err := t.table.AddPeerAddr(sender, peer); err != nil { @@ -89,7 +95,7 @@ func (t *WireguardAnalyzer) analyseHandshakeResponse(packet []byte, peer net.UDP logger.Warn(fmt.Sprintf("unknown receiver in handshake response: %s", err)) return nil, nil } - err = t.table.LinkPeers(sender, receiverIdx) + err = t.table.AssociatePeers(sender, receiverIdx) if err != nil { logger.Error(fmt.Sprintf("failed to link peers: %s", err)) return nil, nil @@ -173,16 +179,16 @@ func (t *WireguardAnalyzer) Analyse(packet []byte, peer net.UDPAddr) ([]net.UDPA } func MakeWireguardAnalyzer(pubkeys [][]byte) WireguardAnalyzer { - salt := make([]byte, 32) - _, err := rand.Read(salt) + secret, err := token(32) if err != nil { - panic(fmt.Errorf("failed to generate salt: %w", err)) + panic(fmt.Errorf("failed to generate cookie secret: %w", err)) } return WireguardAnalyzer{ table: exchange.MakeExchangeTable(), - checker: EndpointChecker{ - endpoints: make(map[string]unverifiedEndpoint), - pubkeys: pubkeys, + checker: macChecker{ + pubkeys: pubkeys, + secret: [32]byte(secret), + start: time.Now(), }, } } diff --git a/internal/analyzer/checker.go b/internal/analyzer/checker.go new file mode 100644 index 0000000..509e5d5 --- /dev/null +++ b/internal/analyzer/checker.go @@ -0,0 +1,66 @@ +package analyzer + +import ( + "crypto/hmac" + "encoding/binary" + "net" + "time" +) + +type macChecker struct { + secret [32]byte + start time.Time + pubkeys [][]byte +} + +func (c *macChecker) cookie(addr net.UDPAddr) [16]byte { + ticks := uint64(time.Now().Sub(c.start) / (time.Duration(120) * time.Minute)) + addrBytes, _ := addr.AddrPort().MarshalBinary() + return mac32(nil, c.secret, binary.BigEndian.AppendUint64(nil, ticks), addrBytes) +} + +func (c *macChecker) VerifyMac2(addr net.UDPAddr, message []byte) bool { + l := len(message) + mac2 := mac16(nil, c.cookie(addr), message[:l-macSize]) + return hmac.Equal(message[l-macSize:], mac2[:]) +} + +func (c *macChecker) MatchPubkey(message []byte) []byte { + if len(message) < 2*macSize { + return nil + } + if len(c.pubkeys) == 0 { + return nil + } + l := len(message) + mac1 := message[l-2*macSize : l-macSize] + d := message[:l-2*macSize] + for _, pubkey := range c.pubkeys { + key := hash(nil, []byte("mac1----"), pubkey) + m := mac32(nil, key, d) + if hmac.Equal(mac1, m[:]) { + return pubkey + } + } + return nil +} + +func (c *macChecker) CreateReply(pubkey []byte, addr net.UDPAddr, msg []byte) ([]byte, error) { + cookie := c.cookie(addr) + nonce, err := token(24) + if err != nil { + return nil, err + } + key := hash(nil, []byte("cookie--"), pubkey) + reply := make([]byte, 64) + reply[0] = 3 + copy(reply[4:8], msg[4:8]) + copy(reply[8:32], nonce) + mac1 := msg[handshakeInitiationSize-2*macSize : handshakeInitiationSize-macSize] + xaead(reply[:32], key, [24]byte(nonce), cookie[:], mac1) + return reply, nil +} + +func (c *macChecker) RequireCheck() bool { + return len(c.pubkeys) > 0 +} diff --git a/internal/analyzer/endpoints.go b/internal/analyzer/endpoints.go deleted file mode 100644 index 23b479f..0000000 --- a/internal/analyzer/endpoints.go +++ /dev/null @@ -1,136 +0,0 @@ -package analyzer - -import ( - "crypto/hmac" - "log/slog" - "net" - "sync" - "time" -) - -type endpointVerificationStatus int - -const ( - illegalEndpoint endpointVerificationStatus = iota - endpointNotVerified - endpointPendingVerification - endpointVerified -) - -type unverifiedEndpoint struct { - cookie [cookieSize]byte - ttl time.Time -} - -type EndpointChecker struct { - mu sync.Mutex - pubkeys [][]byte - endpoints map[string]unverifiedEndpoint -} - -func (c *EndpointChecker) cleanup() { - now := time.Now() - for addr := range c.endpoints { - if c.endpoints[addr].ttl.Before(now) { - slog.Debug("remove expired endpoint cookie", "addr", addr) - delete(c.endpoints, addr) - } - } -} - -// CreateReply generate a cookie reply for the handshake initiation message. -func (c *EndpointChecker) CreateReply(addr net.UDPAddr, msg []byte) ([]byte, error) { - pubkey := c.matchPubkey(msg) - if pubkey == nil { - panic("public key not found") - } - secret, err := token(32) - if err != nil { - return nil, err - } - src, err := addr.AddrPort().MarshalBinary() - if err != nil { - return nil, err - } - cookie := mac32(nil, [32]byte(secret), src) - nonce, err := token(24) - if err != nil { - return nil, err - } - key := hash(nil, []byte("cookie--"), pubkey) - c.mu.Lock() - defer c.mu.Unlock() - c.endpoints[addr.String()] = unverifiedEndpoint{ - cookie: cookie, - ttl: time.Now().Add(time.Duration(10) * time.Second), - } - reply := make([]byte, 64) - reply[0] = 3 - copy(reply[4:8], msg[4:8]) - copy(reply[8:32], nonce) - mac1 := msg[handshakeInitiationSize-2*macSize : handshakeInitiationSize-macSize] - xaead(reply[:32], key, [24]byte(nonce), cookie[:], mac1) - return reply, nil -} - -func (c *EndpointChecker) matchPubkey(packet []byte) []byte { - if len(packet) < 2*macSize { - return nil - } - if len(c.pubkeys) == 0 { - return nil - } - l := len(packet) - mac1 := packet[l-2*macSize : l-macSize] - d := packet[:l-2*macSize] - for _, pubkey := range c.pubkeys { - key := hash(nil, []byte("mac1----"), pubkey) - m := mac32(nil, key, d) - if hmac.Equal(mac1, m[:]) { - return pubkey - } - } - return nil -} - -func (c *EndpointChecker) VerifyHandshakeInitiation(addr net.UDPAddr, msg []byte) endpointVerificationStatus { - if len(c.pubkeys) == 0 { - return endpointVerified - } - logger := slog.With("addr", addr.String()) - pubkey := c.matchPubkey(msg) - if pubkey == nil { - logger.Warn("invalid mac1 in handshake initiation") - return illegalEndpoint - } - - c.mu.Lock() - defer c.mu.Unlock() - c.cleanup() - - endpoint, ok := c.endpoints[addr.String()] - if !ok { - return endpointNotVerified - } - - mac2 := mac16(nil, endpoint.cookie, msg[:handshakeInitiationSize-macSize]) - if hmac.Equal(msg[handshakeInitiationSize-macSize:], mac2[:]) { - delete(c.endpoints, addr.String()) - slog.Debug("endpoint verified", "addr", addr.String()) - return endpointVerified - } - return endpointPendingVerification -} - -func (c *EndpointChecker) VerifyHandshakeResponse(addr net.UDPAddr, msg []byte) bool { - if len(c.pubkeys) == 0 { - return true - } - logger := slog.With("addr", addr.String()) - pubkey := c.matchPubkey(msg) - if pubkey == nil { - logger.Warn("invalid mac1 in handshake response") - return false - } - return true -} diff --git a/internal/exchange/exchange.go b/internal/exchange/exchange.go index 1e2a990..f05b0ec 100644 --- a/internal/exchange/exchange.go +++ b/internal/exchange/exchange.go @@ -8,133 +8,197 @@ import ( "time" ) +const ( + endpointTTL = 1 * time.Minute + handshakeTTL = 5 * time.Second + sessionTTL = 3 * time.Minute +) + +type endpointInfo struct { + refs int + addr net.UDPAddr + expiredAt time.Time +} + +func (e *endpointInfo) isExpired() bool { + if e.refs <= 0 && e.expiredAt.Before(time.Now()) { + return true + } + return false +} + type peerInfo struct { - addr net.UDPAddr - ttl time.Time + index uint32 + addr *endpointInfo established bool counterpart uint32 + expiredAt time.Time +} + +func (p *peerInfo) isExpired() bool { + return p.isExpiredAt(time.Now()) +} + +func (p *peerInfo) isExpiredAt(t time.Time) bool { + if p.expiredAt.Before(t) { + return true + } + return false } // ExchangeTable is a concurrency-safe table that maintains wireguard peer information. type ExchangeTable struct { - table map[uint32]peerInfo - lock sync.RWMutex + mu sync.RWMutex + endpoints map[string]*endpointInfo + peers map[uint32]peerInfo } -// AddPeerAddr adds a new peer's endpoint address to the exchange table -// and automatically removes expired peer information. -func (t *ExchangeTable) AddPeerAddr(index uint32, addr net.UDPAddr) error { - t.lock.Lock() - defer t.lock.Unlock() +func (t *ExchangeTable) refEndpoint(addr net.UDPAddr) *endpointInfo { + addrStr := addr.String() + e, ok := t.endpoints[addrStr] + if ok { + e.refs += 1 + return e + } + e = &endpointInfo{ + addr: addr, + refs: 1, + } + t.endpoints[addrStr] = e + return e +} + +func (t *ExchangeTable) derefEndpoint(endpoint *endpointInfo) { + endpoint.refs -= 1 + if endpoint.refs <= 0 { + endpoint.expiredAt = time.Now().Add(endpointTTL) + } +} +func (t *ExchangeTable) cleanup() { now := time.Now() - for index, peer := range t.table { - if now.After(peer.ttl) { + for index, peer := range t.peers { + if peer.isExpiredAt(now) { slog.Debug("remove expired peer information", "index", index) - delete(t.table, index) + t.derefEndpoint(peer.addr) + delete(t.peers, index) + } + } + for addr, endpoint := range t.endpoints { + if endpoint.isExpired() { + slog.Debug("remove expired endpoint information", "addr", addr) + delete(t.endpoints, addr) } } +} + +// AddPeerAddr adds a new peer's endpoint address to the exchange table. +func (t *ExchangeTable) AddPeerAddr(index uint32, addr net.UDPAddr) error { + t.mu.Lock() + defer t.mu.Unlock() - if _, ok := t.table[index]; ok { + t.cleanup() + if _, ok := t.peers[index]; ok { return fmt.Errorf("peer index collision detected on %d", index) } - - t.table[index] = peerInfo{ - addr: addr, - ttl: time.Now().Add(10 * time.Second), + t.peers[index] = peerInfo{ + index: index, + addr: t.refEndpoint(addr), + established: false, + counterpart: 0, + expiredAt: time.Now().Add(handshakeTTL), } - - slog.Debug("exchange table updated", "entries", len(t.table)) return nil } // UpdatePeerAddr updates the endpoint address of a peer given its index. func (t *ExchangeTable) UpdatePeerAddr(index uint32, addr net.UDPAddr) error { - t.lock.Lock() - defer t.lock.Unlock() + t.mu.Lock() + defer t.mu.Unlock() - peer, ok := t.table[index] + t.cleanup() + peer, ok := t.peers[index] if !ok { - return fmt.Errorf("failed to update: unknown peer %d", index) + return fmt.Errorf("failed to update: peer %d not found", index) } - - peer.addr = addr - t.table[index] = peer + t.derefEndpoint(peer.addr) + peer.addr = t.refEndpoint(addr) + t.peers[index] = peer return nil } // GetPeerAddr retrieves the endpoint address of a peer using its index. func (t *ExchangeTable) GetPeerAddr(index uint32) (net.UDPAddr, error) { - t.lock.RLock() - defer t.lock.RUnlock() + t.mu.RLock() + defer t.mu.RUnlock() - peer, ok := t.table[index] - if !ok { - return net.UDPAddr{}, fmt.Errorf("unknown peer %d", index) + peer, ok := t.peers[index] + if !ok || peer.isExpired() { + return net.UDPAddr{}, fmt.Errorf("peer %d not found", index) } - - return peer.addr, nil + return peer.addr.addr, nil } // ListAddrs returns all known endpoint addresses from the exchange table. func (t *ExchangeTable) ListAddrs(exclude net.UDPAddr) []net.UDPAddr { - t.lock.RLock() - defer t.lock.RUnlock() + t.mu.RLock() + defer t.mu.RUnlock() var addrs []net.UDPAddr - excludes := map[string]struct{}{ - exclude.String(): {}, - } - - for _, peer := range t.table { - addrStr := peer.addr.String() - if _, ok := excludes[addrStr]; !ok { - addrs = append(addrs, peer.addr) - excludes[addrStr] = struct{}{} + for _, endpoint := range t.endpoints { + if !endpoint.isExpired() && endpoint.addr.String() != exclude.String() { + addrs = append(addrs, endpoint.addr) } } - return addrs } -// LinkPeers associates two peers in the same wireguard session. +// AssociatePeers associates two peers in the same wireguard session. // After linking, the counterpart peer can be retrieved using GetPeerCounterpart. -func (t *ExchangeTable) LinkPeers(sender, receiver uint32) error { - t.lock.Lock() - defer t.lock.Unlock() +func (t *ExchangeTable) AssociatePeers(sender, receiver uint32) error { + t.mu.Lock() + defer t.mu.Unlock() - s, ok := t.table[sender] + t.cleanup() + s, ok := t.peers[sender] if !ok { - return fmt.Errorf("failed to link peers: unknown sender %d", sender) + return fmt.Errorf("failed to associate peers: sender %d not found", sender) } - r, ok := t.table[receiver] + r, ok := t.peers[receiver] if !ok { - return fmt.Errorf("failed to link peers: unknown receiver %d", receiver) + return fmt.Errorf("failed to associate peers: receiver %d not found", receiver) + } + + if s.established && r.established && s.counterpart == receiver && r.counterpart == sender { + return nil } - ttl := time.Now().Add(time.Duration(3) * time.Minute) + if s.established || r.established { + return fmt.Errorf("sender or receiver has already been assoicated with another peer") + } + expiredAt := time.Now().Add(sessionTTL) s.established = true s.counterpart = receiver - s.ttl = ttl - t.table[sender] = s + s.expiredAt = expiredAt + t.peers[sender] = s r.established = true r.counterpart = sender - r.ttl = ttl - t.table[receiver] = r + r.expiredAt = expiredAt + t.peers[receiver] = r return nil } // GetPeerCounterpart retrieves the counterpart of a given peer from the same session. func (t *ExchangeTable) GetPeerCounterpart(index uint32) (uint32, error) { - t.lock.RLock() - defer t.lock.RUnlock() + t.mu.RLock() + defer t.mu.RUnlock() - peer, ok := t.table[index] - if !ok || !peer.established { + peer, ok := t.peers[index] + if !ok || !peer.established || peer.isExpired() { return 0, fmt.Errorf("peer %d doesn't exist or has no counterpart", index) } @@ -143,20 +207,19 @@ func (t *ExchangeTable) GetPeerCounterpart(index uint32) (uint32, error) { // Contains checks if an address exists in the exchange table. func (t *ExchangeTable) Contains(addr net.UDPAddr) bool { - t.lock.RLock() - defer t.lock.RUnlock() + t.mu.RLock() + defer t.mu.RUnlock() - for _, peer := range t.table { - if peer.addr.String() == addr.String() { - return true - } + endpoint, ok := t.endpoints[addr.String()] + if !ok || endpoint.isExpired() { + return false } - - return false + return true } func MakeExchangeTable() ExchangeTable { return ExchangeTable{ - table: make(map[uint32]peerInfo), + endpoints: make(map[string]*endpointInfo), + peers: make(map[uint32]peerInfo), } } diff --git a/internal/exchange/exchange_test.go b/internal/exchange/exchange_test.go index ec373ce..0095b45 100644 --- a/internal/exchange/exchange_test.go +++ b/internal/exchange/exchange_test.go @@ -15,17 +15,17 @@ func TestExchangeTable_AddPeerAddr(t *testing.T) { if err := table.AddPeerAddr(index, *addr); err != nil { t.Errorf("AddPeerAddr failed: %v", err) } - if len(table.table) != 1 { - t.Errorf("incorrect number of peers in table, expected 1, got %d", len(table.table)) + if len(table.peers) != 1 { + t.Errorf("incorrect number of peers in table, expected 1, got %d", len(table.peers)) } - peer := table.table[index] - if peer.addr.String() != addr.String() { - t.Errorf("incorrect address of peer, expected %s, got %s", addr.String(), peer.addr.String()) + peer := table.peers[index] + if peer.addr.addr.String() != addr.String() { + t.Errorf("incorrect address of peer, expected %s, got %s", addr.String(), peer.addr.addr.String()) } if err := table.AddPeerAddr(1, *addr); err == nil { t.Errorf("AddPeerAddr didn't return error when adding an existing peer") } - if !peer.ttl.After(now) { + if !peer.expiredAt.After(now) { t.Errorf("incorrect ttl of peer") } } @@ -67,8 +67,8 @@ func TestExchangeTable_GetPeerCounterpart(t *testing.T) { if err := table.AddPeerAddr(index3, *addr3); err != nil { t.Errorf("AddPeerAddr failed: %v", err) } - if err := table.LinkPeers(index1, index2); err != nil { - t.Errorf("LinkPeers failed: %v", err) + if err := table.AssociatePeers(index1, index2); err != nil { + t.Errorf("AssociatePeers failed: %v", err) } c1, err := table.GetPeerCounterpart(index1) if err != nil { @@ -98,15 +98,15 @@ func TestExchangeTable_LinkPeers(t *testing.T) { if err := table.AddPeerAddr(index2, *addr2); err != nil { t.Errorf("AddPeerAddr failed: %v", err) } - if err := table.LinkPeers(index1, index2); err != nil { - t.Errorf("LinkPeers failed: %v", err) + if err := table.AssociatePeers(index1, index2); err != nil { + t.Errorf("AssociatePeers failed: %v", err) } - peer1, ok1 := table.table[index1] - peer2, ok2 := table.table[index2] + peer1, ok1 := table.peers[index1] + peer2, ok2 := table.peers[index2] if !ok1 || !ok2 { t.Errorf("AddPeerAddr failed, peer doesn't exist in the table") } - if peer1.ttl != peer2.ttl { + if peer1.expiredAt != peer2.expiredAt { t.Errorf("linked peers don't have same ttl") } if !peer1.established || !peer2.established {