diff --git a/bin/main.go b/bin/main.go index f67ee77..ab1b9a7 100644 --- a/bin/main.go +++ b/bin/main.go @@ -25,6 +25,7 @@ func main() { identifier := flag.String("i", "", "NKN address identifier") from := flag.String("from", "", `listening at address (omitted or "nkn" for listening on nkn address, ip:port for tcp address)`) to := flag.String("to", "", "dialing to address (nkn address or ip:port)") + dialTimeout := flag.Int("t", 0, "dial timeout in milliseconds") acceptAddr := flag.String("accept", "", "accept incoming nkn address regex, separated by comma") useTuna := flag.Bool("tuna", false, "use tuna instead of nkn client for nkn session") tunaCountry := flag.String("country", "", `tuna service node allowed country code, separated by comma, e.g. "US" or "US,CN"`) @@ -80,6 +81,9 @@ func main() { walletConfig := &nkn.WalletConfig{ SeedRPCServerAddr: seedRPCServerAddr, } + dialConfig := &nkn.DialConfig{ + DialTimeout: int32(*dialTimeout), + } var tsConfig *ts.Config if *useTuna { @@ -107,6 +111,7 @@ func main() { AcceptAddrs: acceptAddrs, ClientConfig: clientConfig, WalletConfig: walletConfig, + DialConfig: dialConfig, TunaSessionConfig: tsConfig, Verbose: *verbose, } diff --git a/config.go b/config.go index c23ebff..cd55d28 100644 --- a/config.go +++ b/config.go @@ -13,6 +13,7 @@ type Config struct { AcceptAddrs *nkngomobile.StringArray ClientConfig *nkn.ClientConfig WalletConfig *nkn.WalletConfig + DialConfig *nkn.DialConfig TunaSessionConfig *ts.Config Verbose bool } @@ -23,6 +24,7 @@ var defaultConfig = Config{ AcceptAddrs: nil, ClientConfig: nil, WalletConfig: nil, + DialConfig: nil, TunaSessionConfig: nil, Verbose: false, } diff --git a/go.mod b/go.mod index b1b6a31..d2bda64 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/imdario/mergo v0.3.9 github.com/nknorg/ncp-go v1.0.4-0.20220224111535-206abfb10fe8 github.com/nknorg/nkn-sdk-go v1.3.8 - github.com/nknorg/nkn-tuna-session v0.2.1 + github.com/nknorg/nkn-tuna-session v0.2.2 github.com/nknorg/nkngomobile v0.0.0-20220125080321-848ddd2e5157 github.com/nknorg/tuna v0.0.0-20220224114148-597496bdcb11 ) diff --git a/go.sum b/go.sum index 4d4fc18..437d4f5 100644 --- a/go.sum +++ b/go.sum @@ -253,8 +253,8 @@ github.com/nknorg/ncp-go v1.0.4-0.20220224111535-206abfb10fe8/go.mod h1:ALtnk9lK github.com/nknorg/nkn-sdk-go v1.3.7/go.mod h1:JSksFP+VQ0S54Ztiht6WHC3tNZklcGg+JaxENuFnqRc= github.com/nknorg/nkn-sdk-go v1.3.8 h1:t4lcHYcEC3ylWDgbObbw2zXBuEVgcMIC3jAOKGgRzmg= github.com/nknorg/nkn-sdk-go v1.3.8/go.mod h1:/2FtpRM4mWpze03V8FIoESQCa6wCdPRPQO9HqpIMNYw= -github.com/nknorg/nkn-tuna-session v0.2.1 h1:+xETWix5dmplLQqFgBwGVNJs5nAhjCZU5sPP4ll/GxI= -github.com/nknorg/nkn-tuna-session v0.2.1/go.mod h1:k+i+HePAOswU74klYpUawqOBKnTqvB9BfUzFTFaggMU= +github.com/nknorg/nkn-tuna-session v0.2.2 h1:3RF0Me8orGt+ZKr00iPPZb/iagKBB0mlbjCuu9xFJBs= +github.com/nknorg/nkn-tuna-session v0.2.2/go.mod h1:k+i+HePAOswU74klYpUawqOBKnTqvB9BfUzFTFaggMU= github.com/nknorg/nkn/v2 v2.0.6/go.mod h1:cXl2WTv72trEXKJiNH0dCMygMtL8nJne07dWajDlRIo= github.com/nknorg/nkn/v2 v2.1.7/go.mod h1:4xzrHJCI/FDFZmlt606Mn9ScKY4UUCFoaWydL1TzQRs= github.com/nknorg/nkn/v2 v2.1.8 h1:h25rqQ0E8CvlN8Jm4zF6CBBLgwdoSS7HHdrU4ZYcmjA= diff --git a/tunnel.go b/tunnel.go index 573cfb7..1db63f5 100644 --- a/tunnel.go +++ b/tunnel.go @@ -6,8 +6,10 @@ import ( "net" "strings" "sync" + "time" "github.com/hashicorp/go-multierror" + "github.com/nknorg/ncp-go" nkn "github.com/nknorg/nkn-sdk-go" ts "github.com/nknorg/nkn-tuna-session" "github.com/nknorg/nkngomobile" @@ -16,6 +18,7 @@ import ( type nknDialer interface { Addr() net.Addr Dial(addr string) (net.Conn, error) + DialWithConfig(addr string, config *nkn.DialConfig) (*ncp.Session, error) Close() error } @@ -173,9 +176,13 @@ func (t *Tunnel) SetAcceptAddrs(addrsRe *nkngomobile.StringArray) error { func (t *Tunnel) dial(addr string) (net.Conn, error) { if t.toNKN { - return t.dialer.Dial(addr) + return t.dialer.DialWithConfig(addr, t.config.DialConfig) } - return net.Dial("tcp", addr) + var dialTimeout time.Duration + if t.config.DialConfig != nil { + dialTimeout = time.Duration(t.config.DialConfig.DialTimeout) * time.Millisecond + } + return net.DialTimeout("tcp", addr, dialTimeout) } // Start starts the tunnel and will return on error. @@ -198,6 +205,7 @@ func (t *Tunnel) Start() error { toConn, err := t.dial(t.to) if err != nil { log.Println(err) + fromConn.Close() return }