diff --git a/go.mod b/go.mod index 36506f6c..06e7747a 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/saucelabs/forwarder -go 1.23 +go 1.23.1 require ( github.com/dop251/goja v0.0.0-20231027120936-b396bb4c349d @@ -11,6 +11,7 @@ require ( github.com/kevinburke/hostsfile v0.0.0-20220522040509-e5e984885321 github.com/mitchellh/go-wordwrap v1.0.1 github.com/mmatczuk/anyflag v0.0.0-20240709090339-eb9e24cd1b44 + github.com/mmatczuk/connfu v0.0.0-20241015064402-db8989f89d8c github.com/prometheus/client_golang v1.20.5 github.com/prometheus/client_model v0.6.1 github.com/prometheus/common v0.60.0 diff --git a/go.sum b/go.sum index e94c7214..78f78955 100644 --- a/go.sum +++ b/go.sum @@ -67,6 +67,8 @@ github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyua github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mmatczuk/anyflag v0.0.0-20240709090339-eb9e24cd1b44 h1:Ds9W8Yj5ti4kQXITpCozfNNibS1fUA8+aK2T5th0vXE= github.com/mmatczuk/anyflag v0.0.0-20240709090339-eb9e24cd1b44/go.mod h1:PT22bA6vWBzPL8tAeK2XCMvWOQ4e19yY3MJIgnTZRaE= +github.com/mmatczuk/connfu v0.0.0-20241015064402-db8989f89d8c h1:1CC7JKZjrhe2AQh2T0Tay4j9Pp7HQl3WYpQvZr/ceA0= +github.com/mmatczuk/connfu v0.0.0-20241015064402-db8989f89d8c/go.mod h1:atoMPmvjynZBBUEoYWCM/ZnXAzZ9RoAnihm7YKXK/nY= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= diff --git a/tracked/tracked.go b/tracked/tracked.go new file mode 100644 index 00000000..2cac9416 --- /dev/null +++ b/tracked/tracked.go @@ -0,0 +1,142 @@ +// Copyright 2022-2024 Sauce Labs Inc., all rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package tracked + +import ( + "io" + "net" + "sync" + "sync/atomic" + + "github.com/mmatczuk/connfu" +) + +// ConnObserver allows to observe the number of bytes read and written from a connection. +type ConnObserver struct { + rx atomic.Uint64 + tx atomic.Uint64 +} + +// Rx returns the number of bytes read from the connection. +// It requires TrackTraffic to be set to true, otherwise it returns 0. +func (o *ConnObserver) Rx() uint64 { + return o.rx.Load() +} + +// Tx returns the number of bytes written to the connection. +// It requires TrackTraffic to be set to true, otherwise it returns 0. +func (o *ConnObserver) Tx() uint64 { + return o.tx.Load() +} + +func (o *ConnObserver) addRx(n uint64) { + o.rx.Add(n) +} + +func (o *ConnObserver) addTx(n uint64) { + o.tx.Add(n) +} + +type closeConn struct { + net.Conn + l closeListener // this is a field to avoid ambiguous selector error on Close method +} + +func (c *closeConn) Close() error { + return c.l.Close() +} + +type closeListener struct { + close func() error + once sync.Once + onClose func() +} + +func (c *closeListener) Close() error { + err := c.close() + c.once.Do(c.onClose) + return err +} + +// conn is a net.Conn that tracks the number of bytes read and written. +// It needs to be configured before first use by setting TrackTraffic and onClose if needed. +type conn struct { + net.Conn + o ConnObserver +} + +func (c *conn) Read(p []byte) (n int, err error) { + n, err = c.Conn.Read(p) + c.o.addRx(uint64(n)) + return +} + +func (c *conn) Write(p []byte) (n int, err error) { + n, err = c.Conn.Write(p) + c.o.addTx(uint64(n)) + return +} + +func (c *conn) ReadFrom(r io.Reader) (n int64, err error) { + n, err = c.Conn.(io.ReaderFrom).ReadFrom(r) + c.o.addTx(uint64(n)) + return +} + +type Builder struct { + // TrackTraffic enables counting of bytes read and written by the connection. + // Use Rx and Tx to get the number of bytes read and written. + TrackTraffic bool + + // OnClose is called after the underlying connection is closed and before the Close method returns. + // OnClose is called at most once. + OnClose func() +} + +func (b Builder) Build(c net.Conn) (net.Conn, *ConnObserver) { + var ( + wc net.Conn + co *ConnObserver + ) + + if b.TrackTraffic { + if b.OnClose != nil { + cc := &struct { + conn + closeListener + }{ + conn: conn{Conn: c}, + closeListener: closeListener{ + close: c.Close, + onClose: b.OnClose, + }, + } + wc = cc + co = &cc.conn.o + } else { + cc := &conn{ + Conn: c, + } + wc = cc + co = &cc.o + } + } else { + if b.OnClose == nil { + wc = c + } else { + wc = &closeConn{ + Conn: c, + l: closeListener{ + close: c.Close, + onClose: b.OnClose, + }, + } + } + } + + return connfu.Combine(wc, c), co +} diff --git a/tracked/tracked_test.go b/tracked/tracked_test.go new file mode 100644 index 00000000..b87a36d1 --- /dev/null +++ b/tracked/tracked_test.go @@ -0,0 +1,73 @@ +// Copyright 2022-2024 Sauce Labs Inc., all rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package tracked + +import ( + "crypto/tls" + "io" + "net" + "runtime" + "strings" + "testing" +) + +type closeWriter interface { + CloseWrite() error +} + +func TestBuildTCP(t *testing.T) { + wc, co := Builder{TrackTraffic: true}.Build(new(net.TCPConn)) + if co == nil { + t.Error("Expected a connection observer") + } + if _, ok := wc.(io.ReaderFrom); ok != (strings.HasPrefix(runtime.GOOS, "linux")) { + t.Error("ReaderFrom missmatch") + } + if _, ok := wc.(io.WriterTo); ok { + t.Error("Unexpected WriterTo") + } + if _, ok := wc.(closeWriter); !ok { + t.Error("Missing CloseWrite") + } +} + +func TestBuildTLS(t *testing.T) { + wc, co := Builder{TrackTraffic: true}.Build(new(tls.Conn)) + if co == nil { + t.Error("Expected a connection observer") + } + if _, ok := wc.(io.ReaderFrom); ok { + t.Error("Unexpected ReaderFrom") + } + if _, ok := wc.(io.WriterTo); ok { + t.Error("Unexpected WriterTo") + } + if _, ok := wc.(closeWriter); !ok { + t.Error("Missing CloseWrite") + } +} + +func TestBuildOnClose(t *testing.T) { + var closed bool + wc, co := Builder{OnClose: func() { closed = true }}.Build(new(net.TCPConn)) + if co != nil { + t.Error("Unexpected connection observer") + } + if _, ok := wc.(io.ReaderFrom); ok { + t.Error("Unexpected ReaderFrom") + } + if _, ok := wc.(io.WriterTo); ok { + t.Error("Unexpected WriterTo") + } + if _, ok := wc.(closeWriter); !ok { + t.Error("Missing CloseWrite") + } + wc.Close() + if !closed { + t.Error("OnClose not called") + } +}