Skip to content

Commit

Permalink
Refactor basic authentication in config.go and connectHandlerFunc.go
Browse files Browse the repository at this point in the history
  • Loading branch information
arloor committed Mar 11, 2024
1 parent 4a619a1 commit 47e1056
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 13 deletions.
12 changes: 7 additions & 5 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"encoding/base64"
"flag"
"fmt"
"gopkg.in/natefinch/lumberjack.v2"
"gopkg.in/yaml.v2"
"io"
"log"
"os"
"strings"

"gopkg.in/natefinch/lumberjack.v2"
"gopkg.in/yaml.v2"
)

type stringArray []string
Expand All @@ -32,7 +34,7 @@ type Config struct {
LogPath string `yaml:"log"`
WebPath string `yaml:"content"`
Users stringArray `yaml:"users"`
BasicAuth map[string]bool `yaml:"auth"`
BasicAuth map[string]string `yaml:"auth"`
Refer string `yaml:"refer"`
}

Expand All @@ -48,10 +50,10 @@ func init() {
flag.Var(&Instance.Users, "user", "")
flag.StringVar(&Instance.Refer, "refer", "", "本站的referer特征")
flag.Parse()
Instance.BasicAuth = make(map[string]bool)
Instance.BasicAuth = make(map[string]string)
for _, user := range Instance.Users {
base64Encode := "Basic " + base64.StdEncoding.EncodeToString([]byte(user))
Instance.BasicAuth[base64Encode] = true
Instance.BasicAuth[base64Encode] = strings.Split(user, ":")[0]
}
initLog()
out, err := yaml.Marshal(Instance)
Expand Down
17 changes: 10 additions & 7 deletions internal/server/connectHandlerFunc.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ func dialContextCheckACL(network, hostPort string) (net.Conn, error) {

func connect(w http.ResponseWriter, r *http.Request) {
clientAddr := strings.Split(r.RemoteAddr, ":")[0]
var username string;
if len(config.Instance.BasicAuth) != 0 {
if _, ok := config.Instance.BasicAuth[r.Header.Get("proxy-authorization")]; !ok {
var ok bool;
username, ok = config.Instance.BasicAuth[r.Header.Get("proxy-authorization")]
if !ok {
log.Println("wrong basic auth from", clientAddr)
http.Error(w, "InternalServerError", http.StatusInternalServerError)
return
Expand Down Expand Up @@ -53,7 +56,7 @@ func connect(w http.ResponseWriter, r *http.Request) {

switch r.ProtoMajor {
case 1: // http1: hijack the whole flow
_, err := serveHijack(w, targetConn, clientAddr, hostPort)
_, err := serveHijack(w, targetConn, clientAddr, hostPort,username)
if err != nil {
log.Println(err, r.RemoteAddr)
}
Expand All @@ -71,7 +74,7 @@ func connect(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Server", "go_web_server")
}
wFlusher.Flush()
err := dualStream(targetConn, r.Body, w, clientAddr, hostPort)
err := dualStream(targetConn, r.Body, w, clientAddr, hostPort,username)
if err != nil {
log.Println(err, r.RemoteAddr)
}
Expand All @@ -83,7 +86,7 @@ func connect(w http.ResponseWriter, r *http.Request) {

// Hijacks the connection from ResponseWriter, writes the response and proxies data between targetConn
// and hijacked connection.
func serveHijack(w http.ResponseWriter, targetConn net.Conn, clientAddr string, hostPort string) (int, error) {
func serveHijack(w http.ResponseWriter, targetConn net.Conn, clientAddr string, hostPort string,username string) (int, error) {
hijacker, ok := w.(http.Hijacker)
if !ok {
return http.StatusInternalServerError, errors.New("ResponseWriter does not implement Hijacker")
Expand Down Expand Up @@ -121,21 +124,21 @@ func serveHijack(w http.ResponseWriter, targetConn net.Conn, clientAddr string,
return http.StatusInternalServerError, errors.New("failed to send response to client: " + err.Error())
}

return 0, dualStream(targetConn, clientConn, clientConn, clientAddr, hostPort)
return 0, dualStream(targetConn, clientConn, clientConn, clientAddr, hostPort,username)
}

var bufferPool = &sync.Pool{New: func() interface{} {
return make([]byte, 32*1024)
}}

func dualStream(target net.Conn, clientReader io.ReadCloser, clientWriter io.Writer, clientAddr string, hostPort string) error {
func dualStream(target net.Conn, clientReader io.ReadCloser, clientWriter io.Writer, clientAddr string, hostPort string,username string) error {
stream := func(w io.Writer, r io.Reader) error {
// copy bytes from r to w
buf := bufferPool.Get().([]byte)
defer bufferPool.Put(buf)
buf = buf[0:cap(buf)]
nw, _err := flushingIoCopy(w, r, buf)
ProxyTraffic.WithLabelValues(clientAddr, hostPort).Add(float64(nw))
ProxyTraffic.WithLabelValues(clientAddr, hostPort,username).Add(float64(nw))
if closeWriter, ok := w.(interface {
CloseWrite() error
}); ok {
Expand Down
2 changes: 1 addition & 1 deletion internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ var (
ProxyTraffic = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "proxy_traffic_total",
Help: "num proxy_traffic",
}, []string{"client", "target"})
}, []string{"client", "target","username"})
)

func Serve() error {
Expand Down

0 comments on commit 47e1056

Please sign in to comment.