diff --git a/connect-udp_test.go b/connect-udp_test.go index 25c01f3..1fc58d7 100644 --- a/connect-udp_test.go +++ b/connect-udp_test.go @@ -23,7 +23,8 @@ func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } -func runEchoServer(t *testing.T) *net.UDPConn { +// runEchoServer runs an echo server that echos back the data it receives n times. +func runEchoServer(t *testing.T, amplification int) *net.UDPConn { t.Helper() conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) require.NoError(t, err) @@ -34,8 +35,10 @@ func runEchoServer(t *testing.T) *net.UDPConn { if err != nil { return } - if _, err := conn.WriteTo(b[:n], addr); err != nil { - return + for i := 0; i < amplification; i++ { + if _, err := conn.WriteTo(b[:n], addr); err != nil { + return + } } } }() @@ -43,7 +46,8 @@ func runEchoServer(t *testing.T) *net.UDPConn { } func TestProxyToIP(t *testing.T) { - remoteServerConn := runEchoServer(t) + const amplification = 3 + remoteServerConn := runEchoServer(t, 3) defer remoteServerConn.Close() conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) @@ -62,6 +66,7 @@ func TestProxyToIP(t *testing.T) { defer server.Close() proxy := masque.Proxy{} defer proxy.Close() + statsChan := make(chan masque.Stats, 1) mux.HandleFunc("/masque", func(w http.ResponseWriter, r *http.Request) { req, err := masque.ParseRequest(r, template) if err != nil { @@ -69,7 +74,8 @@ func TestProxyToIP(t *testing.T) { w.WriteHeader(http.StatusBadRequest) return } - proxy.Proxy(w, req) + stats, _ := proxy.Proxy(w, req) + statsChan <- stats }) go func() { if err := server.Serve(conn); err != nil { @@ -81,20 +87,33 @@ func TestProxyToIP(t *testing.T) { Template: template, TLSClientConfig: &tls.Config{ClientCAs: certPool, NextProtos: []string{http3.NextProtoH3}, InsecureSkipVerify: true}, } - defer cl.Close() proxiedConn, _, err := cl.Dial(context.Background(), remoteServerConn.LocalAddr().(*net.UDPAddr)) require.NoError(t, err) _, err = proxiedConn.WriteTo([]byte("foobar"), remoteServerConn.LocalAddr()) require.NoError(t, err) - b := make([]byte, 1500) - n, _, err := proxiedConn.ReadFrom(b) - require.NoError(t, err) - require.Equal(t, []byte("foobar"), b[:n]) + for i := 0; i < amplification; i++ { + b := make([]byte, 1500) + n, _, err := proxiedConn.ReadFrom(b) + require.NoError(t, err) + require.Equal(t, []byte("foobar"), b[:n]) + } + cl.Close() + select { + case stats := <-statsChan: + require.Equal(t, masque.Stats{ + PacketsSent: 1, + DataSent: 6, + PacketsReceived: amplification, + DataReceived: 6 * amplification, + }, stats) + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for stats") + } } func TestProxyToHostname(t *testing.T) { - remoteServerConn := runEchoServer(t) + remoteServerConn := runEchoServer(t, 1) defer remoteServerConn.Close() conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) @@ -197,7 +216,7 @@ func TestProxyToHostnameMissingPort(t *testing.T) { } func TestProxyShutdown(t *testing.T) { - remoteServerConn := runEchoServer(t) + remoteServerConn := runEchoServer(t, 1) defer remoteServerConn.Close() conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) diff --git a/proxy.go b/proxy.go index 04da117..494652a 100644 --- a/proxy.go +++ b/proxy.go @@ -26,6 +26,11 @@ type proxyEntry struct { conn *net.UDPConn } +type Stats struct { + PacketsSent, PacketsReceived uint64 + DataSent, DataReceived uint64 +} + type Proxy struct { closed atomic.Bool @@ -38,23 +43,23 @@ type Proxy struct { // For more control over the UDP socket, use ProxyConnectedSocket. // Applications may add custom header fields to the response header, // but MUST NOT call WriteHeader on the http.ResponseWriter. -func (s *Proxy) Proxy(w http.ResponseWriter, r *Request) error { +func (s *Proxy) Proxy(w http.ResponseWriter, r *Request) (Stats, error) { if s.closed.Load() { w.WriteHeader(http.StatusServiceUnavailable) - return net.ErrClosed + return Stats{}, net.ErrClosed } addr, err := net.ResolveUDPAddr("udp", r.Target) if err != nil { // TODO: set proxy-status header (might want to use structured headers) w.WriteHeader(http.StatusGatewayTimeout) - return err + return Stats{}, err } conn, err := net.DialUDP("udp", nil, addr) if err != nil { // TODO: set proxy-status header (might want to use structured headers) w.WriteHeader(http.StatusGatewayTimeout) - return err + return Stats{}, err } defer conn.Close() @@ -64,11 +69,11 @@ func (s *Proxy) Proxy(w http.ResponseWriter, r *Request) error { // ProxyConnectedSocket proxies a request on a connected UDP socket. // Applications may add custom header fields to the response header, // but MUST NOT call WriteHeader on the http.ResponseWriter. -func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *net.UDPConn) error { +func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *net.UDPConn) (Stats, error) { if s.closed.Load() { conn.Close() w.WriteHeader(http.StatusServiceUnavailable) - return net.ErrClosed + return Stats{}, net.ErrClosed } s.refCount.Add(1) @@ -87,16 +92,21 @@ func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *ne var wg sync.WaitGroup wg.Add(3) + var packetsSent, packetsReceived, dataSent, dataReceived uint64 go func() { defer wg.Done() - if err := s.proxyConnSend(conn, str); err != nil { + var err error + packetsSent, dataSent, err = s.proxyConnSend(conn, str) + if err != nil && !s.closed.Load() { log.Printf("proxying send side to %s failed: %v", conn.RemoteAddr(), err) } str.Close() }() go func() { defer wg.Done() - if err := s.proxyConnReceive(conn, str); err != nil && !s.closed.Load() { + var err error + packetsReceived, dataReceived, err = s.proxyConnReceive(conn, str) + if err != nil && !s.closed.Load() { log.Printf("proxying receive side to %s failed: %v", conn.RemoteAddr(), err) } str.Close() @@ -111,41 +121,50 @@ func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *ne conn.Close() }() wg.Wait() - return nil + return Stats{ + PacketsSent: packetsSent, + PacketsReceived: packetsReceived, + DataSent: dataSent, + DataReceived: dataReceived, + }, nil } -func (s *Proxy) proxyConnSend(conn *net.UDPConn, str http3.Stream) error { +func (s *Proxy) proxyConnSend(conn *net.UDPConn, str http3.Stream) (packetsSent, dataSent uint64, _ error) { for { data, err := str.ReceiveDatagram(context.Background()) if err != nil { - return err + return packetsSent, dataSent, err } contextID, n, err := quicvarint.Parse(data) if err != nil { - return err + return packetsSent, dataSent, err } if contextID != 0 { // Drop this datagram. We currently only support proxying of UDP payloads. continue } + packetsSent++ + dataSent += uint64(len(data) - n) if _, err := conn.Write(data[n:]); err != nil { - return err + return packetsSent, dataSent, err } } } -func (s *Proxy) proxyConnReceive(conn *net.UDPConn, str http3.Stream) error { +func (s *Proxy) proxyConnReceive(conn *net.UDPConn, str http3.Stream) (packetsReceived, dataReceived uint64, _ error) { b := make([]byte, 1500) for { n, err := conn.Read(b) if err != nil { - return err + return packetsReceived, dataReceived, err } + packetsReceived++ + dataReceived += uint64(n) data := make([]byte, 0, len(contextIDZero)+n) data = append(data, contextIDZero...) data = append(data, b[:n]...) if err := str.SendDatagram(data); err != nil { - return err + return packetsReceived, dataReceived, err } } } diff --git a/proxy_test.go b/proxy_test.go index 5438c3f..b58d564 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -103,7 +103,8 @@ func TestProxyDialFailure(t *testing.T) { require.NoError(t, err) rec := httptest.NewRecorder() - require.ErrorContains(t, p.Proxy(rec, req), "invalid port") + _, err = p.Proxy(rec, req) + require.ErrorContains(t, err, "invalid port") require.Equal(t, http.StatusGatewayTimeout, rec.Code) } @@ -117,7 +118,8 @@ func TestProxyingAfterClose(t *testing.T) { t.Run("proxying", func(t *testing.T) { rec := httptest.NewRecorder() - require.ErrorIs(t, p.Proxy(rec, req), net.ErrClosed) + _, err := p.Proxy(rec, req) + require.ErrorIs(t, err, net.ErrClosed) require.Equal(t, http.StatusServiceUnavailable, rec.Code) }) @@ -125,7 +127,8 @@ func TestProxyingAfterClose(t *testing.T) { rec := httptest.NewRecorder() conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) require.NoError(t, err) - require.ErrorIs(t, p.ProxyConnectedSocket(rec, req, conn), net.ErrClosed) + _, err = p.ProxyConnectedSocket(rec, req, conn) + require.ErrorIs(t, err, net.ErrClosed) require.Equal(t, http.StatusServiceUnavailable, rec.Code) }) }