diff --git a/nanoproxy.go b/nanoproxy.go index e4a97ec..8182962 100644 --- a/nanoproxy.go +++ b/nanoproxy.go @@ -46,10 +46,18 @@ func main() { } if cfg.TorEnabled { - socks5Config.Dial = tor.Dial + torDialer := &tor.DefaultDialer{} + socks5Config.Dial = torDialer.Dial logger.Info().Msg("Tor mode enabled") - go tor.SwitcherIdentity(&logger, cfg.TorIdentityInterval) + torController := tor.NewTorController(torDialer) + ch := make(chan bool) + go tor.SwitcherIdentity(&logger, torController, cfg.TorIdentityInterval, ch) + + go func() { + <-ch + logger.Fatal().Msg("Tor identity switcher stopped") + }() } sock5Server := socks5.New(&socks5Config) diff --git a/pkg/tor/controller.go b/pkg/tor/controller.go new file mode 100644 index 0000000..516238f --- /dev/null +++ b/pkg/tor/controller.go @@ -0,0 +1,43 @@ +package tor + +import ( + "bufio" + "fmt" + "github.com/rs/zerolog" +) + +type Controller struct { + dialer Dialer +} + +func NewTorController(dialer Dialer) *Controller { + return &Controller{dialer: dialer} +} + +func (t *Controller) RequestNewTorIdentity(logger *zerolog.Logger) error { + conn, err := t.dialer.Dial("tcp", "127.0.0.1:9051") + if err != nil { + return fmt.Errorf("failed to connect to tor control port: %w", err) + } + defer conn.Close() + + _, _ = fmt.Fprintf(conn, "AUTHENTICATE\r\n") + _, _ = fmt.Fprintf(conn, "SIGNAL NEWNYM\r\n") + + authStatus, err := bufio.NewReader(conn).ReadString('\n') + if err != nil || authStatus != "250 OK\r\n" { + return fmt.Errorf("failed to authenticate with tor control port: %w", err) + } + + _, _ = fmt.Fprintf(conn, "SIGNAL NEWNYM\r\n") + status, err := bufio.NewReader(conn).ReadString('\n') + if err != nil || status != "250 OK\r\n" { + return fmt.Errorf("failed to switch tor identity: %w", err) + } + + if logger != nil { + logger.Info().Msg("Tor identity changed") + } + + return nil +} diff --git a/pkg/tor/controller_test.go b/pkg/tor/controller_test.go new file mode 100644 index 0000000..0f3f547 --- /dev/null +++ b/pkg/tor/controller_test.go @@ -0,0 +1,68 @@ +package tor_test + +import ( + "bytes" + "fmt" + "github.com/ryanbekhen/nanoproxy/pkg/tor" + "net" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" +) + +type MockConn struct { + net.Conn + responses []string + writeBuf []string + index int +} + +func (mc *MockConn) Read(b []byte) (n int, err error) { + if mc.index >= len(mc.responses) { + return 0, fmt.Errorf("EOF") + } + copy(b, mc.responses[mc.index]) + mc.index++ + return len(mc.responses[mc.index-1]), nil +} + +func (mc *MockConn) Write(b []byte) (n int, err error) { + mc.writeBuf = append(mc.writeBuf, string(b)) + return len(b), nil +} + +func (mc *MockConn) Close() error { + return nil +} + +type MockDialer struct { + net.Conn + shouldFail bool +} + +func (md *MockDialer) Dial(network, address string) (net.Conn, error) { + if md.shouldFail { + return nil, fmt.Errorf("failed to connect to tor control port") + } + return &MockConn{responses: []string{"250 OK\r\n", "250 OK\r\n"}}, nil +} + +func TestRequestNewTorIdentity_Success(t *testing.T) { + logger := zerolog.New(zerolog.ConsoleWriter{Out: &bytes.Buffer{}}).With().Logger() + dialer := &MockDialer{shouldFail: false} + torController := tor.NewTorController(dialer) + + err := torController.RequestNewTorIdentity(&logger) + assert.Nil(t, err, "expected no error during successful RequestNewTorIdentity call") +} + +func TestRequestNewTorIdentity_FailConnect(t *testing.T) { + logger := zerolog.New(zerolog.ConsoleWriter{Out: &bytes.Buffer{}}).With().Logger() + dialer := &MockDialer{shouldFail: true} + torController := tor.NewTorController(dialer) + + err := torController.RequestNewTorIdentity(&logger) + assert.NotNil(t, err, "expected error when connection fails") + assert.Contains(t, err.Error(), "failed to connect to tor control port") +} diff --git a/pkg/tor/dial.go b/pkg/tor/dial.go index 0309ca7..1fc0131 100644 --- a/pkg/tor/dial.go +++ b/pkg/tor/dial.go @@ -6,11 +6,19 @@ import ( "net" ) -func Dial(network, addr string) (net.Conn, error) { - dialer, err := proxy.SOCKS5("tcp", "localhost:9050", nil, proxy.Direct) +var customSOCKS5 = proxy.SOCKS5 + +type Dialer interface { + Dial(network, address string) (net.Conn, error) +} + +type DefaultDialer struct{} + +func (d DefaultDialer) Dial(network, address string) (net.Conn, error) { + dialer, err := customSOCKS5("tcp", "localhost:9050", nil, proxy.Direct) if err != nil { return nil, fmt.Errorf("failed to create tor dialer: %w", err) } - return dialer.Dial(network, addr) + return dialer.Dial(network, address) } diff --git a/pkg/tor/dial_test.go b/pkg/tor/dial_test.go new file mode 100644 index 0000000..c74b799 --- /dev/null +++ b/pkg/tor/dial_test.go @@ -0,0 +1,33 @@ +package tor + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "golang.org/x/net/proxy" + "testing" +) + +func TestDial(t *testing.T) { + network := "tcp" + addr := "example.com:80" + + conn, err := DefaultDialer{}.Dial(network, addr) + assert.Nil(t, err) + defer conn.Close() + assert.NotNil(t, conn) +} + +func TestDial_Error(t *testing.T) { + originalSOCKS5 := customSOCKS5 + customSOCKS5 = func(network, address string, auth *proxy.Auth, forward proxy.Dialer) (proxy.Dialer, error) { + return nil, fmt.Errorf("simulated SOCKS5 error") + } + defer func() { customSOCKS5 = originalSOCKS5 }() + + network := "tcp" + addr := "example.com:80" + conn, err := DefaultDialer{}.Dial(network, addr) + + assert.NotNil(t, err, "expected an error when dialing with simulated SOCKS5 error") + assert.Nil(t, conn, "expected no connection to be returned on error") +} diff --git a/pkg/tor/identity.go b/pkg/tor/identity.go index 2981a02..2bf4ac7 100644 --- a/pkg/tor/identity.go +++ b/pkg/tor/identity.go @@ -1,19 +1,17 @@ package tor import ( - "bufio" "fmt" "github.com/rs/zerolog" - "net" "time" ) -func waitForTorBootstrap(logger *zerolog.Logger, timeout time.Duration) error { +func WaitForTorBootstrap(logger *zerolog.Logger, requester Requester, timeout time.Duration) error { complete := make(chan bool) go func() { for { - if requestNewTorIdentity(nil) == nil { + if requester.RequestNewTorIdentity(nil) == nil { complete <- true break } @@ -30,44 +28,21 @@ func waitForTorBootstrap(logger *zerolog.Logger, timeout time.Duration) error { } } -func SwitcherIdentity(logger *zerolog.Logger, switchInterval time.Duration) { - if err := waitForTorBootstrap(logger, 5*time.Minute); err != nil { +func SwitcherIdentity(logger *zerolog.Logger, requester Requester, switchInterval time.Duration, done <-chan bool) { + if err := WaitForTorBootstrap(logger, requester, 5*time.Minute); err != nil { logger.Error().Msg(err.Error()) return } for { - if err := requestNewTorIdentity(logger); err != nil { - logger.Error().Msg(err.Error()) + select { + case <-done: + return + default: + if err := requester.RequestNewTorIdentity(logger); err != nil { + logger.Error().Msg(err.Error()) + } + time.Sleep(switchInterval) } - time.Sleep(switchInterval) - } -} - -func requestNewTorIdentity(logger *zerolog.Logger) error { - conn, err := net.Dial("tcp", "127.0.0.1:9051") - if err != nil { - return fmt.Errorf("failed to connect to tor control port: %w", err) - } - defer conn.Close() - - _, _ = fmt.Fprintf(conn, "AUTHENTICATE\r\n") - _, _ = fmt.Fprintf(conn, "SIGNAL NEWNYM\r\n") - - authStatus, err := bufio.NewReader(conn).ReadString('\n') - if err != nil || authStatus != "250 OK\r\n" { - return fmt.Errorf("failed to authenticate with tor control port: %w", err) - } - - _, _ = fmt.Fprintf(conn, "SIGNAL NEWNYM\r\n") - status, err := bufio.NewReader(conn).ReadString('\n') - if err != nil || status != "250 OK\r\n" { - return fmt.Errorf("failed to switch tor identity: %w", err) } - - if logger != nil { - logger.Info().Msg("Tor identity changed") - } - - return nil } diff --git a/pkg/tor/identity_test.go b/pkg/tor/identity_test.go new file mode 100644 index 0000000..278b3f4 --- /dev/null +++ b/pkg/tor/identity_test.go @@ -0,0 +1,45 @@ +package tor_test + +import ( + "bytes" + "fmt" + "github.com/ryanbekhen/nanoproxy/pkg/tor" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" +) + +type MockRequester struct { + shouldFail bool + callCount int +} + +func (m *MockRequester) RequestNewTorIdentity(logger *zerolog.Logger) error { + m.callCount++ + if m.shouldFail { + return fmt.Errorf("simulated failure") + } + return nil +} + +func TestSwitcherIdentity(t *testing.T) { + logger := zerolog.New(zerolog.ConsoleWriter{Out: &bytes.Buffer{}}).With().Logger() + requester := &MockRequester{shouldFail: false} + done := make(chan bool) + + // Set up a Goroutine to stop the SwitcherIdentity after a short delay + go func() { + time.Sleep(10 * time.Millisecond) + done <- true + }() + + // Call the SwitcherIdentity function with a very short interval + go tor.SwitcherIdentity(&logger, requester, 1*time.Millisecond, done) + + // Wait for a moment to ensure goroutine have run + time.Sleep(15 * time.Millisecond) + + assert.True(t, requester.callCount > 0, "expected SwitcherIdentity to call RequestNewTorIdentity multiple times") +} diff --git a/pkg/tor/requester.go b/pkg/tor/requester.go new file mode 100644 index 0000000..d610038 --- /dev/null +++ b/pkg/tor/requester.go @@ -0,0 +1,7 @@ +package tor + +import "github.com/rs/zerolog" + +type Requester interface { + RequestNewTorIdentity(logger *zerolog.Logger) error +}