diff --git a/.env.sample b/.env.sample index 6416b54..e69de29 100644 --- a/.env.sample +++ b/.env.sample @@ -1,7 +0,0 @@ -ADDR= -PROTO= -PEM= -KEY= -TIMEOUT= -AUTH= -TZ= \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index fc1a88f..f68266a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ FROM busybox:1.36.1-glibc -WORKDIR /root -COPY nanoproxy /usr/local/bin/nanoproxy +COPY nanoproxy /usr/bin/nanoproxy +EXPOSE 1080 ENTRYPOINT ["nanoproxy"] \ No newline at end of file diff --git a/LICENSE b/LICENSE index d03a842..f12f9c2 100644 --- a/LICENSE +++ b/LICENSE @@ -19,3 +19,26 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +--- + +The MIT License (MIT) + +Copyright (c) 2014 Armon Dadgar + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md index be07a7f..bc114c4 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,12 @@ ![coverage](https://raw.githubusercontent.com/ryanbekhen/nanoproxy/badges/.badges/master/coverage.svg) [![Go Report Card](https://goreportcard.com/badge/github.com/ryanbekhen/nanoproxy?cache=v1)](https://goreportcard.com/report/github.com/vladopajic/go-test-coverage) -NanoProxy is a lightweight HTTP proxy server designed to provide basic proxying functionality. -It supports handling HTTP requests and tunneling for HTTPS. NanoProxy is written in Go and built on top of Fiber. +Note: This code includes modifications from the original go-socks5 project (https://github.com/armon/go-socks5) +Modifications have been made as part of maintenance for NanoProxy. +This version is licensed under the MIT license. + +NanoProxy is a lightweight SOCKS5 proxy server written in Go. It is designed to be simple, minimalistic, and easy to +use. > ⚠️ **Notice:** NanoProxy is currently in pre-production stage. While it provides essential proxying capabilities, > please be aware that it is still under active development. Full backward compatibility is not guaranteed until @@ -13,10 +17,10 @@ It supports handling HTTP requests and tunneling for HTTPS. NanoProxy is written ## Data Flow Through Proxy -NanoProxy acts as an intermediary between user requests and the destination server. When a user makes a request, -NanoProxy forwards the request to the destination server. The destination server processes the request and responds -back to NanoProxy, which then sends the response back to the user. This allows NanoProxy to intercept and manage -network traffic effectively. +NanoProxy acts as a proxy server that forwards network traffic between the user and the destination server. +When a user makes a request, the request is sent to the proxy server. The proxy server then forwards the request to +the destination server. The destination server processes the request and responds back to the proxy server, which then +sends the response back to the user. This allows the proxy server to intercept and manage network traffic effectively. Here's how the data flows through the proxy: @@ -30,17 +34,15 @@ Here's how the data flows through the proxy: `-----------------' ``` -This clear separation of responsibilities helps optimize network communication and enables various -proxy-related functionalities. +This clear separation of responsibilities helps optimize network communication and enables various proxy-related +functionalities. ## Features -- **Simple and minimalistic HTTP proxy server.** NanoProxy is designed with simplicity in mind, making it easy to set -up and use for various purposes. -- **Handles both HTTP requests and tunneling (CONNECT) for HTTPS.** NanoProxy supports both HTTP requests and tunneling, -allowing you to proxy regular HTTP requests as well as secure HTTPS connections. -- **Lightweight and easy to configure.** With a small footprint and straightforward configuration options, NanoProxy is -a lightweight solution that can be quickly configured to suit your needs. +NanoProxy provides the following features: + +- **SOCKS5 proxy server.** NanoProxy is a SOCKS5 proxy server that can be used to proxy network traffic for various +applications. ## Installation @@ -48,7 +50,7 @@ You can easily install NanoProxy using your package manager by adding the offici ### Debian and Ubuntu -Add the NanoProxy repository to your sources list: +Add the NanoProxy repository to your source list: ```shell echo "deb [trusted=yes] https://repo.ryanbekhen.dev/apt/ /" | sudo tee /etc/apt/sources.list.d/ryanbekhen.list @@ -83,7 +85,7 @@ sudo yum install nanoproxy ## Usage -After installing NanoProxy using the provided packages (.deb or .rpm) or accessed it through the repository, +After installing NanoProxy using the provided packages (.deb or .rpm) or accessing it through the repository, you can manage NanoProxy as a service using the system's service management tool (systemd). To enable NanoProxy to start automatically on system boot, run the following command: @@ -104,35 +106,35 @@ sudo systemctl start nanoproxy You can also run NanoProxy using Docker. To do so, you can use the following command: ```shell -docker run -p 8080:8080 ghcr.io/ryanbekhen/nanoproxy:latest +docker run -p 1080:1080 ghcr.io/ryanbekhen/nanoproxy:latest ``` ## Configuration -You can modify the behavior of NanoProxy by adjusting the command line flags when running the proxy. The available flags are: - -- `-addr`: Proxy listen address (default: :8080). -- `-pem`: Path to the PEM file for TLS (HTTPS) support. -- `-key`: Path to the private key file for TLS. -- `-proto`: Proxy protocol `http` or `https`. If set to `https`, the `-pem` and `-key` flags must be set. -- `-timeout`: Timeout duration for tunneling connections (default: 15 seconds). -- `-auth`: Basic authentication credentials in the form of `username:password`. -- `-debug`: Enable debug mode. - -You can set the configuration using environment variables. Create a file -at `/etc/nanoproxy/nanoproxy.env` and add the desired values: +You can also set the configuration using environment variables. Create a file at `/etc/nanoproxy/nanoproxy` and add the +desired values: ```text ADDR=:8080 -PROTO=http -PEM=server.pem -KEY=server.key -TIMEOUT=15s -AUTH=user:pass +NETWORK=tcp TZ=Asia/Jakarta ``` -Modify these flags or environment variables according to your requirements. +The following table lists the available configuration options: + +| Name | Description | Default Value | +|------|-------------|---------------| +| ADDR | The address to listen on. | `:1080` | +| NETWORK | The network to listen on. | `tcp` | +| TZ | The timezone to use. | `Local` | + +## Logging + +NanoProxy logs all requests and responses to the standard output. You can use the `journalctl` command to view the logs: + +```shell +journalctl -u nanoproxy +``` ## Testing @@ -155,8 +157,8 @@ Contributions are welcome! Feel free to open issues and submit pull requests. ## Security -If you discover any security related issues, please email i@ryanbekhen.dev instead of using the issue tracker. +If you discover any security-related issues, please email i@ryanbekhen.dev instead of using the issue tracker. ## License -This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. +This project is licensed under the MIT License—see the [LICENSE](LICENSE) file for details. diff --git a/config/config.go b/config/config.go deleted file mode 100644 index 30d2ffd..0000000 --- a/config/config.go +++ /dev/null @@ -1,114 +0,0 @@ -package config - -import ( - "errors" - "flag" - "os" - "strings" - "time" -) - -type Config struct { - pemPath string - keyPath string - proto string - addr string - tunnelTimeout time.Duration - basicAuth string - debug bool -} - -var ( - ErrInvalidProto = errors.New("invalid protocol") - ErrInvalidBasicAuth = errors.New("invalid basic auth") -) - -func New() *Config { - c := &Config{} - flag.StringVar(&c.pemPath, "pem", "server.pem", "path to pem file") - flag.StringVar(&c.keyPath, "key", "server.key", "path to key file") - flag.StringVar(&c.proto, "proto", "http", "proxy protocol (http or https)") - flag.StringVar(&c.addr, "addr", ":8080", "proxy listen address (default :8080)") - flag.DurationVar(&c.tunnelTimeout, "timeout", time.Second*15, "tunnel timeout (default 15s)") - flag.StringVar(&c.basicAuth, "auth", "", "basic auth (username:password)") - flag.BoolVar(&c.debug, "debug", false, "debug mode") - flag.Parse() - - if os.Getenv("PEM") != "" { - c.pemPath = os.Getenv("PEM") - } - - if os.Getenv("KEY") != "" { - c.keyPath = os.Getenv("KEY") - } - - if os.Getenv("PROTO") != "" { - c.proto = os.Getenv("PROTO") - } - - if os.Getenv("ADDR") != "" { - c.addr = os.Getenv("ADDR") - } - - if os.Getenv("TIMEOUT") != "" { - d, err := time.ParseDuration(os.Getenv("TIMEOUT")) - if err == nil { - c.tunnelTimeout = d - } - } - - if os.Getenv("AUTH") != "" { - c.basicAuth = os.Getenv("AUTH") - } - return c -} - -func (c *Config) PemPath() string { - return c.pemPath -} - -func (c *Config) KeyPath() string { - return c.keyPath -} - -func (c *Config) Addr() string { - return c.addr -} - -func (c *Config) TunnelTimeout() time.Duration { - return c.tunnelTimeout -} - -func (c *Config) BasicAuth() map[string]string { - cred := strings.Split(c.basicAuth, ",") - users := map[string]string{} - for _, cred := range cred { - userPass := strings.Split(cred, ":") - if len(userPass) == 2 { - users[userPass[0]] = userPass[1] - } - } - return users -} - -func (c *Config) Debug() bool { - return c.debug -} - -func (c *Config) IsHTTPS() bool { - return c.proto == "https" -} - -func (c *Config) Validate() error { - if c.proto != "http" && c.proto != "https" { - return ErrInvalidProto - } - - if c.basicAuth != "" { - if len(c.BasicAuth()) == 0 { - return ErrInvalidBasicAuth - } - } - - return nil -} diff --git a/go.mod b/go.mod index 449005a..e1bfabc 100644 --- a/go.mod +++ b/go.mod @@ -3,21 +3,12 @@ module github.com/ryanbekhen/nanoproxy go 1.21 require ( - github.com/gofiber/contrib/fiberzerolog v0.2.2 - github.com/gofiber/fiber/v2 v2.49.2 - github.com/rs/zerolog v1.30.0 - github.com/valyala/fasthttp v1.50.0 + github.com/caarlos0/env/v10 v10.0.0 + github.com/rs/zerolog v1.31.0 ) require ( - github.com/andybalholm/brotli v1.0.5 // indirect - github.com/google/uuid v1.3.1 // indirect - github.com/klauspost/compress v1.17.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect - github.com/mattn/go-runewidth v0.0.15 // indirect - github.com/rivo/uniseg v0.4.4 // indirect - github.com/valyala/bytebufferpool v1.0.0 // indirect - github.com/valyala/tcplisten v1.0.0 // indirect - golang.org/x/sys v0.12.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + golang.org/x/sys v0.14.0 // indirect ) diff --git a/go.sum b/go.sum index 782ceff..f63ab3c 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,7 @@ -github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= -github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/caarlos0/env/v10 v10.0.0 h1:yIHUBZGsyqCnpTkbjk8asUlx6RFhhEs+h7TOBdgdzXA= +github.com/caarlos0/env/v10 v10.0.0/go.mod h1:ZfulV76NvVPw3tm591U4SwL3Xx9ldzBP9aGxzeN7G18= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/gofiber/contrib/fiberzerolog v0.2.2 h1:tvHBW5k+udW02LU1eNneh65znGwhsKcv8XWf22U7dlc= -github.com/gofiber/contrib/fiberzerolog v0.2.2/go.mod h1:CSpu4UUPGWAA/jIIuHXIhJt3W1cRxprxupXndAYuvpU= -github.com/gofiber/fiber/v2 v2.49.2 h1:ONEN3/Vc+dUCxxDgZZwpqvhISgHqb+bu+isBiEyKEQs= -github.com/gofiber/fiber/v2 v2.49.2/go.mod h1:gNsKnyrmfEWFpJxQAV0qvW6l70K1dZGno12oLtukcts= -github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= -github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM= -github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= @@ -17,24 +9,19 @@ github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27k github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= -github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.30.0 h1:SymVODrcRsaRaSInD9yQtKbtWqwsfoPcRff/oRXLj4c= github.com/rs/zerolog v1.30.0/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w= -github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= -github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.50.0 h1:H7fweIlBm0rXLs2q0XbalvJ6r0CUPFWK3/bB4N13e9M= -github.com/valyala/fasthttp v1.50.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA= -github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= -github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= +github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A= +github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= +golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/middleware/basicauth/basicauth.go b/middleware/basicauth/basicauth.go deleted file mode 100644 index cbafba2..0000000 --- a/middleware/basicauth/basicauth.go +++ /dev/null @@ -1,55 +0,0 @@ -package basicauth - -import ( - "encoding/base64" - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/utils" - "strings" -) - -func New(config Config) fiber.Handler { - cfg := configDefault(config) - return func(c *fiber.Ctx) error { - // Skip if basic auth is empty - if len(cfg.Users) == 0 { - return c.Next() - } - - // Get authorization header - auth := c.Get(fiber.HeaderProxyAuthorization) - - // Check if header is valid - if len(auth) < 6 || !utils.EqualFold(auth[:6], "basic ") { - return cfg.Unauthorized(c) - } - - // Decode header - raw, err := base64.StdEncoding.DecodeString(auth[6:]) - if err != nil { - return cfg.Unauthorized(c) - } - - // Get credentials - creds := utils.UnsafeString(raw) - - // Split username and password - index := strings.Index(creds, ":") - if index == -1 { - return cfg.Unauthorized(c) - } - - // Get username and password - user := creds[:index] - pass := creds[index+1:] - - // Check credentials - if cfg.Authorized(user, pass) { - c.Locals("username", user) - c.Locals("password", pass) - return c.Next() - } - - // Credentials doesn't match - return cfg.Unauthorized(c) - } -} diff --git a/middleware/basicauth/basicauth_test.go b/middleware/basicauth/basicauth_test.go deleted file mode 100644 index e6360a1..0000000 --- a/middleware/basicauth/basicauth_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package basicauth - -import ( - "encoding/base64" - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/utils" - "io" - "net/http/httptest" - "testing" -) - -func Test_Middleware_BasicAuth(t *testing.T) { - t.Parallel() - - app := fiber.New() - - app.Use(New(Config{ - Users: map[string]string{ - "john": "doe", - "jane": "doe", - }, - })) - - app.Get("/testauth", func(c *fiber.Ctx) error { - username := c.Locals("username").(string) - password := c.Locals("password").(string) - - return c.SendString(username + password) - }) - - tests := []struct { - url string - statusCode int - username string - password string - }{ - { - url: "/testauth", - statusCode: fiber.StatusOK, - username: "john", - password: "doe", - }, - { - url: "/testauth", - statusCode: fiber.StatusOK, - username: "jane", - password: "doe", - }, - { - url: "/testauth", - statusCode: fiber.StatusProxyAuthRequired, - username: "john", - password: "wrong", - }, - } - - for _, tt := range tests { - // Encode credentials to base64 - cred := base64.StdEncoding.EncodeToString([]byte(tt.username + ":" + tt.password)) - - req := httptest.NewRequest(fiber.MethodGet, "/testauth", nil) - req.Header.Set(fiber.HeaderProxyAuthorization, "Basic "+cred) - resp, err := app.Test(req) - utils.AssertEqual(t, nil, err) - - body, err := io.ReadAll(resp.Body) - utils.AssertEqual(t, nil, err) - - utils.AssertEqual(t, tt.statusCode, resp.StatusCode) - if tt.statusCode == fiber.StatusOK { - utils.AssertEqual(t, tt.username+tt.password, string(body)) - } - } -} - -func Test_Middleware_BasicAuth_No_Users(t *testing.T) { - t.Parallel() - - app := fiber.New() - - app.Use(New(Config{ - Users: map[string]string{}, - })) - - app.Get("/testauth", func(c *fiber.Ctx) error { - return c.SendString("testauth") - }) - - req := httptest.NewRequest(fiber.MethodGet, "/testauth", nil) - resp, err := app.Test(req) - utils.AssertEqual(t, nil, err) - - body, err := io.ReadAll(resp.Body) - utils.AssertEqual(t, nil, err) - - utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) - utils.AssertEqual(t, "testauth", string(body)) -} diff --git a/middleware/basicauth/config.go b/middleware/basicauth/config.go deleted file mode 100644 index 216e5ad..0000000 --- a/middleware/basicauth/config.go +++ /dev/null @@ -1,47 +0,0 @@ -package basicauth - -import ( - "crypto/subtle" - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/utils" -) - -type Config struct { - Users map[string]string - Authorized func(user string, pass string) bool - Unauthorized func(*fiber.Ctx) error -} - -var ConfigDefault = Config{ - Users: map[string]string{}, - Authorized: nil, - Unauthorized: nil, -} - -func configDefault(config ...Config) Config { - if len(config) < 1 { - return ConfigDefault - } - - cfg := config[0] - - if cfg.Users == nil { - cfg.Users = ConfigDefault.Users - } - - if cfg.Authorized == nil { - cfg.Authorized = func(user string, pass string) bool { - userPass, exist := cfg.Users[user] - return exist && subtle.ConstantTimeCompare(utils.UnsafeBytes(userPass), utils.UnsafeBytes(pass)) == 1 - } - } - - if cfg.Unauthorized == nil { - cfg.Unauthorized = func(c *fiber.Ctx) error { - c.Set(fiber.HeaderProxyAuthenticate, "Basic realm=Restricted") - return c.SendStatus(fiber.StatusProxyAuthRequired) - } - } - - return cfg -} diff --git a/middleware/hopbyhop/hopbyhop.go b/middleware/hopbyhop/hopbyhop.go deleted file mode 100644 index 0a8e696..0000000 --- a/middleware/hopbyhop/hopbyhop.go +++ /dev/null @@ -1,27 +0,0 @@ -package hopbyhop - -import "github.com/gofiber/fiber/v2" - -// Hop-by-hop headers. These are removed when sent to the backend. -// (https://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html) -var hopHeaders = []string{ - "Connection", - "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google - "Keep-Alive", - "Proxy-Authenticate", - "Proxy-Authorization", - "Te", // canonicalized version of "TE" - "Trailers", - "Transfer-Encoding", - "Upgrade", -} - -func New() fiber.Handler { - return func(c *fiber.Ctx) error { - // remove hop-by-hop headers - for _, h := range hopHeaders { - c.Request().Header.Del(h) - } - return c.Next() - } -} diff --git a/middleware/hopbyhop/hopbyhop_test.go b/middleware/hopbyhop/hopbyhop_test.go deleted file mode 100644 index 87742b7..0000000 --- a/middleware/hopbyhop/hopbyhop_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package hopbyhop - -import ( - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/utils" - "io" - "net/http/httptest" - "testing" -) - -// go test -run Test_Middleware_HopByHop -func Test_Middleware_HopByHop(t *testing.T) { - t.Parallel() - - app := fiber.New() - app.Use(New()).Get("/test", func(c *fiber.Ctx) error { - return c.Send(c.Request().Header.Header()) - }) - - // request with hop-by-hop headers - req := httptest.NewRequest(fiber.MethodGet, "/test", nil) - req.Header.Set("Proxy-Connection", "close") - req.Header.Set("Test", "test") - - resp, err := app.Test(req) - utils.AssertEqual(t, nil, err) - - body, err := io.ReadAll(resp.Body) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "GET /test HTTP/1.1\r\nHost: example.com\r\nTest: test\r\n\r\n", string(body)) -} diff --git a/nanoproxy.go b/nanoproxy.go index f5f2f27..95bfca1 100644 --- a/nanoproxy.go +++ b/nanoproxy.go @@ -1,62 +1,40 @@ package main import ( - "github.com/gofiber/contrib/fiberzerolog" - "github.com/gofiber/fiber/v2" - recoverFiber "github.com/gofiber/fiber/v2/middleware/recover" + "github.com/caarlos0/env/v10" "github.com/rs/zerolog" - "github.com/rs/zerolog/log" - "github.com/ryanbekhen/nanoproxy/config" - "github.com/ryanbekhen/nanoproxy/middleware/basicauth" - "github.com/ryanbekhen/nanoproxy/middleware/hopbyhop" - "github.com/ryanbekhen/nanoproxy/webproxy" + "github.com/ryanbekhen/nanoproxy/pkg/config" + "github.com/ryanbekhen/nanoproxy/pkg/socks5" "os" "time" + _ "time/tzdata" ) func main() { - cfg := config.New() - loc, _ := time.LoadLocation(os.Getenv("TZ")) - time.Local = loc + cfg := &config.Config{} + logger := zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339}).With().Timestamp().Logger() - logLevel := zerolog.InfoLevel - if cfg.Debug() { - logLevel = zerolog.DebugLevel + if err := env.Parse(cfg); err != nil { + logger.Fatal().Msg(err.Error()) } - logger := log.Level(logLevel).Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}). - With().Timestamp().Logger() + loc, _ := time.LoadLocation(cfg.Timezone) + if loc != nil { + time.Local = loc + } - // validate config - if err := cfg.Validate(); err != nil { - logger.Fatal().Msg(err.Error()) + socks5Config := &socks5.Config{ + Logger: &logger, } - server := fiber.New(fiber.Config{DisableStartupMessage: true}) - srv := webproxy.New(cfg.TunnelTimeout()) + sock5Server, err := socks5.New(socks5Config) + if err != nil { + logger.Fatal().Msg(err.Error()) + } - // middleware - server.Use(recoverFiber.New()) - server.Use(basicauth.New(basicauth.Config{Users: cfg.BasicAuth()})) - server.Use(hopbyhop.New()) - server.Use(fiberzerolog.New(fiberzerolog.Config{ - Logger: &logger, - Fields: []string{"ip", "latency", "status", "url", "error"}, - })) - - // routes - server.All("*", srv.Handler) - - // start server - logger.Info().Msgf("Starting server on %s", cfg.Addr()) - if cfg.IsHTTPS() { - logger.Fatal(). - Err(server.ListenTLS(cfg.Addr(), cfg.PemPath(), cfg.KeyPath())). - Msg("Server closed") - } else { - logger.Fatal(). - Err(server.Listen(cfg.Addr())). - Msg("Server closed") + logger.Info().Msgf("Starting socks5 server on %s://%s", cfg.Network, cfg.ADDR) + if err := sock5Server.ListenAndServe(cfg.Network, cfg.ADDR); err != nil { + logger.Fatal().Msg(err.Error()) } } diff --git a/pkg/config/config.go b/pkg/config/config.go new file mode 100644 index 0000000..9c92a8e --- /dev/null +++ b/pkg/config/config.go @@ -0,0 +1,7 @@ +package config + +type Config struct { + Timezone string `env:"TZ" envDefault:"Local"` + Network string `env:"NETWORK" envDefault:"tcp"` + ADDR string `env:"ADDR" envDefault:":1080"` +} diff --git a/pkg/socks5/auth.go b/pkg/socks5/auth.go new file mode 100644 index 0000000..577d27e --- /dev/null +++ b/pkg/socks5/auth.go @@ -0,0 +1,150 @@ +package socks5 + +import ( + "fmt" + "io" +) + +const ( + NoAuth = uint8(0) + noAcceptable = uint8(255) + UserPassAuth = uint8(2) + userAuthVersion = uint8(1) + authSuccess = uint8(0) + authFailure = uint8(1) +) + +var ( + UserAuthFailed = fmt.Errorf("user authentication failed") + NoSupportedAuth = fmt.Errorf("no supported authentication mechanism") +) + +// AuthContext A Request encapsulates authentication state provided +// during negotiation +type AuthContext struct { + // Provided auth method + Method uint8 + // Payload provided during negotiation. + // Keys depend on the used auth method. + // For UserPass-auth contains Username + Payload map[string]string +} + +type Authenticator interface { + Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) + GetCode() uint8 +} + +// NoAuthAuthenticator is used to handle the "No Authentication" mode +type NoAuthAuthenticator struct{} + +func (a NoAuthAuthenticator) GetCode() uint8 { + return NoAuth +} + +func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { + _, err := writer.Write([]byte{socks5Version, NoAuth}) + return &AuthContext{NoAuth, nil}, err +} + +// UserPassAuthenticator is used to handle username/password-based +// authentication +type UserPassAuthenticator struct { + Credentials CredentialStore +} + +func (a UserPassAuthenticator) GetCode() uint8 { + return UserPassAuth +} + +func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { + // Tell the client to use user/pass auth + if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil { + return nil, err + } + + // Get the version and username length + header := []byte{0, 0} + if _, err := io.ReadAtLeast(reader, header, 2); err != nil { + return nil, err + } + + // Ensure we are compatible + if header[0] != userAuthVersion { + return nil, fmt.Errorf("unsupported auth version: %v", header[0]) + } + + // Get the username + userLen := int(header[1]) + user := make([]byte, userLen) + if _, err := io.ReadAtLeast(reader, user, userLen); err != nil { + return nil, err + } + + // Get the password length + if _, err := reader.Read(header[:1]); err != nil { + return nil, err + } + + // Get the password + passLen := int(header[0]) + pass := make([]byte, passLen) + if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil { + return nil, err + } + + // Verify the password + if a.Credentials.Valid(string(user), string(pass)) { + if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil { + return nil, err + } + } else { + if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil { + return nil, err + } + return nil, UserAuthFailed + } + + // Done + return &AuthContext{UserPassAuth, map[string]string{"Username": string(user)}}, nil +} + +// authenticate is used to handle connection authentication +func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) { + // Get the methods + methods, err := readMethods(bufConn) + if err != nil { + return nil, fmt.Errorf("failed to get auth methods: %v", err) + } + + // Select a usable method + for _, method := range methods { + if auth, ok := s.authMethods[method]; ok { + return auth.Authenticate(bufConn, conn) + } + } + + // No usable method found + return nil, noAcceptableAuth(conn) +} + +// noAcceptableAuth is used to handle when we have no eligible +// authentication mechanism +func noAcceptableAuth(conn io.Writer) error { + conn.Write([]byte{socks5Version, noAcceptable}) + return NoSupportedAuth +} + +// readMethods is used to read the number of methods +// and proceeding auth methods +func readMethods(r io.Reader) ([]byte, error) { + header := []byte{0} + if _, err := r.Read(header); err != nil { + return nil, err + } + + numMethods := int(header[0]) + methods := make([]byte, numMethods) + _, err := io.ReadAtLeast(r, methods, numMethods) + return methods, err +} diff --git a/pkg/socks5/auth_test.go b/pkg/socks5/auth_test.go new file mode 100644 index 0000000..a4fa1b3 --- /dev/null +++ b/pkg/socks5/auth_test.go @@ -0,0 +1,120 @@ +package socks5 + +import ( + "bytes" + "errors" + "testing" +) + +func TestNoAuth(t *testing.T) { + req := bytes.NewBuffer(nil) + req.Write([]byte{1, NoAuth}) + var resp bytes.Buffer + + s, _ := New(&Config{}) + ctx, err := s.authenticate(&resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + if ctx.Method != NoAuth { + t.Fatal("Invalid Context Method") + } + + out := resp.Bytes() + if !bytes.Equal(out, []byte{socks5Version, NoAuth}) { + t.Fatalf("bad: %v", out) + } +} + +func TestPasswordAuth_Valid(t *testing.T) { + req := bytes.NewBuffer(nil) + req.Write([]byte{2, NoAuth, UserPassAuth}) + req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'}) + var resp bytes.Buffer + + cred := StaticCredentials{ + "foo": "bar", + } + + auth := UserPassAuthenticator{Credentials: cred} + + s, _ := New(&Config{AuthMethods: []Authenticator{auth}}) + + ctx, err := s.authenticate(&resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + if ctx.Method != UserPassAuth { + t.Fatal("Invalid Context Method") + } + + val, ok := ctx.Payload["Username"] + if !ok { + t.Fatal("Missing key Username in auth context's payload") + } + + if val != "foo" { + t.Fatal("Invalid Username in auth context's payload") + } + + out := resp.Bytes() + if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authSuccess}) { + t.Fatalf("bad: %v", out) + } +} + +func TestPasswordAuth_Invalid(t *testing.T) { + req := bytes.NewBuffer(nil) + req.Write([]byte{2, NoAuth, UserPassAuth}) + req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'}) + var resp bytes.Buffer + + cred := StaticCredentials{ + "foo": "bar", + } + auth := UserPassAuthenticator{Credentials: cred} + s, _ := New(&Config{AuthMethods: []Authenticator{auth}}) + + ctx, err := s.authenticate(&resp, req) + if !errors.Is(err, UserAuthFailed) { + t.Fatalf("err: %v", err) + } + + if ctx != nil { + t.Fatal("Invalid Context Method") + } + + out := resp.Bytes() + if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authFailure}) { + t.Fatalf("bad: %v", out) + } +} + +func TestNoSupportedAuth(t *testing.T) { + req := bytes.NewBuffer(nil) + req.Write([]byte{1, NoAuth}) + var resp bytes.Buffer + + cred := StaticCredentials{ + "foo": "bar", + } + auth := UserPassAuthenticator{Credentials: cred} + + s, _ := New(&Config{AuthMethods: []Authenticator{auth}}) + + ctx, err := s.authenticate(&resp, req) + if !errors.Is(err, NoSupportedAuth) { + t.Fatalf("err: %v", err) + } + + if ctx != nil { + t.Fatal("Invalid Context Method") + } + + out := resp.Bytes() + if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) { + t.Fatalf("bad: %v", out) + } +} diff --git a/pkg/socks5/credentials.go b/pkg/socks5/credentials.go new file mode 100644 index 0000000..9666427 --- /dev/null +++ b/pkg/socks5/credentials.go @@ -0,0 +1,17 @@ +package socks5 + +// CredentialStore is used to support user/pass authentication +type CredentialStore interface { + Valid(user, password string) bool +} + +// StaticCredentials enables using a map directly as a credential store +type StaticCredentials map[string]string + +func (s StaticCredentials) Valid(user, password string) bool { + pass, ok := s[user] + if !ok { + return false + } + return password == pass +} diff --git a/pkg/socks5/credentials_test.go b/pkg/socks5/credentials_test.go new file mode 100644 index 0000000..4591c09 --- /dev/null +++ b/pkg/socks5/credentials_test.go @@ -0,0 +1,24 @@ +package socks5 + +import ( + "testing" +) + +func TestStaticCredentials(t *testing.T) { + credentials := StaticCredentials{ + "foo": "bar", + "baz": "", + } + + if !credentials.Valid("foo", "bar") { + t.Fatalf("expect valid") + } + + if !credentials.Valid("baz", "") { + t.Fatalf("expect valid") + } + + if credentials.Valid("foo", "") { + t.Fatalf("expect invalid") + } +} diff --git a/pkg/socks5/request.go b/pkg/socks5/request.go new file mode 100644 index 0000000..41d61e9 --- /dev/null +++ b/pkg/socks5/request.go @@ -0,0 +1,386 @@ +package socks5 + +import ( + "context" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" +) + +const ( + ConnectCommand = uint8(1) + BindCommand = uint8(2) + AssociateCommand = uint8(3) + ipv4Address = uint8(1) + fqdnAddress = uint8(3) + ipv6Address = uint8(4) +) + +const ( + successReply uint8 = iota + serverFailure + ruleFailure + networkUnreachable + hostUnreachable + connectionRefused + ttlExpired + commandNotSupported + addrTypeNotSupported +) + +var ( + unrecognizedAddrType = fmt.Errorf("unrecognized address type") +) + +// AddressRewriter is used to rewrite a destination transparently +type AddressRewriter interface { + Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec) +} + +// AddrSpec is used to return the target address of a request +// which may be specified as IPv4, IPv6, or a FQDN +type AddrSpec struct { + FQDN string + IP net.IP + Port int +} + +func (a *AddrSpec) String() string { + if a.FQDN != "" { + return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port) + } + return fmt.Sprintf("%s:%d", a.IP, a.Port) +} + +// Address returns a string suitable to dial; prefer returning IP-based +// address, fallback to FQDN +func (a *AddrSpec) Address() string { + if 0 != len(a.IP) { + return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port)) + } + return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port)) +} + +// A Request represents request received by a server +type Request struct { + // Protocol version + Version uint8 + // Requested command + Command uint8 + // AuthContext provided during negotiation + AuthContext *AuthContext + // AddrSpec of the network that sent the request + RemoteAddr *AddrSpec + // AddrSpec of the desired destination + DestAddr *AddrSpec + // AddrSpec of the actual destination (might be affected by rewrite) + realDestAddr *AddrSpec + bufConn io.Reader + + // Latency is the time it took to establish the connection + Latency time.Duration +} + +type conn interface { + Write([]byte) (int, error) + RemoteAddr() net.Addr +} + +// NewRequest creates a new Request from the tcp connection +func NewRequest(bufConn io.Reader) (*Request, error) { + // Read the version byte + header := []byte{0, 0, 0} + if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil { + return nil, fmt.Errorf("failed to get command version: %v", err) + } + + // Ensure we are compatible + if header[0] != socks5Version { + return nil, fmt.Errorf("unsupported command version: %v", header[0]) + } + + // Read in the destination address + dest, err := readAddrSpec(bufConn) + if err != nil { + return nil, err + } + + request := &Request{ + Version: socks5Version, + Command: header[1], + DestAddr: dest, + bufConn: bufConn, + } + + return request, nil +} + +// handleRequest is used for request processing after authentication +func (s *Server) handleRequest(req *Request, conn conn) error { + ctx := context.Background() + + // Resolve the address if we have a FQDN + dest := req.DestAddr + if dest.FQDN != "" { + ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN) + if err != nil { + if err := sendReply(conn, hostUnreachable, nil); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + return fmt.Errorf("failed to resolve destination '%v': %v", dest.FQDN, err) + } + ctx = ctx_ + dest.IP = addr + } + + // Apply any address rewrites + req.realDestAddr = req.DestAddr + if s.config.Rewriter != nil { + ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req) + } + + // Switch on the command + switch req.Command { + case ConnectCommand: + return s.handleConnect(ctx, conn, req) + case BindCommand: + return s.handleBind(ctx, conn, req) + case AssociateCommand: + return s.handleAssociate(ctx, conn, req) + default: + if err := sendReply(conn, commandNotSupported, nil); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + return fmt.Errorf("unsupported command: %v", req.Command) + } +} + +// handleConnect is used to handle a connect command +func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error { + // Check if this is allowed + if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { + if err := sendReply(conn, ruleFailure, nil); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + return fmt.Errorf("connect to %v blocked by rules", req.DestAddr) + } else { + ctx = ctx_ + } + + // Attempt to connect + dial := s.config.Dial + if dial == nil { + dial = func(ctx context.Context, net_, addr string) (net.Conn, error) { + return net.Dial(net_, addr) + } + } + + start := time.Now() + + target, err := dial(ctx, "tcp", req.realDestAddr.Address()) + + req.Latency = time.Since(start) + + if err != nil { + msg := err.Error() + resp := hostUnreachable + if strings.Contains(msg, "refused") { + resp = connectionRefused + } else if strings.Contains(msg, "network is unreachable") { + resp = networkUnreachable + } + if err := sendReply(conn, resp, nil); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + return fmt.Errorf("connect to %v failed: %v", req.DestAddr, err) + } + defer target.Close() + + // Send success + local := target.LocalAddr().(*net.TCPAddr) + bind := AddrSpec{IP: local.IP, Port: local.Port} + if err := sendReply(conn, successReply, &bind); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + + // Start proxying + errCh := make(chan error, 2) + go proxy(target, req.bufConn, errCh) + go proxy(conn, target, errCh) + + // Wait + for i := 0; i < 2; i++ { + e := <-errCh + if e != nil { + // return from this function closes target (and conn). + return e + } + } + return nil +} + +// handleBind is used to handle a connect command +func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error { + // Check if this is allowed + if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { + if err := sendReply(conn, ruleFailure, nil); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + return fmt.Errorf("bind to %v blocked by rules", req.DestAddr) + } else { + ctx = ctx_ + } + + // TODO: Support bind + if err := sendReply(conn, commandNotSupported, nil); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + return nil +} + +// handleAssociate is used to handle a connect command +func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error { + // Check if this is allowed + if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { + if err := sendReply(conn, ruleFailure, nil); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + return fmt.Errorf("associate to %v blocked by rules", req.DestAddr) + } else { + ctx = ctx_ + } + + // TODO: Support associate + if err := sendReply(conn, commandNotSupported, nil); err != nil { + return fmt.Errorf("failed to send reply: %v", err) + } + return nil +} + +// readAddrSpec is used to read AddrSpec. +// Expects an address type byte, followed by the address and port +func readAddrSpec(r io.Reader) (*AddrSpec, error) { + d := &AddrSpec{} + + // Get the address type + addrType := []byte{0} + if _, err := r.Read(addrType); err != nil { + return nil, err + } + + // Handle on a per-type basis + switch addrType[0] { + case ipv4Address: + addr := make([]byte, 4) + if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { + return nil, err + } + d.IP = addr + + case ipv6Address: + addr := make([]byte, 16) + if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { + return nil, err + } + d.IP = addr + + case fqdnAddress: + if _, err := r.Read(addrType); err != nil { + return nil, err + } + addrLen := int(addrType[0]) + fqdn := make([]byte, addrLen) + if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil { + return nil, err + } + d.FQDN = string(fqdn) + + default: + return nil, unrecognizedAddrType + } + + // Read the port + port := []byte{0, 0} + if _, err := io.ReadAtLeast(r, port, 2); err != nil { + return nil, err + } + d.Port = (int(port[0]) << 8) | int(port[1]) + + return d, nil +} + +// sendReply is used to send a reply message +func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error { + // Format the address + var addrType uint8 + var addrBody []byte + var addrPort uint16 + switch { + case addr == nil: + addrType = ipv4Address + addrBody = []byte{0, 0, 0, 0} + addrPort = 0 + + case addr.FQDN != "": + addrType = fqdnAddress + addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...) + addrPort = uint16(addr.Port) + + case addr.IP.To4() != nil: + addrType = ipv4Address + addrBody = addr.IP.To4() + addrPort = uint16(addr.Port) + + case addr.IP.To16() != nil: + addrType = ipv6Address + addrBody = addr.IP.To16() + addrPort = uint16(addr.Port) + + default: + return fmt.Errorf("failed to format address: %v", addr) + } + + // Format the message + msg := make([]byte, 6+len(addrBody)) + msg[0] = socks5Version + msg[1] = resp + msg[2] = 0 // Reserved + msg[3] = addrType + copy(msg[4:], addrBody) + msg[4+len(addrBody)] = byte(addrPort >> 8) + msg[4+len(addrBody)+1] = byte(addrPort & 0xff) + + // Send the message + _, err := w.Write(msg) + return err +} + +type closeWriter interface { + CloseWrite() error +} + +// proxy is used to shuffle data from src to destination, and sends errors +// down a dedicated channel +func proxy(dst io.Writer, src io.Reader, errCh chan error) { + _, err := io.Copy(dst, src) + if tcpConn, ok := dst.(closeWriter); ok { + tcpConn.CloseWrite() + } + errCh <- err +} + +func parseCommand(cmd uint8) string { + switch cmd { + case ConnectCommand: + return "connect" + case BindCommand: + return "bind" + case AssociateCommand: + return "associate" + default: + return "unknown" + } +} diff --git a/pkg/socks5/request_test.go b/pkg/socks5/request_test.go new file mode 100644 index 0000000..1f1044d --- /dev/null +++ b/pkg/socks5/request_test.go @@ -0,0 +1,178 @@ +package socks5 + +import ( + "bytes" + "encoding/binary" + "github.com/rs/zerolog" + "io" + "net" + "os" + "strings" + "testing" + "time" +) + +type MockConn struct { + buf bytes.Buffer +} + +func (m *MockConn) Write(b []byte) (int, error) { + return m.buf.Write(b) +} + +func (m *MockConn) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: []byte{127, 0, 0, 1}, Port: 65432} +} + +func TestRequest_Connect(t *testing.T) { + // Create a local listener + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + go func() { + conn, err := l.Accept() + if err != nil { + t.Errorf("err: %v", err) + return + } + defer conn.Close() + + buf := make([]byte, 4) + if _, err := io.ReadAtLeast(conn, buf, 4); err != nil { + t.Errorf("err: %v", err) + return + } + + if !bytes.Equal(buf, []byte("ping")) { + t.Errorf("bad: %v", buf) + return + } + conn.Write([]byte("pong")) + }() + lAddr := l.Addr().(*net.TCPAddr) + + // Make server + logger := zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339}).With().Timestamp().Logger() + s := &Server{config: &Config{ + Rules: PermitAll(), + Resolver: DNSResolver{}, + Logger: &logger, + }} + + // Create the connect request + buf := bytes.NewBuffer(nil) + buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) + + port := []byte{0, 0} + binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) + buf.Write(port) + + // Send a ping + buf.Write([]byte("ping")) + + // Handle the request + resp := &MockConn{} + req, err := NewRequest(buf) + if err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.handleRequest(req, resp); err != nil { + t.Fatalf("err: %v", err) + } + + // Verify response + out := resp.buf.Bytes() + expected := []byte{ + 5, + 0, + 0, + 1, + 127, 0, 0, 1, + 0, 0, + 'p', 'o', 'n', 'g', + } + + // Ignore the port for both + out[8] = 0 + out[9] = 0 + + if !bytes.Equal(out, expected) { + t.Fatalf("bad: %v %v", out, expected) + } +} + +func TestRequest_Connect_RuleFail(t *testing.T) { + // Create a local listener + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + go func() { + conn, err := l.Accept() + if err != nil { + t.Errorf("err: %v", err) + return + } + defer conn.Close() + + buf := make([]byte, 4) + if _, err := io.ReadAtLeast(conn, buf, 4); err != nil { + t.Errorf("err: %v", err) + return + } + + if !bytes.Equal(buf, []byte("ping")) { + t.Errorf("bad: %v", buf) + return + } + conn.Write([]byte("pong")) + }() + lAddr := l.Addr().(*net.TCPAddr) + + // Make server + logger := zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339}).With().Timestamp().Logger() + s := &Server{config: &Config{ + Rules: PermitNone(), + Resolver: DNSResolver{}, + Logger: &logger, + }} + + // Create the connect request + buf := bytes.NewBuffer(nil) + buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) + + port := []byte{0, 0} + binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) + buf.Write(port) + + // Send a ping + buf.Write([]byte("ping")) + + // Handle the request + resp := &MockConn{} + req, err := NewRequest(buf) + if err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.handleRequest(req, resp); !strings.Contains(err.Error(), "blocked by rules") { + t.Fatalf("err: %v", err) + } + + // Verify response + out := resp.buf.Bytes() + expected := []byte{ + 5, + 2, + 0, + 1, + 0, 0, 0, 0, + 0, 0, + } + + if !bytes.Equal(out, expected) { + t.Fatalf("bad: %v %v", out, expected) + } +} diff --git a/pkg/socks5/resolver.go b/pkg/socks5/resolver.go new file mode 100644 index 0000000..38c2903 --- /dev/null +++ b/pkg/socks5/resolver.go @@ -0,0 +1,22 @@ +package socks5 + +import ( + "context" + "net" +) + +// NameResolver is used to implement custom name resolution +type NameResolver interface { + Resolve(ctx context.Context, address string) (context.Context, net.IP, error) +} + +// DNSResolver uses the system DNS to resolve host names +type DNSResolver struct{} + +func (d DNSResolver) Resolve(ctx context.Context, address string) (context.Context, net.IP, error) { + addr, err := net.ResolveIPAddr("ip", address) + if err != nil { + return ctx, nil, err + } + return ctx, addr.IP, err +} diff --git a/pkg/socks5/resolver_test.go b/pkg/socks5/resolver_test.go new file mode 100644 index 0000000..46c6e5a --- /dev/null +++ b/pkg/socks5/resolver_test.go @@ -0,0 +1,21 @@ +package socks5 + +import ( + "testing" + + "context" +) + +func TestDNSResolver(t *testing.T) { + d := DNSResolver{} + ctx := context.Background() + + _, addr, err := d.Resolve(ctx, "localhost") + if err != nil { + t.Fatalf("err: %v", err) + } + + if !addr.IsLoopback() { + t.Fatalf("expected loopback") + } +} diff --git a/pkg/socks5/ruleset.go b/pkg/socks5/ruleset.go new file mode 100644 index 0000000..d65699d --- /dev/null +++ b/pkg/socks5/ruleset.go @@ -0,0 +1,41 @@ +package socks5 + +import ( + "context" +) + +// RuleSet is used to provide custom rules to allow or prohibit actions +type RuleSet interface { + Allow(ctx context.Context, req *Request) (context.Context, bool) +} + +// PermitAll returns a RuleSet which allows all types of connections +func PermitAll() RuleSet { + return &PermitCommand{true, true, true} +} + +// PermitNone returns a RuleSet which disallows all types of connections +func PermitNone() RuleSet { + return &PermitCommand{false, false, false} +} + +// PermitCommand is an implementation of the RuleSet which +// enables filtering supported commands +type PermitCommand struct { + EnableConnect bool + EnableBind bool + EnableAssociate bool +} + +func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) { + switch req.Command { + case ConnectCommand: + return ctx, p.EnableConnect + case BindCommand: + return ctx, p.EnableBind + case AssociateCommand: + return ctx, p.EnableAssociate + } + + return ctx, false +} diff --git a/pkg/socks5/ruleset_test.go b/pkg/socks5/ruleset_test.go new file mode 100644 index 0000000..1fefa50 --- /dev/null +++ b/pkg/socks5/ruleset_test.go @@ -0,0 +1,24 @@ +package socks5 + +import ( + "testing" + + "context" +) + +func TestPermitCommand(t *testing.T) { + ctx := context.Background() + r := &PermitCommand{true, false, false} + + if _, ok := r.Allow(ctx, &Request{Command: ConnectCommand}); !ok { + t.Fatalf("expect connect") + } + + if _, ok := r.Allow(ctx, &Request{Command: BindCommand}); ok { + t.Fatalf("do not expect bind") + } + + if _, ok := r.Allow(ctx, &Request{Command: AssociateCommand}); ok { + t.Fatalf("do not expect associate") + } +} diff --git a/pkg/socks5/socks5.go b/pkg/socks5/socks5.go new file mode 100644 index 0000000..2cb3693 --- /dev/null +++ b/pkg/socks5/socks5.go @@ -0,0 +1,173 @@ +package socks5 + +import ( + "bufio" + "errors" + "github.com/rs/zerolog" + "net" + "os" + "time" + + "context" +) + +const ( + socks5Version = uint8(5) +) + +// Config is used to set up and configure a Server +type Config struct { + // AuthMethods can be provided to implement custom authentication + // By default, "auth-less" mode is enabled. + // For password-based auth use UserPassAuthenticator. + AuthMethods []Authenticator + + // If provided, username/password authentication is enabled, + // by appending a UserPassAuthenticator to AuthMethods. If not provided, + // and AUthMethods is a nil, then "auth-less" mode is enabled. + Credentials CredentialStore + + // Resolver can be provided to do custom name resolution. + // Defaults to DNSResolver if not provided. + Resolver NameResolver + + // Rules is provided to enable custom logic around permitting + // various commands. If not provided, PermitAll is used. + Rules RuleSet + + // Rewriter can be used to transparently rewrite addresses. + // This is invoked before the RuleSet is invoked. + // Defaults to NoRewrite. + Rewriter AddressRewriter + + // BindIP is used for bind or udp associate + BindIP net.IP + + // Logger can be used to provide a custom log target. + // Defaults to stdout. + Logger *zerolog.Logger + + // Optional function for dialing out + Dial func(ctx context.Context, network, addr string) (net.Conn, error) +} + +// Server is responsible for accepting connections and handling +// the details of the SOCKS5 protocol +type Server struct { + config *Config + authMethods map[uint8]Authenticator +} + +// New creates a new Server and potentially returns an error +func New(conf *Config) (*Server, error) { + // Ensure we have at least one authentication method enabled + if len(conf.AuthMethods) == 0 { + if conf.Credentials != nil { + conf.AuthMethods = []Authenticator{&UserPassAuthenticator{conf.Credentials}} + } else { + conf.AuthMethods = []Authenticator{&NoAuthAuthenticator{}} + } + } + + // Ensure we have a DNS resolver + if conf.Resolver == nil { + conf.Resolver = DNSResolver{} + } + + // Ensure we have a rule set + if conf.Rules == nil { + conf.Rules = PermitAll() + } + + // Ensure we have a log target + if conf.Logger == nil { + logger := zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339}).With().Timestamp().Logger() + conf.Logger = &logger + } + + server := &Server{ + config: conf, + } + + server.authMethods = make(map[uint8]Authenticator) + + for _, a := range conf.AuthMethods { + server.authMethods[a.GetCode()] = a + } + + return server, nil +} + +// ListenAndServe is used to create a listener and serve on it +func (s *Server) ListenAndServe(network, addr string) error { + l, err := net.Listen(network, addr) + if err != nil { + return err + } + return s.Serve(l) +} + +// Serve is used to serve connections from a listener +func (s *Server) Serve(l net.Listener) error { + for { + conn, err := l.Accept() + if err != nil { + return err + } + go s.ServeConn(conn) + } +} + +// ServeConn is used to serve a single connection. +func (s *Server) ServeConn(conn net.Conn) { + defer conn.Close() + bufConn := bufio.NewReader(conn) + + // Read the version byte + version := []byte{0} + if _, err := bufConn.Read(version); err != nil { + s.config.Logger.Err(err).Msg("failed to get version byte") + return + } + + // Ensure we are compatible + if version[0] != socks5Version { + s.config.Logger.Error().Msgf("unsupported SOCKS version: %v", version) + return + } + + // Authenticate the connection + authContext, err := s.authenticate(conn, bufConn) + if err != nil { + s.config.Logger.Err(err).Msg("failed to authenticate") + return + } + + request, err := NewRequest(bufConn) + if err != nil { + if errors.Is(err, unrecognizedAddrType) { + if err := sendReply(conn, addrTypeNotSupported, nil); err != nil { + s.config.Logger.Err(err).Msg("failed to send reply") + return + } + } + s.config.Logger.Err(err).Msg("failed to read destination address") + return + } + request.AuthContext = authContext + if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok { + request.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port} + } + + // Process the client request + if err := s.handleRequest(request, conn); err != nil { + s.config.Logger.Err(err).Msg("failed to handle request") + } + + s.config.Logger.Info(). + Str("remote_addr", conn.RemoteAddr().String()). + Str("command", parseCommand(request.Command)). + Str("dest_addr", request.DestAddr.String()). + Str("latency", request.Latency.String()). + Msg("request processed") +} diff --git a/pkg/socks5/socks5_test.go b/pkg/socks5/socks5_test.go new file mode 100644 index 0000000..c833db3 --- /dev/null +++ b/pkg/socks5/socks5_test.go @@ -0,0 +1,115 @@ +package socks5 + +import ( + "bytes" + "encoding/binary" + "github.com/rs/zerolog" + "io" + "net" + "os" + "testing" + "time" +) + +func TestSOCKS5_Connect(t *testing.T) { + // Create a local listener + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + go func() { + conn, err := l.Accept() + if err != nil { + t.Errorf("err: %v", err) + return + } + defer conn.Close() + + buf := make([]byte, 4) + if _, err := io.ReadAtLeast(conn, buf, 4); err != nil { + t.Errorf("err: %v", err) + return + } + + if !bytes.Equal(buf, []byte("ping")) { + t.Errorf("bad: %v", buf) + return + } + conn.Write([]byte("pong")) + }() + lAddr := l.Addr().(*net.TCPAddr) + + // Create a socks' server + logger := zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339}).With().Timestamp().Logger() + credentials := StaticCredentials{ + "foo": "bar", + } + auth := UserPassAuthenticator{Credentials: credentials} + conf := &Config{ + AuthMethods: []Authenticator{auth}, + Logger: &logger, + } + serv, err := New(conf) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Start listening + go func() { + if err := serv.ListenAndServe("tcp", "127.0.0.1:12365"); err != nil { + t.Errorf("err: %v", err) + return + } + }() + time.Sleep(10 * time.Millisecond) + + // Get a local conn + conn, err := net.Dial("tcp", "127.0.0.1:12365") + if err != nil { + t.Fatalf("err: %v", err) + } + + // Connect, auth and connect to local + req := bytes.NewBuffer(nil) + req.Write([]byte{5}) + req.Write([]byte{2, NoAuth, UserPassAuth}) + req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'}) + req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) + + port := []byte{0, 0} + binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) + req.Write(port) + + // Send a ping + req.Write([]byte("ping")) + + // Send all the bytes + conn.Write(req.Bytes()) + + // Verify response + expected := []byte{ + socks5Version, UserPassAuth, + 1, authSuccess, + 5, + 0, + 0, + 1, + 127, 0, 0, 1, + 0, 0, + 'p', 'o', 'n', 'g', + } + out := make([]byte, len(expected)) + + conn.SetDeadline(time.Now().Add(time.Second)) + if _, err := io.ReadAtLeast(conn, out, len(out)); err != nil { + t.Fatalf("err: %v", err) + } + + // Ignore the port + out[12] = 0 + out[13] = 0 + + if !bytes.Equal(out, expected) { + t.Fatalf("bad: %v", out) + } +} diff --git a/systemd/nanoproxy.service b/systemd/nanoproxy.service index c280ddb..8b88693 100644 --- a/systemd/nanoproxy.service +++ b/systemd/nanoproxy.service @@ -5,7 +5,7 @@ After=network.target [Service] EnvironmentFile=/etc/nanoproxy/nanoproxy ExecStart=/usr/bin/nanoproxy -WorkingDirectory=/etc/nanoproxy/ +WorkingDirectory=/usr/bin Restart=always User=root diff --git a/webproxy/webproxy.go b/webproxy/webproxy.go deleted file mode 100644 index d17c5fd..0000000 --- a/webproxy/webproxy.go +++ /dev/null @@ -1,84 +0,0 @@ -package webproxy - -import ( - "github.com/gofiber/fiber/v2" - "github.com/valyala/fasthttp" - "io" - "net" - "time" -) - -type WebProxy struct { - TunnelTimeout time.Duration // tunnel timeout in seconds (default 15s) -} - -func New(tunnelTimeout time.Duration) *WebProxy { - return &WebProxy{ - TunnelTimeout: tunnelTimeout, - } -} - -// Handler handles incoming HTTP requests and proxies them to the destination server -func (s *WebProxy) Handler(c *fiber.Ctx) error { - if c.Method() == fiber.MethodConnect { - return s.handleTunneling(c) - } else { - return s.handleHTTP(c) - } -} - -// handleHTTP handles normal HTTP proxy requests -func (s *WebProxy) handleHTTP(c *fiber.Ctx) error { - agent := fiber.AcquireAgent() - - // set request URI and Host header - req := agent.Request() - req.SetRequestURI(c.OriginalURL()) - req.Header.SetMethod(c.Method()) - req.Header.SetHost(c.Hostname()) - - // copy headers - c.Request().Header.CopyTo(&req.Header) - - // parse request - if err := agent.Parse(); err != nil { - return err - } - - // send request and receive response - var resp fiber.Response - if err := agent.DoTimeout(req, &resp, s.TunnelTimeout); err != nil { - return c.SendStatus(fiber.StatusBadGateway) - } - - // copy response headers - resp.Header.CopyTo(&c.Response().Header) - - // set status code - return c.Status(resp.StatusCode()).Send(resp.Body()) -} - -// handleTunneling handles CONNECT requests -func (s *WebProxy) handleTunneling(c *fiber.Ctx) error { - destConn, err := fasthttp.DialTimeout(c.OriginalURL(), s.TunnelTimeout) - if err != nil { - return c.SendStatus(fiber.StatusBadGateway) - } - - // hijack the client connection from the HTTP server - c.Context().Hijack(func(clientConn net.Conn) { - go transfer(destConn, clientConn) - transfer(clientConn, destConn) - }) - - return nil -} - -// transfer bytes from src to dst until either EOF is reached on src or an error occurs -func transfer(destination net.Conn, source net.Conn) { - defer func() { - _ = destination.Close() - _ = source.Close() - }() - _, _ = io.Copy(destination, source) -} diff --git a/webproxy/webproxy_test.go b/webproxy/webproxy_test.go deleted file mode 100644 index 61d41a9..0000000 --- a/webproxy/webproxy_test.go +++ /dev/null @@ -1,238 +0,0 @@ -package webproxy - -import ( - "bytes" - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/utils" - "io" - "math/big" - "net" - "net/http" - "net/url" - "testing" - "time" -) - -func Test_Middleware_WebProxy(t *testing.T) { - t.Parallel() - - lnProxy, err := net.Listen("tcp", "127.0.0.1:0") - utils.AssertEqual(t, nil, err) - appProxy := fiber.New(fiber.Config{ - DisableStartupMessage: true, - }) - proxy := New(time.Second * 5) - appProxy.All("*", proxy.Handler) - - lnTarget, err := net.Listen("tcp", "127.0.0.1:0") - utils.AssertEqual(t, nil, err) - appTarget := fiber.New(fiber.Config{ - DisableStartupMessage: true, - }) - appTarget.Get("/", func(c *fiber.Ctx) error { - return c.SendString("Hello, World!") - }) - - proxyAddr := lnProxy.Addr().String() - targetAddr := lnTarget.Addr().String() - - go func() { - utils.AssertEqual(t, nil, appProxy.Listener(lnProxy)) - }() - go func() { - utils.AssertEqual(t, nil, appTarget.Listener(lnTarget)) - }() - - proxyURL, err := url.Parse("http://" + proxyAddr) - utils.AssertEqual(t, nil, err) - - transport := &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - } - - client := &http.Client{ - Transport: transport, - } - - resp, err := client.Get("http://" + targetAddr) - utils.AssertEqual(t, nil, err) - - body, err := io.ReadAll(resp.Body) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "Hello, World!", string(body)) -} - -func Test_Middleware_WebProxy_HTTPS(t *testing.T) { - t.Parallel() - - lnProxy, err := net.Listen("tcp", "127.0.0.1:0") - utils.AssertEqual(t, nil, err) - appProxy := fiber.New(fiber.Config{ - DisableStartupMessage: true, - }) - proxy := New(3 * time.Second) - appProxy.All("*", proxy.Handler) - - tlsconf, _, err := getTLSConfigs() - utils.AssertEqual(t, nil, err) - - lnTarget, err := tls.Listen("tcp", "127.0.0.1:0", tlsconf) - utils.AssertEqual(t, nil, err) - appTarget := fiber.New(fiber.Config{ - DisableStartupMessage: true, - }) - appTarget.Get("/", func(c *fiber.Ctx) error { - return c.SendString("Hello, World!") - }) - - proxyAddr := lnProxy.Addr().String() - targetAddr := lnTarget.Addr().String() - - go func() { - utils.AssertEqual(t, nil, appProxy.Listener(lnProxy)) - }() - - go func() { - utils.AssertEqual(t, nil, appTarget.Listener(lnTarget)) - }() - - proxyURL, err := url.Parse("http://" + proxyAddr) - utils.AssertEqual(t, nil, err) - - transport := &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - } - - client := &http.Client{ - Transport: transport, - } - - resp, err := client.Get("https://" + targetAddr) - utils.AssertEqual(t, nil, err) - - body, err := io.ReadAll(resp.Body) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "Hello, World!", string(body)) - - client2 := &http.Client{ - Transport: transport, - } - - _, err = client2.Get("https://wronghost") - utils.AssertEqual(t, "Get \"https://wronghost\": Bad Gateway", err.Error()) -} - -// getTLSConfigs returns a server and client TLS config -// this code is copied from https://github.com/gofiber/fiber/blob/master/internal/tlstest/tls.go -func getTLSConfigs() (serverTLSConf, clientTLSConf *tls.Config, err error) { - // set up our CA certificate - ca := &x509.Certificate{ - SerialNumber: big.NewInt(2021), - Subject: pkix.Name{ - Organization: []string{"Fiber"}, - Country: []string{"NL"}, - Province: []string{""}, - Locality: []string{"Amsterdam"}, - StreetAddress: []string{"Huidenstraat"}, - PostalCode: []string{"1011 AA"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), - IsCA: true, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - BasicConstraintsValid: true, - } - - // create our private and public key - caPrivateKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - return nil, nil, err - } - - // create the CA - caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivateKey.PublicKey, caPrivateKey) - if err != nil { - return nil, nil, err - } - - // pem encode - var caPEM bytes.Buffer - _ = pem.Encode(&caPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: caBytes, - }) - - var caPrivKeyPEM bytes.Buffer - _ = pem.Encode(&caPrivKeyPEM, &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(caPrivateKey), - }) - - // set up our server certificate - cert := &x509.Certificate{ - SerialNumber: big.NewInt(2021), - Subject: pkix.Name{ - Organization: []string{"Fiber"}, - Country: []string{"NL"}, - Province: []string{""}, - Locality: []string{"Amsterdam"}, - StreetAddress: []string{"Huidenstraat"}, - PostalCode: []string{"1011 AA"}, - }, - IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), - SubjectKeyId: []byte{1, 2, 3, 4, 6}, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature, - } - - certPrivateKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - return nil, nil, err - } - - certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivateKey.PublicKey, caPrivateKey) - if err != nil { - return nil, nil, err - } - - var certPEM bytes.Buffer - _ = pem.Encode(&certPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - }) - - var certPrivateKeyPEM bytes.Buffer - _ = pem.Encode(&certPrivateKeyPEM, &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(certPrivateKey), - }) - - serverCert, err := tls.X509KeyPair(certPEM.Bytes(), certPrivateKeyPEM.Bytes()) - if err != nil { - return nil, nil, err - } - - serverTLSConf = &tls.Config{ - Certificates: []tls.Certificate{serverCert}, - } - - certPool := x509.NewCertPool() - certPool.AppendCertsFromPEM(caPEM.Bytes()) - clientTLSConf = &tls.Config{ - RootCAs: certPool, - } - - return serverTLSConf, clientTLSConf, nil -}