Skip to content

Commit

Permalink
fix(tor): improve tor identity switching mechanism
Browse files Browse the repository at this point in the history
Refactored the method to request a new Tor identity using a separate control port dial function. This update improves error handling by confirming responses from the Tor control port. Also added tests to ensure robust connection behavior for valid and invalid addresses.
  • Loading branch information
ryanbekhen committed Dec 8, 2024
1 parent 3f0bac5 commit d76b124
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 18 deletions.
21 changes: 9 additions & 12 deletions pkg/tor/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"fmt"
"github.com/rs/zerolog"
"strings"
)

type Controller struct {
Expand All @@ -15,24 +16,20 @@ func NewTorController(dialer Dialer) *Controller {
}

func (t *Controller) RequestNewTorIdentity(logger *zerolog.Logger) error {
conn, err := t.dialer.Dial("tcp", "127.0.0.1:9051")
conn, err := t.dialer.DialControlPort("tcp", "127.0.0.1:9051")
if err != nil {
return fmt.Errorf("failed to connect to tor control port: %w", err)
}
defer conn.Close()

_, _ = fmt.Fprintf(conn, "AUTHENTICATE\r\n")
_, _ = fmt.Fprintf(conn, "SIGNAL NEWNYM\r\n")

authStatus, err := bufio.NewReader(conn).ReadString('\n')
if err != nil || authStatus != "250 OK\r\n" {
return fmt.Errorf("failed to authenticate with tor control port: %w", err)
_, _ = fmt.Fprintf(conn, "AUTHENTICATE \"\"\r\n")
_, err = fmt.Fprintf(conn, "SIGNAL NEWNYM\r\n")
if err != nil {
return fmt.Errorf("failed to request new identity: %w", err)
}

_, _ = fmt.Fprintf(conn, "SIGNAL NEWNYM\r\n")
status, err := bufio.NewReader(conn).ReadString('\n')
if err != nil || status != "250 OK\r\n" {
return fmt.Errorf("failed to switch tor identity: %w", err)
signalResponse, _ := bufio.NewReader(conn).ReadString('\n')
if !strings.HasPrefix(signalResponse, "250") {
return fmt.Errorf("failed to switch tor identity: %v", signalResponse)
}

if logger != nil {
Expand Down
7 changes: 7 additions & 0 deletions pkg/tor/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ func (md *MockDialer) Dial(network, address string) (net.Conn, error) {
return &MockConn{responses: []string{"250 OK\r\n", "250 OK\r\n"}}, nil
}

func (md *MockDialer) DialControlPort(network, address string) (net.Conn, error) {
if md.shouldFail {
return nil, fmt.Errorf("failed to connect to tor control port")
}
return &MockConn{responses: []string{"250 OK\r\n", "250 OK\r\n"}}, nil
}

func TestRequestNewTorIdentity_Success(t *testing.T) {
logger := zerolog.New(zerolog.ConsoleWriter{Out: &bytes.Buffer{}}).With().Logger()
dialer := &MockDialer{shouldFail: false}
Expand Down
12 changes: 6 additions & 6 deletions pkg/tor/dial.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package tor

import (
"fmt"
"golang.org/x/net/proxy"
"net"
)
Expand All @@ -10,15 +9,16 @@ var customSOCKS5 = proxy.SOCKS5

type Dialer interface {
Dial(network, address string) (net.Conn, error)
DialControlPort(network, address string) (net.Conn, error)
}

type DefaultDialer struct{}

func (d DefaultDialer) Dial(network, address string) (net.Conn, error) {
dialer, err := customSOCKS5("tcp", "localhost:9050", nil, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("failed to create tor dialer: %w", err)
}

dialer, _ := customSOCKS5("tcp", "localhost:9050", nil, proxy.Direct)
return dialer.Dial(network, address)
}

func (d DefaultDialer) DialControlPort(network, address string) (net.Conn, error) {
return net.Dial(network, address)
}
26 changes: 26 additions & 0 deletions pkg/tor/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,29 @@ func TestDefaultDialer_Dial_Failure(t *testing.T) {
assert.NotNil(t, err, "expected an error during dial failure simulation")
assert.Nil(t, conn, "expected no connection on dial failure")
}

func TestDefaultDialer_DialControlPort(t *testing.T) {
tests := []struct {
name string
address string
wantErr bool
}{
{"successful connection", "localhost:9051", false},
{"failed connection", "invalid:address", true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dialer := DefaultDialer{}
conn, err := dialer.DialControlPort("tcp", tt.address)

if tt.wantErr {
assert.NotNil(t, err, "expected an error for: %s", tt.name)
assert.Nil(t, conn, "expected no connection for: %s", tt.name)
} else {
assert.Nil(t, err, "expected no error for: %s", tt.name)
assert.NotNil(t, conn, "expected a valid connection for: %s", tt.name)
}
})
}
}

0 comments on commit d76b124

Please sign in to comment.