diff --git a/shadowaead/packet.go b/shadowaead/packet.go index 2ba403fb..5887953c 100644 --- a/shadowaead/packet.go +++ b/shadowaead/packet.go @@ -5,6 +5,7 @@ import ( "errors" "io" "net" + "net/netip" "sync" "github.com/shadowsocks/go-shadowsocks2/internal" @@ -73,6 +74,9 @@ type packetConn struct { // NewPacketConn wraps a net.PacketConn with cipher func NewPacketConn(c net.PacketConn, ciph Cipher) net.PacketConn { const maxPacketSize = 64 * 1024 + if cc, ok := c.(*net.UDPConn); ok { + return &udpConn{UDPConn: cc, Cipher: ciph, buf: make([]byte, maxPacketSize)} + } return &packetConn{PacketConn: c, Cipher: ciph, buf: make([]byte, maxPacketSize)} } @@ -101,3 +105,62 @@ func (c *packetConn) ReadFrom(b []byte) (int, net.Addr, error) { copy(b, bb) return len(bb), addr, err } + +type udpConn struct { + *net.UDPConn + Cipher + sync.Mutex + buf []byte // write lock +} + +// WriteTo encrypts b and write to addr using the embedded UDPConn. +func (c *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { + c.Lock() + defer c.Unlock() + buf, err := Pack(c.buf, b, c) + if err != nil { + return 0, err + } + _, err = c.UDPConn.WriteTo(buf, addr) + return len(b), err +} + +// ReadFrom reads from the embedded UDPConn and decrypts into b. +func (c *udpConn) ReadFrom(b []byte) (int, net.Addr, error) { + n, addr, err := c.UDPConn.ReadFrom(b) + if err != nil { + return n, addr, err + } + bb, err := Unpack(b[c.Cipher.SaltSize():], b[:n], c) + if err != nil { + return n, addr, err + } + copy(b, bb) + return len(bb), addr, err +} + +// WriteToUDPAddrPort encrypts b and write to addr using the embedded PacketConn. +func (c *udpConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + c.Lock() + defer c.Unlock() + buf, err := Pack(c.buf, b, c) + if err != nil { + return 0, err + } + _, err = c.UDPConn.WriteToUDPAddrPort(buf, addr) + return len(b), err +} + +// ReadFromUDPAddrPort reads from the embedded UDPConn and decrypts into b. +func (c *udpConn) ReadFromUDPAddrPort(b []byte) (int, netip.AddrPort, error) { + n, addr, err := c.UDPConn.ReadFromUDPAddrPort(b) + if err != nil { + return n, addr, err + } + bb, err := Unpack(b[c.Cipher.SaltSize():], b[:n], c) + if err != nil { + return n, addr, err + } + copy(b, bb) + return len(bb), addr, err +} diff --git a/udp.go b/udp.go index 06d85789..8d8c52e5 100644 --- a/udp.go +++ b/udp.go @@ -3,6 +3,7 @@ package main import ( "fmt" "net" + "net/netip" "sync" "time" @@ -34,7 +35,13 @@ func udpLocal(laddr, server, target string, shadow func(net.PacketConn) net.Pack return } - c, err := net.ListenPacket("udp", laddr) + lnAddr, err := net.ResolveUDPAddr("udp", laddr) + if err != nil { + logf("UDP listen address error: %v", err) + return + } + + c, err := net.ListenUDP("udp", lnAddr) if err != nil { logf("UDP local listen error: %v", err) return @@ -47,13 +54,13 @@ func udpLocal(laddr, server, target string, shadow func(net.PacketConn) net.Pack logf("UDP tunnel %s <-> %s <-> %s", laddr, server, target) for { - n, raddr, err := c.ReadFrom(buf[len(tgt):]) + n, raddr, err := c.ReadFromUDPAddrPort(buf[len(tgt):]) if err != nil { logf("UDP local read error: %v", err) continue } - pc := nm.Get(raddr.String()) + pc := nm.Get(raddr) if pc == nil { pc, err = net.ListenPacket("udp", "") if err != nil { @@ -81,7 +88,13 @@ func udpSocksLocal(laddr, server string, shadow func(net.PacketConn) net.PacketC return } - c, err := net.ListenPacket("udp", laddr) + lnAddr, err := net.ResolveUDPAddr("udp", laddr) + if err != nil { + logf("UDP listen address error: %v", err) + return + } + + c, err := net.ListenUDP("udp", lnAddr) if err != nil { logf("UDP local listen error: %v", err) return @@ -92,13 +105,13 @@ func udpSocksLocal(laddr, server string, shadow func(net.PacketConn) net.PacketC buf := make([]byte, udpBufSize) for { - n, raddr, err := c.ReadFrom(buf) + n, raddr, err := c.ReadFromUDPAddrPort(buf) if err != nil { logf("UDP local read error: %v", err) continue } - pc := nm.Get(raddr.String()) + pc := nm.Get(raddr) if pc == nil { pc, err = net.ListenPacket("udp", "") if err != nil { @@ -118,22 +131,33 @@ func udpSocksLocal(laddr, server string, shadow func(net.PacketConn) net.PacketC } } +type UDPConn interface { + net.PacketConn + ReadFromUDPAddrPort([]byte) (int, netip.AddrPort, error) + WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) +} + // Listen on addr for encrypted packets and basically do UDP NAT. func udpRemote(addr string, shadow func(net.PacketConn) net.PacketConn) { - c, err := net.ListenPacket("udp", addr) + nAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + logf("UDP server address error: %v", err) + return + } + cc, err := net.ListenUDP("udp", nAddr) if err != nil { logf("UDP remote listen error: %v", err) return } - defer c.Close() - c = shadow(c) + defer cc.Close() + c := shadow(cc).(UDPConn) nm := newNATmap(config.UDPTimeout) buf := make([]byte, udpBufSize) logf("listening UDP on %s", addr) for { - n, raddr, err := c.ReadFrom(buf) + n, raddr, err := c.ReadFromUDPAddrPort(buf) if err != nil { logf("UDP remote read error: %v", err) continue @@ -153,7 +177,7 @@ func udpRemote(addr string, shadow func(net.PacketConn) net.PacketConn) { payload := buf[len(tgtAddr):n] - pc := nm.Get(raddr.String()) + pc := nm.Get(raddr) if pc == nil { pc, err = net.ListenPacket("udp", "") if err != nil { @@ -175,31 +199,31 @@ func udpRemote(addr string, shadow func(net.PacketConn) net.PacketConn) { // Packet NAT table type natmap struct { sync.RWMutex - m map[string]net.PacketConn + m map[netip.AddrPort]net.PacketConn timeout time.Duration } func newNATmap(timeout time.Duration) *natmap { m := &natmap{} - m.m = make(map[string]net.PacketConn) + m.m = make(map[netip.AddrPort]net.PacketConn) m.timeout = timeout return m } -func (m *natmap) Get(key string) net.PacketConn { +func (m *natmap) Get(key netip.AddrPort) net.PacketConn { m.RLock() defer m.RUnlock() return m.m[key] } -func (m *natmap) Set(key string, pc net.PacketConn) { +func (m *natmap) Set(key netip.AddrPort, pc net.PacketConn) { m.Lock() defer m.Unlock() m.m[key] = pc } -func (m *natmap) Del(key string) net.PacketConn { +func (m *natmap) Del(key netip.AddrPort) net.PacketConn { m.Lock() defer m.Unlock() @@ -211,19 +235,19 @@ func (m *natmap) Del(key string) net.PacketConn { return nil } -func (m *natmap) Add(peer net.Addr, dst, src net.PacketConn, role mode) { - m.Set(peer.String(), src) +func (m *natmap) Add(peer netip.AddrPort, dst UDPConn, src net.PacketConn, role mode) { + m.Set(peer, src) go func() { timedCopy(dst, peer, src, m.timeout, role) - if pc := m.Del(peer.String()); pc != nil { + if pc := m.Del(peer); pc != nil { pc.Close() } }() } // copy from src to dst at target with read timeout -func timedCopy(dst net.PacketConn, target net.Addr, src net.PacketConn, timeout time.Duration, role mode) error { +func timedCopy(dst UDPConn, target netip.AddrPort, src net.PacketConn, timeout time.Duration, role mode) error { buf := make([]byte, udpBufSize) for { @@ -238,12 +262,12 @@ func timedCopy(dst net.PacketConn, target net.Addr, src net.PacketConn, timeout srcAddr := socks.ParseAddr(raddr.String()) copy(buf[len(srcAddr):], buf[:n]) copy(buf, srcAddr) - _, err = dst.WriteTo(buf[:len(srcAddr)+n], target) + _, err = dst.WriteToUDPAddrPort(buf[:len(srcAddr)+n], target) case relayClient: // client -> user: strip original packet source srcAddr := socks.SplitAddr(buf[:n]) - _, err = dst.WriteTo(buf[len(srcAddr):n], target) + _, err = dst.WriteToUDPAddrPort(buf[len(srcAddr):n], target) case socksClient: // client -> socks5 program: just set RSV and FRAG = 0 - _, err = dst.WriteTo(append([]byte{0, 0, 0}, buf[:n]...), target) + _, err = dst.WriteToUDPAddrPort(append([]byte{0, 0, 0}, buf[:n]...), target) } if err != nil {