-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: change engine to Fiber, and fix authentication can work on netw…
…ork proxy
- Loading branch information
1 parent
f341795
commit 4a684d6
Showing
13 changed files
with
662 additions
and
185 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,57 +1,114 @@ | ||
package config | ||
|
||
import ( | ||
"errors" | ||
"flag" | ||
"os" | ||
"strings" | ||
"time" | ||
) | ||
|
||
type Config struct { | ||
PemPath string | ||
KeyPath string | ||
Proto string | ||
Addr string | ||
TunnelTimeout time.Duration | ||
BasicAuth string | ||
Debug bool | ||
pemPath string | ||
keyPath string | ||
proto string | ||
addr string | ||
tunnelTimeout time.Duration | ||
basicAuth string | ||
debug bool | ||
} | ||
|
||
var ( | ||
ErrInvalidProto = errors.New("invalid protocol") | ||
ErrInvalidBasicAuth = errors.New("invalid basic auth") | ||
) | ||
|
||
func New() *Config { | ||
c := &Config{} | ||
flag.StringVar(&c.PemPath, "pem", "server.pem", "path to pem file") | ||
flag.StringVar(&c.KeyPath, "key", "server.key", "path to key file") | ||
flag.StringVar(&c.Proto, "proto", "http", "proxy protocol (http or https)") | ||
flag.StringVar(&c.Addr, "addr", ":8080", "proxy listen address (default :8080)") | ||
flag.DurationVar(&c.TunnelTimeout, "timeout", time.Second*15, "tunnel timeout (default 15s)") | ||
flag.StringVar(&c.BasicAuth, "auth", "", "basic auth (username:password)") | ||
flag.BoolVar(&c.Debug, "debug", false, "debug mode") | ||
flag.StringVar(&c.pemPath, "pem", "server.pem", "path to pem file") | ||
flag.StringVar(&c.keyPath, "key", "server.key", "path to key file") | ||
flag.StringVar(&c.proto, "proto", "http", "proxy protocol (http or https)") | ||
flag.StringVar(&c.addr, "addr", ":8080", "proxy listen address (default :8080)") | ||
flag.DurationVar(&c.tunnelTimeout, "timeout", time.Second*15, "tunnel timeout (default 15s)") | ||
flag.StringVar(&c.basicAuth, "auth", "", "basic auth (username:password)") | ||
flag.BoolVar(&c.debug, "debug", false, "debug mode") | ||
flag.Parse() | ||
|
||
if os.Getenv("PEM") != "" { | ||
c.PemPath = os.Getenv("PEM") | ||
c.pemPath = os.Getenv("PEM") | ||
} | ||
|
||
if os.Getenv("KEY") != "" { | ||
c.KeyPath = os.Getenv("KEY") | ||
c.keyPath = os.Getenv("KEY") | ||
} | ||
|
||
if os.Getenv("PROTO") != "" { | ||
c.Proto = os.Getenv("PROTO") | ||
c.proto = os.Getenv("PROTO") | ||
} | ||
|
||
if os.Getenv("ADDR") != "" { | ||
c.Addr = os.Getenv("ADDR") | ||
c.addr = os.Getenv("ADDR") | ||
} | ||
|
||
if os.Getenv("TIMEOUT") != "" { | ||
d, err := time.ParseDuration(os.Getenv("TIMEOUT")) | ||
if err == nil { | ||
c.TunnelTimeout = d | ||
c.tunnelTimeout = d | ||
} | ||
} | ||
|
||
if os.Getenv("AUTH") != "" { | ||
c.BasicAuth = os.Getenv("AUTH") | ||
c.basicAuth = os.Getenv("AUTH") | ||
} | ||
return c | ||
} | ||
|
||
func (c *Config) PemPath() string { | ||
return c.pemPath | ||
} | ||
|
||
func (c *Config) KeyPath() string { | ||
return c.keyPath | ||
} | ||
|
||
func (c *Config) Addr() string { | ||
return c.addr | ||
} | ||
|
||
func (c *Config) TunnelTimeout() time.Duration { | ||
return c.tunnelTimeout | ||
} | ||
|
||
func (c *Config) BasicAuth() map[string]string { | ||
cred := strings.Split(c.basicAuth, ",") | ||
users := map[string]string{} | ||
for _, cred := range cred { | ||
userPass := strings.Split(cred, ":") | ||
if len(userPass) == 2 { | ||
users[userPass[0]] = userPass[1] | ||
} | ||
} | ||
return users | ||
} | ||
|
||
func (c *Config) Debug() bool { | ||
return c.debug | ||
} | ||
|
||
func (c *Config) IsHTTPS() bool { | ||
return c.proto == "https" | ||
} | ||
|
||
func (c *Config) Validate() error { | ||
if c.proto != "http" && c.proto != "https" { | ||
return ErrInvalidProto | ||
} | ||
|
||
if c.basicAuth != "" { | ||
if len(c.BasicAuth()) == 0 { | ||
return ErrInvalidBasicAuth | ||
} | ||
} | ||
|
||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
package basicauth | ||
|
||
import ( | ||
"encoding/base64" | ||
"github.com/gofiber/fiber/v2" | ||
"github.com/gofiber/fiber/v2/utils" | ||
"strings" | ||
) | ||
|
||
func New(config Config) fiber.Handler { | ||
cfg := configDefault(config) | ||
return func(c *fiber.Ctx) error { | ||
// Skip if basic auth is empty | ||
if len(cfg.Users) == 0 { | ||
return c.Next() | ||
} | ||
|
||
// Get authorization header | ||
auth := c.Get(fiber.HeaderProxyAuthorization) | ||
|
||
// Check if header is valid | ||
if len(auth) < 6 || !utils.EqualFold(auth[:6], "basic ") { | ||
return cfg.Unauthorized(c) | ||
} | ||
|
||
// Decode header | ||
raw, err := base64.StdEncoding.DecodeString(auth[6:]) | ||
if err != nil { | ||
return cfg.Unauthorized(c) | ||
} | ||
|
||
// Get credentials | ||
creds := utils.UnsafeString(raw) | ||
|
||
// Split username and password | ||
index := strings.Index(creds, ":") | ||
if index == -1 { | ||
return cfg.Unauthorized(c) | ||
} | ||
|
||
// Get username and password | ||
user := creds[:index] | ||
pass := creds[index+1:] | ||
|
||
// Check credentials | ||
if cfg.Authorized(user, pass) { | ||
c.Locals("username", user) | ||
c.Locals("password", pass) | ||
return c.Next() | ||
} | ||
|
||
// Credentials doesn't match | ||
return cfg.Unauthorized(c) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
package basicauth | ||
|
||
import ( | ||
"encoding/base64" | ||
"github.com/gofiber/fiber/v2" | ||
"github.com/gofiber/fiber/v2/utils" | ||
"io" | ||
"net/http/httptest" | ||
"testing" | ||
) | ||
|
||
func Test_Middleware_BasicAuth(t *testing.T) { | ||
t.Parallel() | ||
|
||
app := fiber.New() | ||
|
||
app.Use(New(Config{ | ||
Users: map[string]string{ | ||
"john": "doe", | ||
"jane": "doe", | ||
}, | ||
})) | ||
|
||
app.Get("/testauth", func(c *fiber.Ctx) error { | ||
username := c.Locals("username").(string) | ||
password := c.Locals("password").(string) | ||
|
||
return c.SendString(username + password) | ||
}) | ||
|
||
tests := []struct { | ||
url string | ||
statusCode int | ||
username string | ||
password string | ||
}{ | ||
{ | ||
url: "/testauth", | ||
statusCode: fiber.StatusOK, | ||
username: "john", | ||
password: "doe", | ||
}, | ||
{ | ||
url: "/testauth", | ||
statusCode: fiber.StatusOK, | ||
username: "jane", | ||
password: "doe", | ||
}, | ||
{ | ||
url: "/testauth", | ||
statusCode: fiber.StatusUnauthorized, | ||
username: "john", | ||
password: "wrong", | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
// Encode credentials to base64 | ||
cred := base64.StdEncoding.EncodeToString([]byte(tt.username + ":" + tt.password)) | ||
|
||
req := httptest.NewRequest(fiber.MethodGet, "/testauth", nil) | ||
req.Header.Set(fiber.HeaderProxyAuthorization, "Basic "+cred) | ||
resp, err := app.Test(req) | ||
utils.AssertEqual(t, nil, err) | ||
|
||
body, err := io.ReadAll(resp.Body) | ||
utils.AssertEqual(t, nil, err) | ||
|
||
utils.AssertEqual(t, tt.statusCode, resp.StatusCode) | ||
if tt.statusCode == fiber.StatusOK { | ||
utils.AssertEqual(t, tt.username+tt.password, string(body)) | ||
} | ||
} | ||
} | ||
|
||
func Test_Middleware_BasicAuth_No_Users(t *testing.T) { | ||
t.Parallel() | ||
|
||
app := fiber.New() | ||
|
||
app.Use(New(Config{ | ||
Users: map[string]string{}, | ||
})) | ||
|
||
app.Get("/testauth", func(c *fiber.Ctx) error { | ||
return c.SendString("testauth") | ||
}) | ||
|
||
req := httptest.NewRequest(fiber.MethodGet, "/testauth", nil) | ||
resp, err := app.Test(req) | ||
utils.AssertEqual(t, nil, err) | ||
|
||
body, err := io.ReadAll(resp.Body) | ||
utils.AssertEqual(t, nil, err) | ||
|
||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) | ||
utils.AssertEqual(t, "testauth", string(body)) | ||
} |
Oops, something went wrong.