From d76b124bce8909c2c4ce74643425210de9c64f13 Mon Sep 17 00:00:00 2001 From: Achmad Irianto Eka Putra Date: Sun, 8 Dec 2024 08:44:28 +0700 Subject: [PATCH] fix(tor): improve tor identity switching mechanism Refactored the method to request a new Tor identity using a separate control port dial function. This update improves error handling by confirming responses from the Tor control port. Also added tests to ensure robust connection behavior for valid and invalid addresses. --- pkg/tor/controller.go | 21 +++++++++------------ pkg/tor/controller_test.go | 7 +++++++ pkg/tor/dial.go | 12 ++++++------ pkg/tor/dial_test.go | 26 ++++++++++++++++++++++++++ 4 files changed, 48 insertions(+), 18 deletions(-) diff --git a/pkg/tor/controller.go b/pkg/tor/controller.go index 516238f..9322f9d 100644 --- a/pkg/tor/controller.go +++ b/pkg/tor/controller.go @@ -4,6 +4,7 @@ import ( "bufio" "fmt" "github.com/rs/zerolog" + "strings" ) type Controller struct { @@ -15,24 +16,20 @@ func NewTorController(dialer Dialer) *Controller { } func (t *Controller) RequestNewTorIdentity(logger *zerolog.Logger) error { - conn, err := t.dialer.Dial("tcp", "127.0.0.1:9051") + conn, err := t.dialer.DialControlPort("tcp", "127.0.0.1:9051") if err != nil { return fmt.Errorf("failed to connect to tor control port: %w", err) } defer conn.Close() - _, _ = fmt.Fprintf(conn, "AUTHENTICATE\r\n") - _, _ = fmt.Fprintf(conn, "SIGNAL NEWNYM\r\n") - - authStatus, err := bufio.NewReader(conn).ReadString('\n') - if err != nil || authStatus != "250 OK\r\n" { - return fmt.Errorf("failed to authenticate with tor control port: %w", err) + _, _ = fmt.Fprintf(conn, "AUTHENTICATE \"\"\r\n") + _, err = fmt.Fprintf(conn, "SIGNAL NEWNYM\r\n") + if err != nil { + return fmt.Errorf("failed to request new identity: %w", err) } - - _, _ = fmt.Fprintf(conn, "SIGNAL NEWNYM\r\n") - status, err := bufio.NewReader(conn).ReadString('\n') - if err != nil || status != "250 OK\r\n" { - return fmt.Errorf("failed to switch tor identity: %w", err) + signalResponse, _ := bufio.NewReader(conn).ReadString('\n') + if !strings.HasPrefix(signalResponse, "250") { + return fmt.Errorf("failed to switch tor identity: %v", signalResponse) } if logger != nil { diff --git a/pkg/tor/controller_test.go b/pkg/tor/controller_test.go index 0f3f547..5d4a5fc 100644 --- a/pkg/tor/controller_test.go +++ b/pkg/tor/controller_test.go @@ -48,6 +48,13 @@ func (md *MockDialer) Dial(network, address string) (net.Conn, error) { return &MockConn{responses: []string{"250 OK\r\n", "250 OK\r\n"}}, nil } +func (md *MockDialer) DialControlPort(network, address string) (net.Conn, error) { + if md.shouldFail { + return nil, fmt.Errorf("failed to connect to tor control port") + } + return &MockConn{responses: []string{"250 OK\r\n", "250 OK\r\n"}}, nil +} + func TestRequestNewTorIdentity_Success(t *testing.T) { logger := zerolog.New(zerolog.ConsoleWriter{Out: &bytes.Buffer{}}).With().Logger() dialer := &MockDialer{shouldFail: false} diff --git a/pkg/tor/dial.go b/pkg/tor/dial.go index 1fc0131..637a7f7 100644 --- a/pkg/tor/dial.go +++ b/pkg/tor/dial.go @@ -1,7 +1,6 @@ package tor import ( - "fmt" "golang.org/x/net/proxy" "net" ) @@ -10,15 +9,16 @@ var customSOCKS5 = proxy.SOCKS5 type Dialer interface { Dial(network, address string) (net.Conn, error) + DialControlPort(network, address string) (net.Conn, error) } type DefaultDialer struct{} func (d DefaultDialer) Dial(network, address string) (net.Conn, error) { - dialer, err := customSOCKS5("tcp", "localhost:9050", nil, proxy.Direct) - if err != nil { - return nil, fmt.Errorf("failed to create tor dialer: %w", err) - } - + dialer, _ := customSOCKS5("tcp", "localhost:9050", nil, proxy.Direct) return dialer.Dial(network, address) } + +func (d DefaultDialer) DialControlPort(network, address string) (net.Conn, error) { + return net.Dial(network, address) +} diff --git a/pkg/tor/dial_test.go b/pkg/tor/dial_test.go index 12f3191..862b6cf 100644 --- a/pkg/tor/dial_test.go +++ b/pkg/tor/dial_test.go @@ -44,3 +44,29 @@ func TestDefaultDialer_Dial_Failure(t *testing.T) { assert.NotNil(t, err, "expected an error during dial failure simulation") assert.Nil(t, conn, "expected no connection on dial failure") } + +func TestDefaultDialer_DialControlPort(t *testing.T) { + tests := []struct { + name string + address string + wantErr bool + }{ + {"successful connection", "localhost:9051", false}, + {"failed connection", "invalid:address", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dialer := DefaultDialer{} + conn, err := dialer.DialControlPort("tcp", tt.address) + + if tt.wantErr { + assert.NotNil(t, err, "expected an error for: %s", tt.name) + assert.Nil(t, conn, "expected no connection for: %s", tt.name) + } else { + assert.Nil(t, err, "expected no error for: %s", tt.name) + assert.NotNil(t, conn, "expected a valid connection for: %s", tt.name) + } + }) + } +}