From 3e6e2effd7a309234b90b79b55fcbbf078361fee Mon Sep 17 00:00:00 2001 From: bill fort Date: Fri, 18 Aug 2023 10:43:34 -0600 Subject: [PATCH] Add GetFreePort in both TCP and UDP to avoid UDP listen fail Signed-off-by: bill fort --- client.go | 6 +++++- udp.go | 4 ++++ util.go | 41 +++++++++++++++++++++++++++++++++++++++++ util_test.go | 12 ++++++++++++ 4 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 util_test.go diff --git a/client.go b/client.go index 248bdb0..c98148c 100644 --- a/client.go +++ b/client.go @@ -898,7 +898,11 @@ func (c *TunaSessionClient) startExits() error { listeners := make([]net.Listener, c.config.NumTunaListeners) var err error for i := 0; i < len(listeners); i++ { - listeners[i], err = net.Listen("tcp", "127.0.0.1:") + port, err := GetFreePort(0) + if err != nil { + return err + } + listeners[i], err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%v", port)) if err != nil { return err } diff --git a/udp.go b/udp.go index b10be80..1787c02 100644 --- a/udp.go +++ b/udp.go @@ -109,6 +109,10 @@ func (c *TunaSessionClient) DialUDPWithConfig(remoteAddr string, config *nkn.Dia } func (c *TunaSessionClient) handleUdpListenerTcp(tcpConn *Conn, remoteAddr string, sessionID []byte) { + if c.listenerUdpSess == nil { // Udp listener not started + tcpConn.Close() + return + } sessKey := sessionKey(remoteAddr, sessionID) c.listenerUdpSess.Lock() diff --git a/util.go b/util.go index 6e9d026..256a915 100644 --- a/util.go +++ b/util.go @@ -2,7 +2,11 @@ package session import ( "encoding/hex" + "errors" + "fmt" + "net" "strconv" + "sync" "time" ) @@ -17,3 +21,40 @@ func sessionKey(remoteAddr string, sessionID []byte) string { func connID(i int) string { return strconv.Itoa(i) } + +// Get free port start from parameter `port` +// If paramenter `port` is 0, return system available port +// The returned port is free in both TCP and UDP +var lock sync.Mutex + +func GetFreePort(port int) (int, error) { + // to avoid race condition + lock.Lock() + defer lock.Unlock() + + for i := 0; i < 100; i++ { + addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err != nil { + return 0, err + } + + l, err := net.ListenTCP("tcp", addr) + if err != nil { + return 0, err + } + defer l.Close() + + port = l.Addr().(*net.TCPAddr).Port + u, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: port}) + if err != nil { + l.Close() + port++ + continue + } + u.Close() + + return port, nil + } + + return 0, errors.New("failed to find free port after 100 tries") +} diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..8005fcd --- /dev/null +++ b/util_test.go @@ -0,0 +1,12 @@ +package session + +import "testing" + +// go test -v -run=TestGetFreePort +func TestGetFreePort(t *testing.T) { + port, err := GetFreePort(0) + if err != nil { + t.Error(err) + } + t.Log(port) +}