From 2e437a0d56fdc130ada4936766c1dd509756a14e Mon Sep 17 00:00:00 2001 From: Lonny Wong Date: Sun, 30 Jun 2024 09:26:50 +0800 Subject: [PATCH] support QUIC protocol --- .goreleaser.yaml | 1 + README.md | 2 +- go.mod | 13 +++- go.sum | 30 +++++++- tsshd/bus.go | 64 +++++++++------- tsshd/forward.go | 44 ++++++----- tsshd/main.go | 27 ++++++- tsshd/proto.go | 192 +++++++++++++++++++++++++++++++++++------------ tsshd/server.go | 163 +++++++++++++++++++++++++++++++++------- tsshd/service.go | 120 +++++++++++++++++++++-------- tsshd/session.go | 83 ++++++++++---------- 11 files changed, 528 insertions(+), 211 deletions(-) diff --git a/.goreleaser.yaml b/.goreleaser.yaml index af4154b..d3eaaaa 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -17,6 +17,7 @@ builds: - amd64 - arm - arm64 + - loong64 goarm: - "6" - "7" diff --git a/README.md b/README.md index 0958422..da533cb 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ The [`tssh --udp`](https://github.com/trzsz/trzsz-ssh) works like [`mosh`](https ## Advanced Features -- Low latency ( based on kcp ) +- Low latency ( based on QUIC / KCP ) - Port forwarding ( same as ssh ) diff --git a/go.mod b/go.mod index 17d9811..cd57ff6 100644 --- a/go.mod +++ b/go.mod @@ -5,20 +5,29 @@ go 1.20 require ( github.com/UserExistsError/conpty v0.1.3 github.com/creack/pty v1.1.21 + github.com/quic-go/quic-go v0.40.1 github.com/trzsz/go-arg v1.5.3 github.com/xtaci/kcp-go/v5 v5.6.1 + github.com/xtaci/smux v1.5.24 golang.org/x/crypto v0.24.0 golang.org/x/sys v0.21.0 ) require ( github.com/alexflint/go-scalar v1.2.0 // indirect + github.com/go-task/slim-sprig/v3 v3.0.0 // indirect + github.com/google/pprof v0.0.0-20240625030939-27f56978b8b0 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/klauspost/reedsolomon v1.12.1 // indirect + github.com/onsi/ginkgo/v2 v2.19.0 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/stretchr/testify v1.8.4 // indirect - github.com/templexxx/cpu v0.1.0 // indirect + github.com/quic-go/qtls-go1-20 v0.4.1 // indirect + github.com/templexxx/cpu v0.1.1-0.20240303154708-598a14b050c5 // indirect github.com/templexxx/xorsimd v0.4.2 // indirect github.com/tjfoc/gmsm v1.4.1 // indirect + go.uber.org/mock v0.4.0 // indirect + golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 // indirect + golang.org/x/mod v0.18.0 // indirect golang.org/x/net v0.26.0 // indirect + golang.org/x/tools v0.22.0 // indirect ) diff --git a/go.sum b/go.sum index 84dabbc..8ce9c4b 100644 --- a/go.sum +++ b/go.sum @@ -15,6 +15,9 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= +github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= +github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -30,6 +33,9 @@ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5a github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/pprof v0.0.0-20240625030939-27f56978b8b0 h1:e+8XbKB6IMn8A4OAyZccO4pYfB3s7bt6azNIPE7AnPg= +github.com/google/pprof v0.0.0-20240625030939-27f56978b8b0/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= github.com/klauspost/cpuid v1.2.4/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/klauspost/cpuid v1.3.1/go.mod h1:bYW4mA6ZgKPob1/Dlai2LviZJO7KGI3uoWLd42rAQw4= github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= @@ -38,20 +44,27 @@ github.com/klauspost/reedsolomon v1.9.9/go.mod h1:O7yFFHiQwDR6b2t63KPUpccPtNdp5A github.com/klauspost/reedsolomon v1.12.1 h1:NhWgum1efX1x58daOBGCFWcxtEhOhXKKl1HAPQUp03Q= github.com/klauspost/reedsolomon v1.12.1/go.mod h1:nEi5Kjb6QqtbofI6s+cbG/j1da11c96IBYBSnVGtuBs= github.com/mmcloughlin/avo v0.0.0-20200803215136-443f81d77104/go.mod h1:wqKykBG2QzQDJEzvRkcS8x6MiSJkF52hXZsXcjaB3ls= +github.com/onsi/ginkgo/v2 v2.19.0 h1:9Cnnf7UHo57Hy3k6/m5k3dRfGTMXGvxhHFvkDTCTpvA= +github.com/onsi/ginkgo/v2 v2.19.0/go.mod h1:rlwLi9PilAFJ8jCg9UE1QP6VBpd6/xj3SRC0d6TU0To= +github.com/onsi/gomega v1.33.1 h1:dsYjIxxSR755MDmKVsaFQTE22ChNBcuuTWgkUDSubOk= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5nfFs= +github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= +github.com/quic-go/quic-go v0.40.1 h1:X3AGzUNFs0jVuO3esAGnTfvdgvL4fq655WaOi1snv1Q= +github.com/quic-go/quic-go v0.40.1/go.mod h1:PeN7kuVJ4xZbxSv/4OX6S1USOX8MJvydwpTx31vx60c= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/templexxx/cpu v0.0.1/go.mod h1:w7Tb+7qgcAlIyX4NhLuDKt78AHA5SzPmq0Wj6HiEnnk= github.com/templexxx/cpu v0.0.7/go.mod h1:w7Tb+7qgcAlIyX4NhLuDKt78AHA5SzPmq0Wj6HiEnnk= -github.com/templexxx/cpu v0.1.0 h1:wVM+WIJP2nYaxVxqgHPD4wGA2aJ9rvrQRV8CvFzNb40= github.com/templexxx/cpu v0.1.0/go.mod h1:w7Tb+7qgcAlIyX4NhLuDKt78AHA5SzPmq0Wj6HiEnnk= +github.com/templexxx/cpu v0.1.1-0.20240303154708-598a14b050c5 h1:Ke6p9WHBy8Ooz8Vg/+o9SHp5yE2VlzzyHVEfHTFmJoM= +github.com/templexxx/cpu v0.1.1-0.20240303154708-598a14b050c5/go.mod h1:w7Tb+7qgcAlIyX4NhLuDKt78AHA5SzPmq0Wj6HiEnnk= github.com/templexxx/xorsimd v0.4.1/go.mod h1:W+ffZz8jJMH2SXwuKu9WhygqBMbFnp14G2fqEr8qaNo= github.com/templexxx/xorsimd v0.4.2 h1:ocZZ+Nvu65LGHmCLZ7OoCtg8Fx8jnHKK37SjvngUoVI= github.com/templexxx/xorsimd v0.4.2/go.mod h1:HgwaPoDREdi6OnULpSfxhzaiiSUY4Fi3JPn1wpt28NI= @@ -64,8 +77,12 @@ github.com/xtaci/kcp-go/v5 v5.6.1 h1:Pwn0aoeNSPF9dTS7IgiPXn0HEtaIlVb6y5UKWPsx8bI github.com/xtaci/kcp-go/v5 v5.6.1/go.mod h1:W3kVPyNYwZ06p79dNwFWQOVFrdcBpDBsdyvK8moQrYo= github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae h1:J0GxkO96kL4WF+AIT3M4mfUVinOCPgf2uUWYFUzN0sM= github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae/go.mod h1:gXtu8J62kEgmN++bm9BVICuT/e8yiLI2KFobd/TRFsE= +github.com/xtaci/smux v1.5.24 h1:77emW9dtnOxxOQ5ltR+8BbsX1kzcOxQ5gB+aaV9hXOY= +github.com/xtaci/smux v1.5.24/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= golang.org/x/arch v0.0.0-20190909030613-46d78d1859ac/go.mod h1:flIaEI6LNU6xOCD5PaJvn9wGP0agmIOqjrtsKGRguv4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -76,11 +93,15 @@ golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY= +golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= +golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -99,6 +120,7 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -111,6 +133,7 @@ golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= @@ -119,6 +142,8 @@ golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200425043458-8463f397d07c/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200808161706-5bf02b21f123/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= +golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -137,6 +162,7 @@ google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/tsshd/bus.go b/tsshd/bus.go index dc3e2af..f731efb 100644 --- a/tsshd/bus.go +++ b/tsshd/bus.go @@ -26,63 +26,64 @@ package tsshd import ( "fmt" + "net" "sync" "sync/atomic" "time" - - "github.com/xtaci/kcp-go/v5" ) +var serving atomic.Bool + var busMutex sync.Mutex -var busSession atomic.Pointer[kcp.UDPSession] +var busStream atomic.Pointer[net.Conn] var lastAliveTime atomic.Pointer[time.Time] func sendBusCommand(command string) error { busMutex.Lock() defer busMutex.Unlock() - session := busSession.Load() - if session == nil { - return fmt.Errorf("bus session is nil") + stream := busStream.Load() + if stream == nil { + return fmt.Errorf("bus stream is nil") } - return SendCommand(session, command) + return SendCommand(*stream, command) } func sendBusMessage(command string, msg any) error { busMutex.Lock() defer busMutex.Unlock() - session := busSession.Load() - if session == nil { - return fmt.Errorf("bus session is nil") + stream := busStream.Load() + if stream == nil { + return fmt.Errorf("bus stream is nil") } - if err := SendCommand(session, command); err != nil { + if err := SendCommand(*stream, command); err != nil { return err } - return SendMessage(session, msg) + return SendMessage(*stream, msg) } func trySendErrorMessage(format string, a ...any) { _ = sendBusMessage("error", ErrorMessage{fmt.Sprintf(format, a...)}) } -func handleBusEvent(session *kcp.UDPSession) { +func handleBusEvent(stream net.Conn) { var msg BusMessage - if err := RecvMessage(session, &msg); err != nil { - SendError(session, fmt.Errorf("recv bus message failed: %v", err)) + if err := RecvMessage(stream, &msg); err != nil { + SendError(stream, fmt.Errorf("recv bus message failed: %v", err)) return } busMutex.Lock() // only one bus - if !busSession.CompareAndSwap(nil, session) { + if !busStream.CompareAndSwap(nil, &stream) { busMutex.Unlock() - SendError(session, fmt.Errorf("bus has been initialized")) + SendError(stream, fmt.Errorf("bus has been initialized")) return } - if err := SendSuccess(session); err != nil { // ack ok + if err := SendSuccess(stream); err != nil { // ack ok busMutex.Unlock() trySendErrorMessage("bus ack ok failed: %v", err) return @@ -99,7 +100,7 @@ func handleBusEvent(session *kcp.UDPSession) { } for { - command, err := RecvCommand(session) + command, err := RecvCommand(stream) if err != nil { trySendErrorMessage("recv bus command failed: %v", err) return @@ -107,15 +108,15 @@ func handleBusEvent(session *kcp.UDPSession) { switch command { case "resize": - err = handleResizeEvent(session) + err = handleResizeEvent(stream) case "close": - exitChan <- true + exitChan <- 0 return case "alive": now := time.Now() lastAliveTime.Store(&now) default: - err = handleUnknownEvent(session) + err = handleUnknownEvent(stream) } if err != nil { trySendErrorMessage("handle bus command [%s] failed: %v", command, err) @@ -123,22 +124,31 @@ func handleBusEvent(session *kcp.UDPSession) { } } -func handleUnknownEvent(session *kcp.UDPSession) error { +func handleUnknownEvent(stream net.Conn) error { var msg struct{} - if err := RecvMessage(session, &msg); err != nil { + if err := RecvMessage(stream, &msg); err != nil { return fmt.Errorf("recv unknown message failed: %v", err) } return fmt.Errorf("unknown command") } func keepAlive(timeout time.Duration) { + sleepTime := timeout / 10 + if sleepTime > 10*time.Second { + sleepTime = 10 * time.Second + } + go func() { + for { + _ = sendBusCommand("alive") + time.Sleep(sleepTime) + } + }() for { - _ = sendBusCommand("alive") if t := lastAliveTime.Load(); t != nil && time.Since(*t) > timeout { trySendErrorMessage("tsshd keep alive timeout") - exitChan <- true + exitChan <- 2 return } - time.Sleep(timeout / 10) + time.Sleep(sleepTime) } } diff --git a/tsshd/forward.go b/tsshd/forward.go index 055c80c..acc86ea 100644 --- a/tsshd/forward.go +++ b/tsshd/forward.go @@ -30,18 +30,16 @@ import ( "net" "sync" "sync/atomic" - - "github.com/xtaci/kcp-go/v5" ) var acceptMutex sync.Mutex var acceptID atomic.Uint64 var acceptMap = make(map[uint64]net.Conn) -func handleDialEvent(session *kcp.UDPSession) { +func handleDialEvent(stream net.Conn) { var msg DialMessage - if err := RecvMessage(session, &msg); err != nil { - SendError(session, fmt.Errorf("recv dial message failed: %v", err)) + if err := RecvMessage(stream, &msg); err != nil { + SendError(stream, fmt.Errorf("recv dial message failed: %v", err)) return } @@ -53,36 +51,36 @@ func handleDialEvent(session *kcp.UDPSession) { conn, err = net.Dial(msg.Network, msg.Addr) } if err != nil { - SendError(session, fmt.Errorf("dial %s [%s] failed: %v", msg.Network, msg.Addr, err)) + SendError(stream, fmt.Errorf("dial %s [%s] failed: %v", msg.Network, msg.Addr, err)) return } defer conn.Close() - if err := SendSuccess(session); err != nil { // ack ok + if err := SendSuccess(stream); err != nil { // ack ok trySendErrorMessage("dial ack ok failed: %v", err) return } - forwardConnection(session, conn) + forwardConnection(stream, conn) } -func handleListenEvent(session *kcp.UDPSession) { +func handleListenEvent(stream net.Conn) { var msg ListenMessage - if err := RecvMessage(session, &msg); err != nil { - SendError(session, fmt.Errorf("recv listen message failed: %v", err)) + if err := RecvMessage(stream, &msg); err != nil { + SendError(stream, fmt.Errorf("recv listen message failed: %v", err)) return } listener, err := net.Listen(msg.Network, msg.Addr) if err != nil { - SendError(session, fmt.Errorf("listen on %s [%s] failed: %v", msg.Network, msg.Addr, err)) + SendError(stream, fmt.Errorf("listen on %s [%s] failed: %v", msg.Network, msg.Addr, err)) return } defer listener.Close() - if err := SendSuccess(session); err != nil { // ack ok + if err := SendSuccess(stream); err != nil { // ack ok trySendErrorMessage("listen ack ok failed: %v", err) return } @@ -99,7 +97,7 @@ func handleListenEvent(session *kcp.UDPSession) { acceptMutex.Lock() id := acceptID.Add(1) - 1 acceptMap[id] = conn - if err := SendMessage(session, AcceptMessage{id}); err != nil { + if err := SendMessage(stream, AcceptMessage{id}); err != nil { acceptMutex.Unlock() trySendErrorMessage("send accept message failed: %v", err) return @@ -108,10 +106,10 @@ func handleListenEvent(session *kcp.UDPSession) { } } -func handleAcceptEvent(session *kcp.UDPSession) { +func handleAcceptEvent(stream net.Conn) { var msg AcceptMessage - if err := RecvMessage(session, &msg); err != nil { - SendError(session, fmt.Errorf("recv accept message failed: %v", err)) + if err := RecvMessage(stream, &msg); err != nil { + SendError(stream, fmt.Errorf("recv accept message failed: %v", err)) return } @@ -120,30 +118,30 @@ func handleAcceptEvent(session *kcp.UDPSession) { conn, ok := acceptMap[msg.ID] if !ok { - SendError(session, fmt.Errorf("invalid accept id: %d", msg.ID)) + SendError(stream, fmt.Errorf("invalid accept id: %d", msg.ID)) return } delete(acceptMap, msg.ID) defer conn.Close() - if err := SendSuccess(session); err != nil { // ack ok + if err := SendSuccess(stream); err != nil { // ack ok trySendErrorMessage("accept ack ok failed: %v", err) return } - forwardConnection(session, conn) + forwardConnection(stream, conn) } -func forwardConnection(session *kcp.UDPSession, conn net.Conn) { +func forwardConnection(stream net.Conn, conn net.Conn) { var wg sync.WaitGroup wg.Add(2) go func() { - _, _ = io.Copy(conn, session) + _, _ = io.Copy(conn, stream) wg.Done() }() go func() { - _, _ = io.Copy(session, conn) + _, _ = io.Copy(stream, conn) wg.Done() }() wg.Wait() diff --git a/tsshd/main.go b/tsshd/main.go index f56555d..2f0cb9d 100644 --- a/tsshd/main.go +++ b/tsshd/main.go @@ -29,13 +29,17 @@ import ( "io" "os" "os/exec" + "time" "github.com/trzsz/go-arg" ) -const kTsshdVersion = "0.1.0" +const kTsshdVersion = "0.1.1" + +var exitChan = make(chan int, 1) type tsshdArgs struct { + KCP bool `arg:"--kcp" help:"KCP protocol (default is QUIC protocol)"` } func (tsshdArgs) Description() string { @@ -84,7 +88,7 @@ func TsshdMain() int { return 0 } - listener, err := initServer(&args) + kcpListener, quicListener, err := initServer(&args) if err != nil { fmt.Println(err) os.Stdout.Close() @@ -93,7 +97,22 @@ func TsshdMain() int { os.Stdout.Close() - serve(listener) + if kcpListener != nil { + defer kcpListener.Close() + go serveKCP(kcpListener) + } + if quicListener != nil { + defer quicListener.Close() + go serveQUIC(quicListener) + } + + go func() { + // should be connected within 20 seconds + time.Sleep(20 * time.Second) + if !serving.Load() { + exitChan <- 1 + } + }() - return 0 + return <-exitChan } diff --git a/tsshd/proto.go b/tsshd/proto.go index dacc434..75bb93e 100644 --- a/tsshd/proto.go +++ b/tsshd/proto.go @@ -25,22 +25,36 @@ SOFTWARE. package tsshd import ( + "context" + "crypto/sha1" + "crypto/tls" + "crypto/x509" "encoding/binary" + "encoding/hex" "encoding/json" "fmt" "io" + "net" + "strings" "time" + "github.com/quic-go/quic-go" "github.com/xtaci/kcp-go/v5" + "github.com/xtaci/smux" + "golang.org/x/crypto/pbkdf2" ) const kNoErrorMsg = "_TSSHD_NO_ERROR_" type ServerInfo struct { - Ver string - Pass string - Salt string - Port int + Ver string + Port int + Mode string + Pass string + Salt string + ServerCert string + ClientCert string + ClientKey string } type ErrorMessage struct { @@ -105,7 +119,7 @@ func writeAll(dst io.Writer, data []byte) error { return nil } -func SendCommand(session *kcp.UDPSession, command string) error { +func SendCommand(stream net.Conn, command string) error { if len(command) == 0 { return fmt.Errorf("send command is empty") } @@ -115,25 +129,25 @@ func SendCommand(session *kcp.UDPSession, command string) error { buffer := make([]byte, len(command)+1) buffer[0] = uint8(len(command)) copy(buffer[1:], []byte(command)) - if err := writeAll(session, buffer); err != nil { + if err := writeAll(stream, buffer); err != nil { return fmt.Errorf("send command write buffer failed: %v", err) } return nil } -func RecvCommand(session *kcp.UDPSession) (string, error) { +func RecvCommand(stream net.Conn) (string, error) { length := make([]byte, 1) - if _, err := session.Read(length); err != nil { + if _, err := stream.Read(length); err != nil { return "", fmt.Errorf("recv command read length failed: %v", err) } command := make([]byte, length[0]) - if _, err := io.ReadFull(session, command); err != nil { + if _, err := io.ReadFull(stream, command); err != nil { return "", fmt.Errorf("recv command read buffer failed: %v", err) } return string(command), nil } -func SendMessage(session *kcp.UDPSession, msg any) error { +func SendMessage(stream net.Conn, msg any) error { msgBuf, err := json.Marshal(msg) if err != nil { return fmt.Errorf("send message marshal failed: %v", err) @@ -141,19 +155,19 @@ func SendMessage(session *kcp.UDPSession, msg any) error { buffer := make([]byte, len(msgBuf)+4) binary.BigEndian.PutUint32(buffer, uint32(len(msgBuf))) copy(buffer[4:], msgBuf) - if err := writeAll(session, buffer); err != nil { + if err := writeAll(stream, buffer); err != nil { return fmt.Errorf("send message write buffer failed: %v", err) } return nil } -func RecvMessage(session *kcp.UDPSession, msg any) error { +func RecvMessage(stream net.Conn, msg any) error { lenBuf := make([]byte, 4) - if _, err := io.ReadFull(session, lenBuf); err != nil { + if _, err := io.ReadFull(stream, lenBuf); err != nil { return fmt.Errorf("recv message read length failed: %v", err) } msgBuf := make([]byte, binary.BigEndian.Uint32(lenBuf)) - if _, err := io.ReadFull(session, msgBuf); err != nil { + if _, err := io.ReadFull(stream, msgBuf); err != nil { return fmt.Errorf("recv message read buffer failed: %v", err) } if err := json.Unmarshal(msgBuf, msg); err != nil { @@ -162,19 +176,19 @@ func RecvMessage(session *kcp.UDPSession, msg any) error { return nil } -func SendError(session *kcp.UDPSession, err error) { - if e := SendMessage(session, ErrorMessage{err.Error()}); e != nil { +func SendError(stream net.Conn, err error) { + if e := SendMessage(stream, ErrorMessage{err.Error()}); e != nil { trySendErrorMessage("send error [%v] failed: %v", err, e) } } -func SendSuccess(session *kcp.UDPSession) error { - return SendMessage(session, ErrorMessage{kNoErrorMsg}) +func SendSuccess(stream net.Conn) error { + return SendMessage(stream, ErrorMessage{kNoErrorMsg}) } -func RecvError(session *kcp.UDPSession) error { +func RecvError(stream net.Conn) error { var errMsg ErrorMessage - if err := RecvMessage(session, &errMsg); err != nil { + if err := RecvMessage(stream, &errMsg); err != nil { return fmt.Errorf("recv error failed: %v", err) } if errMsg.Msg != kNoErrorMsg { @@ -183,41 +197,119 @@ func RecvError(session *kcp.UDPSession) error { return nil } -func NewKcpSession(addr string, key []byte, cmd string) (session *kcp.UDPSession, err error) { +type Client interface { + Close() error + NewStream() (net.Conn, error) +} + +type kcpClient struct { + session *smux.Session +} + +func (c *kcpClient) Close() error { + return c.session.Close() +} + +func (c *kcpClient) NewStream() (net.Conn, error) { + stream, err := c.session.OpenStream() + if err != nil { + return nil, fmt.Errorf("kcp smux open stream failed: %v", err) + } + return stream, nil +} + +type quicClient struct { + conn quic.Connection +} + +func (c *quicClient) Close() error { + return c.conn.CloseWithError(0, "") +} + +func (c *quicClient) NewStream() (net.Conn, error) { + stream, err := c.conn.OpenStreamSync(context.Background()) + if err != nil { + return nil, fmt.Errorf("quic open stream sync failed: %v", err) + } + return &quicStream{stream, c.conn}, err +} + +func NewClient(host string, info *ServerInfo) (Client, error) { + switch info.Mode { + case "": + return nil, fmt.Errorf("Please upgrade tsshd.") + case kModeKCP: + return newKcpClient(host, info) + case kModeQUIC: + return newQuicClient(host, info) + default: + return nil, fmt.Errorf("unknown tsshd mode: %s", info.Mode) + } +} + +func newKcpClient(host string, info *ServerInfo) (Client, error) { + pass, err := hex.DecodeString(info.Pass) + if err != nil { + return nil, fmt.Errorf("decode pass [%s] failed: %v", info.Pass, err) + } + salt, err := hex.DecodeString(info.Salt) + if err != nil { + return nil, fmt.Errorf("decode salt [%s] failed: %v", info.Pass, err) + } + addr := joinHostPort(host, info.Port) + key := pbkdf2.Key(pass, salt, 4096, 32, sha1.New) block, err := kcp.NewAESBlockCrypt(key) if err != nil { return nil, fmt.Errorf("new aes block crypt failed: %v", err) } + conn, err := kcp.DialWithOptions(addr, block, 10, 3) + if err != nil { + return nil, fmt.Errorf("kcp dial [%s] failed: %v", addr, err) + } + conn.SetNoDelay(1, 10, 2, 1) + session, err := smux.Client(conn, &smuxConfig) + if err != nil { + return nil, fmt.Errorf("kcp smux client failed: %v", err) + } + return &kcpClient{session}, nil +} - done := make(chan struct{}, 1) - go func() { - defer func() { - if err != nil && session != nil { - session.Close() - } - done <- struct{}{} - close(done) - }() - session, err = kcp.DialWithOptions(addr, block, 10, 3) - if err != nil { - err = fmt.Errorf("kcp dial [%s] [%s] failed: %v", addr, cmd, err) - return - } - session.SetNoDelay(1, 10, 2, 1) - if err = SendCommand(session, cmd); err != nil { - err = fmt.Errorf("kcp send command [%s] [%s] failed: %v", addr, cmd, err) - return - } - if err = RecvError(session); err != nil { - err = fmt.Errorf("kcp new session [%s] [%s] failed: %v", addr, cmd, err) - return - } - }() +func newQuicClient(host string, info *ServerInfo) (Client, error) { + serverCert, err := hex.DecodeString(info.ServerCert) + if err != nil { + return nil, fmt.Errorf("decode server cert [%s] failed: %v", info.ServerCert, err) + } + clientCert, err := hex.DecodeString(info.ClientCert) + if err != nil { + return nil, fmt.Errorf("decode client cert [%s] failed: %v", info.ClientCert, err) + } + clientKey, err := hex.DecodeString(info.ClientKey) + if err != nil { + return nil, fmt.Errorf("decode client key [%s] failed: %v", info.ClientKey, err) + } + + clientTlsCert, err := tls.X509KeyPair(clientCert, clientKey) + if err != nil { + return nil, fmt.Errorf("x509 key pair failed: %v", err) + } + serverCertPool := x509.NewCertPool() + serverCertPool.AppendCertsFromPEM(serverCert) + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{clientTlsCert}, + RootCAs: serverCertPool, + ServerName: "tsshd", + } + addr := joinHostPort(host, info.Port) + conn, err := quic.DialAddr(context.Background(), addr, tlsConfig, &quicConfig) + if err != nil { + return nil, fmt.Errorf("quic dail [%s] failed: %v", addr, err) + } + return &quicClient{conn}, nil +} - select { - case <-time.After(10 * time.Second): - err = fmt.Errorf("kcp new session [%s] [%s] timeout", addr, cmd) - case <-done: +func joinHostPort(host string, port int) string { + if !strings.HasPrefix(host, "[") && strings.ContainsRune(host, ':') { + return fmt.Sprintf("[%s]:%d", host, port) } - return + return fmt.Sprintf("%s:%d", host, port) } diff --git a/tsshd/server.go b/tsshd/server.go index 1c6a7f8..e20545f 100644 --- a/tsshd/server.go +++ b/tsshd/server.go @@ -25,63 +25,78 @@ SOFTWARE. package tsshd import ( + "crypto/ecdsa" + "crypto/elliptic" crypto_rand "crypto/rand" "crypto/sha1" + "crypto/tls" + "crypto/x509" "encoding/json" + "encoding/pem" "fmt" + "math/big" math_rand "math/rand" "net" + "time" + "github.com/quic-go/quic-go" "github.com/xtaci/kcp-go/v5" "golang.org/x/crypto/pbkdf2" ) -const kDefaultPortRangeLow = 61001 +const ( + kModeKCP = "KCP" + kModeQUIC = "QUIC" +) -const kDefaultPortRangeHigh = 61999 +const ( + kDefaultPortRangeLow = 61001 + kDefaultPortRangeHigh = 61999 +) -func initServer(args *tsshdArgs) (*kcp.Listener, error) { +var quicConfig = quic.Config{ + HandshakeIdleTimeout: 30 * time.Second, + MaxIdleTimeout: 365 * 24 * time.Hour, +} + +func initServer(args *tsshdArgs) (*kcp.Listener, *quic.Listener, error) { portRangeLow := kDefaultPortRangeLow portRangeHigh := kDefaultPortRangeHigh conn, port := listenOnFreePort(portRangeLow, portRangeHigh) if conn == nil { - return nil, fmt.Errorf("no free udp port in [%d, %d]", portRangeLow, portRangeHigh) + return nil, nil, fmt.Errorf("no free udp port in [%d, %d]", portRangeLow, portRangeHigh) } - pass := make([]byte, 32) - if _, err := crypto_rand.Read(pass); err != nil { - return nil, fmt.Errorf("rand pass failed: %v", err) - } - salt := make([]byte, 32) - if _, err := crypto_rand.Read(salt); err != nil { - return nil, fmt.Errorf("rand salt failed: %v", err) + info := &ServerInfo{ + Ver: kTsshdVersion, + Port: port, } - key := pbkdf2.Key(pass, salt, 4096, 32, sha1.New) - block, err := kcp.NewAESBlockCrypt(key) - if err != nil { - return nil, fmt.Errorf("new aes block crypt failed: %v", err) + var err error + var kcpListener *kcp.Listener + var quicListener *quic.Listener + if args.KCP { + kcpListener, err = listenKCP(conn, info) + } else { + quicListener, err = listenQUIC(conn, info) } - - listener, err := kcp.ServeConn(block, 10, 3, conn) if err != nil { - return nil, fmt.Errorf("kcp serve conn failed: %v", err) + return nil, nil, err } - svrInfo := ServerInfo{ - Ver: kTsshdVersion, - Pass: fmt.Sprintf("%x", pass), - Salt: fmt.Sprintf("%x", salt), - Port: port, - } - info, err := json.Marshal(svrInfo) + infoStr, err := json.Marshal(info) if err != nil { - listener.Close() - return nil, fmt.Errorf("json marshal failed: %v\n", err) + if kcpListener != nil { + kcpListener.Close() + } + if quicListener != nil { + quicListener.Close() + } + return nil, nil, fmt.Errorf("json marshal failed: %v\n", err) } - fmt.Printf("\a%s\r\n", string(info)) + fmt.Printf("\a%s\r\n", string(infoStr)) - return listener, nil + return kcpListener, quicListener, nil } func listenOnFreePort(low, high int) (*net.UDPConn, int) { @@ -113,3 +128,93 @@ func listenOnPort(port int) *net.UDPConn { } return conn } + +func listenKCP(conn *net.UDPConn, info *ServerInfo) (*kcp.Listener, error) { + pass := make([]byte, 32) + if _, err := crypto_rand.Read(pass); err != nil { + return nil, fmt.Errorf("rand pass failed: %v", err) + } + salt := make([]byte, 32) + if _, err := crypto_rand.Read(salt); err != nil { + return nil, fmt.Errorf("rand salt failed: %v", err) + } + key := pbkdf2.Key(pass, salt, 4096, 32, sha1.New) + + block, err := kcp.NewAESBlockCrypt(key) + if err != nil { + return nil, fmt.Errorf("new aes block crypt failed: %v", err) + } + + listener, err := kcp.ServeConn(block, 10, 3, conn) + if err != nil { + return nil, fmt.Errorf("kcp serve conn failed: %v", err) + } + + info.Mode = kModeKCP + info.Pass = fmt.Sprintf("%x", pass) + info.Salt = fmt.Sprintf("%x", salt) + return listener, nil +} + +func listenQUIC(conn *net.UDPConn, info *ServerInfo) (*quic.Listener, error) { + serverCertPEM, serverKeyPEM, err := generateCertKeyPair() + if err != nil { + return nil, err + } + clientCertPEM, clientKeyPEM, err := generateCertKeyPair() + if err != nil { + return nil, err + } + + serverTlsCert, err := tls.X509KeyPair(serverCertPEM, serverKeyPEM) + if err != nil { + return nil, fmt.Errorf("x509 key pair failed: %v", err) + } + + clientCertPool := x509.NewCertPool() + clientCertPool.AppendCertsFromPEM(clientCertPEM) + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{serverTlsCert}, + ClientCAs: clientCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + } + + listener, err := (&quic.Transport{Conn: conn}).Listen(tlsConfig, &quicConfig) + if err != nil { + return nil, fmt.Errorf("quic listen failed: %v", err) + } + + info.Mode = kModeQUIC + info.ServerCert = fmt.Sprintf("%x", serverCertPEM) + info.ClientCert = fmt.Sprintf("%x", clientCertPEM) + info.ClientKey = fmt.Sprintf("%x", clientKeyPEM) + + return listener, nil +} + +func generateCertKeyPair() ([]byte, []byte, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), crypto_rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("ecdsa generate key failed: %v", err) + } + now := time.Now() + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + DNSNames: []string{"tsshd"}, + NotBefore: now.AddDate(0, 0, -1), + NotAfter: now.AddDate(1, 0, 0), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + } + certDER, err := x509.CreateCertificate(crypto_rand.Reader, &template, &template, &key.PublicKey, key) + if err != nil { + return nil, nil, fmt.Errorf("x509 create certificate failed: %v", err) + } + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + keyBytes, err := x509.MarshalECPrivateKey(key) + if err != nil { + return nil, nil, fmt.Errorf("x509 marshal ec private key failed: %v", err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyBytes}) + return certPEM, keyPEM, nil +} diff --git a/tsshd/service.go b/tsshd/service.go index 555a7e7..0fa852d 100644 --- a/tsshd/service.go +++ b/tsshd/service.go @@ -25,54 +25,112 @@ SOFTWARE. package tsshd import ( + "context" "fmt" - "sync/atomic" - "time" + "net" + "github.com/quic-go/quic-go" "github.com/xtaci/kcp-go/v5" + "github.com/xtaci/smux" ) -var serving atomic.Bool +var smuxConfig = smux.Config{ + Version: 2, + KeepAliveDisabled: true, + MaxFrameSize: 32 * 1024, + MaxStreamBuffer: 64 * 1024, + MaxReceiveBuffer: 4 * 1024 * 1024, +} + +type quicStream struct { + quic.Stream + conn quic.Connection +} -var exitChan = make(chan bool, 1) +func (s *quicStream) LocalAddr() net.Addr { + return s.conn.LocalAddr() +} -func serve(listener *kcp.Listener) { - defer listener.Close() +func (s *quicStream) RemoteAddr() net.Addr { + return s.conn.RemoteAddr() +} - go func() { - // should be connected within 10 seconds - time.Sleep(10 * time.Second) - if !serving.Load() { - exitChan <- true +func serveKCP(listener *kcp.Listener) { + for { + conn, err := listener.AcceptKCP() + if err != nil { + trySendErrorMessage("kcp accept failed: %v", err) + return } - }() + go handleKcpConn(conn) + } +} + +func handleKcpConn(conn *kcp.UDPSession) { + defer conn.Close() + + if serving.Load() { + return + } + + conn.SetNoDelay(1, 10, 2, 1) + + session, err := smux.Server(conn, &smuxConfig) + if err != nil { + trySendErrorMessage("kcp smux server failed: %v", err) + return + } - go func() { - for { - session, err := listener.AcceptKCP() - if err != nil { - trySendErrorMessage("kcp accept failed: %v", err) - return - } - go handleSession(session) + for { + stream, err := session.AcceptStream() + if err != nil { + trySendErrorMessage("kcp smux accept stream failed: %v", err) + return } - }() + go handleStream(stream) + } +} - <-exitChan +func serveQUIC(listener *quic.Listener) { + for { + conn, err := listener.Accept(context.Background()) + if err != nil { + trySendErrorMessage("quic accept conn failed: %v", err) + return + } + go handleQuicConn(conn) + } } -func handleSession(session *kcp.UDPSession) { - defer session.Close() +func handleQuicConn(conn quic.Connection) { + defer func() { + _ = conn.CloseWithError(0, "") + }() + + if serving.Load() { + return + } + + for { + stream, err := conn.AcceptStream(context.Background()) + if err != nil { + trySendErrorMessage("quic accept stream failed: %v", err) + return + } + go handleStream(&quicStream{stream, conn}) + } +} - session.SetNoDelay(1, 10, 2, 1) +func handleStream(stream net.Conn) { + defer stream.Close() - command, err := RecvCommand(session) + command, err := RecvCommand(stream) if err != nil { - SendError(session, fmt.Errorf("recv session command failed: %v", err)) + SendError(stream, fmt.Errorf("recv stream command failed: %v", err)) return } - var handler func(*kcp.UDPSession) + var handler func(net.Conn) switch command { case "bus": @@ -88,14 +146,14 @@ func handleSession(session *kcp.UDPSession) { case "accept": handler = handleAcceptEvent default: - SendError(session, fmt.Errorf("unknown session command: %s", command)) + SendError(stream, fmt.Errorf("unknown stream command: %s", command)) return } - if err := SendSuccess(session); err != nil { // say hello + if err := SendSuccess(stream); err != nil { // say hello trySendErrorMessage("tsshd say hello failed: %v", err) return } - handler(session) + handler(stream) } diff --git a/tsshd/session.go b/tsshd/session.go index ed3bf05..9475958 100644 --- a/tsshd/session.go +++ b/tsshd/session.go @@ -27,14 +27,13 @@ package tsshd import ( "fmt" "io" + "net" "os" "os/exec" "path/filepath" "runtime" "strings" "sync" - - "github.com/xtaci/kcp-go/v5" ) type sessionContext struct { @@ -50,17 +49,17 @@ type sessionContext struct { started bool } -type stderrContext struct { - id uint64 - wg sync.WaitGroup - session *kcp.UDPSession +type stderrStream struct { + id uint64 + wg sync.WaitGroup + stream net.Conn } var sessionMutex sync.Mutex var sessionMap = make(map[uint64]*sessionContext) var stderrMutex sync.Mutex -var stderrMap = make(map[uint64]*stderrContext) +var stderrMap = make(map[uint64]*stderrStream) func (c *sessionContext) StartPty() error { var err error @@ -92,17 +91,17 @@ func (c *sessionContext) StartCmd() error { return nil } -func (c *sessionContext) forwardIO(session *kcp.UDPSession) { +func (c *sessionContext) forwardIO(stream net.Conn) { if c.stdin != nil { go func() { - _, _ = io.Copy(c.stdin, session) + _, _ = io.Copy(c.stdin, stream) }() } if c.stdout != nil { c.wg.Add(1) go func() { - _, _ = io.Copy(session, c.stdout) + _, _ = io.Copy(stream, c.stdout) c.wg.Done() }() } @@ -111,9 +110,9 @@ func (c *sessionContext) forwardIO(session *kcp.UDPSession) { c.wg.Add(1) go func() { if stderr, ok := stderrMap[c.id]; ok { - _, _ = io.Copy(stderr.session, c.stderr) + _, _ = io.Copy(stderr.stream, c.stderr) } else { - _, _ = io.Copy(session, c.stderr) + _, _ = io.Copy(stream, c.stderr) } c.wg.Done() }() @@ -167,20 +166,20 @@ func (c *sessionContext) SetSize(cols, rows int) error { return nil } -func handleSessionEvent(session *kcp.UDPSession) { +func handleSessionEvent(stream net.Conn) { var msg StartMessage - if err := RecvMessage(session, &msg); err != nil { - SendError(session, fmt.Errorf("recv start message failed: %v", err)) + if err := RecvMessage(stream, &msg); err != nil { + SendError(stream, fmt.Errorf("recv start message failed: %v", err)) return } - if errCtx := getStderrSession(msg.ID); errCtx != nil { - defer errCtx.Close() + if errStream := getStderrStream(msg.ID); errStream != nil { + defer errStream.Close() } - ctx, err := newSession(&msg) + ctx, err := newSessionContext(&msg) if err != nil { - SendError(session, err) + SendError(stream, err) return } defer ctx.Close() @@ -191,21 +190,21 @@ func handleSessionEvent(session *kcp.UDPSession) { err = ctx.StartCmd() } if err != nil { - SendError(session, err) + SendError(stream, err) return } - if err := SendSuccess(session); err != nil { // ack ok + if err := SendSuccess(stream); err != nil { // ack ok trySendErrorMessage("session ack ok failed: %v", err) return } - ctx.forwardIO(session) + ctx.forwardIO(stream) ctx.Wait() } -func newSession(msg *StartMessage) (*sessionContext, error) { +func newSessionContext(msg *StartMessage) (*sessionContext, error) { cmd, err := getSessionStartCmd(msg) if err != nil { return nil, fmt.Errorf("build start command failed: %v", err) @@ -228,34 +227,34 @@ func newSession(msg *StartMessage) (*sessionContext, error) { return ctx, nil } -func (c *stderrContext) Wait() { +func (c *stderrStream) Wait() { c.wg.Wait() } -func (c *stderrContext) Close() { +func (c *stderrStream) Close() { c.wg.Done() stderrMutex.Lock() defer stderrMutex.Unlock() delete(stderrMap, c.id) } -func newStderrSession(id uint64, session *kcp.UDPSession) (*stderrContext, error) { +func newStderrStream(id uint64, stream net.Conn) (*stderrStream, error) { stderrMutex.Lock() defer stderrMutex.Unlock() if _, ok := stderrMap[id]; ok { return nil, fmt.Errorf("session %d stderr already set", id) } - ctx := &stderrContext{id: id, session: session} - ctx.wg.Add(1) - stderrMap[id] = ctx - return ctx, nil + errStream := &stderrStream{id: id, stream: stream} + errStream.wg.Add(1) + stderrMap[id] = errStream + return errStream, nil } -func getStderrSession(id uint64) *stderrContext { +func getStderrStream(id uint64) *stderrStream { stderrMutex.Lock() defer stderrMutex.Unlock() - if ctx, ok := stderrMap[id]; ok { - return ctx + if errStream, ok := stderrMap[id]; ok { + return errStream } return nil } @@ -294,30 +293,30 @@ func getSessionStartCmd(msg *StartMessage) (*exec.Cmd, error) { return cmd, nil } -func handleStderrEvent(session *kcp.UDPSession) { +func handleStderrEvent(stream net.Conn) { var msg StderrMessage - if err := RecvMessage(session, &msg); err != nil { - SendError(session, fmt.Errorf("recv stderr message failed: %v", err)) + if err := RecvMessage(stream, &msg); err != nil { + SendError(stream, fmt.Errorf("recv stderr message failed: %v", err)) return } - ctx, err := newStderrSession(msg.ID, session) + errStream, err := newStderrStream(msg.ID, stream) if err != nil { - SendError(session, err) + SendError(stream, err) return } - if err := SendSuccess(session); err != nil { // ack ok + if err := SendSuccess(stream); err != nil { // ack ok trySendErrorMessage("stderr ack ok failed: %v", err) return } - ctx.Wait() + errStream.Wait() } -func handleResizeEvent(session *kcp.UDPSession) error { +func handleResizeEvent(stream net.Conn) error { var msg ResizeMessage - if err := RecvMessage(session, &msg); err != nil { + if err := RecvMessage(stream, &msg); err != nil { return fmt.Errorf("recv resize message failed: %v", err) } if msg.Cols <= 0 || msg.Rows <= 0 {